feat(app): 优化图片加载、状态管理与搜索交互体验
- 拆分 load_image 为 _load_image_raw 与 load_image,隔离索引构建与 UI 错误提示 - 移除 build_index 中的进度条,避免副作用 - 使用 session_state 管理 query_url 与 query_image_bytes,修复切换输入方式时状态丢失 - 示例图片选择后增加 st.rerun(),确保 UI 即时刷新 - 搜索前增加空值校验,给出更友好的提示信息
This commit is contained in:
51
app.py
51
app.py
@@ -373,18 +373,24 @@ def get_embedder() -> CLIPEmbedder:
|
||||
|
||||
|
||||
# ─── Utilities ───────────────────────────────────────────────────────────────
|
||||
def load_image(source: str | io.BytesIO) -> Image.Image | None:
|
||||
def _load_image_raw(source: str | io.BytesIO) -> Image.Image | None:
|
||||
try:
|
||||
if isinstance(source, str):
|
||||
resp = requests.get(source, timeout=30)
|
||||
resp.raise_for_status()
|
||||
return Image.open(io.BytesIO(resp.content)).convert("RGB")
|
||||
return Image.open(source).convert("RGB")
|
||||
except Exception as e:
|
||||
st.error(f"图片加载失败: {e}")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def load_image(source: str | io.BytesIO) -> Image.Image | None:
|
||||
img = _load_image_raw(source)
|
||||
if img is None:
|
||||
st.error("图片加载失败,请检查链接是否可访问或文件是否损坏")
|
||||
return img
|
||||
|
||||
|
||||
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
|
||||
return float(np.dot(a, b))
|
||||
|
||||
@@ -393,13 +399,10 @@ def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
|
||||
def build_index() -> tuple[list[dict], list[str], list[str]]:
|
||||
embedder = get_embedder()
|
||||
items, succeeded, failed = [], [], []
|
||||
progress = st.progress(0, text="正在构建病虫害图片索引...")
|
||||
total = len(PEST_KNOWLEDGE)
|
||||
for i, pest in enumerate(PEST_KNOWLEDGE):
|
||||
img = load_image(pest.url)
|
||||
for pest in PEST_KNOWLEDGE:
|
||||
img = _load_image_raw(pest.url)
|
||||
if img is None:
|
||||
failed.append(pest.name)
|
||||
progress.progress((i + 1) / total, text=f"索引构建中 ({i + 1}/{total})...")
|
||||
continue
|
||||
try:
|
||||
embedding = embedder.embed(img)
|
||||
@@ -415,8 +418,6 @@ def build_index() -> tuple[list[dict], list[str], list[str]]:
|
||||
succeeded.append(pest.name)
|
||||
except Exception:
|
||||
failed.append(pest.name)
|
||||
progress.progress((i + 1) / total, text=f"索引构建中 ({i + 1}/{total})...")
|
||||
progress.empty()
|
||||
return items, succeeded, failed
|
||||
|
||||
|
||||
@@ -429,16 +430,26 @@ with st.sidebar:
|
||||
st.markdown('<div class="section-header" style="margin-top:0">🖼️ 输入方式</div>', unsafe_allow_html=True)
|
||||
input_mode = st.radio("", ["上传本地图片", "输入图片 URL", "选择示例图片"], label_visibility="collapsed")
|
||||
|
||||
# 初始化 session_state
|
||||
if "query_url" not in st.session_state:
|
||||
st.session_state.query_url = ""
|
||||
if "query_image_bytes" not in st.session_state:
|
||||
st.session_state.query_image_bytes = None
|
||||
|
||||
query_source = None
|
||||
query_url = ""
|
||||
|
||||
if input_mode == "上传本地图片":
|
||||
uploaded = st.file_uploader("选择图片", type=["jpg", "jpeg", "png", "webp"])
|
||||
if uploaded is not None:
|
||||
query_source = io.BytesIO(uploaded.getvalue())
|
||||
query_url = ""
|
||||
st.session_state.query_image_bytes = uploaded.getvalue()
|
||||
st.session_state.query_url = ""
|
||||
if st.session_state.query_image_bytes is not None:
|
||||
query_source = io.BytesIO(st.session_state.query_image_bytes)
|
||||
elif input_mode == "输入图片 URL":
|
||||
query_url = st.text_input("图片 URL", placeholder="https://example.com/image.jpg")
|
||||
query_url = st.text_input("图片 URL", value=st.session_state.query_url, placeholder="https://example.com/image.jpg")
|
||||
st.session_state.query_url = query_url
|
||||
st.session_state.query_image_bytes = None
|
||||
if query_url.strip():
|
||||
query_source = query_url.strip()
|
||||
else:
|
||||
@@ -448,7 +459,9 @@ with st.sidebar:
|
||||
with cols[idx % 2]:
|
||||
if st.button(name, key=f"ex_{name}"):
|
||||
st.session_state.query_url = url
|
||||
if "query_url" in st.session_state:
|
||||
st.session_state.query_image_bytes = None
|
||||
st.rerun()
|
||||
if st.session_state.query_url:
|
||||
query_url = st.session_state.query_url
|
||||
query_source = query_url
|
||||
st.image(query_url, use_container_width=True)
|
||||
@@ -493,7 +506,12 @@ if badges:
|
||||
st.markdown("<br>", unsafe_allow_html=True)
|
||||
|
||||
# ─── Search Logic ────────────────────────────────────────────────────────────
|
||||
if search_clicked and query_source is not None and index_items:
|
||||
if search_clicked:
|
||||
if query_source is None:
|
||||
st.warning("请先上传图片、输入图片 URL 或选择示例图片后再点击搜索")
|
||||
elif not index_items:
|
||||
st.warning("知识库索引为空,请检查网络连接后刷新页面重试。")
|
||||
else:
|
||||
query_img = load_image(query_source)
|
||||
if query_img is not None:
|
||||
col_query, col_preview = st.columns([1, 3])
|
||||
@@ -587,9 +605,6 @@ if search_clicked and query_source is not None and index_items:
|
||||
</div>
|
||||
""", unsafe_allow_html=True)
|
||||
|
||||
elif search_clicked and not index_items:
|
||||
st.warning("知识库索引为空,请检查网络连接后刷新页面重试。")
|
||||
|
||||
# ─── Footer ───────────────────────────────────────────────────────────────────
|
||||
st.markdown("<br>", unsafe_allow_html=True)
|
||||
st.markdown("""
|
||||
|
||||
Reference in New Issue
Block a user