# Notebook: 02_model_training_demo.ipynb (cells written in "percent" script form)

# %% [markdown]
# # Crop disease classification: model training example
#
# This notebook demonstrates how to train a machine-learning model on a crop
# dataset to classify disease severity.
#
# **Scenario:** AI model training (classification task)
#
# **Dataset:** CSV crop disease annotation table
#
# **Models:** random forest / decision tree classifiers

# %% [markdown]
# ## 1. Environment setup and data loading

# %%
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    accuracy_score,
    classification_report,
    confusion_matrix,
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier

# Ask matplotlib for the SimHei / WenQuanYi Micro Hei fonts so the Chinese
# titles and labels used below can be drawn, and render the minus sign as a
# plain ASCII hyphen instead of the Unicode minus glyph.
matplotlib.rcParams.update({
    'font.sans-serif': ['SimHei', 'WenQuanYi Micro Hei'],
    'axes.unicode_minus': False,
})

# abspath('') resolves to the current working directory, so DATA_DIR is the
# *parent* of the directory the kernel was started in.
DATA_DIR = os.path.dirname(os.path.abspath(''))
print('环境准备完成')

# %%
# Load the annotation table from the csv/ folder under DATA_DIR.
csv_path = f'{DATA_DIR}/csv/作物病害标注表.csv'
df = pd.read_csv(csv_path)
print(f'数据集大小: {df.shape}')
df.head()

# %% [markdown]
# ## 2. Data preprocessing

# %%
# Feature engineering: one LabelEncoder per categorical column. The encoders
# stay in the namespace because later cells use them to encode new samples
# and to map predictions back to names.
le_crop, le_disease, le_part, le_region, le_severity = (
    LabelEncoder() for _ in range(5)
)

# (encoder, raw text column, integer-coded column added to df)
encoding_plan = [
    (le_crop, '作物', '作物编码'),
    (le_disease, '病害名称', '病害编码'),
    (le_part, '发病部位', '部位编码'),
    (le_region, '地区', '地区编码'),
    (le_severity, '严重程度', '严重程度编码'),
]
for encoder, raw_col, encoded_col in encoding_plan:
    df[encoded_col] = encoder.fit_transform(df[raw_col])

# Model inputs: the four encoded categoricals plus the numeric measurements.
feature_cols = ['作物编码', '病害编码', '部位编码', '地区编码', '温度_℃', '湿度_%', '经度', '纬度']
X = df[feature_cols]
y = df['严重程度编码']

print(f'特征维度: {X.shape}')
label_counts = pd.Series(y).value_counts().sort_index()
print(f'标签分布:\n{label_counts}')
print(f'\n标签映射:')
for code, severity_name in enumerate(le_severity.classes_):
    print(f' {code}: {severity_name}')

# %% [markdown]
# ## 3. Train/test split

# %%
# 80/20 split, stratified on the severity code so both parts keep the same
# class proportions; the fixed seed makes the split repeatable.
split_options = dict(test_size=0.2, random_state=42, stratify=y)
X_train, X_test, y_train, y_test = train_test_split(X, y, **split_options)
print(f'训练集大小: {len(X_train)}')
print(f'测试集大小: {len(X_test)}')
# %% [markdown]
# ## 4. Train the random forest model

# %%
def _fit_and_score(model, X_fit, y_fit, X_eval, y_eval):
    """Fit ``model`` on the training split and score it on the evaluation split.

    Returns a ``(model, predictions, accuracy)`` tuple: the fitted estimator,
    its predictions for ``X_eval`` and the accuracy of those predictions
    against ``y_eval``.
    """
    model.fit(X_fit, y_fit)
    predicted = model.predict(X_eval)
    return model, predicted, accuracy_score(y_eval, predicted)


# %%
# Random forest: 100 trees, depth capped at 10, fixed seed, all CPU cores.
rf_model, y_pred_rf, acc_rf = _fit_and_score(
    RandomForestClassifier(
        n_estimators=100,
        max_depth=10,
        random_state=42,
        n_jobs=-1,
    ),
    X_train, y_train, X_test, y_test,
)
print(f'随机森林准确率: {acc_rf:.4f}')

# %% [markdown]
# ## 5. Train the decision tree model (for comparison)

# %%
# Single decision tree, depth capped at 8, same seed and same split.
dt_model, y_pred_dt, acc_dt = _fit_and_score(
    DecisionTreeClassifier(max_depth=8, random_state=42),
    X_train, y_train, X_test, y_test,
)
print(f'决策树准确率: {acc_dt:.4f}')
# %% [markdown]
# ## 6. Model evaluation

# %%
# Every integer code the severity encoder can produce. Passing this as
# `labels` keeps the report rows and confusion-matrix axes aligned with
# le_severity.classes_ even when a class is absent from both y_test and the
# predictions; without it scikit-learn rejects the call because the number of
# names no longer matches the number of labels it found.
severity_codes = np.arange(len(le_severity.classes_))

# Classification report for the random forest.
print('=' * 50)
print('随机森林 - 分类报告')
print('=' * 50)
print(classification_report(
    y_test, y_pred_rf,
    labels=severity_codes,
    target_names=le_severity.classes_,
))

# %%
# Confusion matrices, random forest (left) next to decision tree (right).
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# (axes, predictions, colour map, panel title)
panels = [
    (axes[0], y_pred_rf, 'Blues', f'随机森林混淆矩阵 (准确率={acc_rf:.4f})'),
    (axes[1], y_pred_dt, 'Oranges', f'决策树混淆矩阵 (准确率={acc_dt:.4f})'),
]
for ax, y_pred, cmap, title in panels:
    ConfusionMatrixDisplay.from_predictions(
        y_test, y_pred,
        labels=np.arange(len(le_severity.classes_)),
        display_labels=le_severity.classes_,
        cmap=cmap, ax=ax,
    )
    ax.set_title(title)

plt.tight_layout()
plt.show()

# %% [markdown]
# ## 7. Feature importance analysis

# %%
# Random-forest feature importances, sorted from most to least important.
importances = rf_model.feature_importances_
indices = np.argsort(importances)[::-1]
positions = range(len(importances))

fig_imp, ax_imp = plt.subplots(figsize=(10, 6))
ax_imp.bar(positions, importances[indices], align='center', color='steelblue')
ax_imp.set_xticks(positions)
ax_imp.set_xticklabels([feature_cols[i] for i in indices], rotation=45)
ax_imp.set_ylabel('重要性分数')
ax_imp.set_title('随机森林 - 特征重要性排序')
fig_imp.tight_layout()
plt.show()

print('特征重要性排名:')
for rank, idx in enumerate(indices, 1):
    print(f' {rank}. {feature_cols[idx]}: {importances[idx]:.4f}')
# %% [markdown]
# ## 8. Model prediction example

# %%
# Simulated new samples, built with the same columns and column order as the
# training features. Each category name goes through the encoder fitted
# during preprocessing.
# NOTE(review): LabelEncoder.transform raises ValueError for a name it did
# not see while fitting — confirm these names all occur in the training CSV.
new_samples = pd.DataFrame({
    '作物编码': le_crop.transform(['水稻', '小麦', '玉米']),
    '病害编码': le_disease.transform(['稻瘟病', '锈病', '大斑病']),
    '部位编码': le_part.transform(['叶片', '叶片', '穗部']),
    '地区编码': le_region.transform(['长江中下游', '华北平原', '东北平原']),
    '温度_℃': [28.5, 22.0, 25.3],
    '湿度_%': [85.0, 65.0, 78.0],
    '经度': [116.4, 114.5, 123.4],
    '纬度': [30.5, 38.0, 41.8],
})

predictions = rf_model.predict(new_samples)
pred_labels = le_severity.inverse_transform(predictions)

# Decode the crop and disease names straight from the integer code columns.
# Going row by row with iterrows() would hand back floats (the frame also
# holds float columns, so each row Series is upcast), and inverse_transform
# cannot index its classes with float codes.
crop_names = le_crop.inverse_transform(new_samples['作物编码'])
disease_names = le_disease.inverse_transform(new_samples['病害编码'])

print('预测结果:')
print('-' * 50)
decoded_rows = zip(crop_names, disease_names, pred_labels)
for sample_no, (crop_name, disease, severity) in enumerate(decoded_rows, start=1):
    print(f'样本{sample_no}: {crop_name}/{disease} => 预测严重程度: {severity}')

print('\n模型训练和预测流程演示完成!')

# Notebook metadata recorded in the source: kernelspec "python3"
# (display name "Python 3"), language version 3.10.0, nbformat 4.4.