代码篇:完整的鸢尾花分类代码,包含参数详解、交叉验证、可视化、模型保存
阅读时间:8分钟
环境准备
pip install scikit-learn pandas numpy matplotlib
完整代码:鸢尾花分类
import numpy as npimport matplotlib.pyplot as pltfrom sklearn import datasetsfrom sklearn.model_selection import train_test_splitfrom sklearn.svm import SVCfrom sklearn.metrics import accuracy_score, classification_report# 1. 加载数据iris = datasets.load_iris()X = iris.data[:, :2] # 只取前两个特征,方便可视化y = iris.target# 2. 划分训练集和测试集X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.3, random_state=42)# 3. 创建SVM模型(RBF核)model = SVC( kernel='rbf', # 核函数:线性'linear'、多项式'poly'、RBF'rbf' C=1.0, # 惩罚系数:越大越严格,容易过拟合 gamma='scale', # 核系数:控制单个样本的影响范围 probability=True # 启用概率预测(会慢一点))# 4. 训练模型model.fit(X_train, y_train)# 5. 预测y_pred = model.predict(X_test)# 6. 评估print(f"准确率: {accuracy_score(y_test, y_pred):.3f}")print("\n详细报告:")print(classification_report(y_test, y_pred, target_names=iris.target_names))# 7. 查看支持向量print(f"\n支持向量数量: {len(model.support_)}")print(f"支持向量索引: {model.support_[:5]}...") # 显示前5个# 8. 概率预测(需要probability=True)proba = model.predict_proba(X_test[:3])print(f"\n前3个样本的预测概率:\n{proba}")
关键参数详解
C(惩罚系数)
# C越大,模型越"严格",不允许分错# C越小,模型越"宽松",允许一些错误model_hard = SVC(C=100) # 严格版model_soft = SVC(C=0.1) # 宽松版
怎么选? 用交叉验证:
from sklearn.model_selection import GridSearchCVparam_grid = {'C': [0.1, 1, 10, 100], 'gamma': ['scale', 'auto', 0.001, 0.01]}grid = GridSearchCV(SVC(kernel='rbf'), param_grid, cv=5)grid.fit(X_train, y_train)print(f"最佳参数: {grid.best_params_}")
kernel(核函数)
# 快速选择指南SVC(kernel='linear') # 数据线性可分,特征多样本少SVC(kernel='rbf') # 不知道选什么,默认这个SVC(kernel='poly', degree=3) # 需要捕捉特征交互
可视化决策边界
# 创建网格h = 0.02x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))# 预测网格点Z = model.predict(np.c_[xx.ravel(), yy.ravel()])Z = Z.reshape(xx.shape)# 画图plt.contourf(xx, yy, Z, alpha=0.8, cmap=plt.cm.coolwarm)plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.coolwarm, edgecolors='k')plt.xlabel('花萼长度')plt.ylabel('花萼宽度')plt.title('SVM 决策边界')plt.show()
保存和加载模型
import joblib# 保存joblib.dump(model, 'svm_model.pkl')# 加载model_loaded = joblib.load('svm_model.pkl')predictions = model_loaded.predict(X_new)
下期预告:用SVM做垃圾邮件过滤,完整项目代码!
#支持向量机 #python #机器学习 #数据分析 #互联网