refactor(app): 替换本地 CLIP 模型为远程 Qwen3-VL-Embedding API

- 移除 CLIPEmbedder 本地模型类,改用远程图片 Embedding API 获取特征向量
  - 新增 get_image_embedding() 函数,支持重试机制
  - 移除本地图片上传功能,仅保留 URL 输入和示例图片选择
  - build_index() 增加进度条显示,索引失败时展示具体错误信息
  - 移除 torch、transformers、requests 依赖,新增 httpx
  - 更新界面文案,反映新的技术方案
This commit is contained in:
zhenghu
2026-04-15 09:55:32 +08:00
parent 722d7dc57d
commit fdfc3e2e2b
2 changed files with 142 additions and 145 deletions

163
app.py
View File

@@ -1,21 +1,26 @@
"""
病虫害以图搜图
基于 CLIP 本地模型的图片 Embedding 相似度搜索
基于图片 Embedding API 的相似度搜索
"""
from __future__ import annotations
import io
import os
import time
from dataclasses import dataclass
from typing import Literal
import httpx
import numpy as np
import plotly.graph_objects as go
import requests
import streamlit as st
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
# ─── API Config ──────────────────────────────────────────────────────────────
IMAGE_EMBEDDING_API_URL = "https://llm.dev.maimaiag.com/qwen3-vl-embedding/v1/embeddings"
EMBEDDING_MODEL = "Qwen3-VL-Embedding"
API_KEY = "sk--VnOesEU5D8wnHjdg0MEsA"
# ─── Page Config ────────────────────────────────────────────────────────────
st.set_page_config(
@@ -165,35 +170,37 @@ EXAMPLE_IMAGES: list[tuple[str, str]] = [
]
# ─── CLIP Embedder ───────────────────────────────────────────────────────────
class CLIPEmbedder:
MODEL_NAME = "openai/clip-vit-base-patch32"
def __init__(self) -> None:
self._processor: CLIPProcessor | None = None
self._model: CLIPModel | None = None
def _load(self) -> tuple[CLIPProcessor, CLIPModel]:
if self._processor is None or self._model is None:
with st.spinner("首次启动正在加载 CLIP 模型,请稍候..."):
self._processor = CLIPProcessor.from_pretrained(self.MODEL_NAME)
self._model = CLIPModel.from_pretrained(self.MODEL_NAME)
return self._processor, self._model
def embed(self, image: Image.Image) -> np.ndarray:
processor, model = self._load()
inputs = processor(images=image, return_tensors="pt")
image_features = model.get_image_features(**inputs)
vec = image_features.detach().cpu().numpy().flatten()
norm = np.linalg.norm(vec)
if norm == 0:
return vec
return vec / norm
@st.cache_resource(show_spinner=False)
def get_embedder() -> CLIPEmbedder:
return CLIPEmbedder()
# ─── Embedding API ───────────────────────────────────────────────────────────
def get_image_embedding(image_url: str, text: str = "这是什么病虫害?", max_retries: int = 2) -> list[float]:
"""调用远程 API 获取图片 Embedding支持重试。"""
payload = {
"model": EMBEDDING_MODEL,
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": text},
{"type": "image_url", "image_url": {"url": image_url}},
],
}
],
}
last_error = None
for attempt in range(max_retries + 1):
try:
resp = httpx.post(
IMAGE_EMBEDDING_API_URL,
headers={"Content-Type": "application/json"},
json=payload,
timeout=120,
)
resp.raise_for_status()
return resp.json()["data"][0]["embedding"]
except httpx.HTTPStatusError as e:
last_error = e
if attempt < max_retries:
time.sleep(2 * (attempt + 1))
raise last_error
# ─── Utilities ───────────────────────────────────────────────────────────────
@@ -215,21 +222,24 @@ def load_image(source: str | io.BytesIO) -> Image.Image | None:
return img
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
return float(np.dot(a, b))
def cosine_similarity(a: list[float], b: list[float]) -> float:
a_arr = np.array(a)
b_arr = np.array(b)
norm_a = np.linalg.norm(a_arr)
norm_b = np.linalg.norm(b_arr)
if norm_a == 0 or norm_b == 0:
return 0.0
return float(np.dot(a_arr, b_arr) / (norm_a * norm_b))
@st.cache_data(show_spinner=False)
def build_index() -> tuple[list[dict], list[str], list[str]]:
embedder = get_embedder()
@st.cache_resource
def build_index() -> tuple[list[dict], list[str], list[tuple[str, str]]]:
items, succeeded, failed = [], [], []
for pest in PEST_KNOWLEDGE:
img = _load_image_raw(pest.url)
if img is None:
failed.append(pest.name)
continue
progress = st.progress(0, text="正在构建图片索引...")
total = len(PEST_KNOWLEDGE)
for i, pest in enumerate(PEST_KNOWLEDGE):
try:
embedding = embedder.embed(img)
embedding = get_image_embedding(pest.url, text=pest.name)
items.append({
"name": pest.name,
"url": pest.url,
@@ -240,40 +250,32 @@ def build_index() -> tuple[list[dict], list[str], list[str]]:
"category": pest.category,
})
succeeded.append(pest.name)
except Exception:
failed.append(pest.name)
except Exception as e:
failed.append((pest.name, str(e)))
progress.progress((i + 1) / total, text=f"正在构建图片索引 ({i + 1}/{total})...")
progress.empty()
return items, succeeded, failed
# ─── Sidebar ─────────────────────────────────────────────────────────────────
with st.sidebar:
st.header("🌿 病虫害以图搜图")
st.caption("上传图片,智能识别相似病虫害")
st.caption("输入图片 URL通过图片 Embedding 搜索相似病虫害")
st.divider()
st.subheader("🖼️ 输入方式")
input_mode = st.radio("", ["上传本地图片", "输入图片 URL", "选择示例图片"], label_visibility="collapsed")
input_mode = st.radio("", ["输入图片 URL", "选择示例图片"], label_visibility="collapsed")
# 初始化 session_state
if "query_url" not in st.session_state:
st.session_state.query_url = ""
if "query_image_bytes" not in st.session_state:
st.session_state.query_image_bytes = None
query_source = None
query_url = ""
if input_mode == "上传本地图片":
uploaded = st.file_uploader("选择图片", type=["jpg", "jpeg", "png", "webp"])
if uploaded is not None:
st.session_state.query_image_bytes = uploaded.getvalue()
st.session_state.query_url = ""
if st.session_state.query_image_bytes is not None:
query_source = io.BytesIO(st.session_state.query_image_bytes)
elif input_mode == "输入图片 URL":
if input_mode == "输入图片 URL":
query_url = st.text_input("图片 URL", value=st.session_state.query_url, placeholder="https://example.com/image.jpg")
st.session_state.query_url = query_url
st.session_state.query_image_bytes = None
if query_url.strip():
query_source = query_url.strip()
else:
@@ -283,7 +285,6 @@ with st.sidebar:
with cols[idx % 2]:
if st.button(name, key=f"ex_{name}"):
st.session_state.query_url = url
st.session_state.query_image_bytes = None
st.rerun()
if st.session_state.query_url:
query_url = st.session_state.query_url
@@ -298,57 +299,52 @@ with st.sidebar:
st.divider()
st.info(
"**使用说明**\n\n"
"1. 上传病虫害患处图片\n"
"2. 系统自动提取图像特征\n"
"1. 输入病虫害图片 URL 或选择示例\n"
"2. 系统调用 Embedding API 提取图像特征\n"
"3. 与知识库比对返回相似结果\n"
"4. 参考症状与防治建议"
)
# ─── Build Index ─────────────────────────────────────────────────────────────
index_items, succeeded, failed = build_index()
with st.spinner("首次加载需要构建图片索引,请稍候..."):
index_items, succeeded, failed = build_index()
# ─── Main Layout ─────────────────────────────────────────────────────────────
st.title("🌿 病虫害以图搜图")
st.caption("基于 CLIP 视觉模型的病虫害相似度检索与防治建议")
st.caption("基于图片 Embedding API 的病虫害相似度检索与防治建议")
# Status badges
if succeeded:
st.badge(f"📚 知识库 {len(succeeded)}", color="blue")
st.success(f"📚 图片索引构建完成,成功 {len(succeeded)}")
if failed:
st.badge(f"⚠️ 索引失败 {len(failed)} ", color="red")
st.warning(f"以下 {len(failed)} 张图片索引失败:")
for name, err in failed:
st.error(f"- **{name}**{err}")
st.markdown("<br>", unsafe_allow_html=True)
# ─── Search Logic ────────────────────────────────────────────────────────────
if search_clicked:
if query_source is None:
st.warning("请先上传图片、输入图片 URL 或选择示例图片后再点击搜索")
elif not index_items:
if search_clicked and query_url.strip():
if not index_items:
st.warning("知识库索引为空,请检查网络连接后刷新页面重试。")
else:
query_img = load_image(query_source)
if query_img is not None:
with st.spinner("正在分析图片并搜索..."):
try:
query_embedding = get_image_embedding(query_url.strip())
col_query, col_preview = st.columns([1, 3])
with col_query:
st.subheader("🔍 查询图片")
query_img = load_image(query_url.strip())
if query_img is not None:
st.image(query_img, use_container_width=True)
with col_preview:
st.subheader("⏳ 正在分析...")
progress = st.progress(0, text="提取图像特征...")
embedder = get_embedder()
query_embedding = embedder.embed(query_img)
progress.progress(50, text="比对知识库...")
scores = []
for item in index_items:
sim = cosine_similarity(query_embedding, item["embedding"])
scores.append((sim, item))
scores.sort(key=lambda x: x[0], reverse=True)
results = scores[:top_k]
progress.progress(100, text="搜索完成")
progress.empty()
st.subheader(f"🏆 搜索结果Top-{len(results)}")
@@ -408,6 +404,9 @@ if search_clicked:
f"建议结合田间实际情况进一步确认,参考防治方案:**{best['treatment']}**"
)
except Exception as e:
st.error(f"搜索失败: {e}")
# ─── Footer ───────────────────────────────────────────────────────────────────
st.divider()
st.caption("病虫害以图搜图 · 基于 CLIP 视觉模型 · 结果仅供参考,请结合田间实际情况判断")
st.caption("病虫害以图搜图 · 基于 Qwen3-VL-Embedding · 结果仅供参考,请结合田间实际情况判断")

View File

@@ -5,14 +5,12 @@ description = "病虫害以图搜图 — 基于图片 Embedding 的相似度搜
readme = "README.md"
requires-python = ">=3.14"
dependencies = [
"httpx>=0.28.1",
"numpy>=2.3.5",
"pillow>=11.2.1",
"plotly>=6.5.0",
"requests>=2.32.3",
"ruff>=0.14.8",
"streamlit==1.52.1",
"torch>=2.7.0",
"transformers>=4.51.3",
]
[[tool.uv.index]]