Files
multimodal_and_rag_demo/03_multimodal_and_rag_demo.ipynb
2026-04-15 08:05:00 +00:00

14 KiB
Raw Permalink Blame History

多模态与 RAG 检索示例

本 Notebook 演示如何使用多模态数据(文本+图像元数据)构建简单的 RAG 检索和问答系统。

场景: 多模态 AI / RAG 知识库检索

数据集:

  • JSONL: 农业问答微调语料
  • TXT: 农业知识库语料
  • JSON: COCO 标注数据
  • HDF5: 多光谱影像数据

1. 环境准备

In [ ]:
import json
import os
import numpy as np
import pandas as pd
from collections import defaultdict

DATA_DIR = os.path.dirname(os.path.abspath(''))
print('数据目录:', DATA_DIR)
print('环境准备完成')

2. 加载 JSONL 问答语料

In [ ]:
# 加载 JSONL 数据
qa_data = []
with open(f'{DATA_DIR}/jsonl/农业问答_微调语料.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        qa_data.append(json.loads(line))

print(f'问答对数量: {len(qa_data)}')
print('\n--- 示例数据 ---')
sample = qa_data[0]
for conv in sample['conversations']:
    print(f'[{conv["role"]}]: {conv["content"][:100]}...')
print(f'\n元数据: {json.dumps(sample["metadata"], ensure_ascii=False)}')

3. 加载 TXT 知识库语料

In [ ]:
# 加载知识库文本
with open(f'{DATA_DIR}/txt/农业知识库语料.txt', 'r', encoding='utf-8') as f:
    knowledge_raw = f.read()

# 按文档分割
documents = []
for block in knowledge_raw.split('【文档')[1:]:
    lines = block.strip().split('\n', 1)
    doc_id = lines[0].replace('】', '')
    content = lines[1].strip() if len(lines) > 1 else ''
    documents.append({'id': doc_id, 'content': content})

print(f'知识库文档数: {len(documents)}')
print(f'\n第一篇文档前200字:\n{documents[0]["content"][:200]}...')

4. 简单关键词检索(模拟 RAG 检索)

In [ ]:
class SimpleRetriever:
    """基于关键词的简单检索器,模拟 RAG 检索流程"""
    
    def __init__(self, documents):
        self.documents = documents
        self.index = defaultdict(list)  # 倒排索引
        self._build_index()
    
    def _build_index(self):
        """构建简单的倒排索引"""
        for doc in self.documents:
            words = set(doc['content'])
            for word in words:
                self.index[word].append(doc['id'])
    
    def search(self, query, top_k=3):
        """关键词检索"""
        scores = defaultdict(float)
        query_chars = set(query)
        
        for doc in self.documents:
            doc_chars = set(doc['content'])
            overlap = query_chars & doc_chars
            if overlap:
                # 简单的字符重叠度打分
                scores[doc['id']] = len(overlap) / len(query_chars)
        
        # 排序返回 top_k
        ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
        results = []
        for doc_id, score in ranked:
            doc = next(d for d in self.documents if d['id'] == doc_id)
            results.append({'id': doc_id, 'score': round(score, 4), 'content': doc['content']})
        return results

retriever = SimpleRetriever(documents)
print('检索器构建完成')
In [ ]:
# 测试检索
queries = [
    '水稻稻瘟病怎么识别?',
    '遥感技术怎么监测病虫害?',
    '智能灌溉系统如何工作?',
]

for query in queries:
    print(f'\n🔍 查询: {query}')
    print('-' * 60)
    results = retriever.search(query, top_k=2)
    for i, r in enumerate(results, 1):
        print(f'  [{i}] 相关度={r["score"]:.4f} | 文档{r["id"]}')
        print(f'      {r["content"][:100]}...')

5. 多模态数据联动:图像标注 + 知识检索

In [ ]:
# 加载 COCO 标注
with open(f'{DATA_DIR}/json/作物病害检测_COCO格式.json', 'r', encoding='utf-8') as f:
    coco = json.load(f)

cat_names = {c['id']: c['name'] for c in coco['categories']}

# 构建图像-标注映射
img_annotations = defaultdict(list)
for ann in coco['annotations']:
    img_annotations[ann['image_id']].append({
        'category': cat_names[ann['category_id']],
        'bbox': ann['bbox'],
        'area': ann['area']
    })

# 模拟多模态检索输入图片ID返回标注+相关知识
def multimodal_search(image_id):
    """根据图像ID检索标注信息和相关知识"""
    # 获取图像信息
    img_info = next(img for img in coco['images'] if img['id'] == image_id)
    annotations = img_annotations.get(image_id, [])
    
    result = {
        'image': img_info['file_name'],
        'size': f"{img_info['width']}x{img_info['height']}",
        'annotations': [],
        'knowledge': []
    }
    
    for ann in annotations:
        category = ann['category']
        result['annotations'].append({
            '病害类别': category,
            '边界框': ann['bbox'],
            '面积': ann['area']
        })
        
        # 关联知识库检索
        knowledge = retriever.search(category, top_k=1)
        if knowledge:
            result['knowledge'].append({
                'query': category,
                'related_doc': knowledge[0]['content'][:200]
            })
    
    return result

# 测试多模态检索
sample_img_id = 1
result = multimodal_search(sample_img_id)
print(f'\n图像: {result["image"]} ({result["size"]})')
print(f'检测到标注: {len(result["annotations"])} 个')
for ann in result['annotations']:
    print(f'  - {ann["病害类别"]} | bbox={ann["边界框"]} | area={ann["面积"]}')
if result['knowledge']:
    print(f'\n关联知识:')
    for k in result['knowledge']:
        print(f'  查询[{k["query"]}]: {k["related_doc"][:100]}...')

6. 问答对匹配(模拟 LLM 问答)

In [ ]:
def answer_question(question, qa_data, knowledge_docs):
    """模拟 RAG 问答流程:检索相关文档 + 匹配问答对"""
    
    # Step 1: 从知识库检索相关文档
    related_docs = retriever.search(question, top_k=2)
    
    # Step 2: 匹配最相似的问答对
    best_match = None
    best_score = 0
    
    for qa in qa_data:
        user_q = qa['conversations'][0]['content']
        # 简单字符重叠度
        overlap = len(set(question) & set(user_q))
        score = overlap / max(len(set(question)), 1)
        if score > best_score:
            best_score = score
            best_match = qa
    
    # Step 3: 组装回答
    response = {
        'question': question,
        'matched_qa_score': round(best_score, 4),
        'knowledge_context': [d['id'] for d in related_docs] if related_docs else [],
    }
    
    if best_match and best_score > 0.2:
        response['answer'] = best_match['conversations'][1]['content']
        response['source'] = 'QA语料库'
    else:
        response['answer'] = '暂未找到精确匹配的回答,请参考知识库文档。'
        response['source'] = '知识库'
    
    return response

# 测试问答
test_questions = [
    '水稻叶片有褐色病斑,是不是稻瘟病?',
    '无人机怎么监测作物病害?',
    '小麦锈病怎么防治?',
    '土壤湿度传感器数据怎么看?',
]

for q in test_questions:
    result = answer_question(q, qa_data, documents)
    print(f'\n❓ 问题: {q}')
    print(f'   匹配度: {result["matched_qa_score"]}')
    print(f'   来源: {result["source"]}')
    print(f'   回答: {result["answer"][:150]}...')
    print(f'   参考文档: {result["knowledge_context"]}')

7. 多光谱数据探索 (HDF5)

In [ ]:
import h5py

# 加载 HDF5 数据
with h5py.File(f'{DATA_DIR}/hdf5/多光谱作物影像数据.h5', 'r') as hf:
    print('HDF5 文件结构:')
    print(f'  顶层组: {list(hf.keys())}')
    print(f'\n  元数据属性:')
    meta = hf['metadata']
    for key, val in meta.attrs.items():
        print(f'    {key}: {val}')
    
    print(f'\n  场景数量: {len(hf["images"])}')
    
    # 查看第一个场景
    scene = hf['images/scene_001']
    print(f'\n  场景1属性:')
    for key, val in scene.attrs.items():
        print(f'    {key}: {val}')
    
    ms_data = scene['multispectral'][()]
    ndvi_data = scene['ndvi'][()]
    print(f'\n  多光谱数据形状: {ms_data.shape} (波段 x 高 x 宽)')
    print(f'  NDVI 数据形状: {ndvi_data.shape}')
    print(f'  NDVI 范围: [{ndvi_data.min():.3f}, {ndvi_data.max():.3f}]')
    
    # 查看标注
    labels = hf['labels/segmentation_masks'][()]
    classes = json.loads(hf['labels'].attrs['classes'])
    print(f'\n  标注数据形状: {labels.shape}')
    print(f'  分类类别: {classes}')
In [ ]:
# 多光谱数据可视化
import matplotlib.pyplot as plt
import matplotlib

matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'WenQuanYi Micro Hei']

with h5py.File(f'{DATA_DIR}/hdf5/多光谱作物影像数据.h5', 'r') as hf:
    scene = hf['images/scene_001']
    ms = scene['multispectral'][()]
    ndvi = scene['ndvi'][()]
    
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    band_names = ['Blue (450nm)', 'Green (560nm)', 'Red (650nm)', 'RedEdge (730nm)', 'NIR (840nm)']
    
    for i, (ax, name) in enumerate(zip(axes[0], band_names)):
        im = ax.imshow(ms[i], cmap='gray')
        ax.set_title(name)
        plt.colorbar(im, ax=ax, fraction=0.046)
    
    # NDVI 伪彩色
    im = axes[1, 0].imshow(ndvi, cmap='RdYlGn', vmin=-1, vmax=1)
    axes[1, 0].set_title('NDVI 分布')
    plt.colorbar(im, ax=axes[1, 0], fraction=0.046)
    
    # NDVI 直方图
    axes[1, 1].hist(ndvi.flatten(), bins=50, color='green', alpha=0.7, edgecolor='white')
    axes[1, 1].set_title('NDVI 值分布直方图')
    axes[1, 1].set_xlabel('NDVI')
    axes[1, 1].axvline(ndvi.mean(), color='red', linestyle='--', label=f'均值={ndvi.mean():.3f}')
    axes[1, 1].legend()
    
    # 标注掩码
    label = hf['labels/segmentation_masks'][0]
    classes = json.loads(hf['labels'].attrs['classes'])
    im = axes[1, 2].imshow(label, cmap='jet')
    axes[1, 2].set_title('分割标注掩码')
    cbar = plt.colorbar(im, ax=axes[1, 2], fraction=0.046, ticks=[0, 1, 2, 3])
    cbar.ax.set_yticklabels(classes)

plt.tight_layout()
plt.show()
print('多模态与 RAG 检索示例演示完成!')