From c72fd420082b4d1b904a8a7be7e2d2d3b6bc3a88 Mon Sep 17 00:00:00 2001 From: bdhtc Date: Wed, 15 Apr 2026 08:05:00 +0000 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 03_multimodal_and_rag_demo.ipynb | 421 +++++++++++++++++++++++++++++++ 1 file changed, 421 insertions(+) create mode 100644 03_multimodal_and_rag_demo.ipynb diff --git a/03_multimodal_and_rag_demo.ipynb b/03_multimodal_and_rag_demo.ipynb new file mode 100644 index 0000000..a249a3d --- /dev/null +++ b/03_multimodal_and_rag_demo.ipynb @@ -0,0 +1,421 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 多模态与 RAG 检索示例\n", + "\n", + "本 Notebook 演示如何使用多模态数据(文本+图像元数据)构建简单的 RAG 检索和问答系统。\n", + "\n", + "**场景:** 多模态 AI / RAG 知识库检索\n", + "\n", + "**数据集:**\n", + "- JSONL: 农业问答微调语料\n", + "- TXT: 农业知识库语料\n", + "- JSON: COCO 标注数据\n", + "- HDF5: 多光谱影像数据" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. 环境准备" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import os\n", + "import numpy as np\n", + "import pandas as pd\n", + "from collections import defaultdict\n", + "\n", + "DATA_DIR = os.path.dirname(os.path.abspath(''))\n", + "print('数据目录:', DATA_DIR)\n", + "print('环境准备完成')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. 
# --- 2. Load the JSONL Q&A fine-tuning corpus ---------------------------------
# One JSON object per line. Blank lines (including a trailing empty line) are
# skipped, because json.loads('') raises JSONDecodeError.
qa_data = []
with open(f'{DATA_DIR}/jsonl/农业问答_微调语料.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if line:
            qa_data.append(json.loads(line))

print(f'问答对数量: {len(qa_data)}')
print('\n--- 示例数据 ---')
if qa_data:  # an empty corpus has no sample to preview
    sample = qa_data[0]
    for conv in sample['conversations']:
        print(f'[{conv["role"]}]: {conv["content"][:100]}...')
    print(f'\n元数据: {json.dumps(sample["metadata"], ensure_ascii=False)}')

# --- 3. Load the TXT knowledge-base corpus -----------------------------------
with open(f'{DATA_DIR}/txt/农业知识库语料.txt', 'r', encoding='utf-8') as f:
    knowledge_raw = f.read()

# Split into documents on the '【文档' marker. Text before the first marker is
# discarded. The first line of each block (minus the closing '】') is the id,
# and everything after the first newline is the content.
documents = []
for block in knowledge_raw.split('【文档')[1:]:
    header, _, body = block.strip().partition('\n')
    documents.append({'id': header.replace('】', ''), 'content': body.strip()})

print(f'知识库文档数: {len(documents)}')
if documents:  # nothing to preview when no document marker was found
    print(f'\n第一篇文档前200字:\n{documents[0]["content"][:200]}...')
# --- 4. Simple keyword retrieval (simulated RAG retrieval) --------------------
class SimpleRetriever:
    """Character-overlap retriever that stands in for the retrieval step of RAG.

    A document's score for a query is the fraction of the query's distinct
    characters that also occur in the document's ``content``.

    Attributes:
        documents: The list of ``{'id': ..., 'content': ...}`` dicts passed in.
        index: Inverted index mapping each character to the ids of the
            documents that contain it. It is not consulted by ``search``.
    """

    def __init__(self, documents):
        """Store *documents* and precompute the lookup structures."""
        self.documents = documents
        self.index = defaultdict(list)  # character -> ids of documents containing it
        self._char_sets = []            # (doc_id, frozenset(content)) in document order
        self._doc_by_id = {}            # doc_id -> first document carrying that id
        self._build_index()

    def _build_index(self):
        """(Re)build the inverted index and the per-document caches.

        Safe to call again after ``self.documents`` has been modified; all
        structures are reset first so entries are never duplicated.
        """
        self.index.clear()
        self._char_sets = []
        self._doc_by_id = {}
        for doc in self.documents:
            doc_id = doc['id']
            chars = frozenset(doc['content'])
            self._char_sets.append((doc_id, chars))
            # setdefault keeps the first document when ids collide.
            self._doc_by_id.setdefault(doc_id, doc)
            for ch in chars:
                self.index[ch].append(doc_id)

    def search(self, query, top_k=3):
        """Return up to *top_k* documents ranked by character overlap with *query*.

        Args:
            query: Query string; scoring is done on its distinct characters.
            top_k: Maximum number of hits. ``None`` returns every match; zero
                or a negative value returns an empty list.

        Returns:
            A list of ``{'id', 'score', 'content'}`` dicts, best score first.
            Documents sharing no character with the query are omitted. Ties
            keep document order. ``score`` is rounded to 4 decimals.
        """
        query_chars = set(query)
        if not query_chars or (top_k is not None and top_k <= 0):
            return []

        scores = {}
        # Use the cached character sets instead of rebuilding one per document
        # on every query.
        for doc_id, doc_chars in self._char_sets:
            shared = len(query_chars & doc_chars)
            if shared:
                scores[doc_id] = shared / len(query_chars)

        # sorted() is stable, so equal scores stay in document order.
        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:top_k]
        return [
            {
                'id': doc_id,
                'score': round(score, 4),
                'content': self._doc_by_id[doc_id]['content'],
            }
            for doc_id, score in ranked
        ]


# Demo driver. A Jupyter kernel runs cells with __name__ == '__main__', so this
# still executes in the notebook but not when the code is imported as a module.
if __name__ == '__main__':
    retriever = SimpleRetriever(documents)
    print('检索器构建完成')

    # Try a few sample queries.
    queries = [
        '水稻稻瘟病怎么识别?',
        '遥感技术怎么监测病虫害?',
        '智能灌溉系统如何工作?',
    ]

    for query in queries:
        print(f'\n🔍 查询: {query}')
        print('-' * 60)
        results = retriever.search(query, top_k=2)
        for i, r in enumerate(results, 1):
            print(f' [{i}] 相关度={r["score"]:.4f} | 文档{r["id"]}')
            print(f' {r["content"][:100]}...')
# --- 5. Multimodal linkage: image annotations + knowledge retrieval ----------
# Load the COCO-format annotation file.
with open(f'{DATA_DIR}/json/作物病害检测_COCO格式.json', 'r', encoding='utf-8') as f:
    coco = json.load(f)

cat_names = {c['id']: c['name'] for c in coco['categories']}

# Map image id -> list of simplified annotation records.
img_annotations = defaultdict(list)
for ann in coco['annotations']:
    img_annotations[ann['image_id']].append({
        'category': cat_names[ann['category_id']],
        'bbox': ann['bbox'],
        'area': ann['area']
    })


def multimodal_search(image_id):
    """Look up an image's annotations and retrieve related knowledge for each.

    Reads the notebook-level ``coco``, ``img_annotations`` and ``retriever``
    objects.

    Args:
        image_id: The ``id`` of an entry in ``coco['images']``.

    Returns:
        A dict with keys ``'image'`` (file name), ``'size'`` (``'WxH'``),
        ``'annotations'`` (one record per annotation) and ``'knowledge'`` (one
        record per annotation whose category matched a knowledge-base
        document; the document text is truncated to 200 characters).

    Raises:
        ValueError: If no image with *image_id* exists. The original code let a
            bare ``StopIteration`` escape from ``next()`` in that case.
    """
    img_info = next((img for img in coco['images'] if img['id'] == image_id), None)
    if img_info is None:
        raise ValueError(f'unknown image_id: {image_id!r}')

    result = {
        'image': img_info['file_name'],
        'size': f"{img_info['width']}x{img_info['height']}",
        'annotations': [],
        'knowledge': []
    }

    # Several annotations often share a category, so search once per category.
    hits_by_category = {}
    for ann in img_annotations.get(image_id, []):
        category = ann['category']
        result['annotations'].append({
            '病害类别': category,
            '边界框': ann['bbox'],
            '面积': ann['area']
        })

        if category not in hits_by_category:
            hits_by_category[category] = retriever.search(category, top_k=1)
        knowledge = hits_by_category[category]
        if knowledge:
            result['knowledge'].append({
                'query': category,
                'related_doc': knowledge[0]['content'][:200]
            })

    return result


# Try the multimodal lookup on one image.
sample_img_id = 1
result = multimodal_search(sample_img_id)
print(f'\n图像: {result["image"]} ({result["size"]})')
print(f'检测到标注: {len(result["annotations"])} 个')
for ann in result['annotations']:
    print(f' - {ann["病害类别"]} | bbox={ann["边界框"]} | area={ann["面积"]}')
if result['knowledge']:
    print(f'\n关联知识:')
    for k in result['knowledge']:
        print(f' 查询[{k["query"]}]: {k["related_doc"][:100]}...')
# --- 6. Q&A pair matching (simulated LLM answering) ---------------------------
def _char_overlap_score(query_chars, text):
    """Return the fraction of *query_chars* that occur in *text* (0 if empty)."""
    return len(query_chars & set(text)) / max(len(query_chars), 1)


def answer_question(question, qa_data, knowledge_docs, *, min_score=0.2, top_k=2):
    """Simulate a RAG answer: rank knowledge documents, then match a Q&A pair.

    Both steps score by character overlap, the same measure that
    ``SimpleRetriever.search`` uses.

    Args:
        question: The user question.
        qa_data: Iterable of records whose ``'conversations'`` list holds the
            question turn at index 0 and the answer turn at index 1. Records
            with fewer than two turns are skipped.
        knowledge_docs: Iterable of ``{'id', 'content'}`` dicts to rank as
            context. ``None`` falls back to the notebook-level ``retriever``,
            which is what the original implementation always used.
        min_score: A Q&A match must score strictly above this to be used.
        top_k: Maximum number of context document ids to return.

    Returns:
        A dict with ``'question'``, ``'matched_qa_score'`` (rounded to 4
        decimals), ``'knowledge_context'`` (document ids, best first),
        ``'answer'`` and ``'source'``.
    """
    query_chars = set(question)

    # Step 1: rank knowledge-base documents for context.
    if knowledge_docs is None:
        context_ids = [d['id'] for d in retriever.search(question, top_k=top_k)]
    else:
        doc_scores = {}
        for doc in knowledge_docs:
            doc_score = _char_overlap_score(query_chars, doc['content'])
            if doc_score > 0:
                doc_scores[doc['id']] = doc_score
        # sorted() is stable, so equal scores keep document order.
        context_ids = sorted(doc_scores, key=doc_scores.get, reverse=True)[:max(top_k, 0)]

    # Step 2: find the most similar stored question.
    best_match = None
    best_score = 0
    for qa in qa_data:
        turns = qa['conversations']
        if len(turns) < 2:
            continue  # no answer turn to return
        score = _char_overlap_score(query_chars, turns[0]['content'])
        if score > best_score:
            best_score = score
            best_match = qa

    # Step 3: assemble the response.
    response = {
        'question': question,
        'matched_qa_score': round(best_score, 4),
        'knowledge_context': context_ids,
    }

    if best_match is not None and best_score > min_score:
        response['answer'] = best_match['conversations'][1]['content']
        response['source'] = 'QA语料库'
    else:
        response['answer'] = '暂未找到精确匹配的回答,请参考知识库文档。'
        response['source'] = '知识库'

    return response


# Demo driver. A Jupyter kernel runs cells with __name__ == '__main__', so this
# still executes in the notebook but not when the code is imported as a module.
if __name__ == '__main__':
    test_questions = [
        '水稻叶片有褐色病斑,是不是稻瘟病?',
        '无人机怎么监测作物病害?',
        '小麦锈病怎么防治?',
        '土壤湿度传感器数据怎么看?',
    ]

    for q in test_questions:
        result = answer_question(q, qa_data, documents)
        print(f'\n❓ 问题: {q}')
        print(f' 匹配度: {result["matched_qa_score"]}')
        print(f' 来源: {result["source"]}')
        print(f' 回答: {result["answer"][:150]}...')
        print(f' 参考文档: {result["knowledge_context"]}')
# --- 7. Multispectral data exploration (HDF5) --------------------------------
import h5py

# Walk the HDF5 file and print its structure.
with h5py.File(f'{DATA_DIR}/hdf5/多光谱作物影像数据.h5', 'r') as hf:
    print('HDF5 文件结构:')
    print(f' 顶层组: {list(hf.keys())}')
    print(f'\n 元数据属性:')
    meta = hf['metadata']
    for key, val in meta.attrs.items():
        print(f' {key}: {val}')

    print(f'\n 场景数量: {len(hf["images"])}')

    # Inspect the first scene.
    scene = hf['images/scene_001']
    print(f'\n 场景1属性:')
    for key, val in scene.attrs.items():
        print(f' {key}: {val}')

    ms_data = scene['multispectral'][()]
    ndvi_data = scene['ndvi'][()]
    print(f'\n 多光谱数据形状: {ms_data.shape} (波段 x 高 x 宽)')
    print(f' NDVI 数据形状: {ndvi_data.shape}')
    print(f' NDVI 范围: [{ndvi_data.min():.3f}, {ndvi_data.max():.3f}]')

    # Inspect the labels.
    labels = hf['labels/segmentation_masks'][()]
    classes = json.loads(hf['labels'].attrs['classes'])
    print(f'\n 标注数据形状: {labels.shape}')
    print(f' 分类类别: {classes}')

# Multispectral data visualisation.
import matplotlib.pyplot as plt
import matplotlib

matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'WenQuanYi Micro Hei']
# CJK fonts such as SimHei typically lack the Unicode minus glyph, so negative
# NDVI tick labels would render as empty boxes without this.
matplotlib.rcParams['axes.unicode_minus'] = False

with h5py.File(f'{DATA_DIR}/hdf5/多光谱作物影像数据.h5', 'r') as hf:
    scene = hf['images/scene_001']
    ms = scene['multispectral'][()]
    ndvi = scene['ndvi'][()]

    band_names = ['Blue (450nm)', 'Green (560nm)', 'Red (650nm)', 'RedEdge (730nm)', 'NIR (840nm)']

    # Give the top row one axis per band. The original 2x3 grid zipped three
    # axes against five names, so RedEdge and NIR were never drawn.
    n_bands = min(len(band_names), ms.shape[0])
    n_cols = max(n_bands, 3)  # the bottom row always needs three panels
    fig, axes = plt.subplots(2, n_cols, figsize=(5 * n_cols, 10))

    for i in range(n_bands):
        ax = axes[0, i]
        im = ax.imshow(ms[i], cmap='gray')
        ax.set_title(band_names[i])
        plt.colorbar(im, ax=ax, fraction=0.046)

    # Hide grid cells that hold no plot.
    for ax in list(axes[0, n_bands:]) + list(axes[1, 3:]):
        ax.axis('off')

    # NDVI pseudo-colour map.
    im = axes[1, 0].imshow(ndvi, cmap='RdYlGn', vmin=-1, vmax=1)
    axes[1, 0].set_title('NDVI 分布')
    plt.colorbar(im, ax=axes[1, 0], fraction=0.046)

    # NDVI histogram with the mean marked.
    axes[1, 1].hist(ndvi.flatten(), bins=50, color='green', alpha=0.7, edgecolor='white')
    axes[1, 1].set_title('NDVI 值分布直方图')
    axes[1, 1].set_xlabel('NDVI')
    axes[1, 1].axvline(ndvi.mean(), color='red', linestyle='--', label=f'均值={ndvi.mean():.3f}')
    axes[1, 1].legend()

    # Segmentation mask of the first entry.
    label = hf['labels/segmentation_masks'][0]
    classes = json.loads(hf['labels'].attrs['classes'])
    im = axes[1, 2].imshow(label, cmap='jet')
    axes[1, 2].set_title('分割标注掩码')
    # One tick per class, so the tick count always matches the label count;
    # set_yticklabels raises when the two differ.
    cbar = plt.colorbar(im, ax=axes[1, 2], fraction=0.046, ticks=list(range(len(classes))))
    cbar.ax.set_yticklabels(classes)

plt.tight_layout()
plt.show()
print('多模态与 RAG 检索示例演示完成!')