diff --git a/app.py b/app.py
index 1d80157..bf02b26 100644
--- a/app.py
+++ b/app.py
@@ -1,21 +1,26 @@
"""
病虫害以图搜图
-基于 CLIP 本地模型的图片 Embedding 相似度搜索
+基于图片 Embedding API 的相似度搜索
"""
from __future__ import annotations
import io
-import os
+import time
from dataclasses import dataclass
from typing import Literal
+import httpx
import numpy as np
import plotly.graph_objects as go
import requests
import streamlit as st
from PIL import Image
-from transformers import CLIPModel, CLIPProcessor
+
+# ─── API Config ──────────────────────────────────────────────────────────────
+IMAGE_EMBEDDING_API_URL = "https://llm.dev.maimaiag.com/qwen3-vl-embedding/v1/embeddings"
+EMBEDDING_MODEL = "Qwen3-VL-Embedding"
+API_KEY = "sk--VnOesEU5D8wnHjdg0MEsA"
# ─── Page Config ────────────────────────────────────────────────────────────
st.set_page_config(
@@ -165,35 +170,37 @@ EXAMPLE_IMAGES: list[tuple[str, str]] = [
]
-# ─── CLIP Embedder ───────────────────────────────────────────────────────────
-class CLIPEmbedder:
- MODEL_NAME = "openai/clip-vit-base-patch32"
-
- def __init__(self) -> None:
- self._processor: CLIPProcessor | None = None
- self._model: CLIPModel | None = None
-
- def _load(self) -> tuple[CLIPProcessor, CLIPModel]:
- if self._processor is None or self._model is None:
- with st.spinner("首次启动正在加载 CLIP 模型,请稍候..."):
- self._processor = CLIPProcessor.from_pretrained(self.MODEL_NAME)
- self._model = CLIPModel.from_pretrained(self.MODEL_NAME)
- return self._processor, self._model
-
- def embed(self, image: Image.Image) -> np.ndarray:
- processor, model = self._load()
- inputs = processor(images=image, return_tensors="pt")
- image_features = model.get_image_features(**inputs)
- vec = image_features.detach().cpu().numpy().flatten()
- norm = np.linalg.norm(vec)
- if norm == 0:
- return vec
- return vec / norm
-
-
-@st.cache_resource(show_spinner=False)
-def get_embedder() -> CLIPEmbedder:
- return CLIPEmbedder()
+# ─── Embedding API ───────────────────────────────────────────────────────────
+def get_image_embedding(image_url: str, text: str = "这是什么病虫害?", max_retries: int = 2) -> list[float]:
+ """调用远程 API 获取图片 Embedding,支持重试。"""
+ payload = {
+ "model": EMBEDDING_MODEL,
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": text},
+ {"type": "image_url", "image_url": {"url": image_url}},
+ ],
+ }
+ ],
+ }
+ last_error = None
+ for attempt in range(max_retries + 1):
+ try:
+ resp = httpx.post(
+ IMAGE_EMBEDDING_API_URL,
+ headers={"Content-Type": "application/json"},
+ json=payload,
+ timeout=120,
+ )
+ resp.raise_for_status()
+ return resp.json()["data"][0]["embedding"]
+ except httpx.HTTPStatusError as e:
+ last_error = e
+ if attempt < max_retries:
+ time.sleep(2 * (attempt + 1))
+ raise last_error
# ─── Utilities ───────────────────────────────────────────────────────────────
@@ -215,21 +222,24 @@ def load_image(source: str | io.BytesIO) -> Image.Image | None:
return img
-def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
- return float(np.dot(a, b))
+def cosine_similarity(a: list[float], b: list[float]) -> float:
+ a_arr = np.array(a)
+ b_arr = np.array(b)
+ norm_a = np.linalg.norm(a_arr)
+ norm_b = np.linalg.norm(b_arr)
+ if norm_a == 0 or norm_b == 0:
+ return 0.0
+ return float(np.dot(a_arr, b_arr) / (norm_a * norm_b))
-@st.cache_data(show_spinner=False)
-def build_index() -> tuple[list[dict], list[str], list[str]]:
- embedder = get_embedder()
+@st.cache_resource
+def build_index() -> tuple[list[dict], list[str], list[tuple[str, str]]]:
items, succeeded, failed = [], [], []
- for pest in PEST_KNOWLEDGE:
- img = _load_image_raw(pest.url)
- if img is None:
- failed.append(pest.name)
- continue
+ progress = st.progress(0, text="正在构建图片索引...")
+ total = len(PEST_KNOWLEDGE)
+ for i, pest in enumerate(PEST_KNOWLEDGE):
try:
- embedding = embedder.embed(img)
+ embedding = get_image_embedding(pest.url, text=pest.name)
items.append({
"name": pest.name,
"url": pest.url,
@@ -240,40 +250,32 @@ def build_index() -> tuple[list[dict], list[str], list[str]]:
"category": pest.category,
})
succeeded.append(pest.name)
- except Exception:
- failed.append(pest.name)
+ except Exception as e:
+ failed.append((pest.name, str(e)))
+ progress.progress((i + 1) / total, text=f"正在构建图片索引 ({i + 1}/{total})...")
+ progress.empty()
return items, succeeded, failed
# ─── Sidebar ─────────────────────────────────────────────────────────────────
with st.sidebar:
st.header("🌿 病虫害以图搜图")
- st.caption("上传图片,智能识别相似病虫害")
+ st.caption("输入图片 URL,通过图片 Embedding 搜索相似病虫害")
st.divider()
st.subheader("🖼️ 输入方式")
- input_mode = st.radio("", ["上传本地图片", "输入图片 URL", "选择示例图片"], label_visibility="collapsed")
+ input_mode = st.radio("", ["输入图片 URL", "选择示例图片"], label_visibility="collapsed")
# 初始化 session_state
if "query_url" not in st.session_state:
st.session_state.query_url = ""
- if "query_image_bytes" not in st.session_state:
- st.session_state.query_image_bytes = None
query_source = None
query_url = ""
- if input_mode == "上传本地图片":
- uploaded = st.file_uploader("选择图片", type=["jpg", "jpeg", "png", "webp"])
- if uploaded is not None:
- st.session_state.query_image_bytes = uploaded.getvalue()
- st.session_state.query_url = ""
- if st.session_state.query_image_bytes is not None:
- query_source = io.BytesIO(st.session_state.query_image_bytes)
- elif input_mode == "输入图片 URL":
+ if input_mode == "输入图片 URL":
query_url = st.text_input("图片 URL", value=st.session_state.query_url, placeholder="https://example.com/image.jpg")
st.session_state.query_url = query_url
- st.session_state.query_image_bytes = None
if query_url.strip():
query_source = query_url.strip()
else:
@@ -283,7 +285,6 @@ with st.sidebar:
with cols[idx % 2]:
if st.button(name, key=f"ex_{name}"):
st.session_state.query_url = url
- st.session_state.query_image_bytes = None
st.rerun()
if st.session_state.query_url:
query_url = st.session_state.query_url
@@ -298,116 +299,114 @@ with st.sidebar:
st.divider()
st.info(
"**使用说明**\n\n"
- "1. 上传病虫害患处图片\n"
- "2. 系统自动提取图像特征\n"
+ "1. 输入病虫害图片 URL 或选择示例\n"
+ "2. 系统调用 Embedding API 提取图像特征\n"
"3. 与知识库比对返回相似结果\n"
"4. 参考症状与防治建议"
)
# ─── Build Index ─────────────────────────────────────────────────────────────
-index_items, succeeded, failed = build_index()
+with st.spinner("首次加载需要构建图片索引,请稍候..."):
+ index_items, succeeded, failed = build_index()
# ─── Main Layout ─────────────────────────────────────────────────────────────
st.title("🌿 病虫害以图搜图")
-st.caption("基于 CLIP 视觉模型的病虫害相似度检索与防治建议")
+st.caption("基于图片 Embedding API 的病虫害相似度检索与防治建议")
# Status badges
-if succeeded:
- st.badge(f"📚 知识库 {len(succeeded)} 种", color="blue")
+st.success(f"📚 图片索引构建完成,成功 {len(succeeded)} 张")
if failed:
- st.badge(f"⚠️ 索引失败 {len(failed)} 种", color="red")
+ st.warning(f"以下 {len(failed)} 张图片索引失败:")
+ for name, err in failed:
+ st.error(f"- **{name}**:{err}")
st.markdown("
", unsafe_allow_html=True)
# ─── Search Logic ────────────────────────────────────────────────────────────
-if search_clicked:
- if query_source is None:
- st.warning("请先上传图片、输入图片 URL 或选择示例图片后再点击搜索")
- elif not index_items:
+if search_clicked and query_url.strip():
+ if not index_items:
st.warning("知识库索引为空,请检查网络连接后刷新页面重试。")
else:
- query_img = load_image(query_source)
- if query_img is not None:
- col_query, col_preview = st.columns([1, 3])
- with col_query:
- st.subheader("🔍 查询图片")
- st.image(query_img, use_container_width=True)
+ with st.spinner("正在分析图片并搜索..."):
+ try:
+ query_embedding = get_image_embedding(query_url.strip())
- with col_preview:
- st.subheader("⏳ 正在分析...")
- progress = st.progress(0, text="提取图像特征...")
+ col_query, col_preview = st.columns([1, 3])
+ with col_query:
+ st.subheader("🔍 查询图片")
+ query_img = load_image(query_url.strip())
+ if query_img is not None:
+ st.image(query_img, use_container_width=True)
- embedder = get_embedder()
- query_embedding = embedder.embed(query_img)
- progress.progress(50, text="比对知识库...")
+ with col_preview:
+ scores = []
+ for item in index_items:
+ sim = cosine_similarity(query_embedding, item["embedding"])
+ scores.append((sim, item))
+ scores.sort(key=lambda x: x[0], reverse=True)
+ results = scores[:top_k]
- scores = []
- for item in index_items:
- sim = cosine_similarity(query_embedding, item["embedding"])
- scores.append((sim, item))
- scores.sort(key=lambda x: x[0], reverse=True)
- results = scores[:top_k]
- progress.progress(100, text="搜索完成")
- progress.empty()
+ st.subheader(f"🏆 搜索结果(Top-{len(results)})")
- st.subheader(f"🏆 搜索结果(Top-{len(results)})")
+ # Similarity bar chart
+ names = [f"{r[1]['name']}" for r in results]
+ sims = [r[0] * 100 for r in results]
+ colors = ["#c45c4a" if r[1]["category"] == "虫害" else "#4a7c59" for r in results]
- # Similarity bar chart
- names = [f"{r[1]['name']}" for r in results]
- sims = [r[0] * 100 for r in results]
- colors = ["#c45c4a" if r[1]["category"] == "虫害" else "#4a7c59" for r in results]
+ fig_bar = go.Figure()
+ fig_bar.add_trace(go.Bar(
+ x=sims,
+ y=names,
+ orientation="h",
+ marker=dict(color=colors, opacity=0.85, line=dict(color="rgba(0,0,0,0.08)", width=1)),
+ text=[f"{s:.1f}%" for s in sims],
+ textposition="outside",
+ textfont=dict(color="#5a5a5a", size=10),
+ ))
+ fig_bar.update_layout(
+ xaxis=dict(title="相似度 (%)", color="#5a5a5a", gridcolor="rgba(0,0,0,0.06)", range=[0, 105]),
+ yaxis=dict(color="#5a5a5a", gridcolor="rgba(0,0,0,0.04)", autorange="reversed"),
+ paper_bgcolor="rgba(0,0,0,0)",
+ plot_bgcolor="rgba(0,0,0,0)",
+ font=dict(color="#2c2c2c", size=11),
+ margin=dict(t=10, b=30, l=80, r=50),
+ height=160 + len(results) * 34,
+ showlegend=False,
+ )
+ st.plotly_chart(fig_bar, use_container_width=True)
- fig_bar = go.Figure()
- fig_bar.add_trace(go.Bar(
- x=sims,
- y=names,
- orientation="h",
- marker=dict(color=colors, opacity=0.85, line=dict(color="rgba(0,0,0,0.08)", width=1)),
- text=[f"{s:.1f}%" for s in sims],
- textposition="outside",
- textfont=dict(color="#5a5a5a", size=10),
- ))
- fig_bar.update_layout(
- xaxis=dict(title="相似度 (%)", color="#5a5a5a", gridcolor="rgba(0,0,0,0.06)", range=[0, 105]),
- yaxis=dict(color="#5a5a5a", gridcolor="rgba(0,0,0,0.04)", autorange="reversed"),
- paper_bgcolor="rgba(0,0,0,0)",
- plot_bgcolor="rgba(0,0,0,0)",
- font=dict(color="#2c2c2c", size=11),
- margin=dict(t=10, b=30, l=80, r=50),
- height=160 + len(results) * 34,
- showlegend=False,
- )
- st.plotly_chart(fig_bar, use_container_width=True)
+ # Result cards below
+ st.subheader("📋 详细结果")
+ for rank, (sim, item) in enumerate(results, 1):
+ with st.container(border=True):
+ c1, c2 = st.columns([1, 4])
+ with c1:
+ st.image(item["url"], use_container_width=True)
+ with c2:
+ header_col, score_col = st.columns([3, 1])
+ header_col.markdown(f"**#{rank} {item['name']}**")
+ score_col.markdown(f"