diff --git a/app.py b/app.py index 16a59b4..b1359ae 100644 --- a/app.py +++ b/app.py @@ -12,7 +12,7 @@ from PIL import Image # 国内 HuggingFace 镜像加速 os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com") -from transformers import AutoImageProcessor, AutoModelForImageClassification, pipeline # noqa: E402 +from transformers import ViTForImageClassification, ViTImageProcessor, pipeline # noqa: E402 # ─── Disease Label Mapping ────────────────────────────────────────────────── # 模型输出 LABEL_0 ~ LABEL_6,映射为实际病害名称 @@ -91,21 +91,13 @@ html, body, [class*="css"] { def load_model(): """加载 HuggingFace 模型(首次运行自动下载,约 343MB)""" model_name = "Dmitry43243242/banana-disease-leaf-model" - try: - classifier = pipeline( - "image-classification", - model=model_name, - ) - except ValueError: - # 模型的 preprocessor_config.json 中 image_processor_type 可能是旧版名称 - # 手动加载并构造 pipeline - processor = AutoImageProcessor.from_pretrained( - model_name, trust_remote_code=False - ) - model = AutoModelForImageClassification.from_pretrained(model_name) - classifier = pipeline( - "image-classification", model=model, image_processor=processor - ) + # 模型的 preprocessor_config.json 中 image_processor_type 为旧版 ViTFeatureExtractor, + # AutoImageProcessor 无法识别,直接用具体的 ViT 类加载 + processor = ViTImageProcessor.from_pretrained(model_name) + model = ViTForImageClassification.from_pretrained(model_name) + classifier = pipeline( + "image-classification", model=model, image_processor=processor + ) return classifier