Files
model_training_demo/02_model_training_demo.ipynb
2026-04-15 08:05:42 +00:00

301 lines
8.3 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 作物病害分类模型训练示例\n",
"\n",
"本 Notebook 演示如何使用作物数据集训练机器学习模型,完成病害严重程度分类任务。\n",
"\n",
"**场景:** AI 模型训练(分类任务)\n",
"\n",
"**数据集:** CSV 作物病害标注表\n",
"\n",
"**模型:** 随机森林 / 决策树分类器"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. 环境准备与数据加载"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"from sklearn.preprocessing import LabelEncoder\n",
"from sklearn.metrics import (\n",
"    classification_report, confusion_matrix, accuracy_score,\n",
"    ConfusionMatrixDisplay\n",
")\n",
"import os\n",
"\n",
"# Put CJK-capable fonts first so Chinese titles and labels render, and turn\n",
"# off the Unicode minus glyph for axis ticks.\n",
"matplotlib.rcParams.update({\n",
"    'font.sans-serif': ['SimHei', 'WenQuanYi Micro Hei'],\n",
"    'axes.unicode_minus': False,\n",
"})\n",
"\n",
"# The data root is the parent of the current working directory.\n",
"working_dir = os.path.abspath('')\n",
"DATA_DIR = os.path.dirname(working_dir)\n",
"print('环境准备完成')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Read the labelled crop-disease table from <DATA_DIR>/csv and preview it\n",
"csv_path = f'{DATA_DIR}/csv/作物病害标注表.csv'\n",
"df = pd.read_csv(csv_path)\n",
"print(f'数据集大小: {df.shape}')\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. 数据预处理"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Feature engineering: one LabelEncoder per categorical column, each bound to\n",
"# its own name so the fitted mapping stays available after this cell.\n",
"le_crop, le_disease, le_part, le_region, le_severity = (\n",
"    LabelEncoder() for _ in range(5)\n",
")\n",
"\n",
"# (source column, encoded column, encoder), in the order the columns are added\n",
"encoding_plan = [\n",
"    ('作物', '作物编码', le_crop),\n",
"    ('病害名称', '病害编码', le_disease),\n",
"    ('发病部位', '部位编码', le_part),\n",
"    ('地区', '地区编码', le_region),\n",
"    ('严重程度', '严重程度编码', le_severity),\n",
"]\n",
"for raw_col, code_col, encoder in encoding_plan:\n",
"    df[code_col] = encoder.fit_transform(df[raw_col])\n",
"\n",
"# Model inputs: the four encoded categoricals plus the temperature, humidity,\n",
"# longitude and latitude columns; the target is the encoded severity.\n",
"feature_cols = [\n",
"    '作物编码', '病害编码', '部位编码', '地区编码',\n",
"    '温度_℃', '湿度_%', '经度', '纬度',\n",
"]\n",
"X = df[feature_cols]\n",
"y = df['严重程度编码']\n",
"\n",
"label_counts = pd.Series(y).value_counts().sort_index()\n",
"print(f'特征维度: {X.shape}')\n",
"print(f'标签分布:\\n{label_counts}')\n",
"print(f'\\n标签映射:')\n",
"for code, severity_name in enumerate(le_severity.classes_):\n",
"    print(f' {code}: {severity_name}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. 划分训练集和测试集"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hold out 20% of the rows for testing, stratified on y, with a fixed seed\n",
"split_options = dict(test_size=0.2, random_state=42, stratify=y)\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, **split_options)\n",
"\n",
"n_train, n_test = X_train.shape[0], X_test.shape[0]\n",
"print(f'训练集大小: {n_train}')\n",
"print(f'测试集大小: {n_test}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. 训练随机森林模型"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Random forest: 100 trees, depth capped at 10, fixed seed, n_jobs=-1\n",
"rf_params = {\n",
"    'n_estimators': 100,\n",
"    'max_depth': 10,\n",
"    'random_state': 42,\n",
"    'n_jobs': -1,\n",
"}\n",
"rf_model = RandomForestClassifier(**rf_params)\n",
"rf_model.fit(X_train, y_train)\n",
"\n",
"# Score the fitted forest on the held-out split\n",
"y_pred_rf = rf_model.predict(X_test)\n",
"acc_rf = accuracy_score(y_test, y_pred_rf)\n",
"print(f'随机森林准确率: {acc_rf:.4f}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. 训练决策树模型(对比)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Single decision tree (depth capped at 8, same seed) for comparison\n",
"dt_model = DecisionTreeClassifier(random_state=42, max_depth=8)\n",
"dt_model.fit(X_train, y_train)\n",
"\n",
"# Score the tree on the same held-out split\n",
"y_pred_dt = dt_model.predict(X_test)\n",
"acc_dt = accuracy_score(y_true=y_test, y_pred=y_pred_dt)\n",
"print(f'决策树准确率: {acc_dt:.4f}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. 模型评估"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Classification report for the random forest on the test split, under a\n",
"# three-line banner\n",
"banner = '=' * 50\n",
"print('\\n'.join([banner, '随机森林 - 分类报告', banner]))\n",
"print(\n",
"    classification_report(\n",
"        y_test,\n",
"        y_pred_rf,\n",
"        target_names=le_severity.classes_,\n",
"    )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Side-by-side confusion matrices: random forest (left), decision tree (right)\n",
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
"\n",
"# (predictions, colormap, model name, accuracy) for each panel, left to right\n",
"panels = [\n",
"    (y_pred_rf, 'Blues', '随机森林', acc_rf),\n",
"    (y_pred_dt, 'Oranges', '决策树', acc_dt),\n",
"]\n",
"for ax, (y_pred, cmap_name, model_name, acc) in zip(axes, panels):\n",
"    ConfusionMatrixDisplay.from_predictions(\n",
"        y_test, y_pred,\n",
"        display_labels=le_severity.classes_,\n",
"        cmap=cmap_name, ax=ax\n",
"    )\n",
"    ax.set_title(f'{model_name}混淆矩阵 (准确率={acc:.4f})')\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7. 特征重要性分析"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Rank the features by the fitted forest's feature_importances_, highest first\n",
"importances = rf_model.feature_importances_\n",
"indices = np.argsort(importances)[::-1]\n",
"ranked_names = [feature_cols[i] for i in indices]\n",
"ranked_scores = importances[indices]\n",
"positions = range(len(importances))\n",
"\n",
"# Bar chart of the ranked scores, labelled with the matching feature names\n",
"plt.figure(figsize=(10, 6))\n",
"plt.bar(positions, ranked_scores, align='center', color='steelblue')\n",
"plt.xticks(positions, ranked_names, rotation=45)\n",
"plt.ylabel('重要性分数')\n",
"plt.title('随机森林 - 特征重要性排序')\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"# Same ranking as text\n",
"print('特征重要性排名:')\n",
"for rank, (name, score) in enumerate(zip(ranked_names, ranked_scores), start=1):\n",
"    print(f' {rank}. {name}: {score:.4f}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8. 模型预测示例"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Predict severity for three hand-written samples.\n",
"# Categorical values are encoded with the LabelEncoders fitted earlier in this\n",
"# notebook, so every name used here must already occur in the training table;\n",
"# LabelEncoder.transform raises ValueError on an unseen label.\n",
"new_samples = pd.DataFrame({\n",
"    '作物编码': le_crop.transform(['水稻', '小麦', '玉米']),\n",
"    '病害编码': le_disease.transform(['稻瘟病', '锈病', '大斑病']),\n",
"    '部位编码': le_part.transform(['叶片', '叶片', '穗部']),\n",
"    '地区编码': le_region.transform(['长江中下游', '华北平原', '东北平原']),\n",
"    '温度_℃': [28.5, 22.0, 25.3],\n",
"    '湿度_%': [85.0, 65.0, 78.0],\n",
"    '经度': [116.4, 114.5, 123.4],\n",
"    '纬度': [30.5, 38.0, 41.8],\n",
"})\n",
"\n",
"predictions = rf_model.predict(new_samples)\n",
"pred_labels = le_severity.inverse_transform(predictions)\n",
"\n",
"# Decode the integer code columns directly rather than row by row:\n",
"# DataFrame.iterrows() upcasts each mixed int/float row to float64, and\n",
"# LabelEncoder.inverse_transform cannot index its classes with float codes.\n",
"crop_names = le_crop.inverse_transform(new_samples['作物编码'])\n",
"disease_names = le_disease.inverse_transform(new_samples['病害编码'])\n",
"\n",
"print('预测结果:')\n",
"print('-' * 50)\n",
"for i, (crop_name, disease, severity) in enumerate(\n",
"    zip(crop_names, disease_names, pred_labels), start=1\n",
"):\n",
"    print(f'样本{i}: {crop_name}/{disease} => 预测严重程度: {severity}')\n",
"\n",
"print('\\n模型训练和预测流程演示完成')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}