refactor(app): 替换本地 CLIP 模型为远程 Qwen3-VL-Embedding API
- 移除 CLIPEmbedder 本地模型类,改用远程图片 Embedding API 获取特征向量
- 新增 get_image_embedding() 函数,支持重试机制
- 移除本地图片上传功能,仅保留 URL 输入和示例图片选择
- build_index() 增加进度条显示,索引失败时展示具体错误信息
- 移除 torch、transformers、requests 依赖,新增 httpx
- 更新界面文案,反映新的技术方案
This commit is contained in:
283
app.py
283
app.py
@@ -1,21 +1,26 @@
|
||||
"""
|
||||
病虫害以图搜图
|
||||
基于 CLIP 本地模型的图片 Embedding 相似度搜索
|
||||
基于图片 Embedding API 的相似度搜索
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
import plotly.graph_objects as go
|
||||
import requests
|
||||
import streamlit as st
|
||||
from PIL import Image
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
|
||||
# ─── API Config ──────────────────────────────────────────────────────────────
|
||||
# Endpoint and model identifier of the remote image-embedding service.
IMAGE_EMBEDDING_API_URL = "https://llm.dev.maimaiag.com/qwen3-vl-embedding/v1/embeddings"
EMBEDDING_MODEL = "Qwen3-VL-Embedding"
# SECURITY: a credential is committed in source control. It can now be
# overridden through the IMAGE_EMBEDDING_API_KEY environment variable; the
# literal is kept only as a backward-compatible fallback and should be rotated
# and removed from the repository.
API_KEY = os.environ.get("IMAGE_EMBEDDING_API_KEY", "sk--VnOesEU5D8wnHjdg0MEsA")
|
||||
|
||||
# ─── Page Config ────────────────────────────────────────────────────────────
|
||||
st.set_page_config(
|
||||
@@ -165,35 +170,37 @@ EXAMPLE_IMAGES: list[tuple[str, str]] = [
|
||||
]
|
||||
|
||||
|
||||
# ─── CLIP Embedder ───────────────────────────────────────────────────────────
|
||||
class CLIPEmbedder:
    """Image embedder backed by a local CLIP checkpoint.

    The processor and model are not loaded in the constructor; they are
    fetched the first time ``_load`` runs and kept on the instance afterwards.
    """

    MODEL_NAME = "openai/clip-vit-base-patch32"

    def __init__(self) -> None:
        # Both stay None until the first call to _load().
        self._processor: CLIPProcessor | None = None
        self._model: CLIPModel | None = None

    def _load(self) -> tuple[CLIPProcessor, CLIPModel]:
        """Return ``(processor, model)``, loading them under a spinner if either is missing."""
        if self._processor is not None and self._model is not None:
            return self._processor, self._model

        with st.spinner("首次启动正在加载 CLIP 模型,请稍候..."):
            self._processor = CLIPProcessor.from_pretrained(self.MODEL_NAME)
            self._model = CLIPModel.from_pretrained(self.MODEL_NAME)
        return self._processor, self._model

    def embed(self, image: Image.Image) -> np.ndarray:
        """Return the flattened image-feature vector of ``image``, scaled to unit L2 norm.

        A vector whose norm is exactly zero is returned as-is to avoid a
        division by zero.
        """
        processor, model = self._load()
        features = model.get_image_features(**processor(images=image, return_tensors="pt"))
        flat = features.detach().cpu().numpy().flatten()
        length = np.linalg.norm(flat)
        return flat if length == 0 else flat / length
|
||||
|
||||
|
||||
@st.cache_resource(show_spinner=False)
def get_embedder() -> CLIPEmbedder:
    """Build a ``CLIPEmbedder``; the call is wrapped by ``st.cache_resource`` with its spinner disabled."""
    embedder = CLIPEmbedder()
    return embedder
|
||||
# ─── Embedding API ───────────────────────────────────────────────────────────
|
||||
def get_image_embedding(image_url: str, text: str = "这是什么病虫害?", max_retries: int = 2) -> list[float]:
    """Request the embedding of an image from the remote embedding API, with retries.

    Args:
        image_url: URL of the image, forwarded to the API as an ``image_url`` content part.
        text: Text prompt sent alongside the image in the same user message.
        max_retries: Number of additional attempts after the first one fails.
            Negative values are treated as 0, so at least one request is made.

    Returns:
        The value found at ``data[0].embedding`` in the JSON response.

    Raises:
        httpx.HTTPError: If the last attempt fails, either with a non-2xx
            status (``httpx.HTTPStatusError``) or with a transport-level
            failure such as a timeout or connection error
            (``httpx.RequestError``).
        KeyError, IndexError: If a successful response does not have the
            expected ``data[0].embedding`` structure.
    """
    payload = {
        "model": EMBEDDING_MODEL,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": text},
                    {"type": "image_url", "image_url": {"url": image_url}},
                ],
            }
        ],
    }
    # NOTE(review): API_KEY is defined in this module but is not sent with the
    # request; confirm whether the endpoint expects an Authorization header.
    headers = {"Content-Type": "application/json"}

    # Always make at least one attempt, even for a negative max_retries;
    # otherwise the loop is skipped and there is no error to re-raise.
    attempts = max(0, max_retries) + 1
    for attempt in range(attempts):
        try:
            resp = httpx.post(
                IMAGE_EMBEDDING_API_URL,
                headers=headers,
                json=payload,
                timeout=120,
            )
            resp.raise_for_status()
        except httpx.HTTPError:
            # httpx.HTTPError covers both bad status codes and transport
            # failures (timeouts, connection errors), so both are retried.
            if attempt == attempts - 1:
                raise
            # Linear backoff: 2s, 4s, ...
            time.sleep(2 * (attempt + 1))
        else:
            # Parsed outside the try body so a malformed payload is reported
            # immediately instead of being retried.
            return resp.json()["data"][0]["embedding"]

    # The loop above either returns or re-raises on its final iteration.
    raise RuntimeError("retry loop exited without a result")
|
||||
|
||||
|
||||
# ─── Utilities ───────────────────────────────────────────────────────────────
|
||||
@@ -215,21 +222,24 @@ def load_image(source: str | io.BytesIO) -> Image.Image | None:
|
||||
return img
|
||||
|
||||
|
||||
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Return the dot product of ``a`` and ``b`` as a Python float.

    No normalization happens here, so the result is a cosine similarity only
    when both inputs are already unit-length vectors.
    """
    dot_product = np.dot(a, b)
    return float(dot_product)
|
||||
def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Return the cosine similarity of two vectors.

    Both inputs are converted to NumPy arrays and the dot product is divided
    by the product of their L2 norms. If either norm is exactly zero the
    similarity is defined as 0.0 instead of dividing by zero.
    """
    vec_a, vec_b = np.array(a), np.array(b)
    len_a = np.linalg.norm(vec_a)
    len_b = np.linalg.norm(vec_b)
    if len_a != 0 and len_b != 0:
        return float(np.dot(vec_a, vec_b) / (len_a * len_b))
    return 0.0
|
||||
|
||||
|
||||
@st.cache_data(show_spinner=False)
|
||||
def build_index() -> tuple[list[dict], list[str], list[str]]:
|
||||
embedder = get_embedder()
|
||||
@st.cache_resource
|
||||
def build_index() -> tuple[list[dict], list[str], list[tuple[str, str]]]:
|
||||
items, succeeded, failed = [], [], []
|
||||
for pest in PEST_KNOWLEDGE:
|
||||
img = _load_image_raw(pest.url)
|
||||
if img is None:
|
||||
failed.append(pest.name)
|
||||
continue
|
||||
progress = st.progress(0, text="正在构建图片索引...")
|
||||
total = len(PEST_KNOWLEDGE)
|
||||
for i, pest in enumerate(PEST_KNOWLEDGE):
|
||||
try:
|
||||
embedding = embedder.embed(img)
|
||||
embedding = get_image_embedding(pest.url, text=pest.name)
|
||||
items.append({
|
||||
"name": pest.name,
|
||||
"url": pest.url,
|
||||
@@ -240,40 +250,32 @@ def build_index() -> tuple[list[dict], list[str], list[str]]:
|
||||
"category": pest.category,
|
||||
})
|
||||
succeeded.append(pest.name)
|
||||
except Exception:
|
||||
failed.append(pest.name)
|
||||
except Exception as e:
|
||||
failed.append((pest.name, str(e)))
|
||||
progress.progress((i + 1) / total, text=f"正在构建图片索引 ({i + 1}/{total})...")
|
||||
progress.empty()
|
||||
return items, succeeded, failed
|
||||
|
||||
|
||||
# ─── Sidebar ─────────────────────────────────────────────────────────────────
|
||||
with st.sidebar:
|
||||
st.header("🌿 病虫害以图搜图")
|
||||
st.caption("上传图片,智能识别相似病虫害")
|
||||
st.caption("输入图片 URL,通过图片 Embedding 搜索相似病虫害")
|
||||
st.divider()
|
||||
|
||||
st.subheader("🖼️ 输入方式")
|
||||
input_mode = st.radio("", ["上传本地图片", "输入图片 URL", "选择示例图片"], label_visibility="collapsed")
|
||||
input_mode = st.radio("", ["输入图片 URL", "选择示例图片"], label_visibility="collapsed")
|
||||
|
||||
# 初始化 session_state
|
||||
if "query_url" not in st.session_state:
|
||||
st.session_state.query_url = ""
|
||||
if "query_image_bytes" not in st.session_state:
|
||||
st.session_state.query_image_bytes = None
|
||||
|
||||
query_source = None
|
||||
query_url = ""
|
||||
|
||||
if input_mode == "上传本地图片":
|
||||
uploaded = st.file_uploader("选择图片", type=["jpg", "jpeg", "png", "webp"])
|
||||
if uploaded is not None:
|
||||
st.session_state.query_image_bytes = uploaded.getvalue()
|
||||
st.session_state.query_url = ""
|
||||
if st.session_state.query_image_bytes is not None:
|
||||
query_source = io.BytesIO(st.session_state.query_image_bytes)
|
||||
elif input_mode == "输入图片 URL":
|
||||
if input_mode == "输入图片 URL":
|
||||
query_url = st.text_input("图片 URL", value=st.session_state.query_url, placeholder="https://example.com/image.jpg")
|
||||
st.session_state.query_url = query_url
|
||||
st.session_state.query_image_bytes = None
|
||||
if query_url.strip():
|
||||
query_source = query_url.strip()
|
||||
else:
|
||||
@@ -283,7 +285,6 @@ with st.sidebar:
|
||||
with cols[idx % 2]:
|
||||
if st.button(name, key=f"ex_{name}"):
|
||||
st.session_state.query_url = url
|
||||
st.session_state.query_image_bytes = None
|
||||
st.rerun()
|
||||
if st.session_state.query_url:
|
||||
query_url = st.session_state.query_url
|
||||
@@ -298,116 +299,114 @@ with st.sidebar:
|
||||
st.divider()
|
||||
st.info(
|
||||
"**使用说明**\n\n"
|
||||
"1. 上传病虫害患处图片\n"
|
||||
"2. 系统自动提取图像特征\n"
|
||||
"1. 输入病虫害图片 URL 或选择示例\n"
|
||||
"2. 系统调用 Embedding API 提取图像特征\n"
|
||||
"3. 与知识库比对返回相似结果\n"
|
||||
"4. 参考症状与防治建议"
|
||||
)
|
||||
|
||||
# ─── Build Index ─────────────────────────────────────────────────────────────
|
||||
index_items, succeeded, failed = build_index()
|
||||
with st.spinner("首次加载需要构建图片索引,请稍候..."):
|
||||
index_items, succeeded, failed = build_index()
|
||||
|
||||
# ─── Main Layout ─────────────────────────────────────────────────────────────
|
||||
st.title("🌿 病虫害以图搜图")
|
||||
st.caption("基于 CLIP 视觉模型的病虫害相似度检索与防治建议")
|
||||
st.caption("基于图片 Embedding API 的病虫害相似度检索与防治建议")
|
||||
|
||||
# Status badges
|
||||
if succeeded:
|
||||
st.badge(f"📚 知识库 {len(succeeded)} 种", color="blue")
|
||||
st.success(f"📚 图片索引构建完成,成功 {len(succeeded)} 张")
|
||||
if failed:
|
||||
st.badge(f"⚠️ 索引失败 {len(failed)} 种", color="red")
|
||||
st.warning(f"以下 {len(failed)} 张图片索引失败:")
|
||||
for name, err in failed:
|
||||
st.error(f"- **{name}**:{err}")
|
||||
|
||||
st.markdown("<br>", unsafe_allow_html=True)
|
||||
|
||||
# ─── Search Logic ────────────────────────────────────────────────────────────
|
||||
if search_clicked:
|
||||
if query_source is None:
|
||||
st.warning("请先上传图片、输入图片 URL 或选择示例图片后再点击搜索")
|
||||
elif not index_items:
|
||||
if search_clicked and query_url.strip():
|
||||
if not index_items:
|
||||
st.warning("知识库索引为空,请检查网络连接后刷新页面重试。")
|
||||
else:
|
||||
query_img = load_image(query_source)
|
||||
if query_img is not None:
|
||||
col_query, col_preview = st.columns([1, 3])
|
||||
with col_query:
|
||||
st.subheader("🔍 查询图片")
|
||||
st.image(query_img, use_container_width=True)
|
||||
with st.spinner("正在分析图片并搜索..."):
|
||||
try:
|
||||
query_embedding = get_image_embedding(query_url.strip())
|
||||
|
||||
with col_preview:
|
||||
st.subheader("⏳ 正在分析...")
|
||||
progress = st.progress(0, text="提取图像特征...")
|
||||
col_query, col_preview = st.columns([1, 3])
|
||||
with col_query:
|
||||
st.subheader("🔍 查询图片")
|
||||
query_img = load_image(query_url.strip())
|
||||
if query_img is not None:
|
||||
st.image(query_img, use_container_width=True)
|
||||
|
||||
embedder = get_embedder()
|
||||
query_embedding = embedder.embed(query_img)
|
||||
progress.progress(50, text="比对知识库...")
|
||||
with col_preview:
|
||||
scores = []
|
||||
for item in index_items:
|
||||
sim = cosine_similarity(query_embedding, item["embedding"])
|
||||
scores.append((sim, item))
|
||||
scores.sort(key=lambda x: x[0], reverse=True)
|
||||
results = scores[:top_k]
|
||||
|
||||
scores = []
|
||||
for item in index_items:
|
||||
sim = cosine_similarity(query_embedding, item["embedding"])
|
||||
scores.append((sim, item))
|
||||
scores.sort(key=lambda x: x[0], reverse=True)
|
||||
results = scores[:top_k]
|
||||
progress.progress(100, text="搜索完成")
|
||||
progress.empty()
|
||||
st.subheader(f"🏆 搜索结果(Top-{len(results)})")
|
||||
|
||||
st.subheader(f"🏆 搜索结果(Top-{len(results)})")
|
||||
# Similarity bar chart
|
||||
names = [f"{r[1]['name']}" for r in results]
|
||||
sims = [r[0] * 100 for r in results]
|
||||
colors = ["#c45c4a" if r[1]["category"] == "虫害" else "#4a7c59" for r in results]
|
||||
|
||||
# Similarity bar chart
|
||||
names = [f"{r[1]['name']}" for r in results]
|
||||
sims = [r[0] * 100 for r in results]
|
||||
colors = ["#c45c4a" if r[1]["category"] == "虫害" else "#4a7c59" for r in results]
|
||||
fig_bar = go.Figure()
|
||||
fig_bar.add_trace(go.Bar(
|
||||
x=sims,
|
||||
y=names,
|
||||
orientation="h",
|
||||
marker=dict(color=colors, opacity=0.85, line=dict(color="rgba(0,0,0,0.08)", width=1)),
|
||||
text=[f"{s:.1f}%" for s in sims],
|
||||
textposition="outside",
|
||||
textfont=dict(color="#5a5a5a", size=10),
|
||||
))
|
||||
fig_bar.update_layout(
|
||||
xaxis=dict(title="相似度 (%)", color="#5a5a5a", gridcolor="rgba(0,0,0,0.06)", range=[0, 105]),
|
||||
yaxis=dict(color="#5a5a5a", gridcolor="rgba(0,0,0,0.04)", autorange="reversed"),
|
||||
paper_bgcolor="rgba(0,0,0,0)",
|
||||
plot_bgcolor="rgba(0,0,0,0)",
|
||||
font=dict(color="#2c2c2c", size=11),
|
||||
margin=dict(t=10, b=30, l=80, r=50),
|
||||
height=160 + len(results) * 34,
|
||||
showlegend=False,
|
||||
)
|
||||
st.plotly_chart(fig_bar, use_container_width=True)
|
||||
|
||||
fig_bar = go.Figure()
|
||||
fig_bar.add_trace(go.Bar(
|
||||
x=sims,
|
||||
y=names,
|
||||
orientation="h",
|
||||
marker=dict(color=colors, opacity=0.85, line=dict(color="rgba(0,0,0,0.08)", width=1)),
|
||||
text=[f"{s:.1f}%" for s in sims],
|
||||
textposition="outside",
|
||||
textfont=dict(color="#5a5a5a", size=10),
|
||||
))
|
||||
fig_bar.update_layout(
|
||||
xaxis=dict(title="相似度 (%)", color="#5a5a5a", gridcolor="rgba(0,0,0,0.06)", range=[0, 105]),
|
||||
yaxis=dict(color="#5a5a5a", gridcolor="rgba(0,0,0,0.04)", autorange="reversed"),
|
||||
paper_bgcolor="rgba(0,0,0,0)",
|
||||
plot_bgcolor="rgba(0,0,0,0)",
|
||||
font=dict(color="#2c2c2c", size=11),
|
||||
margin=dict(t=10, b=30, l=80, r=50),
|
||||
height=160 + len(results) * 34,
|
||||
showlegend=False,
|
||||
)
|
||||
st.plotly_chart(fig_bar, use_container_width=True)
|
||||
# Result cards below
|
||||
st.subheader("📋 详细结果")
|
||||
for rank, (sim, item) in enumerate(results, 1):
|
||||
with st.container(border=True):
|
||||
c1, c2 = st.columns([1, 4])
|
||||
with c1:
|
||||
st.image(item["url"], use_container_width=True)
|
||||
with c2:
|
||||
header_col, score_col = st.columns([3, 1])
|
||||
header_col.markdown(f"**#{rank} {item['name']}**")
|
||||
score_col.markdown(f"<div style='text-align:right; font-weight:600;'>相似度 {sim*100:.1f}%</div>", unsafe_allow_html=True)
|
||||
|
||||
# Result cards below
|
||||
st.subheader("📋 详细结果")
|
||||
for rank, (sim, item) in enumerate(results, 1):
|
||||
with st.container(border=True):
|
||||
c1, c2 = st.columns([1, 4])
|
||||
with c1:
|
||||
st.image(item["url"], use_container_width=True)
|
||||
with c2:
|
||||
header_col, score_col = st.columns([3, 1])
|
||||
header_col.markdown(f"**#{rank} {item['name']}**")
|
||||
score_col.markdown(f"<div style='text-align:right; font-weight:600;'>相似度 {sim*100:.1f}%</div>", unsafe_allow_html=True)
|
||||
badge_cols = st.columns([1, 1, 4])
|
||||
badge_cols[0].caption(f"🌾 {item['crop']}")
|
||||
badge_cols[1].caption(f"🐛 {item['category']}" if item["category"] == "虫害" else f"🍃 {item['category']}")
|
||||
|
||||
badge_cols = st.columns([1, 1, 4])
|
||||
badge_cols[0].caption(f"🌾 {item['crop']}")
|
||||
badge_cols[1].caption(f"🐛 {item['category']}" if item["category"] == "虫害" else f"🍃 {item['category']}")
|
||||
st.markdown(f"**症状:** {item['symptoms']}")
|
||||
st.markdown(f"**防治:** {item['treatment']}")
|
||||
|
||||
st.markdown(f"**症状:** {item['symptoms']}")
|
||||
st.markdown(f"**防治:** {item['treatment']}")
|
||||
# Advisory summary
|
||||
if results:
|
||||
best = results[0][1]
|
||||
st.subheader("💡 初步建议")
|
||||
st.info(
|
||||
f"系统判断该图片与 **{best['name']}**({best['crop']}{best['category']})最为相似,"
|
||||
f"相似度 **{results[0][0]*100:.1f}%**。\n\n"
|
||||
f"建议结合田间实际情况进一步确认,参考防治方案:**{best['treatment']}**"
|
||||
)
|
||||
|
||||
# Advisory summary
|
||||
if results:
|
||||
best = results[0][1]
|
||||
st.subheader("💡 初步建议")
|
||||
st.info(
|
||||
f"系统判断该图片与 **{best['name']}**({best['crop']}{best['category']})最为相似,"
|
||||
f"相似度 **{results[0][0]*100:.1f}%**。\n\n"
|
||||
f"建议结合田间实际情况进一步确认,参考防治方案:**{best['treatment']}**"
|
||||
)
|
||||
except Exception as e:
|
||||
st.error(f"搜索失败: {e}")
|
||||
|
||||
# ─── Footer ───────────────────────────────────────────────────────────────────
|
||||
st.divider()
|
||||
st.caption("病虫害以图搜图 · 基于 CLIP 视觉模型 · 结果仅供参考,请结合田间实际情况判断")
|
||||
st.caption("病虫害以图搜图 · 基于 Qwen3-VL-Embedding · 结果仅供参考,请结合田间实际情况判断")
|
||||
|
||||
Reference in New Issue
Block a user