{ "cells": [
 { "cell_type": "markdown", "metadata": {}, "source": [
  "# 多模态与 RAG 检索示例\n",
  "\n",
  "本 Notebook 演示如何使用多模态数据(文本+图像元数据)构建简单的 RAG 检索和问答系统。\n",
  "\n",
  "**场景:** 多模态 AI / RAG 知识库检索\n",
  "\n",
  "**数据集:**\n",
  "- JSONL: 农业问答微调语料\n",
  "- TXT: 农业知识库语料\n",
  "- JSON: COCO 标注数据\n",
  "- HDF5: 多光谱影像数据" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. 环境准备" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "import json\n",
  "import os\n",
  "import numpy as np\n",
  "import pandas as pd\n",
  "from collections import defaultdict\n",
  "\n",
  "# os.path.abspath('') resolves to the current working directory, so DATA_DIR\n",
  "# is the parent of the directory the kernel was started in.\n",
  "DATA_DIR = os.path.dirname(os.path.abspath(''))\n",
  "print('数据目录:', DATA_DIR)\n",
  "print('环境准备完成')" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. 加载 JSONL 问答语料" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Load the JSONL fine-tuning corpus: one JSON object per non-empty line.\n",
  "qa_data = []\n",
  "with open(f'{DATA_DIR}/jsonl/农业问答_微调语料.jsonl', 'r', encoding='utf-8') as f:\n",
  "    for line in f:\n",
  "        # A blank line (e.g. a trailing newline at end of file) is not valid\n",
  "        # JSON and would make json.loads raise, so skip it.\n",
  "        if line.strip():\n",
  "            qa_data.append(json.loads(line))\n",
  "\n",
  "print(f'问答对数量: {len(qa_data)}')\n",
  "print('\\n--- 示例数据 ---')\n",
  "# Show the first record: every conversation turn (truncated) plus its metadata.\n",
  "sample = qa_data[0]\n",
  "for conv in sample['conversations']:\n",
  "    print(f'[{conv[\"role\"]}]: {conv[\"content\"][:100]}...')\n",
  "print(f'\\n元数据: {json.dumps(sample[\"metadata\"], ensure_ascii=False)}')" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. 
加载 TXT 知识库语料" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Load the raw knowledge-base text\n",
  "with open(f'{DATA_DIR}/txt/农业知识库语料.txt', 'r', encoding='utf-8') as f:\n",
  "    knowledge_raw = f.read()\n",
  "\n",
  "# Split into documents. Every document starts with a '【文档' marker; whatever\n",
  "# precedes the first marker is dropped. The first line of a block (with '】'\n",
  "# removed) becomes the id, the remainder becomes the content.\n",
  "documents = []\n",
  "for block in knowledge_raw.split('【文档')[1:]:\n",
  "    header, _, body = block.strip().partition('\\n')\n",
  "    documents.append({'id': header.replace('】', ''), 'content': body.strip()})\n",
  "\n",
  "print(f'知识库文档数: {len(documents)}')\n",
  "print(f'\\n第一篇文档前200字:\\n{documents[0][\"content\"][:200]}...')" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. 简单关键词检索(模拟 RAG 检索)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "class SimpleRetriever:\n",
  "    \"\"\"Character-overlap retriever that mimics the retrieval step of a RAG pipeline.\n",
  "\n",
  "    Each document is a dict with an 'id' and a 'content' string. Ids are\n",
  "    expected to be unique; if one repeats, only the first document is indexed.\n",
  "    A query is scored against a document as the fraction of the query's\n",
  "    distinct characters that also occur in the document.\n",
  "    \"\"\"\n",
  "\n",
  "    def __init__(self, documents):\n",
  "        self.documents = documents\n",
  "        self.index = defaultdict(list)  # inverted index: character -> ids of documents containing it\n",
  "        self._doc_by_id = {}            # id -> document, for O(1) lookup when building results\n",
  "        self._rank_of = {}              # id -> position in `documents`, used to break score ties\n",
  "        self._build_index()\n",
  "\n",
  "    def _build_index(self):\n",
  "        \"\"\"Populate the inverted index and the id lookup tables.\"\"\"\n",
  "        for position, doc in enumerate(self.documents):\n",
  "            doc_id = doc['id']\n",
  "            if doc_id in self._doc_by_id:\n",
  "                continue  # repeated id: keep the first document only\n",
  "            self._doc_by_id[doc_id] = doc\n",
  "            self._rank_of[doc_id] = position\n",
  "            for char in set(doc['content']):\n",
  "                self.index[char].append(doc_id)\n",
  "\n",
  "    def search(self, query, top_k=3):\n",
  "        \"\"\"Return up to `top_k` documents ranked by character overlap with `query`.\n",
  "\n",
  "        Each result is a dict with 'id', 'score' (rounded to 4 decimals) and\n",
  "        'content'. Documents sharing no character with the query are omitted;\n",
  "        documents with equal scores keep their original order.\n",
  "        \"\"\"\n",
  "        query_chars = set(query)\n",
  "        if not query_chars:\n",
  "            return []\n",
  "\n",
  "        # Count matching query characters per document through the inverted\n",
  "        # index, so only documents sharing at least one character are visited.\n",
  "        hits = defaultdict(int)\n",
  "        for char in query_chars:\n",
  "            # .get() avoids inserting empty postings for characters never indexed\n",
  "            for doc_id in self.index.get(char, ()):\n",
  "                hits[doc_id] += 1\n",
  "\n",
  "        # Highest hit count first; ties fall back to document order.\n",
  "        ranked = sorted(hits, key=lambda doc_id: (-hits[doc_id], self._rank_of[doc_id]))[:top_k]\n",
  "        return [\n",
  "            {\n",
  "                'id': doc_id,\n",
  "                'score': round(hits[doc_id] / len(query_chars), 4),\n",
  "                'content': self._doc_by_id[doc_id]['content'],\n",
  "            }\n",
  "            for doc_id in ranked\n",
  "        ]\n",
  "\n",
  "retriever = SimpleRetriever(documents)\n",
  "print('检索器构建完成')" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Smoke-test the retriever with a few sample queries\n",
  "queries = [\n",
  "    '水稻稻瘟病怎么识别?',\n",
  "    '遥感技术怎么监测病虫害?',\n",
  "    '智能灌溉系统如何工作?',\n",
  "]\n",
  "\n",
  "for query in queries:\n",
  "    print(f'\\n🔍 查询: {query}')\n",
  "    print('-' * 60)\n",
  "    results = retriever.search(query, top_k=2)\n",
  "    for i, r in enumerate(results, 1):\n",
  "        print(f' [{i}] 相关度={r[\"score\"]:.4f} | 文档{r[\"id\"]}')\n",
  "        print(f' {r[\"content\"][:100]}...')" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. 多模态数据联动:图像标注 + 知识检索" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Load the COCO-format annotations\n",
  "with open(f'{DATA_DIR}/json/作物病害检测_COCO格式.json', 'r', encoding='utf-8') as f:\n",
  "    coco = json.load(f)\n",
  "\n",
  "cat_names = {c['id']: c['name'] for c in coco['categories']}\n",
  "\n",
  "# image id -> image record; setdefault keeps the first record if an id repeats\n",
  "images_by_id = {}\n",
  "for img in coco['images']:\n",
  "    images_by_id.setdefault(img['id'], img)\n",
  "\n",
  "# image id -> its annotations, with the category id resolved to a name\n",
  "img_annotations = defaultdict(list)\n",
  "for ann in coco['annotations']:\n",
  "    img_annotations[ann['image_id']].append({\n",
  "        'category': cat_names[ann['category_id']],\n",
  "        'bbox': ann['bbox'],\n",
  "        'area': ann['area']\n",
  "    })\n",
  "\n",
  "def multimodal_search(image_id):\n",
  "    \"\"\"Look up an image's annotations and the knowledge-base text related to each.\n",
  "\n",
  "    Args:\n",
  "        image_id: id of an entry in coco['images'].\n",
  "\n",
  "    Returns:\n",
  "        dict with 'image' (file name), 'size' ('<width>x<height>'),\n",
  "        'annotations' (one entry per annotation) and 'knowledge' (one entry\n",
  "        per annotation whose category name matched a document).\n",
  "\n",
  "    Raises:\n",
  "        ValueError: if `image_id` is not present in the COCO file.\n",
  "    \"\"\"\n",
  "    img_info = images_by_id.get(image_id)\n",
  "    if img_info is None:\n",
  "        raise ValueError(f'未找到图像 ID: {image_id}')\n",
  "\n",
  "    result = {\n",
  "        'image': img_info['file_name'],\n",
  "        'size': f\"{img_info['width']}x{img_info['height']}\",\n",
  "        'annotations': [],\n",
  "        'knowledge': []\n",
  "    }\n",
  "\n",
  "    # category -> search result, so a category that appears on several\n",
  "    # annotations is only searched once\n",
  "    knowledge_cache = {}\n",
  "    for ann in img_annotations.get(image_id, []):\n",
  "        category = ann['category']\n",
  "        result['annotations'].append({\n",
  "            '病害类别': category,\n",
  "            '边界框': ann['bbox'],\n",
  "            '面积': ann['area']\n",
  "        })\n",
  "\n",
  "        # Use the category name as the query against the knowledge base\n",
  "        if category not in knowledge_cache:\n",
  "            knowledge_cache[category] = retriever.search(category, top_k=1)\n",
  "        knowledge = knowledge_cache[category]\n",
  "        if knowledge:\n",
  "            result['knowledge'].append({\n",
  "                'query': category,\n",
  "                'related_doc': knowledge[0]['content'][:200]\n",
  "            })\n",
  "\n",
  "    return result\n",
  "\n",
  "# Try the multimodal lookup on one image\n",
  "sample_img_id = 1\n",
  "result = multimodal_search(sample_img_id)\n",
  "print(f'\\n图像: {result[\"image\"]} ({result[\"size\"]})')\n",
  "print(f'检测到标注: {len(result[\"annotations\"])} 个')\n",
  "for ann in result['annotations']:\n",
  "    print(f' - {ann[\"病害类别\"]} | bbox={ann[\"边界框\"]} | area={ann[\"面积\"]}')\n",
  "if result['knowledge']:\n",
  "    print(f'\\n关联知识:')\n",
  "    for k in result['knowledge']:\n",
  "        print(f' 查询[{k[\"query\"]}]: {k[\"related_doc\"][:100]}...')" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. 问答对匹配(模拟 LLM 问答)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "def answer_question(question, qa_data, knowledge_docs, min_score=0.2):\n",
  "    \"\"\"Simulate a RAG answer: retrieve context documents, then reuse the closest QA pair.\n",
  "\n",
  "    Args:\n",
  "        question: the user's question string.\n",
  "        qa_data: QA records; each has a 'conversations' list whose first turn\n",
  "            is treated as the question and whose second turn is the answer.\n",
  "        knowledge_docs: documents to retrieve context from. The shared\n",
  "            `retriever` is reused when this is the corpus it already indexes\n",
  "            (or None); otherwise a retriever is built over these documents.\n",
  "        min_score: the matched answer is returned only when its overlap score\n",
  "            is strictly greater than this value.\n",
  "\n",
  "    Returns:\n",
  "        dict with 'question', 'matched_qa_score', 'knowledge_context' (ids of\n",
  "        the retrieved documents), 'answer' and 'source'.\n",
  "    \"\"\"\n",
  "    # Step 1: retrieve context from the corpus the caller passed in\n",
  "    if knowledge_docs is None or knowledge_docs is retriever.documents:\n",
  "        doc_retriever = retriever\n",
  "    else:\n",
  "        doc_retriever = SimpleRetriever(knowledge_docs)\n",
  "    related_docs = doc_retriever.search(question, top_k=2)\n",
  "\n",
  "    # Step 2: find the QA pair whose question shares the most characters\n",
  "    question_chars = set(question)\n",
  "    best_match = None\n",
  "    best_score = 0\n",
  "\n",
  "    for qa in qa_data:\n",
  "        turns = qa['conversations']\n",
  "        if len(turns) < 2:\n",
  "            continue  # no answer turn to return, so it cannot be a match\n",
  "        overlap = len(question_chars & set(turns[0]['content']))\n",
  "        # max(..., 1) guards against division by zero for an empty question\n",
  "        score = overlap / max(len(question_chars), 1)\n",
  "        if score > best_score:\n",
  "            best_score = score\n",
  "            best_match = qa\n",
  "\n",
  "    # Step 3: assemble the response\n",
  "    response = {\n",
  "        'question': question,\n",
  "        'matched_qa_score': round(best_score, 4),\n",
  "        'knowledge_context': [d['id'] for d in related_docs],\n",
  "    }\n",
  "\n",
  "    if best_match and best_score > min_score:\n",
  "        response['answer'] = best_match['conversations'][1]['content']\n",
  "        response['source'] = 'QA语料库'\n",
  "    else:\n",
  "        response['answer'] = '暂未找到精确匹配的回答,请参考知识库文档。'\n",
  "        response['source'] = '知识库'\n",
  "\n",
  "    return response\n",
  "\n",
  "# Try a few questions\n",
  "test_questions = [\n",
  "    '水稻叶片有褐色病斑,是不是稻瘟病?',\n",
  "    '无人机怎么监测作物病害?',\n",
  "    '小麦锈病怎么防治?',\n",
  "    '土壤湿度传感器数据怎么看?',\n",
  "]\n",
  "\n",
  "for q in test_questions:\n",
  " result = answer_question(q, qa_data, documents)\n",
  " print(f'\\n❓ 问题: {q}')\n",
  " print(f' 匹配度: 
{result[\"matched_qa_score\"]}')\n",
  " print(f' 来源: {result[\"source\"]}')\n",
  " print(f' 回答: {result[\"answer\"][:150]}...')\n",
  " print(f' 参考文档: {result[\"knowledge_context\"]}')" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## 7. 多光谱数据探索 (HDF5)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "import h5py\n",
  "\n",
  "# Inspect the HDF5 file: top-level groups, metadata attributes, the first\n",
  "# scene and the segmentation labels.\n",
  "with h5py.File(f'{DATA_DIR}/hdf5/多光谱作物影像数据.h5', 'r') as hf:\n",
  "    print('HDF5 文件结构:')\n",
  "    print(f' 顶层组: {list(hf.keys())}')\n",
  "    print(f'\\n 元数据属性:')\n",
  "    meta = hf['metadata']\n",
  "    for key, val in meta.attrs.items():\n",
  "        print(f' {key}: {val}')\n",
  "\n",
  "    print(f'\\n 场景数量: {len(hf[\"images\"])}')\n",
  "\n",
  "    # First scene and its attributes\n",
  "    scene = hf['images/scene_001']\n",
  "    print(f'\\n 场景1属性:')\n",
  "    for key, val in scene.attrs.items():\n",
  "        print(f' {key}: {val}')\n",
  "\n",
  "    # [()] reads the whole dataset into memory\n",
  "    ms_data = scene['multispectral'][()]\n",
  "    ndvi_data = scene['ndvi'][()]\n",
  "    print(f'\\n 多光谱数据形状: {ms_data.shape} (波段 x 高 x 宽)')\n",
  "    print(f' NDVI 数据形状: {ndvi_data.shape}')\n",
  "    print(f' NDVI 范围: [{ndvi_data.min():.3f}, {ndvi_data.max():.3f}]')\n",
  "\n",
  "    # Segmentation labels; the class names are stored as a JSON string attribute\n",
  "    labels = hf['labels/segmentation_masks'][()]\n",
  "    classes = json.loads(hf['labels'].attrs['classes'])\n",
  "    print(f'\\n 标注数据形状: {labels.shape}')\n",
  "    print(f' 分类类别: {classes}')" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# Visualise the spectral bands, the NDVI map/histogram and the label mask\n",
  "import matplotlib.pyplot as plt\n",
  "import matplotlib\n",
  "\n",
  "matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'WenQuanYi Micro Hei']\n",
  "# Render negative tick labels with an ASCII hyphen: CJK fonts commonly lack the\n",
  "# Unicode minus glyph, and the NDVI axes contain negative values.\n",
  "matplotlib.rcParams['axes.unicode_minus'] = False\n",
  "\n",
  "# Read everything needed into memory, then plot with the file closed\n",
  "with h5py.File(f'{DATA_DIR}/hdf5/多光谱作物影像数据.h5', 'r') as hf:\n",
  "    scene = hf['images/scene_001']\n",
  "    ms = scene['multispectral'][()]\n",
  "    ndvi = scene['ndvi'][()]\n",
  "    label = hf['labels/segmentation_masks'][0]\n",
  "    classes = json.loads(hf['labels'].attrs['classes'])\n",
  "\n",
  "band_names = ['Blue (450nm)', 'Green (560nm)', 'Red (650nm)', 'RedEdge (730nm)', 'NIR (840nm)']\n",
  "# Never index past the bands actually stored in the file\n",
  "n_bands = min(len(band_names), ms.shape[0])\n",
  "\n",
  "# 2x4 grid: up to five band panels, then NDVI map, NDVI histogram, label mask\n",
  "fig, axes = plt.subplots(2, 4, figsize=(20, 10))\n",
  "panels = axes.ravel()\n",
  "\n",
  "for i in range(n_bands):\n",
  "    im = panels[i].imshow(ms[i], cmap='gray')\n",
  "    panels[i].set_title(band_names[i])\n",
  "    fig.colorbar(im, ax=panels[i], fraction=0.046)\n",
  "for unused_ax in panels[n_bands:len(band_names)]:\n",
  "    unused_ax.axis('off')  # hide band slots the file has no data for\n",
  "\n",
  "ndvi_ax, hist_ax, mask_ax = panels[len(band_names):]\n",
  "\n",
  "# NDVI pseudo-colour map\n",
  "im = ndvi_ax.imshow(ndvi, cmap='RdYlGn', vmin=-1, vmax=1)\n",
  "ndvi_ax.set_title('NDVI 分布')\n",
  "fig.colorbar(im, ax=ndvi_ax, fraction=0.046)\n",
  "\n",
  "# NDVI histogram with the mean marked\n",
  "ndvi_mean = ndvi.mean()\n",
  "hist_ax.hist(ndvi.flatten(), bins=50, color='green', alpha=0.7, edgecolor='white')\n",
  "hist_ax.set_title('NDVI 值分布直方图')\n",
  "hist_ax.set_xlabel('NDVI')\n",
  "hist_ax.axvline(ndvi_mean, color='red', linestyle='--', label=f'均值={ndvi_mean:.3f}')\n",
  "hist_ax.legend()\n",
  "\n",
  "# Label mask. Assumes mask values are indices into `classes` (TODO confirm);\n",
  "# colour limits and ticks follow len(classes) so every tick gets its label.\n",
  "class_ticks = list(range(len(classes)))\n",
  "im = mask_ax.imshow(label, cmap='jet', vmin=0, vmax=max(len(classes) - 1, 1))\n",
  "mask_ax.set_title('分割标注掩码')\n",
  "cbar = fig.colorbar(im, ax=mask_ax, fraction=0.046, ticks=class_ticks)\n",
  "cbar.ax.set_yticklabels(classes)\n",
  "\n",
  "plt.tight_layout()\n",
  "plt.show()\n",
  "print('多模态与 RAG 检索示例演示完成!')" ] } ],
 "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10.0" } }, "nbformat": 4, "nbformat_minor": 4 }