新增文件
This commit is contained in:
421
03_multimodal_and_rag_demo.ipynb
Normal file
421
03_multimodal_and_rag_demo.ipynb
Normal file
@@ -0,0 +1,421 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 多模态与 RAG 检索示例\n",
|
||||
"\n",
|
||||
"本 Notebook 演示如何使用多模态数据(文本+图像元数据)构建简单的 RAG 检索和问答系统。\n",
|
||||
"\n",
|
||||
"**场景:** 多模态 AI / RAG 知识库检索\n",
|
||||
"\n",
|
||||
"**数据集:**\n",
|
||||
"- JSONL: 农业问答微调语料\n",
|
||||
"- TXT: 农业知识库语料\n",
|
||||
"- JSON: COCO 标注数据\n",
|
||||
"- HDF5: 多光谱影像数据"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1. 环境准备"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"from collections import defaultdict\n",
|
||||
"\n",
|
||||
"DATA_DIR = os.path.dirname(os.path.abspath(''))\n",
|
||||
"print('数据目录:', DATA_DIR)\n",
|
||||
"print('环境准备完成')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2. 加载 JSONL 问答语料"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 加载 JSONL 数据\n",
|
||||
"qa_data = []\n",
|
||||
"with open(f'{DATA_DIR}/jsonl/农业问答_微调语料.jsonl', 'r', encoding='utf-8') as f:\n",
|
||||
" for line in f:\n",
|
||||
" qa_data.append(json.loads(line))\n",
|
||||
"\n",
|
||||
"print(f'问答对数量: {len(qa_data)}')\n",
|
||||
"print('\\n--- 示例数据 ---')\n",
|
||||
"sample = qa_data[0]\n",
|
||||
"for conv in sample['conversations']:\n",
|
||||
" print(f'[{conv[\"role\"]}]: {conv[\"content\"][:100]}...')\n",
|
||||
"print(f'\\n元数据: {json.dumps(sample[\"metadata\"], ensure_ascii=False)}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 3. 加载 TXT 知识库语料"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 加载知识库文本\n",
|
||||
"with open(f'{DATA_DIR}/txt/农业知识库语料.txt', 'r', encoding='utf-8') as f:\n",
|
||||
" knowledge_raw = f.read()\n",
|
||||
"\n",
|
||||
"# 按文档分割\n",
|
||||
"documents = []\n",
|
||||
"for block in knowledge_raw.split('【文档')[1:]:\n",
|
||||
" lines = block.strip().split('\\n', 1)\n",
|
||||
" doc_id = lines[0].replace('】', '')\n",
|
||||
" content = lines[1].strip() if len(lines) > 1 else ''\n",
|
||||
" documents.append({'id': doc_id, 'content': content})\n",
|
||||
"\n",
|
||||
"print(f'知识库文档数: {len(documents)}')\n",
|
||||
"print(f'\\n第一篇文档前200字:\\n{documents[0][\"content\"][:200]}...')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 4. 简单关键词检索(模拟 RAG 检索)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class SimpleRetriever:\n",
|
||||
" \"\"\"基于关键词的简单检索器,模拟 RAG 检索流程\"\"\"\n",
|
||||
" \n",
|
||||
" def __init__(self, documents):\n",
|
||||
" self.documents = documents\n",
|
||||
" self.index = defaultdict(list) # 倒排索引\n",
|
||||
" self._build_index()\n",
|
||||
" \n",
|
||||
" def _build_index(self):\n",
|
||||
" \"\"\"构建简单的倒排索引\"\"\"\n",
|
||||
" for doc in self.documents:\n",
|
||||
" words = set(doc['content'])\n",
|
||||
" for word in words:\n",
|
||||
" self.index[word].append(doc['id'])\n",
|
||||
" \n",
|
||||
" def search(self, query, top_k=3):\n",
|
||||
" \"\"\"关键词检索\"\"\"\n",
|
||||
" scores = defaultdict(float)\n",
|
||||
" query_chars = set(query)\n",
|
||||
" \n",
|
||||
" for doc in self.documents:\n",
|
||||
" doc_chars = set(doc['content'])\n",
|
||||
" overlap = query_chars & doc_chars\n",
|
||||
" if overlap:\n",
|
||||
" # 简单的字符重叠度打分\n",
|
||||
" scores[doc['id']] = len(overlap) / len(query_chars)\n",
|
||||
" \n",
|
||||
" # 排序返回 top_k\n",
|
||||
" ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k]\n",
|
||||
" results = []\n",
|
||||
" for doc_id, score in ranked:\n",
|
||||
" doc = next(d for d in self.documents if d['id'] == doc_id)\n",
|
||||
" results.append({'id': doc_id, 'score': round(score, 4), 'content': doc['content']})\n",
|
||||
" return results\n",
|
||||
"\n",
|
||||
"retriever = SimpleRetriever(documents)\n",
|
||||
"print('检索器构建完成')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 测试检索\n",
|
||||
"queries = [\n",
|
||||
" '水稻稻瘟病怎么识别?',\n",
|
||||
" '遥感技术怎么监测病虫害?',\n",
|
||||
" '智能灌溉系统如何工作?',\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"for query in queries:\n",
|
||||
" print(f'\\n🔍 查询: {query}')\n",
|
||||
" print('-' * 60)\n",
|
||||
" results = retriever.search(query, top_k=2)\n",
|
||||
" for i, r in enumerate(results, 1):\n",
|
||||
" print(f' [{i}] 相关度={r[\"score\"]:.4f} | 文档{r[\"id\"]}')\n",
|
||||
" print(f' {r[\"content\"][:100]}...')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 5. 多模态数据联动:图像标注 + 知识检索"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 加载 COCO 标注\n",
|
||||
"with open(f'{DATA_DIR}/json/作物病害检测_COCO格式.json', 'r', encoding='utf-8') as f:\n",
|
||||
" coco = json.load(f)\n",
|
||||
"\n",
|
||||
"cat_names = {c['id']: c['name'] for c in coco['categories']}\n",
|
||||
"\n",
|
||||
"# 构建图像-标注映射\n",
|
||||
"img_annotations = defaultdict(list)\n",
|
||||
"for ann in coco['annotations']:\n",
|
||||
" img_annotations[ann['image_id']].append({\n",
|
||||
" 'category': cat_names[ann['category_id']],\n",
|
||||
" 'bbox': ann['bbox'],\n",
|
||||
" 'area': ann['area']\n",
|
||||
" })\n",
|
||||
"\n",
|
||||
"# 模拟多模态检索:输入图片ID,返回标注+相关知识\n",
|
||||
"def multimodal_search(image_id):\n",
|
||||
" \"\"\"根据图像ID检索标注信息和相关知识\"\"\"\n",
|
||||
" # 获取图像信息\n",
|
||||
" img_info = next(img for img in coco['images'] if img['id'] == image_id)\n",
|
||||
" annotations = img_annotations.get(image_id, [])\n",
|
||||
" \n",
|
||||
" result = {\n",
|
||||
" 'image': img_info['file_name'],\n",
|
||||
" 'size': f\"{img_info['width']}x{img_info['height']}\",\n",
|
||||
" 'annotations': [],\n",
|
||||
" 'knowledge': []\n",
|
||||
" }\n",
|
||||
" \n",
|
||||
" for ann in annotations:\n",
|
||||
" category = ann['category']\n",
|
||||
" result['annotations'].append({\n",
|
||||
" '病害类别': category,\n",
|
||||
" '边界框': ann['bbox'],\n",
|
||||
" '面积': ann['area']\n",
|
||||
" })\n",
|
||||
" \n",
|
||||
" # 关联知识库检索\n",
|
||||
" knowledge = retriever.search(category, top_k=1)\n",
|
||||
" if knowledge:\n",
|
||||
" result['knowledge'].append({\n",
|
||||
" 'query': category,\n",
|
||||
" 'related_doc': knowledge[0]['content'][:200]\n",
|
||||
" })\n",
|
||||
" \n",
|
||||
" return result\n",
|
||||
"\n",
|
||||
"# 测试多模态检索\n",
|
||||
"sample_img_id = 1\n",
|
||||
"result = multimodal_search(sample_img_id)\n",
|
||||
"print(f'\\n图像: {result[\"image\"]} ({result[\"size\"]})')\n",
|
||||
"print(f'检测到标注: {len(result[\"annotations\"])} 个')\n",
|
||||
"for ann in result['annotations']:\n",
|
||||
" print(f' - {ann[\"病害类别\"]} | bbox={ann[\"边界框\"]} | area={ann[\"面积\"]}')\n",
|
||||
"if result['knowledge']:\n",
|
||||
" print(f'\\n关联知识:')\n",
|
||||
" for k in result['knowledge']:\n",
|
||||
" print(f' 查询[{k[\"query\"]}]: {k[\"related_doc\"][:100]}...')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 6. 问答对匹配(模拟 LLM 问答)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def answer_question(question, qa_data, knowledge_docs):\n",
|
||||
" \"\"\"模拟 RAG 问答流程:检索相关文档 + 匹配问答对\"\"\"\n",
|
||||
" \n",
|
||||
" # Step 1: 从知识库检索相关文档\n",
|
||||
" related_docs = retriever.search(question, top_k=2)\n",
|
||||
" \n",
|
||||
" # Step 2: 匹配最相似的问答对\n",
|
||||
" best_match = None\n",
|
||||
" best_score = 0\n",
|
||||
" \n",
|
||||
" for qa in qa_data:\n",
|
||||
" user_q = qa['conversations'][0]['content']\n",
|
||||
" # 简单字符重叠度\n",
|
||||
" overlap = len(set(question) & set(user_q))\n",
|
||||
" score = overlap / max(len(set(question)), 1)\n",
|
||||
" if score > best_score:\n",
|
||||
" best_score = score\n",
|
||||
" best_match = qa\n",
|
||||
" \n",
|
||||
" # Step 3: 组装回答\n",
|
||||
" response = {\n",
|
||||
" 'question': question,\n",
|
||||
" 'matched_qa_score': round(best_score, 4),\n",
|
||||
" 'knowledge_context': [d['id'] for d in related_docs] if related_docs else [],\n",
|
||||
" }\n",
|
||||
" \n",
|
||||
" if best_match and best_score > 0.2:\n",
|
||||
" response['answer'] = best_match['conversations'][1]['content']\n",
|
||||
" response['source'] = 'QA语料库'\n",
|
||||
" else:\n",
|
||||
" response['answer'] = '暂未找到精确匹配的回答,请参考知识库文档。'\n",
|
||||
" response['source'] = '知识库'\n",
|
||||
" \n",
|
||||
" return response\n",
|
||||
"\n",
|
||||
"# 测试问答\n",
|
||||
"test_questions = [\n",
|
||||
" '水稻叶片有褐色病斑,是不是稻瘟病?',\n",
|
||||
" '无人机怎么监测作物病害?',\n",
|
||||
" '小麦锈病怎么防治?',\n",
|
||||
" '土壤湿度传感器数据怎么看?',\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"for q in test_questions:\n",
|
||||
" result = answer_question(q, qa_data, documents)\n",
|
||||
" print(f'\\n❓ 问题: {q}')\n",
|
||||
" print(f' 匹配度: {result[\"matched_qa_score\"]}')\n",
|
||||
" print(f' 来源: {result[\"source\"]}')\n",
|
||||
" print(f' 回答: {result[\"answer\"][:150]}...')\n",
|
||||
" print(f' 参考文档: {result[\"knowledge_context\"]}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 7. 多光谱数据探索 (HDF5)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import h5py\n",
|
||||
"\n",
|
||||
"# 加载 HDF5 数据\n",
|
||||
"with h5py.File(f'{DATA_DIR}/hdf5/多光谱作物影像数据.h5', 'r') as hf:\n",
|
||||
" print('HDF5 文件结构:')\n",
|
||||
" print(f' 顶层组: {list(hf.keys())}')\n",
|
||||
" print(f'\\n 元数据属性:')\n",
|
||||
" meta = hf['metadata']\n",
|
||||
" for key, val in meta.attrs.items():\n",
|
||||
" print(f' {key}: {val}')\n",
|
||||
" \n",
|
||||
" print(f'\\n 场景数量: {len(hf[\"images\"])}')\n",
|
||||
" \n",
|
||||
" # 查看第一个场景\n",
|
||||
" scene = hf['images/scene_001']\n",
|
||||
" print(f'\\n 场景1属性:')\n",
|
||||
" for key, val in scene.attrs.items():\n",
|
||||
" print(f' {key}: {val}')\n",
|
||||
" \n",
|
||||
" ms_data = scene['multispectral'][()]\n",
|
||||
" ndvi_data = scene['ndvi'][()]\n",
|
||||
" print(f'\\n 多光谱数据形状: {ms_data.shape} (波段 x 高 x 宽)')\n",
|
||||
" print(f' NDVI 数据形状: {ndvi_data.shape}')\n",
|
||||
" print(f' NDVI 范围: [{ndvi_data.min():.3f}, {ndvi_data.max():.3f}]')\n",
|
||||
" \n",
|
||||
" # 查看标注\n",
|
||||
" labels = hf['labels/segmentation_masks'][()]\n",
|
||||
" classes = json.loads(hf['labels'].attrs['classes'])\n",
|
||||
" print(f'\\n 标注数据形状: {labels.shape}')\n",
|
||||
" print(f' 分类类别: {classes}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 多光谱数据可视化\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import matplotlib\n",
|
||||
"\n",
|
||||
"matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'WenQuanYi Micro Hei']\n",
|
||||
"\n",
|
||||
"with h5py.File(f'{DATA_DIR}/hdf5/多光谱作物影像数据.h5', 'r') as hf:\n",
|
||||
" scene = hf['images/scene_001']\n",
|
||||
" ms = scene['multispectral'][()]\n",
|
||||
" ndvi = scene['ndvi'][()]\n",
|
||||
" \n",
|
||||
" fig, axes = plt.subplots(2, 3, figsize=(15, 10))\n",
|
||||
" band_names = ['Blue (450nm)', 'Green (560nm)', 'Red (650nm)', 'RedEdge (730nm)', 'NIR (840nm)']\n",
|
||||
" \n",
|
||||
" for i, (ax, name) in enumerate(zip(axes[0], band_names)):\n",
|
||||
" im = ax.imshow(ms[i], cmap='gray')\n",
|
||||
" ax.set_title(name)\n",
|
||||
" plt.colorbar(im, ax=ax, fraction=0.046)\n",
|
||||
" \n",
|
||||
" # NDVI 伪彩色\n",
|
||||
" im = axes[1, 0].imshow(ndvi, cmap='RdYlGn', vmin=-1, vmax=1)\n",
|
||||
" axes[1, 0].set_title('NDVI 分布')\n",
|
||||
" plt.colorbar(im, ax=axes[1, 0], fraction=0.046)\n",
|
||||
" \n",
|
||||
" # NDVI 直方图\n",
|
||||
" axes[1, 1].hist(ndvi.flatten(), bins=50, color='green', alpha=0.7, edgecolor='white')\n",
|
||||
" axes[1, 1].set_title('NDVI 值分布直方图')\n",
|
||||
" axes[1, 1].set_xlabel('NDVI')\n",
|
||||
" axes[1, 1].axvline(ndvi.mean(), color='red', linestyle='--', label=f'均值={ndvi.mean():.3f}')\n",
|
||||
" axes[1, 1].legend()\n",
|
||||
" \n",
|
||||
" # 标注掩码\n",
|
||||
" label = hf['labels/segmentation_masks'][0]\n",
|
||||
" classes = json.loads(hf['labels'].attrs['classes'])\n",
|
||||
" im = axes[1, 2].imshow(label, cmap='jet')\n",
|
||||
" axes[1, 2].set_title('分割标注掩码')\n",
|
||||
" cbar = plt.colorbar(im, ax=axes[1, 2], fraction=0.046, ticks=[0, 1, 2, 3])\n",
|
||||
" cbar.ax.set_yticklabels(classes)\n",
|
||||
"\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()\n",
|
||||
"print('多模态与 RAG 检索示例演示完成!')"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python",
|
||||
"version": "3.10.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
Reference in New Issue
Block a user