feat(app): 优化图片加载、状态管理与搜索交互体验
- 拆分 load_image 为 _load_image_raw 与 load_image,隔离索引构建与 UI 错误提示 - 移除 build_index 中的进度条,避免副作用 - 使用 session_state 管理 query_url 与 query_image_bytes,修复切换输入方式时状态丢失 - 示例图片选择后增加 st.rerun(),确保 UI 即时刷新 - 搜索前增加空值校验,给出更友好的提示信息
This commit is contained in:
215
app.py
215
app.py
@@ -373,18 +373,24 @@ def get_embedder() -> CLIPEmbedder:
|
|||||||
|
|
||||||
|
|
||||||
# ─── Utilities ───────────────────────────────────────────────────────────────
|
# ─── Utilities ───────────────────────────────────────────────────────────────
|
||||||
def load_image(source: str | io.BytesIO) -> Image.Image | None:
|
def _load_image_raw(source: str | io.BytesIO) -> Image.Image | None:
|
||||||
try:
|
try:
|
||||||
if isinstance(source, str):
|
if isinstance(source, str):
|
||||||
resp = requests.get(source, timeout=30)
|
resp = requests.get(source, timeout=30)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
return Image.open(io.BytesIO(resp.content)).convert("RGB")
|
return Image.open(io.BytesIO(resp.content)).convert("RGB")
|
||||||
return Image.open(source).convert("RGB")
|
return Image.open(source).convert("RGB")
|
||||||
except Exception as e:
|
except Exception:
|
||||||
st.error(f"图片加载失败: {e}")
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def load_image(source: str | io.BytesIO) -> Image.Image | None:
|
||||||
|
img = _load_image_raw(source)
|
||||||
|
if img is None:
|
||||||
|
st.error("图片加载失败,请检查链接是否可访问或文件是否损坏")
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
|
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
|
||||||
return float(np.dot(a, b))
|
return float(np.dot(a, b))
|
||||||
|
|
||||||
@@ -393,13 +399,10 @@ def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
|
|||||||
def build_index() -> tuple[list[dict], list[str], list[str]]:
|
def build_index() -> tuple[list[dict], list[str], list[str]]:
|
||||||
embedder = get_embedder()
|
embedder = get_embedder()
|
||||||
items, succeeded, failed = [], [], []
|
items, succeeded, failed = [], [], []
|
||||||
progress = st.progress(0, text="正在构建病虫害图片索引...")
|
for pest in PEST_KNOWLEDGE:
|
||||||
total = len(PEST_KNOWLEDGE)
|
img = _load_image_raw(pest.url)
|
||||||
for i, pest in enumerate(PEST_KNOWLEDGE):
|
|
||||||
img = load_image(pest.url)
|
|
||||||
if img is None:
|
if img is None:
|
||||||
failed.append(pest.name)
|
failed.append(pest.name)
|
||||||
progress.progress((i + 1) / total, text=f"索引构建中 ({i + 1}/{total})...")
|
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
embedding = embedder.embed(img)
|
embedding = embedder.embed(img)
|
||||||
@@ -415,8 +418,6 @@ def build_index() -> tuple[list[dict], list[str], list[str]]:
|
|||||||
succeeded.append(pest.name)
|
succeeded.append(pest.name)
|
||||||
except Exception:
|
except Exception:
|
||||||
failed.append(pest.name)
|
failed.append(pest.name)
|
||||||
progress.progress((i + 1) / total, text=f"索引构建中 ({i + 1}/{total})...")
|
|
||||||
progress.empty()
|
|
||||||
return items, succeeded, failed
|
return items, succeeded, failed
|
||||||
|
|
||||||
|
|
||||||
@@ -429,16 +430,26 @@ with st.sidebar:
|
|||||||
st.markdown('<div class="section-header" style="margin-top:0">🖼️ 输入方式</div>', unsafe_allow_html=True)
|
st.markdown('<div class="section-header" style="margin-top:0">🖼️ 输入方式</div>', unsafe_allow_html=True)
|
||||||
input_mode = st.radio("", ["上传本地图片", "输入图片 URL", "选择示例图片"], label_visibility="collapsed")
|
input_mode = st.radio("", ["上传本地图片", "输入图片 URL", "选择示例图片"], label_visibility="collapsed")
|
||||||
|
|
||||||
|
# 初始化 session_state
|
||||||
|
if "query_url" not in st.session_state:
|
||||||
|
st.session_state.query_url = ""
|
||||||
|
if "query_image_bytes" not in st.session_state:
|
||||||
|
st.session_state.query_image_bytes = None
|
||||||
|
|
||||||
query_source = None
|
query_source = None
|
||||||
query_url = ""
|
query_url = ""
|
||||||
|
|
||||||
if input_mode == "上传本地图片":
|
if input_mode == "上传本地图片":
|
||||||
uploaded = st.file_uploader("选择图片", type=["jpg", "jpeg", "png", "webp"])
|
uploaded = st.file_uploader("选择图片", type=["jpg", "jpeg", "png", "webp"])
|
||||||
if uploaded is not None:
|
if uploaded is not None:
|
||||||
query_source = io.BytesIO(uploaded.getvalue())
|
st.session_state.query_image_bytes = uploaded.getvalue()
|
||||||
query_url = ""
|
st.session_state.query_url = ""
|
||||||
|
if st.session_state.query_image_bytes is not None:
|
||||||
|
query_source = io.BytesIO(st.session_state.query_image_bytes)
|
||||||
elif input_mode == "输入图片 URL":
|
elif input_mode == "输入图片 URL":
|
||||||
query_url = st.text_input("图片 URL", placeholder="https://example.com/image.jpg")
|
query_url = st.text_input("图片 URL", value=st.session_state.query_url, placeholder="https://example.com/image.jpg")
|
||||||
|
st.session_state.query_url = query_url
|
||||||
|
st.session_state.query_image_bytes = None
|
||||||
if query_url.strip():
|
if query_url.strip():
|
||||||
query_source = query_url.strip()
|
query_source = query_url.strip()
|
||||||
else:
|
else:
|
||||||
@@ -448,7 +459,9 @@ with st.sidebar:
|
|||||||
with cols[idx % 2]:
|
with cols[idx % 2]:
|
||||||
if st.button(name, key=f"ex_{name}"):
|
if st.button(name, key=f"ex_{name}"):
|
||||||
st.session_state.query_url = url
|
st.session_state.query_url = url
|
||||||
if "query_url" in st.session_state:
|
st.session_state.query_image_bytes = None
|
||||||
|
st.rerun()
|
||||||
|
if st.session_state.query_url:
|
||||||
query_url = st.session_state.query_url
|
query_url = st.session_state.query_url
|
||||||
query_source = query_url
|
query_source = query_url
|
||||||
st.image(query_url, use_container_width=True)
|
st.image(query_url, use_container_width=True)
|
||||||
@@ -493,103 +506,105 @@ if badges:
|
|||||||
st.markdown("<br>", unsafe_allow_html=True)
|
st.markdown("<br>", unsafe_allow_html=True)
|
||||||
|
|
||||||
# ─── Search Logic ────────────────────────────────────────────────────────────
|
# ─── Search Logic ────────────────────────────────────────────────────────────
|
||||||
if search_clicked and query_source is not None and index_items:
|
if search_clicked:
|
||||||
query_img = load_image(query_source)
|
if query_source is None:
|
||||||
if query_img is not None:
|
st.warning("请先上传图片、输入图片 URL 或选择示例图片后再点击搜索")
|
||||||
col_query, col_preview = st.columns([1, 3])
|
elif not index_items:
|
||||||
with col_query:
|
st.warning("知识库索引为空,请检查网络连接后刷新页面重试。")
|
||||||
st.markdown('<div class="section-header" style="margin-top:0">🔍 查询图片</div>', unsafe_allow_html=True)
|
else:
|
||||||
st.image(query_img, use_container_width=True)
|
query_img = load_image(query_source)
|
||||||
|
if query_img is not None:
|
||||||
|
col_query, col_preview = st.columns([1, 3])
|
||||||
|
with col_query:
|
||||||
|
st.markdown('<div class="section-header" style="margin-top:0">🔍 查询图片</div>', unsafe_allow_html=True)
|
||||||
|
st.image(query_img, use_container_width=True)
|
||||||
|
|
||||||
with col_preview:
|
with col_preview:
|
||||||
st.markdown('<div class="section-header" style="margin-top:0">⏳ 正在分析...</div>', unsafe_allow_html=True)
|
st.markdown('<div class="section-header" style="margin-top:0">⏳ 正在分析...</div>', unsafe_allow_html=True)
|
||||||
progress = st.progress(0, text="提取图像特征...")
|
progress = st.progress(0, text="提取图像特征...")
|
||||||
|
|
||||||
embedder = get_embedder()
|
embedder = get_embedder()
|
||||||
query_embedding = embedder.embed(query_img)
|
query_embedding = embedder.embed(query_img)
|
||||||
progress.progress(50, text="比对知识库...")
|
progress.progress(50, text="比对知识库...")
|
||||||
|
|
||||||
scores = []
|
scores = []
|
||||||
for item in index_items:
|
for item in index_items:
|
||||||
sim = cosine_similarity(query_embedding, item["embedding"])
|
sim = cosine_similarity(query_embedding, item["embedding"])
|
||||||
scores.append((sim, item))
|
scores.append((sim, item))
|
||||||
scores.sort(key=lambda x: x[0], reverse=True)
|
scores.sort(key=lambda x: x[0], reverse=True)
|
||||||
results = scores[:top_k]
|
results = scores[:top_k]
|
||||||
progress.progress(100, text="搜索完成")
|
progress.progress(100, text="搜索完成")
|
||||||
progress.empty()
|
progress.empty()
|
||||||
|
|
||||||
st.markdown(f'<div class="section-header" style="margin-top:0">🏆 搜索结果(Top-{len(results)})</div>', unsafe_allow_html=True)
|
st.markdown(f'<div class="section-header" style="margin-top:0">🏆 搜索结果(Top-{len(results)})</div>', unsafe_allow_html=True)
|
||||||
|
|
||||||
# Similarity bar chart
|
# Similarity bar chart
|
||||||
names = [f"{r[1]['name']}" for r in results]
|
names = [f"{r[1]['name']}" for r in results]
|
||||||
sims = [r[0] * 100 for r in results]
|
sims = [r[0] * 100 for r in results]
|
||||||
colors = ["#c45c4a" if r[1]["category"] == "虫害" else "#4a7c59" for r in results]
|
colors = ["#c45c4a" if r[1]["category"] == "虫害" else "#4a7c59" for r in results]
|
||||||
|
|
||||||
fig_bar = go.Figure()
|
fig_bar = go.Figure()
|
||||||
fig_bar.add_trace(go.Bar(
|
fig_bar.add_trace(go.Bar(
|
||||||
x=sims,
|
x=sims,
|
||||||
y=names,
|
y=names,
|
||||||
orientation="h",
|
orientation="h",
|
||||||
marker=dict(color=colors, opacity=0.85, line=dict(color="rgba(0,0,0,0.08)", width=1)),
|
marker=dict(color=colors, opacity=0.85, line=dict(color="rgba(0,0,0,0.08)", width=1)),
|
||||||
text=[f"{s:.1f}%" for s in sims],
|
text=[f"{s:.1f}%" for s in sims],
|
||||||
textposition="outside",
|
textposition="outside",
|
||||||
textfont=dict(color="#5a5a5a", size=10),
|
textfont=dict(color="#5a5a5a", size=10),
|
||||||
))
|
))
|
||||||
fig_bar.update_layout(
|
fig_bar.update_layout(
|
||||||
xaxis=dict(title="相似度 (%)", color="#5a5a5a", gridcolor="rgba(0,0,0,0.06)", range=[0, 105]),
|
xaxis=dict(title="相似度 (%)", color="#5a5a5a", gridcolor="rgba(0,0,0,0.06)", range=[0, 105]),
|
||||||
yaxis=dict(color="#5a5a5a", gridcolor="rgba(0,0,0,0.04)", autorange="reversed"),
|
yaxis=dict(color="#5a5a5a", gridcolor="rgba(0,0,0,0.04)", autorange="reversed"),
|
||||||
paper_bgcolor="rgba(0,0,0,0)",
|
paper_bgcolor="rgba(0,0,0,0)",
|
||||||
plot_bgcolor="rgba(0,0,0,0)",
|
plot_bgcolor="rgba(0,0,0,0)",
|
||||||
font=dict(color="#2c2c2c", size=11),
|
font=dict(color="#2c2c2c", size=11),
|
||||||
margin=dict(t=10, b=30, l=80, r=50),
|
margin=dict(t=10, b=30, l=80, r=50),
|
||||||
height=160 + len(results) * 34,
|
height=160 + len(results) * 34,
|
||||||
showlegend=False,
|
showlegend=False,
|
||||||
)
|
)
|
||||||
st.plotly_chart(fig_bar, use_container_width=True)
|
st.plotly_chart(fig_bar, use_container_width=True)
|
||||||
|
|
||||||
# Result cards below
|
# Result cards below
|
||||||
st.markdown('<div class="section-header">📋 详细结果</div>', unsafe_allow_html=True)
|
st.markdown('<div class="section-header">📋 详细结果</div>', unsafe_allow_html=True)
|
||||||
for rank, (sim, item) in enumerate(results, 1):
|
for rank, (sim, item) in enumerate(results, 1):
|
||||||
with st.container():
|
with st.container():
|
||||||
|
st.markdown(f"""
|
||||||
|
<div class="result-card">
|
||||||
|
<div style="display:flex; gap:14px; align-items:flex-start;">
|
||||||
|
<div style="flex:0 0 140px;">
|
||||||
|
<img src="{item['url']}" style="width:100%; border-radius:10px; border:1px solid var(--border);">
|
||||||
|
</div>
|
||||||
|
<div style="flex:1;">
|
||||||
|
<div style="display:flex; align-items:center; margin-bottom:8px;">
|
||||||
|
<span class="result-rank">{rank}</span>
|
||||||
|
<span class="result-name">{item['name']}</span>
|
||||||
|
<span style="margin-left:auto;" class="result-score">相似度 {sim*100:.1f}%</span>
|
||||||
|
</div>
|
||||||
|
<div style="margin-bottom:8px;">
|
||||||
|
<span class="tag">{item['crop']}</span>
|
||||||
|
<span class="tag{' tag-warn' if item['category'] == '虫害' else ''}">{item['category']}</span>
|
||||||
|
</div>
|
||||||
|
<div style="font-size:0.88rem; color:var(--ink); line-height:1.6;">
|
||||||
|
<b>症状:</b>{item['symptoms']}<br>
|
||||||
|
<b>防治:</b>{item['treatment']}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
""", unsafe_allow_html=True)
|
||||||
|
|
||||||
|
# Advisory summary
|
||||||
|
if results:
|
||||||
|
best = results[0][1]
|
||||||
|
st.markdown('<div class="section-header">💡 初步建议</div>', unsafe_allow_html=True)
|
||||||
st.markdown(f"""
|
st.markdown(f"""
|
||||||
<div class="result-card">
|
<div class="info-panel" style="border-left:3px solid var(--leaf-light); border-radius:0 12px 12px 0;">
|
||||||
<div style="display:flex; gap:14px; align-items:flex-start;">
|
系统判断该图片与 <b>{best['name']}</b>({best['crop']}{best['category']})最为相似,相似度 <b>{results[0][0]*100:.1f}%</b>。<br>
|
||||||
<div style="flex:0 0 140px;">
|
建议结合田间实际情况进一步确认,参考防治方案:<b>{best['treatment']}</b>
|
||||||
<img src="{item['url']}" style="width:100%; border-radius:10px; border:1px solid var(--border);">
|
|
||||||
</div>
|
|
||||||
<div style="flex:1;">
|
|
||||||
<div style="display:flex; align-items:center; margin-bottom:8px;">
|
|
||||||
<span class="result-rank">{rank}</span>
|
|
||||||
<span class="result-name">{item['name']}</span>
|
|
||||||
<span style="margin-left:auto;" class="result-score">相似度 {sim*100:.1f}%</span>
|
|
||||||
</div>
|
|
||||||
<div style="margin-bottom:8px;">
|
|
||||||
<span class="tag">{item['crop']}</span>
|
|
||||||
<span class="tag{' tag-warn' if item['category'] == '虫害' else ''}">{item['category']}</span>
|
|
||||||
</div>
|
|
||||||
<div style="font-size:0.88rem; color:var(--ink); line-height:1.6;">
|
|
||||||
<b>症状:</b>{item['symptoms']}<br>
|
|
||||||
<b>防治:</b>{item['treatment']}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
</div>
|
||||||
""", unsafe_allow_html=True)
|
""", unsafe_allow_html=True)
|
||||||
|
|
||||||
# Advisory summary
|
|
||||||
if results:
|
|
||||||
best = results[0][1]
|
|
||||||
st.markdown('<div class="section-header">💡 初步建议</div>', unsafe_allow_html=True)
|
|
||||||
st.markdown(f"""
|
|
||||||
<div class="info-panel" style="border-left:3px solid var(--leaf-light); border-radius:0 12px 12px 0;">
|
|
||||||
系统判断该图片与 <b>{best['name']}</b>({best['crop']}{best['category']})最为相似,相似度 <b>{results[0][0]*100:.1f}%</b>。<br>
|
|
||||||
建议结合田间实际情况进一步确认,参考防治方案:<b>{best['treatment']}</b>
|
|
||||||
</div>
|
|
||||||
""", unsafe_allow_html=True)
|
|
||||||
|
|
||||||
elif search_clicked and not index_items:
|
|
||||||
st.warning("知识库索引为空,请检查网络连接后刷新页面重试。")
|
|
||||||
|
|
||||||
# ─── Footer ───────────────────────────────────────────────────────────────────
|
# ─── Footer ───────────────────────────────────────────────────────────────────
|
||||||
st.markdown("<br>", unsafe_allow_html=True)
|
st.markdown("<br>", unsafe_allow_html=True)
|
||||||
st.markdown("""
|
st.markdown("""
|
||||||
|
|||||||
Reference in New Issue
Block a user