新增文件
This commit is contained in:
301
02_model_training_demo.ipynb
Normal file
301
02_model_training_demo.ipynb
Normal file
@@ -0,0 +1,301 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 作物病害分类模型训练示例\n",
|
||||
"\n",
|
||||
"本 Notebook 演示如何使用作物数据集训练机器学习模型,完成病害严重程度分类任务。\n",
|
||||
"\n",
|
||||
"**场景:** AI 模型训练(分类任务)\n",
|
||||
"\n",
|
||||
"**数据集:** CSV 作物病害标注表\n",
|
||||
"\n",
|
||||
"**模型:** 随机森林 / 决策树分类器"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1. 环境准备与数据加载"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import matplotlib\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"from sklearn.ensemble import RandomForestClassifier\n",
|
||||
"from sklearn.tree import DecisionTreeClassifier\n",
|
||||
"from sklearn.preprocessing import LabelEncoder\n",
|
||||
"from sklearn.metrics import (\n",
|
||||
" classification_report, confusion_matrix, accuracy_score,\n",
|
||||
" ConfusionMatrixDisplay\n",
|
||||
")\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'WenQuanYi Micro Hei']\n",
|
||||
"matplotlib.rcParams['axes.unicode_minus'] = False\n",
|
||||
"\n",
|
||||
"DATA_DIR = os.path.dirname(os.path.abspath(''))\n",
|
||||
"print('环境准备完成')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 加载数据\n",
|
||||
"df = pd.read_csv(f'{DATA_DIR}/csv/作物病害标注表.csv')\n",
|
||||
"print(f'数据集大小: {df.shape}')\n",
|
||||
"df.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2. 数据预处理"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 特征工程\n",
|
||||
"le_crop = LabelEncoder()\n",
|
||||
"le_disease = LabelEncoder()\n",
|
||||
"le_part = LabelEncoder()\n",
|
||||
"le_region = LabelEncoder()\n",
|
||||
"le_severity = LabelEncoder()\n",
|
||||
"\n",
|
||||
"df['作物编码'] = le_crop.fit_transform(df['作物'])\n",
|
||||
"df['病害编码'] = le_disease.fit_transform(df['病害名称'])\n",
|
||||
"df['部位编码'] = le_part.fit_transform(df['发病部位'])\n",
|
||||
"df['地区编码'] = le_region.fit_transform(df['地区'])\n",
|
||||
"df['严重程度编码'] = le_severity.fit_transform(df['严重程度'])\n",
|
||||
"\n",
|
||||
"# 选择特征\n",
|
||||
"feature_cols = ['作物编码', '病害编码', '部位编码', '地区编码', '温度_℃', '湿度_%', '经度', '纬度']\n",
|
||||
"X = df[feature_cols]\n",
|
||||
"y = df['严重程度编码']\n",
|
||||
"\n",
|
||||
"print(f'特征维度: {X.shape}')\n",
|
||||
"print(f'标签分布:\\n{pd.Series(y).value_counts().sort_index()}')\n",
|
||||
"print(f'\\n标签映射:')\n",
|
||||
"for i, cls in enumerate(le_severity.classes_):\n",
|
||||
" print(f' {i}: {cls}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 3. 划分训练集和测试集"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X_train, X_test, y_train, y_test = train_test_split(\n",
|
||||
" X, y, test_size=0.2, random_state=42, stratify=y\n",
|
||||
")\n",
|
||||
"print(f'训练集大小: {X_train.shape[0]}')\n",
|
||||
"print(f'测试集大小: {X_test.shape[0]}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 4. 训练随机森林模型"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 训练随机森林\n",
|
||||
"rf_model = RandomForestClassifier(\n",
|
||||
" n_estimators=100,\n",
|
||||
" max_depth=10,\n",
|
||||
" random_state=42,\n",
|
||||
" n_jobs=-1\n",
|
||||
")\n",
|
||||
"rf_model.fit(X_train, y_train)\n",
|
||||
"\n",
|
||||
"# 预测\n",
|
||||
"y_pred_rf = rf_model.predict(X_test)\n",
|
||||
"acc_rf = accuracy_score(y_test, y_pred_rf)\n",
|
||||
"print(f'随机森林准确率: {acc_rf:.4f}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 5. 训练决策树模型(对比)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 训练决策树\n",
|
||||
"dt_model = DecisionTreeClassifier(max_depth=8, random_state=42)\n",
|
||||
"dt_model.fit(X_train, y_train)\n",
|
||||
"\n",
|
||||
"y_pred_dt = dt_model.predict(X_test)\n",
|
||||
"acc_dt = accuracy_score(y_test, y_pred_dt)\n",
|
||||
"print(f'决策树准确率: {acc_dt:.4f}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 6. 模型评估"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 分类报告\n",
|
||||
"print('=' * 50)\n",
|
||||
"print('随机森林 - 分类报告')\n",
|
||||
"print('=' * 50)\n",
|
||||
"print(classification_report(y_test, y_pred_rf, target_names=le_severity.classes_))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 混淆矩阵可视化\n",
|
||||
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
|
||||
"\n",
|
||||
"ConfusionMatrixDisplay.from_predictions(\n",
|
||||
" y_test, y_pred_rf,\n",
|
||||
" display_labels=le_severity.classes_,\n",
|
||||
" cmap='Blues', ax=axes[0]\n",
|
||||
")\n",
|
||||
"axes[0].set_title(f'随机森林混淆矩阵 (准确率={acc_rf:.4f})')\n",
|
||||
"\n",
|
||||
"ConfusionMatrixDisplay.from_predictions(\n",
|
||||
" y_test, y_pred_dt,\n",
|
||||
" display_labels=le_severity.classes_,\n",
|
||||
" cmap='Oranges', ax=axes[1]\n",
|
||||
")\n",
|
||||
"axes[1].set_title(f'决策树混淆矩阵 (准确率={acc_dt:.4f})')\n",
|
||||
"\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 7. 特征重要性分析"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 特征重要性\n",
|
||||
"importances = rf_model.feature_importances_\n",
|
||||
"indices = np.argsort(importances)[::-1]\n",
|
||||
"\n",
|
||||
"plt.figure(figsize=(10, 6))\n",
|
||||
"plt.bar(range(len(importances)), importances[indices], align='center', color='steelblue')\n",
|
||||
"plt.xticks(range(len(importances)), [feature_cols[i] for i in indices], rotation=45)\n",
|
||||
"plt.ylabel('重要性分数')\n",
|
||||
"plt.title('随机森林 - 特征重要性排序')\n",
|
||||
"plt.tight_layout()\n",
|
||||
"plt.show()\n",
|
||||
"\n",
|
||||
"print('特征重要性排名:')\n",
|
||||
"for rank, idx in enumerate(indices, 1):\n",
|
||||
" print(f' {rank}. {feature_cols[idx]}: {importances[idx]:.4f}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 8. 模型预测示例"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 模拟新样本预测\n",
|
||||
"new_samples = pd.DataFrame({\n",
|
||||
" '作物编码': le_crop.transform(['水稻', '小麦', '玉米']),\n",
|
||||
" '病害编码': le_disease.transform(['稻瘟病', '锈病', '大斑病']),\n",
|
||||
" '部位编码': le_part.transform(['叶片', '叶片', '穗部']),\n",
|
||||
" '地区编码': le_region.transform(['长江中下游', '华北平原', '东北平原']),\n",
|
||||
" '温度_℃': [28.5, 22.0, 25.3],\n",
|
||||
" '湿度_%': [85.0, 65.0, 78.0],\n",
|
||||
" '经度': [116.4, 114.5, 123.4],\n",
|
||||
" '纬度': [30.5, 38.0, 41.8],\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
"predictions = rf_model.predict(new_samples)\n",
|
||||
"pred_labels = le_severity.inverse_transform(predictions)\n",
|
||||
"\n",
|
||||
"print('预测结果:')\n",
|
||||
"print('-' * 50)\n",
|
||||
"for i, (_, row) in enumerate(new_samples.iterrows()):\n",
|
||||
"    crop_name = le_crop.inverse_transform([int(row['作物编码'])])[0]\n",
|
||||
"    disease = le_disease.inverse_transform([int(row['病害编码'])])[0]\n",
|
||||
"    print(f'样本{i+1}: {crop_name}/{disease} => 预测严重程度: {pred_labels[i]}')\n",
|
||||
"\n",
|
||||
"print('\\n模型训练和预测流程演示完成!')"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python",
|
||||
"version": "3.10.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
Reference in New Issue
Block a user