diff --git a/app.py b/app.py
index ca503ee..82eb89c 100644
--- a/app.py
+++ b/app.py
@@ -373,18 +373,24 @@ def get_embedder() -> CLIPEmbedder:
# ─── Utilities ───────────────────────────────────────────────────────────────
-def load_image(source: str | io.BytesIO) -> Image.Image | None:
+def _load_image_raw(source: str | io.BytesIO) -> Image.Image | None:
try:
if isinstance(source, str):
resp = requests.get(source, timeout=30)
resp.raise_for_status()
return Image.open(io.BytesIO(resp.content)).convert("RGB")
return Image.open(source).convert("RGB")
- except Exception as e:
- st.error(f"图片加载失败: {e}")
+ except Exception:
return None
+def load_image(source: str | io.BytesIO) -> Image.Image | None:
+ img = _load_image_raw(source)
+ if img is None:
+ st.error("图片加载失败,请检查链接是否可访问或文件是否损坏")
+ return img
+
+
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
return float(np.dot(a, b))
@@ -393,13 +399,10 @@ def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
def build_index() -> tuple[list[dict], list[str], list[str]]:
embedder = get_embedder()
items, succeeded, failed = [], [], []
- progress = st.progress(0, text="正在构建病虫害图片索引...")
- total = len(PEST_KNOWLEDGE)
- for i, pest in enumerate(PEST_KNOWLEDGE):
- img = load_image(pest.url)
+ for pest in PEST_KNOWLEDGE:
+ img = _load_image_raw(pest.url)
if img is None:
failed.append(pest.name)
- progress.progress((i + 1) / total, text=f"索引构建中 ({i + 1}/{total})...")
continue
try:
embedding = embedder.embed(img)
@@ -415,8 +418,6 @@ def build_index() -> tuple[list[dict], list[str], list[str]]:
succeeded.append(pest.name)
except Exception:
failed.append(pest.name)
- progress.progress((i + 1) / total, text=f"索引构建中 ({i + 1}/{total})...")
- progress.empty()
return items, succeeded, failed
@@ -429,16 +430,26 @@ with st.sidebar:
st.markdown('
', unsafe_allow_html=True)
input_mode = st.radio("", ["上传本地图片", "输入图片 URL", "选择示例图片"], label_visibility="collapsed")
+ # 初始化 session_state
+ if "query_url" not in st.session_state:
+ st.session_state.query_url = ""
+ if "query_image_bytes" not in st.session_state:
+ st.session_state.query_image_bytes = None
+
query_source = None
query_url = ""
if input_mode == "上传本地图片":
uploaded = st.file_uploader("选择图片", type=["jpg", "jpeg", "png", "webp"])
if uploaded is not None:
- query_source = io.BytesIO(uploaded.getvalue())
- query_url = ""
+ st.session_state.query_image_bytes = uploaded.getvalue()
+ st.session_state.query_url = ""
+ if st.session_state.query_image_bytes is not None:
+ query_source = io.BytesIO(st.session_state.query_image_bytes)
elif input_mode == "输入图片 URL":
- query_url = st.text_input("图片 URL", placeholder="https://example.com/image.jpg")
+ query_url = st.text_input("图片 URL", value=st.session_state.query_url, placeholder="https://example.com/image.jpg")
+ st.session_state.query_url = query_url
+ st.session_state.query_image_bytes = None
if query_url.strip():
query_source = query_url.strip()
else:
@@ -448,7 +459,9 @@ with st.sidebar:
with cols[idx % 2]:
if st.button(name, key=f"ex_{name}"):
st.session_state.query_url = url
- if "query_url" in st.session_state:
+ st.session_state.query_image_bytes = None
+ st.rerun()
+ if st.session_state.query_url:
query_url = st.session_state.query_url
query_source = query_url
st.image(query_url, use_container_width=True)
@@ -493,103 +506,105 @@ if badges:
st.markdown("
", unsafe_allow_html=True)
# ─── Search Logic ────────────────────────────────────────────────────────────
-if search_clicked and query_source is not None and index_items:
- query_img = load_image(query_source)
- if query_img is not None:
- col_query, col_preview = st.columns([1, 3])
- with col_query:
- st.markdown('', unsafe_allow_html=True)
- st.image(query_img, use_container_width=True)
+if search_clicked:
+ if query_source is None:
+ st.warning("请先上传图片、输入图片 URL 或选择示例图片后再点击搜索")
+ elif not index_items:
+ st.warning("知识库索引为空,请检查网络连接后刷新页面重试。")
+ else:
+ query_img = load_image(query_source)
+ if query_img is not None:
+ col_query, col_preview = st.columns([1, 3])
+ with col_query:
+ st.markdown('', unsafe_allow_html=True)
+ st.image(query_img, use_container_width=True)
- with col_preview:
- st.markdown('', unsafe_allow_html=True)
- progress = st.progress(0, text="提取图像特征...")
+ with col_preview:
+ st.markdown('', unsafe_allow_html=True)
+ progress = st.progress(0, text="提取图像特征...")
- embedder = get_embedder()
- query_embedding = embedder.embed(query_img)
- progress.progress(50, text="比对知识库...")
+ embedder = get_embedder()
+ query_embedding = embedder.embed(query_img)
+ progress.progress(50, text="比对知识库...")
- scores = []
- for item in index_items:
- sim = cosine_similarity(query_embedding, item["embedding"])
- scores.append((sim, item))
- scores.sort(key=lambda x: x[0], reverse=True)
- results = scores[:top_k]
- progress.progress(100, text="搜索完成")
- progress.empty()
+ scores = []
+ for item in index_items:
+ sim = cosine_similarity(query_embedding, item["embedding"])
+ scores.append((sim, item))
+ scores.sort(key=lambda x: x[0], reverse=True)
+ results = scores[:top_k]
+ progress.progress(100, text="搜索完成")
+ progress.empty()
- st.markdown(f'', unsafe_allow_html=True)
+ st.markdown(f'', unsafe_allow_html=True)
- # Similarity bar chart
- names = [f"{r[1]['name']}" for r in results]
- sims = [r[0] * 100 for r in results]
- colors = ["#c45c4a" if r[1]["category"] == "虫害" else "#4a7c59" for r in results]
+ # Similarity bar chart
+ names = [f"{r[1]['name']}" for r in results]
+ sims = [r[0] * 100 for r in results]
+ colors = ["#c45c4a" if r[1]["category"] == "虫害" else "#4a7c59" for r in results]
- fig_bar = go.Figure()
- fig_bar.add_trace(go.Bar(
- x=sims,
- y=names,
- orientation="h",
- marker=dict(color=colors, opacity=0.85, line=dict(color="rgba(0,0,0,0.08)", width=1)),
- text=[f"{s:.1f}%" for s in sims],
- textposition="outside",
- textfont=dict(color="#5a5a5a", size=10),
- ))
- fig_bar.update_layout(
- xaxis=dict(title="相似度 (%)", color="#5a5a5a", gridcolor="rgba(0,0,0,0.06)", range=[0, 105]),
- yaxis=dict(color="#5a5a5a", gridcolor="rgba(0,0,0,0.04)", autorange="reversed"),
- paper_bgcolor="rgba(0,0,0,0)",
- plot_bgcolor="rgba(0,0,0,0)",
- font=dict(color="#2c2c2c", size=11),
- margin=dict(t=10, b=30, l=80, r=50),
- height=160 + len(results) * 34,
- showlegend=False,
- )
- st.plotly_chart(fig_bar, use_container_width=True)
+ fig_bar = go.Figure()
+ fig_bar.add_trace(go.Bar(
+ x=sims,
+ y=names,
+ orientation="h",
+ marker=dict(color=colors, opacity=0.85, line=dict(color="rgba(0,0,0,0.08)", width=1)),
+ text=[f"{s:.1f}%" for s in sims],
+ textposition="outside",
+ textfont=dict(color="#5a5a5a", size=10),
+ ))
+ fig_bar.update_layout(
+ xaxis=dict(title="相似度 (%)", color="#5a5a5a", gridcolor="rgba(0,0,0,0.06)", range=[0, 105]),
+ yaxis=dict(color="#5a5a5a", gridcolor="rgba(0,0,0,0.04)", autorange="reversed"),
+ paper_bgcolor="rgba(0,0,0,0)",
+ plot_bgcolor="rgba(0,0,0,0)",
+ font=dict(color="#2c2c2c", size=11),
+ margin=dict(t=10, b=30, l=80, r=50),
+ height=160 + len(results) * 34,
+ showlegend=False,
+ )
+ st.plotly_chart(fig_bar, use_container_width=True)
- # Result cards below
- st.markdown('', unsafe_allow_html=True)
- for rank, (sim, item) in enumerate(results, 1):
- with st.container():
+ # Result cards below
+ st.markdown('', unsafe_allow_html=True)
+ for rank, (sim, item) in enumerate(results, 1):
+ with st.container():
+ st.markdown(f"""
+
+
+
+

+
+
+
+ {rank}
+ {item['name']}
+ 相似度 {sim*100:.1f}%
+
+
+ {item['crop']}
+ {item['category']}
+
+
+ 症状:{item['symptoms']}
+ 防治:{item['treatment']}
+
+
+
+
+ """, unsafe_allow_html=True)
+
+ # Advisory summary
+ if results:
+ best = results[0][1]
+ st.markdown('', unsafe_allow_html=True)
st.markdown(f"""
-
-
-
-

-
-
-
- {rank}
- {item['name']}
- 相似度 {sim*100:.1f}%
-
-
- {item['crop']}
- {item['category']}
-
-
- 症状:{item['symptoms']}
- 防治:{item['treatment']}
-
-
-
+
+ 系统判断该图片与 {best['name']}({best['crop']}{best['category']})最为相似,相似度 {results[0][0]*100:.1f}%。
+ 建议结合田间实际情况进一步确认,参考防治方案:{best['treatment']}
""", unsafe_allow_html=True)
- # Advisory summary
- if results:
- best = results[0][1]
- st.markdown('', unsafe_allow_html=True)
- st.markdown(f"""
-
- 系统判断该图片与 {best['name']}({best['crop']}{best['category']})最为相似,相似度 {results[0][0]*100:.1f}%。
- 建议结合田间实际情况进一步确认,参考防治方案:{best['treatment']}
-
- """, unsafe_allow_html=True)
-
-elif search_clicked and not index_items:
- st.warning("知识库索引为空,请检查网络连接后刷新页面重试。")
-
# ─── Footer ───────────────────────────────────────────────────────────────────
st.markdown("
", unsafe_allow_html=True)
st.markdown("""