
代码绘制成果展示












代码解释


第一部分

# =========================================================================================# ====================================== 1. 环境设置 =======================================# =========================================================================================import matplotlib.pyplot as pltimport numpy as npimport matplotlib.gridspec as gridspecimport shapimport pandas as pd

第二部分

# =========================================================================================# ======================================2.颜色库=======================================# =========================================================================================COLOR_SCHEMES = {1: {'beeswarm': 'Spectral_r','rose': ['#4A1028', '#7B241C', '#A93226', '#CB4335', '#E67E22', '#F5B041', '#F7DC6F', '#76D7C4', '#1F618D', '#2874A6']},}

第三部分

# =========================================================================================# ======================================3.蜂巢图抖动计算函数=======================================# =========================================================================================def simple_beeswarm(x_values, nbins=40, width=0.1):hist_range = (np.min(x_values), np.max(x_values)) #数据的最小值和最大值范围if hist_range[0] == hist_range[1]: # 如果最大值等于最小值hist_range = (hist_range[0] - 0.1, hist_range[1] + 0.1) #手动扩展范围counts, edges = np.histogram(x_values, bins=nbins, range=hist_range) #计算直方图,获取各区间的计数和边界current_width = (counts[i] / max_count) * width # 根据当前箱子的密度计算抖动宽度ys = np.linspace(-current_width, current_width, len(idxs)) # 在宽度范围内生成均匀分布的Y值np.random.shuffle(ys) # 打乱Y值顺序y_values[idxs] = ys # 将计算好的Y值赋给对应的数据点return y_values # 返回计算好的Y轴抖动坐标

第四部分

# =========================================================================================# ======================================4.单幅蜂巢图及玫瑰图绘制函数=======================================# =========================================================================================def plot_shap_beeswarm(ax, features_base, shap_values, feature_values, title, letter, config, top_k_global_inds):beeswarm_cmap = config['beeswarm'] #蜂巢图颜色rose_colors = config['rose'] #玫瑰图颜色n_features = len(features_base) #特征总数num_colors = len(rose_colors) #玫瑰图特征数fv_norm = (fv - fv_min) / (fv_max - fv_min) # 将特征值进行Min-Max归一化,缩放到[0, 1]区间以匹配颜色映射条#调用函数根据SHAP值的密度分布计算Y轴方向的偏移量y_offset = simple_beeswarm(sv,nbins=100,width=0.25)#绘制散点ax.scatter(sv, #xi + y_offset, #yc=fv_norm, #根据归一化后的特征大小来决定颜色cmap=beeswarm_cmap, #颜色映射方案s=12, #散点大小alpha=1, #透明度edgecolors='none') #不绘制边缘线

第五部分

#X=0的垂直辅助线ax.axvline(x=0, #xcolor='grey', #颜色linestyle='-', #样式linewidth=2, #粗细alpha=1) #透明度ax.set_xlabel('Feature Impact model output (SHAP)', fontsize=20, fontweight='bold') #x轴标题ax.set_yticks(range(n_features)) #Y轴上为每一个特征刻度ax.text(-0.1, #x1.05, #yletter, #编号transform=ax.transAxes, #坐标系fontsize=20, #大小va='top', #垂直ha='right', #水平fontweight='bold') #加粗#目标类ax.text(0.95, #x0.05, #ytitle, #文本transform=ax.transAxes, #坐标系fontsize=20, #字体大小fontweight='bold', #加粗ha='right') #水平

第六部分

#==================================================================================================================================#========================================================内嵌玫瑰图============================================================#==================================================================================================================================#创建轴inset_ax = ax.inset_axes([0.6, -0.1, 0.65, 0.65], polar=True)inset_ax.set_theta_offset(np.pi / 2) #起始角度inset_ax.set_theta_direction(-1) #系绘图方向top_k_inds = top_k_global_inds #全局排名前10的特征索引raw_contributions = mean_abs_shap[top_k_inds] #提取绝对SHAP平均值contributions = raw_contributions / (raw_contributions.sum() + 1e-8) #百分比归一化#添加文本inset_ax.text(current_theta + width_val / 2, #角度radius_val + 4, #半径f"{prob * 100:.1f}%", #文本fontsize=13, #字体大小ha='center', #水平va='center', #垂直fontweight='bold') #加粗current_theta += width_val #更新起始角度inset_ax.axis('off') #去掉默认网格线及刻度标签

第七部分

# =========================================================================================# ======================================5.主绘图函数=======================================# =========================================================================================def generate_shap_plots(shap_values_list, X_eval, features_base, dataset_name, scheme_id):current_config = COLOR_SCHEMES.get(scheme_id, COLOR_SCHEMES[1]) #提取配色方案rose_colors = current_config['rose'] #玫瑰图配色num_colors = len(rose_colors) #获取数量长度#创建画布fig = plt.figure(figsize=(17, 14), constrained_layout=True)#设置布局gs = gridspec.GridSpec(2, #行3, #列figure=fig, #图表对象wspace=0.05, #水平间hspace=0.05) #垂直间#如果是第一张子图if letters[i] == "a":ax.text(-0.15, #x1.1, #y"Features", #文本transform=ax.transAxes, #坐标系fontsize=20, #文字大小fontweight='bold') #加粗#右下角子图ax_f = fig.add_subplot(gs[1, 2])ax_f.axis('off') #去掉坐标轴线、刻度标签#圆环图标题ax_f.text(0.25, #x1.35, #yf"SHAP summary for diffrent classess\nwith contrubution for CatBoost % ({dataset_name})\nCatboost", #文本ha='center', #水平fontsize=20, #字体大小fontweight='bold') #加粗radius =1.2 #最外层圆环半径ring_width = 0.14 #圆环厚度legend_labels, #文本loc='lower center', #位置bbox_to_anchor=(0.5, -0.72), #精确位置ncol=3, #列frameon=False, #去掉边框title='Rosechart Feature Contribution %',#图例标题handlelength=2, #长handleheight=1.2, #高fontsize=16, #字体大小。title_fontsize=20 #标题的字体大小。)

第八部分

GridSearchCV寻找表现最优的超参数。进行模型性能评估。调用SHAP的树解释器为验证集进行分析,批量生成并保存绘图结果。# =========================================================================================# ======================================7.执行部分=======================================# =========================================================================================if __name__ == "__main__":df_train_test = pd.read_excel(r'dataset_train_test.xlsx') #模型数据集df_val = pd.read_excel(r'dataset_validation.xlsx') #独立验证数据X_train_test = df_train_test.drop(columns=['Target_Class']) #特征数据y_train_test = df_train_test['Target_Class'] #目标数据features_base = X_train_test.columns.tolist() #特征名#划分书记X_train, X_test, y_train, y_test = train_test_split(X_train_test, y_train_test, test_size=0.3, random_state=42)X_val = df_val.drop(columns=['Target_Class']) #特征数据y_val = df_val['Target_Class'] #目标数据print("网格搜索")#实例化模型base_model = CatBoostClassifier(loss_function='MultiClass', verbose=0, random_state=42)plot_all =Trueif plot_all:for i in COLOR_SCHEMES.keys():generate_shap_plots(shap_values_val_list, X_val.values, features_base, "Validation_Set", scheme_id=i)else:target_scheme = 1generate_shap_plots(shap_values_val_list, X_val.values, features_base, "Validation_Set",scheme_id=target_scheme)

如何应用到你自己的数据

1.设置是一次绘制一张图还是一次性绘制出所有配色的图,执行部分:
plot_all =True2.设置模型数据集保存的路径,执行部分:
df_train_test = pd.read_excel(r'dataset_train_test.xlsx') #模型数据集3.设置独立验证数据集保存的路径,执行部分:
df_val = pd.read_excel(r'dataset_validation.xlsx') #独立验证数据4.设置目标变量,执行部分:
y_train_test = df_train_test['Target_Class'] #目标数据5.设置特征变量,执行部分:
X_train_test = df_train_test.drop(columns=['Target_Class']) #特征数据6.设置超参数,执行部分:
param_grid = {'iterations': [100, 200, 300],'learning_rate': [0.01, 0.05, 0.1],'depth': [4, 6, 8],}
7.设置保存路径,主绘图函数:
plt.savefig(fr'\shap_results_{dataset_name}_scheme{scheme_id}.png', dpi=300, bbox_inches='tight')
推荐


获取方式
