Files
model_training_demo/02_model_training_demo.ipynb
2026-04-15 08:05:42 +00:00

8.3 KiB

作物病害分类模型训练示例

本 Notebook 演示如何使用作物数据集训练机器学习模型,完成病害严重程度分类任务。

场景: AI 模型训练(分类任务)

数据集: CSV 作物病害标注表

模型: 随机森林 / 决策树分类器

1. 环境准备与数据加载

In [ ]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import (
    classification_report, confusion_matrix, accuracy_score,
    ConfusionMatrixDisplay
)
import os

# Put CJK-capable fonts first in the sans-serif list (the plots in this
# notebook use Chinese titles and labels) and turn off the Unicode minus sign.
matplotlib.rcParams.update({
    'font.sans-serif': ['SimHei', 'WenQuanYi Micro Hei'],
    'axes.unicode_minus': False,
})

# DATA_DIR is the parent of the current working directory.
DATA_DIR = os.path.split(os.path.abspath(''))[0]
print('环境准备完成')
In [ ]:
# Load the annotated crop-disease table from the csv folder under DATA_DIR.
csv_path = f'{DATA_DIR}/csv/作物病害标注表.csv'
df = pd.read_csv(csv_path)
dataset_shape = df.shape
print(f'数据集大小: {dataset_shape}')
# Last expression of the cell: displays the first rows as the cell output.
df.head()

2. 数据预处理

In [ ]:
# Feature engineering: one LabelEncoder per categorical column, so each set of
# integer codes can be mapped back to its own category names.
le_crop, le_disease, le_part, le_region, le_severity = (
    LabelEncoder() for _ in range(5)
)

# (new encoded column, encoder, raw source column) -- applied in this order.
encoding_plan = [
    ('作物编码', le_crop, '作物'),
    ('病害编码', le_disease, '病害名称'),
    ('部位编码', le_part, '发病部位'),
    ('地区编码', le_region, '地区'),
    ('严重程度编码', le_severity, '严重程度'),
]
for encoded_col, encoder, raw_col in encoding_plan:
    df[encoded_col] = encoder.fit_transform(df[raw_col])

# Model inputs: the four encoded categoricals plus the four numeric columns.
# The encoded severity column is the prediction target.
feature_cols = ['作物编码', '病害编码', '部位编码', '地区编码', '温度_℃', '湿度_%', '经度', '纬度']
X = df[feature_cols]
y = df['严重程度编码']

print(f'特征维度: {X.shape}')
label_counts = y.value_counts().sort_index()
print(f'标签分布:\n{label_counts}')
print(f'\n标签映射:')
# The position of a class in classes_ is the integer code assigned to it.
for code, severity_name in enumerate(le_severity.classes_):
    print(f'  {code}: {severity_name}')

3. 划分训练集和测试集

In [ ]:
# Hold out 20% of the rows as the test set. The split is stratified on y and
# uses a fixed seed so it is reproducible.
split_parts = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
    stratify=y,
)
X_train, X_test, y_train, y_test = split_parts
print(f'训练集大小: {len(X_train)}')
print(f'测试集大小: {len(X_test)}')

4. 训练随机森林模型

In [ ]:
# Random forest: 100 trees, depth capped at 10, fixed seed, n_jobs=-1.
rf_params = {
    'n_estimators': 100,
    'max_depth': 10,
    'random_state': 42,
    'n_jobs': -1,
}
rf_model = RandomForestClassifier(**rf_params)
rf_model.fit(X_train, y_train)

# Score the fitted forest on the held-out test set.
y_pred_rf = rf_model.predict(X_test)
acc_rf = accuracy_score(y_true=y_test, y_pred=y_pred_rf)
print(f'随机森林准确率: {acc_rf:.4f}')

5. 训练决策树模型(对比)

In [ ]:
# Single decision tree (depth capped at 8, fixed seed) as a comparison model.
dt_params = {'max_depth': 8, 'random_state': 42}
dt_model = DecisionTreeClassifier(**dt_params)
dt_model.fit(X_train, y_train)

# Score the tree on the same held-out test set used for the random forest.
y_pred_dt = dt_model.predict(X_test)
acc_dt = accuracy_score(y_true=y_test, y_pred=y_pred_dt)
print(f'决策树准确率: {acc_dt:.4f}')

6. 模型评估

In [ ]:
# Classification report for the random forest, with rows labelled by the
# original severity names instead of the integer codes.
divider = '=' * 50
print(divider, '随机森林 - 分类报告', divider, sep='\n')
rf_report = classification_report(
    y_test,
    y_pred_rf,
    target_names=le_severity.classes_,
)
print(rf_report)
In [ ]:
# Confusion matrices of both models, side by side on one figure.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# One entry per panel: (axis, predictions, colour map, title).
panels = [
    (axes[0], y_pred_rf, 'Blues', f'随机森林混淆矩阵 (准确率={acc_rf:.4f})'),
    (axes[1], y_pred_dt, 'Oranges', f'决策树混淆矩阵 (准确率={acc_dt:.4f})'),
]
for panel_ax, panel_pred, panel_cmap, panel_title in panels:
    ConfusionMatrixDisplay.from_predictions(
        y_test, panel_pred,
        display_labels=le_severity.classes_,
        cmap=panel_cmap, ax=panel_ax,
    )
    panel_ax.set_title(panel_title)

plt.tight_layout()
plt.show()

7. 特征重要性分析

In [ ]:
# Rank the features by the random forest's importance scores, highest first.
importances = rf_model.feature_importances_
indices = np.argsort(importances)[::-1]

ranked_names = [feature_cols[i] for i in indices]
ranked_scores = importances[indices]
bar_positions = range(len(importances))

# Bar chart of the scores in descending order, labelled with feature names.
plt.figure(figsize=(10, 6))
plt.bar(bar_positions, ranked_scores, align='center', color='steelblue')
plt.xticks(bar_positions, ranked_names, rotation=45)
plt.ylabel('重要性分数')
plt.title('随机森林 - 特征重要性排序')
plt.tight_layout()
plt.show()

# The same ranking as text.
print('特征重要性排名:')
for rank, (name, score) in enumerate(zip(ranked_names, ranked_scores), 1):
    print(f'  {rank}. {name}: {score:.4f}')

8. 模型预测示例

In [ ]:
# Predict the severity of three simulated samples with the random forest.
#
# The raw category names are kept in plain lists so the report at the end can
# print them directly instead of decoding the integer codes again.
crop_names = ['水稻', '小麦', '玉米']
disease_names = ['稻瘟病', '锈病', '大斑病']
part_names = ['叶片', '叶片', '穗部']
region_names = ['长江中下游', '华北平原', '东北平原']

# NOTE(review): LabelEncoder.transform raises ValueError for a category the
# encoder was not fitted on -- confirm these names occur in the training CSV.
new_samples = pd.DataFrame({
    '作物编码': le_crop.transform(crop_names),
    '病害编码': le_disease.transform(disease_names),
    '部位编码': le_part.transform(part_names),
    '地区编码': le_region.transform(region_names),
    '温度_℃': [28.5, 22.0, 25.3],
    '湿度_%': [85.0, 65.0, 78.0],
    '经度': [116.4, 114.5, 123.4],
    '纬度': [30.5, 38.0, 41.8],
})

predictions = rf_model.predict(new_samples)
# Map the predicted integer codes back to the severity names.
pred_labels = le_severity.inverse_transform(predictions)

print('预测结果:')
print('-' * 50)
# Do not read the codes back through DataFrame.iterrows(): each row mixes int
# and float columns, so it is upcast to float64, and float codes cannot be
# used by LabelEncoder.inverse_transform as indices into classes_. Pairing
# the original names with the predictions avoids that round trip.
for i, (crop_name, disease, label) in enumerate(
    zip(crop_names, disease_names, pred_labels), 1
):
    print(f'样本{i}: {crop_name}/{disease} => 预测严重程度: {label}')
print('\n模型训练和预测流程演示完成!')