refactor(app): 替换本地 CLIP 模型为远程 Qwen3-VL-Embedding API

- 移除 CLIPEmbedder 本地模型类,改用远程图片 Embedding API 获取特征向量
  - 新增 get_image_embedding() 函数,支持重试机制
  - 移除本地图片上传功能,仅保留 URL 输入和示例图片选择
  - build_index() 增加进度条显示,索引失败时展示具体错误信息
  - 移除 torch、transformers、requests 依赖,新增 httpx
  - 更新界面文案,反映新的技术方案
This commit is contained in:
zhenghu
2026-04-15 09:55:32 +08:00
parent 722d7dc57d
commit fdfc3e2e2b
2 changed files with 142 additions and 145 deletions

161
app.py
View File

@@ -1,21 +1,26 @@
""" """
病虫害以图搜图 病虫害以图搜图
基于 CLIP 本地模型的图片 Embedding 相似度搜索 基于图片 Embedding API 的相似度搜索
""" """
from __future__ import annotations from __future__ import annotations
import io import io
import os import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Literal from typing import Literal
import httpx
import numpy as np import numpy as np
import plotly.graph_objects as go import plotly.graph_objects as go
import requests import requests
import streamlit as st import streamlit as st
from PIL import Image from PIL import Image
from transformers import CLIPModel, CLIPProcessor
# ─── API Config ──────────────────────────────────────────────────────────────
IMAGE_EMBEDDING_API_URL = "https://llm.dev.maimaiag.com/qwen3-vl-embedding/v1/embeddings"
EMBEDDING_MODEL = "Qwen3-VL-Embedding"
API_KEY = "sk--VnOesEU5D8wnHjdg0MEsA"
# ─── Page Config ──────────────────────────────────────────────────────────── # ─── Page Config ────────────────────────────────────────────────────────────
st.set_page_config( st.set_page_config(
@@ -165,35 +170,37 @@ EXAMPLE_IMAGES: list[tuple[str, str]] = [
] ]
# ─── CLIP Embedder ─────────────────────────────────────────────────────────── # ─── Embedding API ───────────────────────────────────────────────────────────
class CLIPEmbedder: def get_image_embedding(image_url: str, text: str = "这是什么病虫害?", max_retries: int = 2) -> list[float]:
MODEL_NAME = "openai/clip-vit-base-patch32" """调用远程 API 获取图片 Embedding支持重试。"""
payload = {
def __init__(self) -> None: "model": EMBEDDING_MODEL,
self._processor: CLIPProcessor | None = None "messages": [
self._model: CLIPModel | None = None {
"role": "user",
def _load(self) -> tuple[CLIPProcessor, CLIPModel]: "content": [
if self._processor is None or self._model is None: {"type": "text", "text": text},
with st.spinner("首次启动正在加载 CLIP 模型,请稍候..."): {"type": "image_url", "image_url": {"url": image_url}},
self._processor = CLIPProcessor.from_pretrained(self.MODEL_NAME) ],
self._model = CLIPModel.from_pretrained(self.MODEL_NAME) }
return self._processor, self._model ],
}
def embed(self, image: Image.Image) -> np.ndarray: last_error = None
processor, model = self._load() for attempt in range(max_retries + 1):
inputs = processor(images=image, return_tensors="pt") try:
image_features = model.get_image_features(**inputs) resp = httpx.post(
vec = image_features.detach().cpu().numpy().flatten() IMAGE_EMBEDDING_API_URL,
norm = np.linalg.norm(vec) headers={"Content-Type": "application/json"},
if norm == 0: json=payload,
return vec timeout=120,
return vec / norm )
resp.raise_for_status()
return resp.json()["data"][0]["embedding"]
@st.cache_resource(show_spinner=False) except httpx.HTTPStatusError as e:
def get_embedder() -> CLIPEmbedder: last_error = e
return CLIPEmbedder() if attempt < max_retries:
time.sleep(2 * (attempt + 1))
raise last_error
# ─── Utilities ─────────────────────────────────────────────────────────────── # ─── Utilities ───────────────────────────────────────────────────────────────
@@ -215,21 +222,24 @@ def load_image(source: str | io.BytesIO) -> Image.Image | None:
return img return img
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float: def cosine_similarity(a: list[float], b: list[float]) -> float:
return float(np.dot(a, b)) a_arr = np.array(a)
b_arr = np.array(b)
norm_a = np.linalg.norm(a_arr)
norm_b = np.linalg.norm(b_arr)
if norm_a == 0 or norm_b == 0:
return 0.0
return float(np.dot(a_arr, b_arr) / (norm_a * norm_b))
@st.cache_data(show_spinner=False) @st.cache_resource
def build_index() -> tuple[list[dict], list[str], list[str]]: def build_index() -> tuple[list[dict], list[str], list[tuple[str, str]]]:
embedder = get_embedder()
items, succeeded, failed = [], [], [] items, succeeded, failed = [], [], []
for pest in PEST_KNOWLEDGE: progress = st.progress(0, text="正在构建图片索引...")
img = _load_image_raw(pest.url) total = len(PEST_KNOWLEDGE)
if img is None: for i, pest in enumerate(PEST_KNOWLEDGE):
failed.append(pest.name)
continue
try: try:
embedding = embedder.embed(img) embedding = get_image_embedding(pest.url, text=pest.name)
items.append({ items.append({
"name": pest.name, "name": pest.name,
"url": pest.url, "url": pest.url,
@@ -240,40 +250,32 @@ def build_index() -> tuple[list[dict], list[str], list[str]]:
"category": pest.category, "category": pest.category,
}) })
succeeded.append(pest.name) succeeded.append(pest.name)
except Exception: except Exception as e:
failed.append(pest.name) failed.append((pest.name, str(e)))
progress.progress((i + 1) / total, text=f"正在构建图片索引 ({i + 1}/{total})...")
progress.empty()
return items, succeeded, failed return items, succeeded, failed
# ─── Sidebar ───────────────────────────────────────────────────────────────── # ─── Sidebar ─────────────────────────────────────────────────────────────────
with st.sidebar: with st.sidebar:
st.header("🌿 病虫害以图搜图") st.header("🌿 病虫害以图搜图")
st.caption("上传图片,智能识别相似病虫害") st.caption("输入图片 URL通过图片 Embedding 搜索相似病虫害")
st.divider() st.divider()
st.subheader("🖼️ 输入方式") st.subheader("🖼️ 输入方式")
input_mode = st.radio("", ["上传本地图片", "输入图片 URL", "选择示例图片"], label_visibility="collapsed") input_mode = st.radio("", ["输入图片 URL", "选择示例图片"], label_visibility="collapsed")
# 初始化 session_state # 初始化 session_state
if "query_url" not in st.session_state: if "query_url" not in st.session_state:
st.session_state.query_url = "" st.session_state.query_url = ""
if "query_image_bytes" not in st.session_state:
st.session_state.query_image_bytes = None
query_source = None query_source = None
query_url = "" query_url = ""
if input_mode == "上传本地图片": if input_mode == "输入图片 URL":
uploaded = st.file_uploader("选择图片", type=["jpg", "jpeg", "png", "webp"])
if uploaded is not None:
st.session_state.query_image_bytes = uploaded.getvalue()
st.session_state.query_url = ""
if st.session_state.query_image_bytes is not None:
query_source = io.BytesIO(st.session_state.query_image_bytes)
elif input_mode == "输入图片 URL":
query_url = st.text_input("图片 URL", value=st.session_state.query_url, placeholder="https://example.com/image.jpg") query_url = st.text_input("图片 URL", value=st.session_state.query_url, placeholder="https://example.com/image.jpg")
st.session_state.query_url = query_url st.session_state.query_url = query_url
st.session_state.query_image_bytes = None
if query_url.strip(): if query_url.strip():
query_source = query_url.strip() query_source = query_url.strip()
else: else:
@@ -283,7 +285,6 @@ with st.sidebar:
with cols[idx % 2]: with cols[idx % 2]:
if st.button(name, key=f"ex_{name}"): if st.button(name, key=f"ex_{name}"):
st.session_state.query_url = url st.session_state.query_url = url
st.session_state.query_image_bytes = None
st.rerun() st.rerun()
if st.session_state.query_url: if st.session_state.query_url:
query_url = st.session_state.query_url query_url = st.session_state.query_url
@@ -298,57 +299,52 @@ with st.sidebar:
st.divider() st.divider()
st.info( st.info(
"**使用说明**\n\n" "**使用说明**\n\n"
"1. 上传病虫害患处图片\n" "1. 输入病虫害图片 URL 或选择示例\n"
"2. 系统自动提取图像特征\n" "2. 系统调用 Embedding API 提取图像特征\n"
"3. 与知识库比对返回相似结果\n" "3. 与知识库比对返回相似结果\n"
"4. 参考症状与防治建议" "4. 参考症状与防治建议"
) )
# ─── Build Index ───────────────────────────────────────────────────────────── # ─── Build Index ─────────────────────────────────────────────────────────────
with st.spinner("首次加载需要构建图片索引,请稍候..."):
index_items, succeeded, failed = build_index() index_items, succeeded, failed = build_index()
# ─── Main Layout ───────────────────────────────────────────────────────────── # ─── Main Layout ─────────────────────────────────────────────────────────────
st.title("🌿 病虫害以图搜图") st.title("🌿 病虫害以图搜图")
st.caption("基于 CLIP 视觉模型的病虫害相似度检索与防治建议") st.caption("基于图片 Embedding API 的病虫害相似度检索与防治建议")
# Status badges # Status badges
if succeeded: st.success(f"📚 图片索引构建完成,成功 {len(succeeded)}")
st.badge(f"📚 知识库 {len(succeeded)}", color="blue")
if failed: if failed:
st.badge(f"⚠️ 索引失败 {len(failed)} ", color="red") st.warning(f"以下 {len(failed)} 张图片索引失败:")
for name, err in failed:
st.error(f"- **{name}**{err}")
st.markdown("<br>", unsafe_allow_html=True) st.markdown("<br>", unsafe_allow_html=True)
# ─── Search Logic ──────────────────────────────────────────────────────────── # ─── Search Logic ────────────────────────────────────────────────────────────
if search_clicked: if search_clicked and query_url.strip():
if query_source is None: if not index_items:
st.warning("请先上传图片、输入图片 URL 或选择示例图片后再点击搜索")
elif not index_items:
st.warning("知识库索引为空,请检查网络连接后刷新页面重试。") st.warning("知识库索引为空,请检查网络连接后刷新页面重试。")
else: else:
query_img = load_image(query_source) with st.spinner("正在分析图片并搜索..."):
if query_img is not None: try:
query_embedding = get_image_embedding(query_url.strip())
col_query, col_preview = st.columns([1, 3]) col_query, col_preview = st.columns([1, 3])
with col_query: with col_query:
st.subheader("🔍 查询图片") st.subheader("🔍 查询图片")
query_img = load_image(query_url.strip())
if query_img is not None:
st.image(query_img, use_container_width=True) st.image(query_img, use_container_width=True)
with col_preview: with col_preview:
st.subheader("⏳ 正在分析...")
progress = st.progress(0, text="提取图像特征...")
embedder = get_embedder()
query_embedding = embedder.embed(query_img)
progress.progress(50, text="比对知识库...")
scores = [] scores = []
for item in index_items: for item in index_items:
sim = cosine_similarity(query_embedding, item["embedding"]) sim = cosine_similarity(query_embedding, item["embedding"])
scores.append((sim, item)) scores.append((sim, item))
scores.sort(key=lambda x: x[0], reverse=True) scores.sort(key=lambda x: x[0], reverse=True)
results = scores[:top_k] results = scores[:top_k]
progress.progress(100, text="搜索完成")
progress.empty()
st.subheader(f"🏆 搜索结果Top-{len(results)}") st.subheader(f"🏆 搜索结果Top-{len(results)}")
@@ -408,6 +404,9 @@ if search_clicked:
f"建议结合田间实际情况进一步确认,参考防治方案:**{best['treatment']}**" f"建议结合田间实际情况进一步确认,参考防治方案:**{best['treatment']}**"
) )
except Exception as e:
st.error(f"搜索失败: {e}")
# ─── Footer ─────────────────────────────────────────────────────────────────── # ─── Footer ───────────────────────────────────────────────────────────────────
st.divider() st.divider()
st.caption("病虫害以图搜图 · 基于 CLIP 视觉模型 · 结果仅供参考,请结合田间实际情况判断") st.caption("病虫害以图搜图 · 基于 Qwen3-VL-Embedding · 结果仅供参考,请结合田间实际情况判断")

View File

@@ -5,14 +5,12 @@ description = "病虫害以图搜图 — 基于图片 Embedding 的相似度搜
readme = "README.md" readme = "README.md"
requires-python = ">=3.14" requires-python = ">=3.14"
dependencies = [ dependencies = [
"httpx>=0.28.1",
"numpy>=2.3.5", "numpy>=2.3.5",
"pillow>=11.2.1", "pillow>=11.2.1",
"plotly>=6.5.0", "plotly>=6.5.0",
"requests>=2.32.3",
"ruff>=0.14.8", "ruff>=0.14.8",
"streamlit==1.52.1", "streamlit==1.52.1",
"torch>=2.7.0",
"transformers>=4.51.3",
] ]
[[tool.uv.index]] [[tool.uv.index]]