refactor(app): 替换本地 CLIP 模型为远程 Qwen3-VL-Embedding API
- 移除 CLIPEmbedder 本地模型类,改用远程图片 Embedding API 获取特征向量 - 新增 get_image_embedding() 函数,支持重试机制 - 移除本地图片上传功能,仅保留 URL 输入和示例图片选择 - build_index() 增加进度条显示,索引失败时展示具体错误信息 - 移除 torch、transformers、requests 依赖,新增 httpx - 更新界面文案,反映新的技术方案
This commit is contained in:
163
app.py
163
app.py
@@ -1,21 +1,26 @@
|
|||||||
"""
|
"""
|
||||||
病虫害以图搜图
|
病虫害以图搜图
|
||||||
基于 CLIP 本地模型的图片 Embedding 相似度搜索
|
基于图片 Embedding API 的相似度搜索
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import os
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
|
import httpx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import plotly.graph_objects as go
|
import plotly.graph_objects as go
|
||||||
import requests
|
import requests
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import CLIPModel, CLIPProcessor
|
|
||||||
|
# ─── API Config ──────────────────────────────────────────────────────────────
|
||||||
|
IMAGE_EMBEDDING_API_URL = "https://llm.dev.maimaiag.com/qwen3-vl-embedding/v1/embeddings"
|
||||||
|
EMBEDDING_MODEL = "Qwen3-VL-Embedding"
|
||||||
|
API_KEY = "sk--VnOesEU5D8wnHjdg0MEsA"
|
||||||
|
|
||||||
# ─── Page Config ────────────────────────────────────────────────────────────
|
# ─── Page Config ────────────────────────────────────────────────────────────
|
||||||
st.set_page_config(
|
st.set_page_config(
|
||||||
@@ -165,35 +170,37 @@ EXAMPLE_IMAGES: list[tuple[str, str]] = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# ─── CLIP Embedder ───────────────────────────────────────────────────────────
|
# ─── Embedding API ───────────────────────────────────────────────────────────
|
||||||
class CLIPEmbedder:
|
def get_image_embedding(image_url: str, text: str = "这是什么病虫害?", max_retries: int = 2) -> list[float]:
|
||||||
MODEL_NAME = "openai/clip-vit-base-patch32"
|
"""调用远程 API 获取图片 Embedding,支持重试。"""
|
||||||
|
payload = {
|
||||||
def __init__(self) -> None:
|
"model": EMBEDDING_MODEL,
|
||||||
self._processor: CLIPProcessor | None = None
|
"messages": [
|
||||||
self._model: CLIPModel | None = None
|
{
|
||||||
|
"role": "user",
|
||||||
def _load(self) -> tuple[CLIPProcessor, CLIPModel]:
|
"content": [
|
||||||
if self._processor is None or self._model is None:
|
{"type": "text", "text": text},
|
||||||
with st.spinner("首次启动正在加载 CLIP 模型,请稍候..."):
|
{"type": "image_url", "image_url": {"url": image_url}},
|
||||||
self._processor = CLIPProcessor.from_pretrained(self.MODEL_NAME)
|
],
|
||||||
self._model = CLIPModel.from_pretrained(self.MODEL_NAME)
|
}
|
||||||
return self._processor, self._model
|
],
|
||||||
|
}
|
||||||
def embed(self, image: Image.Image) -> np.ndarray:
|
last_error = None
|
||||||
processor, model = self._load()
|
for attempt in range(max_retries + 1):
|
||||||
inputs = processor(images=image, return_tensors="pt")
|
try:
|
||||||
image_features = model.get_image_features(**inputs)
|
resp = httpx.post(
|
||||||
vec = image_features.detach().cpu().numpy().flatten()
|
IMAGE_EMBEDDING_API_URL,
|
||||||
norm = np.linalg.norm(vec)
|
headers={"Content-Type": "application/json"},
|
||||||
if norm == 0:
|
json=payload,
|
||||||
return vec
|
timeout=120,
|
||||||
return vec / norm
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()["data"][0]["embedding"]
|
||||||
@st.cache_resource(show_spinner=False)
|
except httpx.HTTPStatusError as e:
|
||||||
def get_embedder() -> CLIPEmbedder:
|
last_error = e
|
||||||
return CLIPEmbedder()
|
if attempt < max_retries:
|
||||||
|
time.sleep(2 * (attempt + 1))
|
||||||
|
raise last_error
|
||||||
|
|
||||||
|
|
||||||
# ─── Utilities ───────────────────────────────────────────────────────────────
|
# ─── Utilities ───────────────────────────────────────────────────────────────
|
||||||
@@ -215,21 +222,24 @@ def load_image(source: str | io.BytesIO) -> Image.Image | None:
|
|||||||
return img
|
return img
|
||||||
|
|
||||||
|
|
||||||
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
|
def cosine_similarity(a: list[float], b: list[float]) -> float:
|
||||||
return float(np.dot(a, b))
|
a_arr = np.array(a)
|
||||||
|
b_arr = np.array(b)
|
||||||
|
norm_a = np.linalg.norm(a_arr)
|
||||||
|
norm_b = np.linalg.norm(b_arr)
|
||||||
|
if norm_a == 0 or norm_b == 0:
|
||||||
|
return 0.0
|
||||||
|
return float(np.dot(a_arr, b_arr) / (norm_a * norm_b))
|
||||||
|
|
||||||
|
|
||||||
@st.cache_data(show_spinner=False)
|
@st.cache_resource
|
||||||
def build_index() -> tuple[list[dict], list[str], list[str]]:
|
def build_index() -> tuple[list[dict], list[str], list[tuple[str, str]]]:
|
||||||
embedder = get_embedder()
|
|
||||||
items, succeeded, failed = [], [], []
|
items, succeeded, failed = [], [], []
|
||||||
for pest in PEST_KNOWLEDGE:
|
progress = st.progress(0, text="正在构建图片索引...")
|
||||||
img = _load_image_raw(pest.url)
|
total = len(PEST_KNOWLEDGE)
|
||||||
if img is None:
|
for i, pest in enumerate(PEST_KNOWLEDGE):
|
||||||
failed.append(pest.name)
|
|
||||||
continue
|
|
||||||
try:
|
try:
|
||||||
embedding = embedder.embed(img)
|
embedding = get_image_embedding(pest.url, text=pest.name)
|
||||||
items.append({
|
items.append({
|
||||||
"name": pest.name,
|
"name": pest.name,
|
||||||
"url": pest.url,
|
"url": pest.url,
|
||||||
@@ -240,40 +250,32 @@ def build_index() -> tuple[list[dict], list[str], list[str]]:
|
|||||||
"category": pest.category,
|
"category": pest.category,
|
||||||
})
|
})
|
||||||
succeeded.append(pest.name)
|
succeeded.append(pest.name)
|
||||||
except Exception:
|
except Exception as e:
|
||||||
failed.append(pest.name)
|
failed.append((pest.name, str(e)))
|
||||||
|
progress.progress((i + 1) / total, text=f"正在构建图片索引 ({i + 1}/{total})...")
|
||||||
|
progress.empty()
|
||||||
return items, succeeded, failed
|
return items, succeeded, failed
|
||||||
|
|
||||||
|
|
||||||
# ─── Sidebar ─────────────────────────────────────────────────────────────────
|
# ─── Sidebar ─────────────────────────────────────────────────────────────────
|
||||||
with st.sidebar:
|
with st.sidebar:
|
||||||
st.header("🌿 病虫害以图搜图")
|
st.header("🌿 病虫害以图搜图")
|
||||||
st.caption("上传图片,智能识别相似病虫害")
|
st.caption("输入图片 URL,通过图片 Embedding 搜索相似病虫害")
|
||||||
st.divider()
|
st.divider()
|
||||||
|
|
||||||
st.subheader("🖼️ 输入方式")
|
st.subheader("🖼️ 输入方式")
|
||||||
input_mode = st.radio("", ["上传本地图片", "输入图片 URL", "选择示例图片"], label_visibility="collapsed")
|
input_mode = st.radio("", ["输入图片 URL", "选择示例图片"], label_visibility="collapsed")
|
||||||
|
|
||||||
# 初始化 session_state
|
# 初始化 session_state
|
||||||
if "query_url" not in st.session_state:
|
if "query_url" not in st.session_state:
|
||||||
st.session_state.query_url = ""
|
st.session_state.query_url = ""
|
||||||
if "query_image_bytes" not in st.session_state:
|
|
||||||
st.session_state.query_image_bytes = None
|
|
||||||
|
|
||||||
query_source = None
|
query_source = None
|
||||||
query_url = ""
|
query_url = ""
|
||||||
|
|
||||||
if input_mode == "上传本地图片":
|
if input_mode == "输入图片 URL":
|
||||||
uploaded = st.file_uploader("选择图片", type=["jpg", "jpeg", "png", "webp"])
|
|
||||||
if uploaded is not None:
|
|
||||||
st.session_state.query_image_bytes = uploaded.getvalue()
|
|
||||||
st.session_state.query_url = ""
|
|
||||||
if st.session_state.query_image_bytes is not None:
|
|
||||||
query_source = io.BytesIO(st.session_state.query_image_bytes)
|
|
||||||
elif input_mode == "输入图片 URL":
|
|
||||||
query_url = st.text_input("图片 URL", value=st.session_state.query_url, placeholder="https://example.com/image.jpg")
|
query_url = st.text_input("图片 URL", value=st.session_state.query_url, placeholder="https://example.com/image.jpg")
|
||||||
st.session_state.query_url = query_url
|
st.session_state.query_url = query_url
|
||||||
st.session_state.query_image_bytes = None
|
|
||||||
if query_url.strip():
|
if query_url.strip():
|
||||||
query_source = query_url.strip()
|
query_source = query_url.strip()
|
||||||
else:
|
else:
|
||||||
@@ -283,7 +285,6 @@ with st.sidebar:
|
|||||||
with cols[idx % 2]:
|
with cols[idx % 2]:
|
||||||
if st.button(name, key=f"ex_{name}"):
|
if st.button(name, key=f"ex_{name}"):
|
||||||
st.session_state.query_url = url
|
st.session_state.query_url = url
|
||||||
st.session_state.query_image_bytes = None
|
|
||||||
st.rerun()
|
st.rerun()
|
||||||
if st.session_state.query_url:
|
if st.session_state.query_url:
|
||||||
query_url = st.session_state.query_url
|
query_url = st.session_state.query_url
|
||||||
@@ -298,57 +299,52 @@ with st.sidebar:
|
|||||||
st.divider()
|
st.divider()
|
||||||
st.info(
|
st.info(
|
||||||
"**使用说明**\n\n"
|
"**使用说明**\n\n"
|
||||||
"1. 上传病虫害患处图片\n"
|
"1. 输入病虫害图片 URL 或选择示例\n"
|
||||||
"2. 系统自动提取图像特征\n"
|
"2. 系统调用 Embedding API 提取图像特征\n"
|
||||||
"3. 与知识库比对返回相似结果\n"
|
"3. 与知识库比对返回相似结果\n"
|
||||||
"4. 参考症状与防治建议"
|
"4. 参考症状与防治建议"
|
||||||
)
|
)
|
||||||
|
|
||||||
# ─── Build Index ─────────────────────────────────────────────────────────────
|
# ─── Build Index ─────────────────────────────────────────────────────────────
|
||||||
index_items, succeeded, failed = build_index()
|
with st.spinner("首次加载需要构建图片索引,请稍候..."):
|
||||||
|
index_items, succeeded, failed = build_index()
|
||||||
|
|
||||||
# ─── Main Layout ─────────────────────────────────────────────────────────────
|
# ─── Main Layout ─────────────────────────────────────────────────────────────
|
||||||
st.title("🌿 病虫害以图搜图")
|
st.title("🌿 病虫害以图搜图")
|
||||||
st.caption("基于 CLIP 视觉模型的病虫害相似度检索与防治建议")
|
st.caption("基于图片 Embedding API 的病虫害相似度检索与防治建议")
|
||||||
|
|
||||||
# Status badges
|
# Status badges
|
||||||
if succeeded:
|
st.success(f"📚 图片索引构建完成,成功 {len(succeeded)} 张")
|
||||||
st.badge(f"📚 知识库 {len(succeeded)} 种", color="blue")
|
|
||||||
if failed:
|
if failed:
|
||||||
st.badge(f"⚠️ 索引失败 {len(failed)} 种", color="red")
|
st.warning(f"以下 {len(failed)} 张图片索引失败:")
|
||||||
|
for name, err in failed:
|
||||||
|
st.error(f"- **{name}**:{err}")
|
||||||
|
|
||||||
st.markdown("<br>", unsafe_allow_html=True)
|
st.markdown("<br>", unsafe_allow_html=True)
|
||||||
|
|
||||||
# ─── Search Logic ────────────────────────────────────────────────────────────
|
# ─── Search Logic ────────────────────────────────────────────────────────────
|
||||||
if search_clicked:
|
if search_clicked and query_url.strip():
|
||||||
if query_source is None:
|
if not index_items:
|
||||||
st.warning("请先上传图片、输入图片 URL 或选择示例图片后再点击搜索")
|
|
||||||
elif not index_items:
|
|
||||||
st.warning("知识库索引为空,请检查网络连接后刷新页面重试。")
|
st.warning("知识库索引为空,请检查网络连接后刷新页面重试。")
|
||||||
else:
|
else:
|
||||||
query_img = load_image(query_source)
|
with st.spinner("正在分析图片并搜索..."):
|
||||||
if query_img is not None:
|
try:
|
||||||
|
query_embedding = get_image_embedding(query_url.strip())
|
||||||
|
|
||||||
col_query, col_preview = st.columns([1, 3])
|
col_query, col_preview = st.columns([1, 3])
|
||||||
with col_query:
|
with col_query:
|
||||||
st.subheader("🔍 查询图片")
|
st.subheader("🔍 查询图片")
|
||||||
|
query_img = load_image(query_url.strip())
|
||||||
|
if query_img is not None:
|
||||||
st.image(query_img, use_container_width=True)
|
st.image(query_img, use_container_width=True)
|
||||||
|
|
||||||
with col_preview:
|
with col_preview:
|
||||||
st.subheader("⏳ 正在分析...")
|
|
||||||
progress = st.progress(0, text="提取图像特征...")
|
|
||||||
|
|
||||||
embedder = get_embedder()
|
|
||||||
query_embedding = embedder.embed(query_img)
|
|
||||||
progress.progress(50, text="比对知识库...")
|
|
||||||
|
|
||||||
scores = []
|
scores = []
|
||||||
for item in index_items:
|
for item in index_items:
|
||||||
sim = cosine_similarity(query_embedding, item["embedding"])
|
sim = cosine_similarity(query_embedding, item["embedding"])
|
||||||
scores.append((sim, item))
|
scores.append((sim, item))
|
||||||
scores.sort(key=lambda x: x[0], reverse=True)
|
scores.sort(key=lambda x: x[0], reverse=True)
|
||||||
results = scores[:top_k]
|
results = scores[:top_k]
|
||||||
progress.progress(100, text="搜索完成")
|
|
||||||
progress.empty()
|
|
||||||
|
|
||||||
st.subheader(f"🏆 搜索结果(Top-{len(results)})")
|
st.subheader(f"🏆 搜索结果(Top-{len(results)})")
|
||||||
|
|
||||||
@@ -408,6 +404,9 @@ if search_clicked:
|
|||||||
f"建议结合田间实际情况进一步确认,参考防治方案:**{best['treatment']}**"
|
f"建议结合田间实际情况进一步确认,参考防治方案:**{best['treatment']}**"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"搜索失败: {e}")
|
||||||
|
|
||||||
# ─── Footer ───────────────────────────────────────────────────────────────────
|
# ─── Footer ───────────────────────────────────────────────────────────────────
|
||||||
st.divider()
|
st.divider()
|
||||||
st.caption("病虫害以图搜图 · 基于 CLIP 视觉模型 · 结果仅供参考,请结合田间实际情况判断")
|
st.caption("病虫害以图搜图 · 基于 Qwen3-VL-Embedding · 结果仅供参考,请结合田间实际情况判断")
|
||||||
|
|||||||
@@ -5,14 +5,12 @@ description = "病虫害以图搜图 — 基于图片 Embedding 的相似度搜
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.14"
|
requires-python = ">=3.14"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"httpx>=0.28.1",
|
||||||
"numpy>=2.3.5",
|
"numpy>=2.3.5",
|
||||||
"pillow>=11.2.1",
|
"pillow>=11.2.1",
|
||||||
"plotly>=6.5.0",
|
"plotly>=6.5.0",
|
||||||
"requests>=2.32.3",
|
|
||||||
"ruff>=0.14.8",
|
"ruff>=0.14.8",
|
||||||
"streamlit==1.52.1",
|
"streamlit==1.52.1",
|
||||||
"torch>=2.7.0",
|
|
||||||
"transformers>=4.51.3",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[[tool.uv.index]]
|
[[tool.uv.index]]
|
||||||
|
|||||||
Reference in New Issue
Block a user