diff --git a/app.py b/app.py index ca503ee..82eb89c 100644 --- a/app.py +++ b/app.py @@ -373,18 +373,24 @@ def get_embedder() -> CLIPEmbedder: # ─── Utilities ─────────────────────────────────────────────────────────────── -def load_image(source: str | io.BytesIO) -> Image.Image | None: +def _load_image_raw(source: str | io.BytesIO) -> Image.Image | None: try: if isinstance(source, str): resp = requests.get(source, timeout=30) resp.raise_for_status() return Image.open(io.BytesIO(resp.content)).convert("RGB") return Image.open(source).convert("RGB") - except Exception as e: - st.error(f"图片加载失败: {e}") + except Exception: return None +def load_image(source: str | io.BytesIO) -> Image.Image | None: + img = _load_image_raw(source) + if img is None: + st.error("图片加载失败,请检查链接是否可访问或文件是否损坏") + return img + + def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float: return float(np.dot(a, b)) @@ -393,13 +399,10 @@ def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float: def build_index() -> tuple[list[dict], list[str], list[str]]: embedder = get_embedder() items, succeeded, failed = [], [], [] - progress = st.progress(0, text="正在构建病虫害图片索引...") - total = len(PEST_KNOWLEDGE) - for i, pest in enumerate(PEST_KNOWLEDGE): - img = load_image(pest.url) + for pest in PEST_KNOWLEDGE: + img = _load_image_raw(pest.url) if img is None: failed.append(pest.name) - progress.progress((i + 1) / total, text=f"索引构建中 ({i + 1}/{total})...") continue try: embedding = embedder.embed(img) @@ -415,8 +418,6 @@ def build_index() -> tuple[list[dict], list[str], list[str]]: succeeded.append(pest.name) except Exception: failed.append(pest.name) - progress.progress((i + 1) / total, text=f"索引构建中 ({i + 1}/{total})...") - progress.empty() return items, succeeded, failed @@ -429,16 +430,26 @@ with st.sidebar: st.markdown('
🖼️ 输入方式
', unsafe_allow_html=True) input_mode = st.radio("", ["上传本地图片", "输入图片 URL", "选择示例图片"], label_visibility="collapsed") + # 初始化 session_state + if "query_url" not in st.session_state: + st.session_state.query_url = "" + if "query_image_bytes" not in st.session_state: + st.session_state.query_image_bytes = None + query_source = None query_url = "" if input_mode == "上传本地图片": uploaded = st.file_uploader("选择图片", type=["jpg", "jpeg", "png", "webp"]) if uploaded is not None: - query_source = io.BytesIO(uploaded.getvalue()) - query_url = "" + st.session_state.query_image_bytes = uploaded.getvalue() + st.session_state.query_url = "" + if st.session_state.query_image_bytes is not None: + query_source = io.BytesIO(st.session_state.query_image_bytes) elif input_mode == "输入图片 URL": - query_url = st.text_input("图片 URL", placeholder="https://example.com/image.jpg") + query_url = st.text_input("图片 URL", value=st.session_state.query_url, placeholder="https://example.com/image.jpg") + st.session_state.query_url = query_url + st.session_state.query_image_bytes = None if query_url.strip(): query_source = query_url.strip() else: @@ -448,7 +459,9 @@ with st.sidebar: with cols[idx % 2]: if st.button(name, key=f"ex_{name}"): st.session_state.query_url = url - if "query_url" in st.session_state: + st.session_state.query_image_bytes = None + st.rerun() + if st.session_state.query_url: query_url = st.session_state.query_url query_source = query_url st.image(query_url, use_container_width=True) @@ -493,103 +506,105 @@ if badges: st.markdown("
", unsafe_allow_html=True) # ─── Search Logic ──────────────────────────────────────────────────────────── -if search_clicked and query_source is not None and index_items: - query_img = load_image(query_source) - if query_img is not None: - col_query, col_preview = st.columns([1, 3]) - with col_query: - st.markdown('
🔍 查询图片
', unsafe_allow_html=True) - st.image(query_img, use_container_width=True) +if search_clicked: + if query_source is None: + st.warning("请先上传图片、输入图片 URL 或选择示例图片后再点击搜索") + elif not index_items: + st.warning("知识库索引为空,请检查网络连接后刷新页面重试。") + else: + query_img = load_image(query_source) + if query_img is not None: + col_query, col_preview = st.columns([1, 3]) + with col_query: + st.markdown('
🔍 查询图片
', unsafe_allow_html=True) + st.image(query_img, use_container_width=True) - with col_preview: - st.markdown('
⏳ 正在分析...
', unsafe_allow_html=True) - progress = st.progress(0, text="提取图像特征...") + with col_preview: + st.markdown('
⏳ 正在分析...
', unsafe_allow_html=True) + progress = st.progress(0, text="提取图像特征...") - embedder = get_embedder() - query_embedding = embedder.embed(query_img) - progress.progress(50, text="比对知识库...") + embedder = get_embedder() + query_embedding = embedder.embed(query_img) + progress.progress(50, text="比对知识库...") - scores = [] - for item in index_items: - sim = cosine_similarity(query_embedding, item["embedding"]) - scores.append((sim, item)) - scores.sort(key=lambda x: x[0], reverse=True) - results = scores[:top_k] - progress.progress(100, text="搜索完成") - progress.empty() + scores = [] + for item in index_items: + sim = cosine_similarity(query_embedding, item["embedding"]) + scores.append((sim, item)) + scores.sort(key=lambda x: x[0], reverse=True) + results = scores[:top_k] + progress.progress(100, text="搜索完成") + progress.empty() - st.markdown(f'
🏆 搜索结果(Top-{len(results)})
', unsafe_allow_html=True) + st.markdown(f'
🏆 搜索结果(Top-{len(results)})
', unsafe_allow_html=True) - # Similarity bar chart - names = [f"{r[1]['name']}" for r in results] - sims = [r[0] * 100 for r in results] - colors = ["#c45c4a" if r[1]["category"] == "虫害" else "#4a7c59" for r in results] + # Similarity bar chart + names = [f"{r[1]['name']}" for r in results] + sims = [r[0] * 100 for r in results] + colors = ["#c45c4a" if r[1]["category"] == "虫害" else "#4a7c59" for r in results] - fig_bar = go.Figure() - fig_bar.add_trace(go.Bar( - x=sims, - y=names, - orientation="h", - marker=dict(color=colors, opacity=0.85, line=dict(color="rgba(0,0,0,0.08)", width=1)), - text=[f"{s:.1f}%" for s in sims], - textposition="outside", - textfont=dict(color="#5a5a5a", size=10), - )) - fig_bar.update_layout( - xaxis=dict(title="相似度 (%)", color="#5a5a5a", gridcolor="rgba(0,0,0,0.06)", range=[0, 105]), - yaxis=dict(color="#5a5a5a", gridcolor="rgba(0,0,0,0.04)", autorange="reversed"), - paper_bgcolor="rgba(0,0,0,0)", - plot_bgcolor="rgba(0,0,0,0)", - font=dict(color="#2c2c2c", size=11), - margin=dict(t=10, b=30, l=80, r=50), - height=160 + len(results) * 34, - showlegend=False, - ) - st.plotly_chart(fig_bar, use_container_width=True) + fig_bar = go.Figure() + fig_bar.add_trace(go.Bar( + x=sims, + y=names, + orientation="h", + marker=dict(color=colors, opacity=0.85, line=dict(color="rgba(0,0,0,0.08)", width=1)), + text=[f"{s:.1f}%" for s in sims], + textposition="outside", + textfont=dict(color="#5a5a5a", size=10), + )) + fig_bar.update_layout( + xaxis=dict(title="相似度 (%)", color="#5a5a5a", gridcolor="rgba(0,0,0,0.06)", range=[0, 105]), + yaxis=dict(color="#5a5a5a", gridcolor="rgba(0,0,0,0.04)", autorange="reversed"), + paper_bgcolor="rgba(0,0,0,0)", + plot_bgcolor="rgba(0,0,0,0)", + font=dict(color="#2c2c2c", size=11), + margin=dict(t=10, b=30, l=80, r=50), + height=160 + len(results) * 34, + showlegend=False, + ) + st.plotly_chart(fig_bar, use_container_width=True) - # Result cards below - st.markdown('
📋 详细结果
', unsafe_allow_html=True) - for rank, (sim, item) in enumerate(results, 1): - with st.container(): + # Result cards below + st.markdown('
📋 详细结果
', unsafe_allow_html=True) + for rank, (sim, item) in enumerate(results, 1): + with st.container(): + st.markdown(f""" +
+
+
+ +
+
+
+ {rank} + {item['name']} + 相似度 {sim*100:.1f}% +
+
+ {item['crop']} + {item['category']} +
+
+ 症状:{item['symptoms']}
+ 防治:{item['treatment']} +
+
+
+
+ """, unsafe_allow_html=True) + + # Advisory summary + if results: + best = results[0][1] + st.markdown('
💡 初步建议
', unsafe_allow_html=True) st.markdown(f""" -
-
-
- -
-
-
- {rank} - {item['name']} - 相似度 {sim*100:.1f}% -
-
- {item['crop']} - {item['category']} -
-
- 症状:{item['symptoms']}
- 防治:{item['treatment']} -
-
-
+
+ 系统判断该图片与 {best['name']}({best['crop']}{best['category']})最为相似,相似度 {results[0][0]*100:.1f}%
+ 建议结合田间实际情况进一步确认,参考防治方案:{best['treatment']}
""", unsafe_allow_html=True) - # Advisory summary - if results: - best = results[0][1] - st.markdown('
💡 初步建议
', unsafe_allow_html=True) - st.markdown(f""" -
- 系统判断该图片与 {best['name']}({best['crop']}{best['category']})最为相似,相似度 {results[0][0]*100:.1f}%
- 建议结合田间实际情况进一步确认,参考防治方案:{best['treatment']} -
- """, unsafe_allow_html=True) - -elif search_clicked and not index_items: - st.warning("知识库索引为空,请检查网络连接后刷新页面重试。") - # ─── Footer ─────────────────────────────────────────────────────────────────── st.markdown("
", unsafe_allow_html=True) st.markdown("""