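In the previous post we covered the theory behind decision trees. Today we go straight to the code and implement a complete decision tree model in Python.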
    # Install the required libraries
    pip install scikit-learn pandas matplotlib

The complete code:

    import pandas as pd
    import matplotlib.pyplot as plt
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    from sklearn.tree import DecisionTreeClassifier, plot_tree
    from sklearn.metrics import accuracy_score, classification_report

    # 1. Prepare the data
    # Use the classic iris dataset
    iris = load_iris()
    X = iris.data    # Features: sepal length/width, petal length/width
    y = iris.target  # Labels: three iris species

    # 2. Split into training and test sets
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # 3. Create and train the model
    clf = DecisionTreeClassifier(
        max_depth=3,          # Maximum tree depth, guards against overfitting
        min_samples_split=5,  # Minimum number of samples needed to split a node
        random_state=42
    )
    clf.fit(X_train, y_train)

    # 4. Predict
    y_pred = clf.predict(X_test)

    # 5. Evaluate
    print(f"Accuracy: {accuracy_score(y_test, y_pred):.2%}")
    print("\nDetailed report:")
    print(classification_report(y_test, y_pred, target_names=iris.target_names))

    # 6. Visualize the decision tree
    plt.figure(figsize=(20, 10))
    plot_tree(clf,
              feature_names=iris.feature_names,
              class_names=iris.target_names,
              filled=True,   # Fill nodes with color
              rounded=True,  # Rounded node boxes
              fontsize=10)
    plt.savefig('decision_tree.png', dpi=150, bbox_inches='tight')
    plt.show()

    # 7. Predict new data
    new_flower = [[5.1, 3.5, 1.4, 0.2]]  # Measurements of one new flower
    prediction = clf.predict(new_flower)
    prob = clf.predict_proba(new_flower)
    print(f"\nPrediction: {iris.target_names[prediction[0]]}")
    print(f"Class probabilities: {dict(zip(iris.target_names, prob[0]))}")

Loading the data:

    from sklearn.datasets import load_iris
    iris = load_iris()

The iris dataset is the "Hello World" of machine learning: it contains 150 samples in 3 classes.
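As a quick sanity check on those numbers, the loaded dataset can be inspected directly. This is only a small sketch that uses the same `load_iris` call as above; `np.bincount` is added here purely to count the samples per class.

```python
import numpy as np
from sklearn.datasets import load_iris

iris = load_iris()
X, y = iris.data, iris.target

# 150 samples, each described by 4 numeric features
print(X.shape)

# Names of the 4 features and the 3 classes
print(iris.feature_names)
print(iris.target_names.tolist())

# Number of samples per class (the dataset is balanced)
class_counts = np.bincount(y)
print(class_counts.tolist())
```

Because the three classes are the same size, plain accuracy is a reasonable first metric for this dataset.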
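The main model parameters:

    DecisionTreeClassifier(
        max_depth=3,          # Maximum tree depth
        min_samples_split=5,  # Minimum number of samples needed to split a node
        min_samples_leaf=2,   # Minimum number of samples in a leaf node
        criterion='gini'      # Split criterion: 'gini' or 'entropy'
    )

| Parameter | Meaning |
| --- | --- |
| max_depth | Maximum depth of the tree |
| min_samples_split | Minimum number of samples needed to split a node |
| min_samples_leaf | Minimum number of samples in a leaf node |
| criterion | Split criterion: 'gini' or 'entropy' |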
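To get a feel for what `max_depth` does, one option is to fit the same data with several depth limits and compare the resulting trees. The sketch below is only an illustration: the depth values and the `results` dictionary are choices made for this example, and the scores it prints depend on the train/test split.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42
)

results = {}
for depth in (1, 2, 3, None):  # None means the depth is not limited
    model = DecisionTreeClassifier(max_depth=depth, random_state=42)
    model.fit(X_train, y_train)
    results[depth] = {
        "actual_depth": model.get_depth(),
        "leaves": model.get_n_leaves(),
        "train_acc": model.score(X_train, y_train),
        "test_acc": model.score(X_test, y_test),
    }
    print(depth, results[depth])
```

A widening gap between `train_acc` and `test_acc` as the depth grows is the usual sign of overfitting, which is what the depth limit is meant to control.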
    # Inspect the importance of each feature
    importance = pd.DataFrame({
        'feature': iris.feature_names,
        'importance': clf.feature_importances_
    }).sort_values('importance', ascending=False)
    print(importance)

Example output:

                 feature  importance
    2  petal length (cm)       0.513
    3   petal width (cm)       0.487
    0  sepal length (cm)       0.000
    1   sepal width (cm)       0.000

Petal length and petal width are the key features for telling the iris species apart.
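When reading these numbers, it helps to know that scikit-learn normalizes `feature_importances_` so that the values are non-negative and sum to 1. Each value is that feature's share of the total impurity reduction across the tree. The sketch below retrains the same model as in the main example and checks this. The `ranked` list is just a helper for this illustration, and the exact values may differ from the example output above.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42
)

clf = DecisionTreeClassifier(max_depth=3, min_samples_split=5, random_state=42)
clf.fit(X_train, y_train)

importances = clf.feature_importances_

# Pair each feature name with its importance, highest first
ranked = sorted(
    zip(iris.feature_names, importances),
    key=lambda pair: pair[1],
    reverse=True,
)
for name, score in ranked:
    print(f"{name:<20} {score:.3f}")

# The importances are shares of a whole, so they add up to 1
print(f"sum = {importances.sum():.3f}")
```

A value of 0 means the tree never split on that feature. It does not necessarily mean the feature carries no information.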
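Using your own data:

    # Read the CSV data
    df = pd.read_csv('your_data.csv')

    # Separate features and labels
    X = df.drop('target_column', axis=1)
    y = df['target_column']

    # Encode categorical features (if any)
    X = pd.get_dummies(X, columns=['category_column'])

    # Train the model (same as above)

Common mistakes:

    # Wrong
    clf.fit(X, y)
    pred = clf.predict(X)  # Evaluating on the training set inflates the result!

    # Always split into training and test sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
    clf.fit(X_train, y_train)
    pred = clf.predict(X_test)

    # Prone to overfitting
    clf = DecisionTreeClassifier()  # No depth limit by default

    # Set a sensible depth limit
    clf = DecisionTreeClassifier(max_depth=5, min_samples_split=10)

Implementing a decision tree with sklearn takes just 4 steps: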
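1. Split the data: train_test_split
2. Create the model: DecisionTreeClassifier
3. Train: fit()
4. Predict and evaluate: predict() + accuracy_score

The complete code is under 50 lines, yet it builds an interpretable, easy-to-deploy predictive model.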
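The four steps can also be sketched on a tabular dataset like the CSV example. The small DataFrame below is made-up data standing in for `your_data.csv`, and `numeric_feature` is a hypothetical column name; only `category_column` and `target_column` come from the earlier snippet. `stratify=y` is an optional extra that keeps the class ratio the same in both splits.

```python
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Made-up data standing in for pd.read_csv('your_data.csv')
df = pd.DataFrame({
    "numeric_feature": [22, 35, 47, 51, 29, 63, 41, 38, 26, 57, 33, 45],
    "category_column": ["a", "b"] * 6,
    "target_column": [1, 0] * 6,
})

# Separate features and labels, then one-hot encode the categorical column
X = pd.get_dummies(df.drop("target_column", axis=1), columns=["category_column"])
y = df["target_column"]

# Step 1: split the data
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42, stratify=y
)

# Step 2: create the model
clf = DecisionTreeClassifier(max_depth=3, random_state=42)

# Step 3: train
clf.fit(X_train, y_train)

# Step 4: predict and evaluate on the held-out test set
y_pred = clf.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print(f"Accuracy: {acc:.2%}")
```

With only 12 rows the accuracy figure means very little. The point is the shape of the workflow, which stays the same when the DataFrame comes from a real file.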
Next up: we will use a decision tree to work through a real-world customer churn prediction case. Stay tuned!