From fdfc3e2e2bc0deeeda442159de68fbeb2232ade4 Mon Sep 17 00:00:00 2001 From: zhenghu <1831829219@qq.com> Date: Wed, 15 Apr 2026 09:55:32 +0800 Subject: [PATCH] =?UTF-8?q?refactor(app):=20=E6=9B=BF=E6=8D=A2=E6=9C=AC?= =?UTF-8?q?=E5=9C=B0=20CLIP=20=E6=A8=A1=E5=9E=8B=E4=B8=BA=E8=BF=9C?= =?UTF-8?q?=E7=A8=8B=20Qwen3-VL-Embedding=20API?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 移除 CLIPEmbedder 本地模型类,改用远程图片 Embedding API 获取特征向量 - 新增 get_image_embedding() 函数,支持重试机制 - 移除本地图片上传功能,仅保留 URL 输入和示例图片选择 - build_index() 增加进度条显示,索引失败时展示具体错误信息 - 移除 torch、transformers、requests 依赖,新增 httpx - 更新界面文案,反映新的技术方案 --- app.py | 283 ++++++++++++++++++++++++------------------------- pyproject.toml | 4 +- 2 files changed, 142 insertions(+), 145 deletions(-) diff --git a/app.py b/app.py index 1d80157..bf02b26 100644 --- a/app.py +++ b/app.py @@ -1,21 +1,26 @@ """ 病虫害以图搜图 -基于 CLIP 本地模型的图片 Embedding 相似度搜索 +基于图片 Embedding API 的相似度搜索 """ from __future__ import annotations import io -import os +import time from dataclasses import dataclass from typing import Literal +import httpx import numpy as np import plotly.graph_objects as go import requests import streamlit as st from PIL import Image -from transformers import CLIPModel, CLIPProcessor + +# ─── API Config ────────────────────────────────────────────────────────────── +IMAGE_EMBEDDING_API_URL = "https://llm.dev.maimaiag.com/qwen3-vl-embedding/v1/embeddings" +EMBEDDING_MODEL = "Qwen3-VL-Embedding" +API_KEY = "sk--VnOesEU5D8wnHjdg0MEsA" # ─── Page Config ──────────────────────────────────────────────────────────── st.set_page_config( @@ -165,35 +170,37 @@ EXAMPLE_IMAGES: list[tuple[str, str]] = [ ] -# ─── CLIP Embedder ─────────────────────────────────────────────────────────── -class CLIPEmbedder: - MODEL_NAME = "openai/clip-vit-base-patch32" - - def __init__(self) -> None: - self._processor: CLIPProcessor | None = None - self._model: CLIPModel | None = None - - def _load(self) -> tuple[CLIPProcessor, CLIPModel]: - if self._processor is None or self._model is None: - with st.spinner("首次启动正在加载 CLIP 模型,请稍候..."): - self._processor = CLIPProcessor.from_pretrained(self.MODEL_NAME) - self._model = CLIPModel.from_pretrained(self.MODEL_NAME) - return self._processor, self._model - - def embed(self, image: Image.Image) -> np.ndarray: - processor, model = self._load() - inputs = processor(images=image, return_tensors="pt") - image_features = model.get_image_features(**inputs) - vec = image_features.detach().cpu().numpy().flatten() - norm = np.linalg.norm(vec) - if norm == 0: - return vec - return vec / norm - - -@st.cache_resource(show_spinner=False) -def get_embedder() -> CLIPEmbedder: - return CLIPEmbedder() +# ─── Embedding API ─────────────────────────────────────────────────────────── +def get_image_embedding(image_url: str, text: str = "这是什么病虫害?", max_retries: int = 2) -> list[float]: + """调用远程 API 获取图片 Embedding,支持重试。""" + payload = { + "model": EMBEDDING_MODEL, + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": text}, + {"type": "image_url", "image_url": {"url": image_url}}, + ], + } + ], + } + last_error = None + for attempt in range(max_retries + 1): + try: + resp = httpx.post( + IMAGE_EMBEDDING_API_URL, + headers={"Content-Type": "application/json"}, + json=payload, + timeout=120, + ) + resp.raise_for_status() + return resp.json()["data"][0]["embedding"] + except httpx.HTTPStatusError as e: + last_error = e + if attempt < max_retries: + time.sleep(2 * (attempt + 1)) + raise last_error # ─── Utilities ─────────────────────────────────────────────────────────────── @@ -215,21 +222,24 @@ def load_image(source: str | io.BytesIO) -> Image.Image | None: return img -def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float: - return float(np.dot(a, b)) +def cosine_similarity(a: list[float], b: list[float]) -> float: + a_arr = np.array(a) + b_arr = np.array(b) + norm_a = np.linalg.norm(a_arr) + norm_b = np.linalg.norm(b_arr) + if norm_a == 0 or norm_b == 0: + return 0.0 + return float(np.dot(a_arr, b_arr) / (norm_a * norm_b)) -@st.cache_data(show_spinner=False) -def build_index() -> tuple[list[dict], list[str], list[str]]: - embedder = get_embedder() +@st.cache_resource +def build_index() -> tuple[list[dict], list[str], list[tuple[str, str]]]: items, succeeded, failed = [], [], [] - for pest in PEST_KNOWLEDGE: - img = _load_image_raw(pest.url) - if img is None: - failed.append(pest.name) - continue + progress = st.progress(0, text="正在构建图片索引...") + total = len(PEST_KNOWLEDGE) + for i, pest in enumerate(PEST_KNOWLEDGE): try: - embedding = embedder.embed(img) + embedding = get_image_embedding(pest.url, text=pest.name) items.append({ "name": pest.name, "url": pest.url, @@ -240,40 +250,32 @@ def build_index() -> tuple[list[dict], list[str], list[str]]: "category": pest.category, }) succeeded.append(pest.name) - except Exception: - failed.append(pest.name) + except Exception as e: + failed.append((pest.name, str(e))) + progress.progress((i + 1) / total, text=f"正在构建图片索引 ({i + 1}/{total})...") + progress.empty() return items, succeeded, failed # ─── Sidebar ───────────────────────────────────────────────────────────────── with st.sidebar: st.header("🌿 病虫害以图搜图") - st.caption("上传图片,智能识别相似病虫害") + st.caption("输入图片 URL,通过图片 Embedding 搜索相似病虫害") st.divider() st.subheader("🖼️ 输入方式") - input_mode = st.radio("", ["上传本地图片", "输入图片 URL", "选择示例图片"], label_visibility="collapsed") + input_mode = st.radio("", ["输入图片 URL", "选择示例图片"], label_visibility="collapsed") # 初始化 session_state if "query_url" not in st.session_state: st.session_state.query_url = "" - if "query_image_bytes" not in st.session_state: - st.session_state.query_image_bytes = None query_source = None query_url = "" - if input_mode == "上传本地图片": - uploaded = st.file_uploader("选择图片", type=["jpg", "jpeg", "png", "webp"]) - if uploaded is not None: - st.session_state.query_image_bytes = uploaded.getvalue() - st.session_state.query_url = "" - if st.session_state.query_image_bytes is not None: - query_source = io.BytesIO(st.session_state.query_image_bytes) - elif input_mode == "输入图片 URL": + if input_mode == "输入图片 URL": query_url = st.text_input("图片 URL", value=st.session_state.query_url, placeholder="https://example.com/image.jpg") st.session_state.query_url = query_url - st.session_state.query_image_bytes = None if query_url.strip(): query_source = query_url.strip() else: @@ -283,7 +285,6 @@ with st.sidebar: with cols[idx % 2]: if st.button(name, key=f"ex_{name}"): st.session_state.query_url = url - st.session_state.query_image_bytes = None st.rerun() if st.session_state.query_url: query_url = st.session_state.query_url @@ -298,116 +299,114 @@ with st.sidebar: st.divider() st.info( "**使用说明**\n\n" - "1. 上传病虫害患处图片\n" - "2. 系统自动提取图像特征\n" + "1. 输入病虫害图片 URL 或选择示例\n" + "2. 系统调用 Embedding API 提取图像特征\n" "3. 与知识库比对返回相似结果\n" "4. 参考症状与防治建议" ) # ─── Build Index ───────────────────────────────────────────────────────────── -index_items, succeeded, failed = build_index() +with st.spinner("首次加载需要构建图片索引,请稍候..."): + index_items, succeeded, failed = build_index() # ─── Main Layout ───────────────────────────────────────────────────────────── st.title("🌿 病虫害以图搜图") -st.caption("基于 CLIP 视觉模型的病虫害相似度检索与防治建议") +st.caption("基于图片 Embedding API 的病虫害相似度检索与防治建议") # Status badges -if succeeded: - st.badge(f"📚 知识库 {len(succeeded)} 种", color="blue") +st.success(f"📚 图片索引构建完成,成功 {len(succeeded)} 张") if failed: - st.badge(f"⚠️ 索引失败 {len(failed)} 种", color="red") + st.warning(f"以下 {len(failed)} 张图片索引失败:") + for name, err in failed: + st.error(f"- **{name}**:{err}") st.markdown("
", unsafe_allow_html=True) # ─── Search Logic ──────────────────────────────────────────────────────────── -if search_clicked: - if query_source is None: - st.warning("请先上传图片、输入图片 URL 或选择示例图片后再点击搜索") - elif not index_items: +if search_clicked and query_url.strip(): + if not index_items: st.warning("知识库索引为空,请检查网络连接后刷新页面重试。") else: - query_img = load_image(query_source) - if query_img is not None: - col_query, col_preview = st.columns([1, 3]) - with col_query: - st.subheader("🔍 查询图片") - st.image(query_img, use_container_width=True) + with st.spinner("正在分析图片并搜索..."): + try: + query_embedding = get_image_embedding(query_url.strip()) - with col_preview: - st.subheader("⏳ 正在分析...") - progress = st.progress(0, text="提取图像特征...") + col_query, col_preview = st.columns([1, 3]) + with col_query: + st.subheader("🔍 查询图片") + query_img = load_image(query_url.strip()) + if query_img is not None: + st.image(query_img, use_container_width=True) - embedder = get_embedder() - query_embedding = embedder.embed(query_img) - progress.progress(50, text="比对知识库...") + with col_preview: + scores = [] + for item in index_items: + sim = cosine_similarity(query_embedding, item["embedding"]) + scores.append((sim, item)) + scores.sort(key=lambda x: x[0], reverse=True) + results = scores[:top_k] - scores = [] - for item in index_items: - sim = cosine_similarity(query_embedding, item["embedding"]) - scores.append((sim, item)) - scores.sort(key=lambda x: x[0], reverse=True) - results = scores[:top_k] - progress.progress(100, text="搜索完成") - progress.empty() + st.subheader(f"🏆 搜索结果(Top-{len(results)})") - st.subheader(f"🏆 搜索结果(Top-{len(results)})") + # Similarity bar chart + names = [f"{r[1]['name']}" for r in results] + sims = [r[0] * 100 for r in results] + colors = ["#c45c4a" if r[1]["category"] == "虫害" else "#4a7c59" for r in results] - # Similarity bar chart - names = [f"{r[1]['name']}" for r in results] - sims = [r[0] * 100 for r in results] - colors = ["#c45c4a" if r[1]["category"] == "虫害" else "#4a7c59" for r in results] + fig_bar = go.Figure() + fig_bar.add_trace(go.Bar( + x=sims, + y=names, + orientation="h", + marker=dict(color=colors, opacity=0.85, line=dict(color="rgba(0,0,0,0.08)", width=1)), + text=[f"{s:.1f}%" for s in sims], + textposition="outside", + textfont=dict(color="#5a5a5a", size=10), + )) + fig_bar.update_layout( + xaxis=dict(title="相似度 (%)", color="#5a5a5a", gridcolor="rgba(0,0,0,0.06)", range=[0, 105]), + yaxis=dict(color="#5a5a5a", gridcolor="rgba(0,0,0,0.04)", autorange="reversed"), + paper_bgcolor="rgba(0,0,0,0)", + plot_bgcolor="rgba(0,0,0,0)", + font=dict(color="#2c2c2c", size=11), + margin=dict(t=10, b=30, l=80, r=50), + height=160 + len(results) * 34, + showlegend=False, + ) + st.plotly_chart(fig_bar, use_container_width=True) - fig_bar = go.Figure() - fig_bar.add_trace(go.Bar( - x=sims, - y=names, - orientation="h", - marker=dict(color=colors, opacity=0.85, line=dict(color="rgba(0,0,0,0.08)", width=1)), - text=[f"{s:.1f}%" for s in sims], - textposition="outside", - textfont=dict(color="#5a5a5a", size=10), - )) - fig_bar.update_layout( - xaxis=dict(title="相似度 (%)", color="#5a5a5a", gridcolor="rgba(0,0,0,0.06)", range=[0, 105]), - yaxis=dict(color="#5a5a5a", gridcolor="rgba(0,0,0,0.04)", autorange="reversed"), - paper_bgcolor="rgba(0,0,0,0)", - plot_bgcolor="rgba(0,0,0,0)", - font=dict(color="#2c2c2c", size=11), - margin=dict(t=10, b=30, l=80, r=50), - height=160 + len(results) * 34, - showlegend=False, - ) - st.plotly_chart(fig_bar, use_container_width=True) + # Result cards below + st.subheader("📋 详细结果") + for rank, (sim, item) in enumerate(results, 1): + with st.container(border=True): + c1, c2 = st.columns([1, 4]) + with c1: + st.image(item["url"], use_container_width=True) + with c2: + header_col, score_col = st.columns([3, 1]) + header_col.markdown(f"**#{rank} {item['name']}**") + score_col.markdown(f"
相似度 {sim*100:.1f}%
", unsafe_allow_html=True) - # Result cards below - st.subheader("📋 详细结果") - for rank, (sim, item) in enumerate(results, 1): - with st.container(border=True): - c1, c2 = st.columns([1, 4]) - with c1: - st.image(item["url"], use_container_width=True) - with c2: - header_col, score_col = st.columns([3, 1]) - header_col.markdown(f"**#{rank} {item['name']}**") - score_col.markdown(f"
相似度 {sim*100:.1f}%
", unsafe_allow_html=True) + badge_cols = st.columns([1, 1, 4]) + badge_cols[0].caption(f"🌾 {item['crop']}") + badge_cols[1].caption(f"🐛 {item['category']}" if item["category"] == "虫害" else f"🍃 {item['category']}") - badge_cols = st.columns([1, 1, 4]) - badge_cols[0].caption(f"🌾 {item['crop']}") - badge_cols[1].caption(f"🐛 {item['category']}" if item["category"] == "虫害" else f"🍃 {item['category']}") + st.markdown(f"**症状:** {item['symptoms']}") + st.markdown(f"**防治:** {item['treatment']}") - st.markdown(f"**症状:** {item['symptoms']}") - st.markdown(f"**防治:** {item['treatment']}") + # Advisory summary + if results: + best = results[0][1] + st.subheader("💡 初步建议") + st.info( + f"系统判断该图片与 **{best['name']}**({best['crop']}{best['category']})最为相似," + f"相似度 **{results[0][0]*100:.1f}%**。\n\n" + f"建议结合田间实际情况进一步确认,参考防治方案:**{best['treatment']}**" + ) - # Advisory summary - if results: - best = results[0][1] - st.subheader("💡 初步建议") - st.info( - f"系统判断该图片与 **{best['name']}**({best['crop']}{best['category']})最为相似," - f"相似度 **{results[0][0]*100:.1f}%**。\n\n" - f"建议结合田间实际情况进一步确认,参考防治方案:**{best['treatment']}**" - ) + except Exception as e: + st.error(f"搜索失败: {e}") # ─── Footer ─────────────────────────────────────────────────────────────────── st.divider() -st.caption("病虫害以图搜图 · 基于 CLIP 视觉模型 · 结果仅供参考,请结合田间实际情况判断") +st.caption("病虫害以图搜图 · 基于 Qwen3-VL-Embedding · 结果仅供参考,请结合田间实际情况判断") diff --git a/pyproject.toml b/pyproject.toml index f415dea..ce139b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,14 +5,12 @@ description = "病虫害以图搜图 — 基于图片 Embedding 的相似度搜 readme = "README.md" requires-python = ">=3.14" dependencies = [ + "httpx>=0.28.1", "numpy>=2.3.5", "pillow>=11.2.1", "plotly>=6.5.0", - "requests>=2.32.3", "ruff>=0.14.8", "streamlit==1.52.1", - "torch>=2.7.0", - "transformers>=4.51.3", ] [[tool.uv.index]]