14 KiB
14 KiB
多模态与 RAG 检索示例
本 Notebook 演示如何使用多模态数据(文本+图像元数据)构建简单的 RAG 检索和问答系统。
场景: 多模态 AI / RAG 知识库检索
数据集:
- JSONL: 农业问答微调语料
- TXT: 农业知识库语料
- JSON: COCO 标注数据
- HDF5: 多光谱影像数据
1. 环境准备
In [ ]:
import json
import os
import numpy as np
import pandas as pd
from collections import defaultdict
# os.path.abspath('') resolves to the current working directory; the dataset
# root used by every later cell is that directory's parent.
working_dir = os.path.abspath('')
DATA_DIR = os.path.dirname(working_dir)
print('数据目录:', DATA_DIR)
print('环境准备完成')
2. 加载 JSONL 问答语料
In [ ]:
# Load the JSONL fine-tuning corpus: one JSON object per line.
qa_data = []
with open(f'{DATA_DIR}/jsonl/农业问答_微调语料.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        # Blank lines (e.g. a trailing empty line at EOF) are not valid JSON
        # and would make json.loads raise, so skip them.
        if not line:
            continue
        qa_data.append(json.loads(line))
print(f'问答对数量: {len(qa_data)}')
print('\n--- 示例数据 ---')
# Only preview a sample when the corpus is non-empty; indexing [0] on an
# empty list would raise IndexError.
if qa_data:
    sample = qa_data[0]
    for conv in sample['conversations']:
        print(f'[{conv["role"]}]: {conv["content"][:100]}...')
    print(f'\n元数据: {json.dumps(sample["metadata"], ensure_ascii=False)}')
3. 加载 TXT 知识库语料
In [ ]:
# Read the whole knowledge-base text file into memory.
with open(f'{DATA_DIR}/txt/农业知识库语料.txt', 'r', encoding='utf-8') as f:
    knowledge_raw = f.read()

# Every document starts with a '【文档' marker. The text before the first
# marker is discarded; within each chunk the first line is the header and
# the remainder is the body.
documents = []
for chunk in knowledge_raw.split('【文档')[1:]:
    header, _, body = chunk.strip().partition('\n')
    documents.append({
        'id': header.replace('】', ''),
        'content': body.strip(),
    })

print(f'知识库文档数: {len(documents)}')
print(f'\n第一篇文档前200字:\n{documents[0]["content"][:200]}...')
4. 简单关键词检索(模拟 RAG 检索)
In [ ]:
class SimpleRetriever:
    """Character-overlap retriever that mimics the retrieval step of RAG.

    Each document is a dict with the keys ``'id'`` and ``'content'``. A query
    is scored against a document as the fraction of the query's distinct
    characters that also occur in the document's content.

    The per-document character sets are computed once at construction time,
    so documents appended to ``self.documents`` afterwards are not searched
    until ``_build_index()`` is called again.
    """

    def __init__(self, documents):
        """Store *documents* and build the lookup tables used by ``search``."""
        self.documents = documents
        self.index = defaultdict(list)  # inverted index: character -> doc ids
        self._doc_chars = []            # distinct characters per document
        self._doc_by_id = {}            # doc id -> first document with that id
        self._build_index()

    def _build_index(self):
        """(Re)build the inverted index and the per-document caches."""
        self.index.clear()
        self._doc_chars = []
        self._doc_by_id = {}
        for doc in self.documents:
            chars = set(doc['content'])
            self._doc_chars.append(chars)
            # Keep the first document for a given id, which is what a linear
            # "first match" scan over self.documents would return.
            self._doc_by_id.setdefault(doc['id'], doc)
            for ch in chars:
                self.index[ch].append(doc['id'])

    def search(self, query, top_k=3):
        """Return up to *top_k* documents ranked by character overlap.

        Args:
            query: Query string; only its distinct characters matter.
            top_k: Maximum number of results to return.

        Returns:
            A list of dicts with the keys ``'id'``, ``'score'`` (overlap
            ratio rounded to 4 decimals) and ``'content'``, best match first.
            Documents sharing no character with the query are omitted; ties
            keep document order. An empty query yields an empty list.
        """
        query_chars = set(query)
        if not query_chars:
            return []

        scores = {}
        # Reuse the cached character sets instead of rebuilding one per
        # document on every query.
        for doc, doc_chars in zip(self.documents, self._doc_chars):
            shared = len(query_chars & doc_chars)
            if shared:
                scores[doc['id']] = shared / len(query_chars)

        # sorted() is stable, so equal scores stay in document order.
        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:top_k]
        return [
            {
                'id': doc_id,
                'score': round(score, 4),
                'content': self._doc_by_id[doc_id]['content'],
            }
            for doc_id, score in ranked
        ]
# Build the module-level retriever over the parsed knowledge-base documents;
# later cells call retriever.search(...) directly.
retriever = SimpleRetriever(documents)
print('检索器构建完成')
In [ ]:
# Smoke-test the retriever with a few sample queries, showing the two best
# hits for each one.
queries = [
    '水稻稻瘟病怎么识别?',
    '遥感技术怎么监测病虫害?',
    '智能灌溉系统如何工作?',
]
for query in queries:
    print(f'\n🔍 查询: {query}')
    print('-' * 60)
    results = retriever.search(query, top_k=2)
    for rank, hit in enumerate(results, start=1):
        score_line = f' [{rank}] 相关度={hit["score"]:.4f} | 文档{hit["id"]}'
        preview_line = f' {hit["content"][:100]}...'
        print(score_line)
        print(preview_line)
5. 多模态数据联动:图像标注 + 知识检索
In [ ]:
# Load the COCO-format annotation file.
with open(f'{DATA_DIR}/json/作物病害检测_COCO格式.json', 'r', encoding='utf-8') as f:
    coco = json.load(f)

# Map category id -> category name.
cat_names = {}
for c in coco['categories']:
    cat_names[c['id']] = c['name']

# Group annotations by image id, keeping only the fields used later.
img_annotations = defaultdict(list)
for ann in coco['annotations']:
    record = {
        'category': cat_names[ann['category_id']],
        'bbox': ann['bbox'],
        'area': ann['area'],
    }
    img_annotations[ann['image_id']].append(record)
# Simulated multimodal lookup: image id in, annotations plus related knowledge out.
def multimodal_search(image_id):
    """Return annotation info and related knowledge for one image.

    Reads the module-level ``coco``, ``img_annotations`` and ``retriever``.

    Args:
        image_id: Value matched against the ``'id'`` field of ``coco['images']``.

    Returns:
        A dict with the keys ``'image'`` (file name), ``'size'``
        (``"<width>x<height>"``), ``'annotations'`` (one entry per annotation
        of the image) and ``'knowledge'`` (one entry per annotation whose
        category produced a retrieval hit).

    Raises:
        ValueError: If no image with *image_id* exists in ``coco['images']``.
    """
    # Use a default with next() so an unknown id produces a descriptive error
    # instead of a bare StopIteration.
    img_info = next((img for img in coco['images'] if img['id'] == image_id), None)
    if img_info is None:
        raise ValueError(f'未找到图像 ID: {image_id}')

    result = {
        'image': img_info['file_name'],
        'size': f"{img_info['width']}x{img_info['height']}",
        'annotations': [],
        'knowledge': [],
    }

    # Several annotations on one image often share a category; search the
    # knowledge base only once per distinct category.
    knowledge_by_category = {}
    for ann in img_annotations.get(image_id, []):
        category = ann['category']
        result['annotations'].append({
            '病害类别': category,
            '边界框': ann['bbox'],
            '面积': ann['area'],
        })
        if category not in knowledge_by_category:
            knowledge_by_category[category] = retriever.search(category, top_k=1)
        knowledge = knowledge_by_category[category]
        if knowledge:
            result['knowledge'].append({
                'query': category,
                'related_doc': knowledge[0]['content'][:200],
            })
    return result
# Demo: run the multimodal lookup for one image and print what came back.
sample_img_id = 1
result = multimodal_search(sample_img_id)

found_annotations = result['annotations']
print(f'\n图像: {result["image"]} ({result["size"]})')
print(f'检测到标注: {len(found_annotations)} 个')
for ann in found_annotations:
    print(f' - {ann["病害类别"]} | bbox={ann["边界框"]} | area={ann["面积"]}')

# The knowledge section is printed only when at least one category matched.
if result['knowledge']:
    print('\n关联知识:')
    for k in result['knowledge']:
        print(f' 查询[{k["query"]}]: {k["related_doc"][:100]}...')
6. 问答对匹配(模拟 LLM 问答)
In [ ]:
def answer_question(question, qa_data, knowledge_docs):
    """Simulate a RAG answer: retrieve documents, then match a stored QA pair.

    Args:
        question: The user's question.
        qa_data: Iterable of QA records. Each record has a ``'conversations'``
            list whose first entry is treated as the question turn and whose
            second entry is treated as the answer turn.
        knowledge_docs: Currently unused; retrieval goes through the
            module-level ``retriever``. Kept for interface compatibility.

    Returns:
        A dict with the keys ``'question'``, ``'matched_qa_score'``,
        ``'knowledge_context'`` (ids of the retrieved documents), ``'answer'``
        and ``'source'``.
    """
    # Step 1: retrieve related documents from the knowledge base.
    related_docs = retriever.search(question, top_k=2)

    # Step 2: find the QA pair whose question shares the most characters.
    # The question's character set is loop-invariant, so compute it once.
    question_chars = set(question)
    denominator = max(len(question_chars), 1)  # avoid dividing by zero
    best_match = None
    best_score = 0
    for qa in qa_data:
        turns = qa['conversations']
        # A record needs both a question and an answer turn to be usable.
        if len(turns) < 2:
            continue
        score = len(question_chars & set(turns[0]['content'])) / denominator
        if score > best_score:
            best_score = score
            best_match = qa

    # Step 3: assemble the response.
    response = {
        'question': question,
        'matched_qa_score': round(best_score, 4),
        'knowledge_context': [d['id'] for d in related_docs],
    }
    # Accept the stored answer only above a minimal overlap threshold.
    if best_match and best_score > 0.2:
        response['answer'] = best_match['conversations'][1]['content']
        response['source'] = 'QA语料库'
    else:
        response['answer'] = '暂未找到精确匹配的回答,请参考知识库文档。'
        response['source'] = '知识库'
    return response
# Demo: run the QA pipeline over a few sample questions and print a short
# report for each one.
test_questions = [
    '水稻叶片有褐色病斑,是不是稻瘟病?',
    '无人机怎么监测作物病害?',
    '小麦锈病怎么防治?',
    '土壤湿度传感器数据怎么看?',
]
for q in test_questions:
    result = answer_question(q, qa_data, documents)
    report_lines = [
        f'\n❓ 问题: {q}',
        f' 匹配度: {result["matched_qa_score"]}',
        f' 来源: {result["source"]}',
        f' 回答: {result["answer"][:150]}...',
        f' 参考文档: {result["knowledge_context"]}',
    ]
    print('\n'.join(report_lines))
7. 多光谱数据探索 (HDF5)
In [ ]:
import h5py
# Open the HDF5 file read-only and print an overview of its contents.
def _print_attrs(attrs):
    """Print every key/value pair of an HDF5 attribute collection."""
    for key, val in attrs.items():
        print(f' {key}: {val}')


with h5py.File(f'{DATA_DIR}/hdf5/多光谱作物影像数据.h5', 'r') as hf:
    print('HDF5 文件结构:')
    print(f' 顶层组: {list(hf.keys())}')

    print(f'\n 元数据属性:')
    meta = hf['metadata']
    _print_attrs(meta.attrs)

    print(f'\n 场景数量: {len(hf["images"])}')

    # Inspect the first scene.
    scene = hf['images/scene_001']
    print(f'\n 场景1属性:')
    _print_attrs(scene.attrs)

    # [()] reads each full dataset into memory while the file is still open.
    ms_data = scene['multispectral'][()]
    ndvi_data = scene['ndvi'][()]
    print(f'\n 多光谱数据形状: {ms_data.shape} (波段 x 高 x 宽)')
    print(f' NDVI 数据形状: {ndvi_data.shape}')
    print(f' NDVI 范围: [{ndvi_data.min():.3f}, {ndvi_data.max():.3f}]')

    # Inspect the labels; the class list is stored as a JSON string attribute.
    labels = hf['labels/segmentation_masks'][()]
    classes = json.loads(hf['labels'].attrs['classes'])
    print(f'\n 标注数据形状: {labels.shape}')
    print(f' 分类类别: {classes}')
In [ ]:
# Visualise the multispectral bands, the NDVI map/histogram and the
# segmentation mask of the first scene.
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'WenQuanYi Micro Hei']
# Many CJK fonts do not include the Unicode minus sign (U+2212); use the ASCII
# hyphen so negative NDVI tick labels render instead of missing-glyph boxes.
matplotlib.rcParams['axes.unicode_minus'] = False

# Read everything needed while the file is open; plotting then works on
# in-memory data only.
with h5py.File(f'{DATA_DIR}/hdf5/多光谱作物影像数据.h5', 'r') as hf:
    scene = hf['images/scene_001']
    ms = scene['multispectral'][()]
    ndvi = scene['ndvi'][()]
    label = hf['labels/segmentation_masks'][0]
    classes = json.loads(hf['labels'].attrs['classes'])

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
band_names = ['Blue (450nm)', 'Green (560nm)', 'Red (650nm)', 'RedEdge (730nm)', 'NIR (840nm)']
# NOTE: the top row has three axes, so zip() stops after the first three
# bands; RedEdge and NIR are not drawn in this layout.
for i, (ax, name) in enumerate(zip(axes[0], band_names)):
    im = ax.imshow(ms[i], cmap='gray')
    ax.set_title(name)
    plt.colorbar(im, ax=ax, fraction=0.046)

# NDVI pseudo-colour map, fixed to the [-1, 1] range.
im = axes[1, 0].imshow(ndvi, cmap='RdYlGn', vmin=-1, vmax=1)
axes[1, 0].set_title('NDVI 分布')
plt.colorbar(im, ax=axes[1, 0], fraction=0.046)

# NDVI histogram with the mean marked.
ndvi_mean = ndvi.mean()
axes[1, 1].hist(ndvi.flatten(), bins=50, color='green', alpha=0.7, edgecolor='white')
axes[1, 1].set_title('NDVI 值分布直方图')
axes[1, 1].set_xlabel('NDVI')
axes[1, 1].axvline(ndvi_mean, color='red', linestyle='--', label=f'均值={ndvi_mean:.3f}')
axes[1, 1].legend()

# Segmentation mask. Derive the colorbar ticks from the class list so the
# number of ticks always matches the number of tick labels.
im = axes[1, 2].imshow(label, cmap='jet')
axes[1, 2].set_title('分割标注掩码')
cbar = plt.colorbar(im, ax=axes[1, 2], fraction=0.046, ticks=list(range(len(classes))))
cbar.ax.set_yticklabels(classes)

plt.tight_layout()
plt.show()
print('多模态与 RAG 检索示例演示完成!')