用SHAP打破机器学习“黑箱”,核心看两个关键可视化。1. 特征重要性交互网络:节点代表各类核心特征,节点大小体现特征重要性,连线强弱反映特征间交互关系,直观呈现多特征协同作用的规律。2. SHAP瀑布图:从基线值出发,逐项拆解每个特征的贡献,可精准定位关键影响特征(标准瀑布图聚焦单一样本;本文代码对所有样本的SHAP值取平均后绘制,展示的是各特征的平均贡献方向)。

先看复现前后的对比图


Python代码
import os
import pickle

import matplotlib.colors as mcolors
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import shap
import xgboost as xgb
from lightgbm import LGBMRegressor
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
from sklearn.model_selection import GridSearchCV, train_test_split
# Load the training table. Everything from the fifth column onward is used as
# the feature matrix and the 'Yield' column is the regression target.
# NOTE(review): if 'Yield' sits at column index >= 4 it is also part of X,
# which would leak the target into the features -- confirm the CSV layout.
df = pd.read_csv("./data/train.csv")
X = df.iloc[:, 4:]
y = df['Yield'].values

# Reproducible 80/20 train/test split.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)
# Booster hyper-parameters passed straight to xgb.train.
# NOTE(review): 'device': 'cuda' requests GPU training -- confirm that a
# CUDA-enabled XGBoost build and a GPU are available where this runs.
param = {
    'max_depth': 6,
    'eta': 0.036,
    'tree_method': 'hist',
    'device': 'cuda',
}
n_trees = 500  # number of boosting rounds

# Train on the training split only; the test split is kept for SHAP analysis.
dmat = xgb.DMatrix(X_train, label=y_train)
best_model = xgb.train(param, dmat, num_boost_round=n_trees)
# Explain the trained booster on the held-out test split: per-feature SHAP
# values plus pairwise SHAP interaction values.
explainer = shap.TreeExplainer(best_model)
shap_values = explainer.shap_values(X_test)
shap_interaction_values = explainer.shap_interaction_values(X_test)
# Global plotting style: Arial text, and font type 42 for the PDF/PS backends
# so exported vector figures remain editable.
plt.rcParams.update({
    'font.family': 'Arial',
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

# Colour library, keyed by an integer scheme id.
COLOR_SCHEMES = {
    1: {
        'nodes': plt.cm.Greens,
        'edges': plt.cm.Purples,
        'bar_pos': '#d55e5b',
        'bar_neg': '#5b85ba',
        'scatter_color': '#e0e0e0',
    },
}
# Colour scheme to use (the original value 45 was out of range, changed to 1).
scheme_index = 1
# Fall back to scheme 1 when the requested id does not exist.
current_color_scheme = COLOR_SCHEMES.get(scheme_index, COLOR_SCHEMES[1])

# Marker / line-style library, keyed by an integer scheme id.
STYLE_SCHEMES = {
    1: {
        'marker': 'o',
        'linestyle': '-',
    },
}
# Shape scheme to use.
style_index = 1
# Fall back to scheme 1 when the requested id does not exist.
current_style_scheme = STYLE_SCHEMES.get(style_index, STYLE_SCHEMES[1])
def _linear_rescale(value, lo, hi, out_min, out_max):
    """Map ``value`` from ``[lo, hi]`` linearly onto ``[out_min, out_max]``.

    The fraction is clipped to ``[0, 1]``; a degenerate input range
    (``hi <= lo``) maps to ``out_max``.
    """
    span = hi - lo
    if span <= 0:
        return out_max
    fraction = min(max((value - lo) / span, 0.0), 1.0)
    return out_min + (out_max - out_min) * fraction


def plot_interaction(
    features,
    importance,
    interaction_matrix,
    top_n=None,
    node_size_min=220,
    node_size_max=2200,
    node_size_power=1.15,
    edge_width_weak_min=0.5,
    edge_width_weak_max=1.1,
    edge_width_strong_min=1.6,
    edge_width_strong_max=7.6,
    edge_highlight_percentile=85,
    edge_alpha=0.85,
    save_pdf_path=None,
):
    """Draw a circular feature-interaction network.

    Each feature is a node on a circle; node colour and size encode
    ``importance``. Every feature pair whose symmetrised interaction
    ``(m[i, j] + m[j, i]) / 2`` is strictly positive is drawn as an edge whose
    colour encodes that strength. Edges at or above the
    ``edge_highlight_percentile`` percentile of the drawn strengths use the
    "strong" width range, all others the "weak" width range.

    Colours and marker/line styles are read from the module-level
    ``current_color_scheme`` and ``current_style_scheme``.

    Args:
        features: Sequence of feature names (used as node ids and labels).
        importance: One importance value per feature (any array-like).
        interaction_matrix: Square array-like of pairwise interaction values,
            ordered like ``features``.
        top_n: If given and smaller than the number of features, keep only the
            ``top_n`` most important features, ordered from most to least
            important.
        node_size_min, node_size_max: Node size range.
        node_size_power: Exponent applied to the normalised importance before
            it is mapped onto the node size range.
        edge_width_weak_min, edge_width_weak_max: Width range for edges below
            the highlight threshold.
        edge_width_strong_min, edge_width_strong_max: Width range for edges at
            or above the highlight threshold.
        edge_highlight_percentile: Percentile (clipped to 0-100) of the drawn
            edge strengths that separates weak from strong edges.
        edge_alpha: Edge opacity (clipped to 0-1).
        save_pdf_path: If not None, the figure is also saved there as PDF.

    Returns:
        The ``(fig, ax)`` pair of the network plot.

    Raises:
        ValueError: If ``features`` is empty.
    """
    # Accept plain lists as well as arrays; fancy indexing below needs ndarrays.
    features = list(features)
    importance = np.asarray(importance, dtype=float)
    interaction_matrix = np.asarray(interaction_matrix, dtype=float)
    if not features:
        raise ValueError('features must contain at least one feature name')

    # Optionally keep only the top_n most important features, most important first.
    if top_n is not None and top_n < len(features):
        keep = np.argsort(importance)[-top_n:][::-1]
        features = [features[i] for i in keep]
        importance = importance[keep]
        interaction_matrix = interaction_matrix[np.ix_(keep, keep)]

    cmap_nodes = current_color_scheme['nodes']  # node colour map
    cmap_edges = current_color_scheme['edges']  # edge colour map
    node_marker = current_style_scheme['marker']  # node marker shape
    edge_linestyle = current_style_scheme['linestyle']  # edge line style

    fig, ax = plt.subplots(figsize=(12, 10), subplot_kw={'aspect': 'equal'})
    n_features = len(features)

    G = nx.Graph()
    G.add_nodes_from(features)
    pos = nx.circular_layout(G)
    # Labels sit slightly outside the circle of nodes.
    label_pos = {k: (v[0] * 1.18, v[1] * 1.18) for k, v in pos.items()}

    # Normalisation limits; the edge limits cover exactly the edges that are drawn.
    max_imp = float(np.max(importance)) if np.max(importance) > 0 else 1.0

    # Symmetrised strength of every unordered pair (upper triangle, no diagonal);
    # only strictly positive pairs become edges.
    symmetric = (interaction_matrix + interaction_matrix.T) / 2
    rows, cols = np.triu_indices(n_features, k=1)
    strengths = symmetric[rows, cols]
    drawn = strengths > 0
    rows, cols, pair_interactions = rows[drawn], cols[drawn], strengths[drawn]

    if pair_interactions.size > 0:
        edge_vmin = float(np.min(pair_interactions))
        edge_vmax = float(np.max(pair_interactions))
        if edge_vmax <= edge_vmin:
            # Avoid a zero-width colour range when all strengths are equal.
            edge_vmax = edge_vmin + 1e-12
    else:
        edge_vmin, edge_vmax = 0.0, 1.0
    max_interaction = edge_vmax

    norm_nodes = mcolors.Normalize(vmin=0, vmax=max_imp)
    norm_edges = mcolors.Normalize(vmin=edge_vmin, vmax=edge_vmax)

    # Threshold above which an edge is drawn with the "strong" width range.
    edge_highlight_percentile = float(np.clip(edge_highlight_percentile, 0, 100))
    highlight_threshold = (
        np.percentile(pair_interactions, edge_highlight_percentile)
        if pair_interactions.size > 0
        else 0
    )

    # Sanitise the size/width/alpha parameters.
    node_size_min = max(float(node_size_min), 1.0)
    node_size_max = max(float(node_size_max), node_size_min)
    node_size_power = max(float(node_size_power), 0.01)
    edge_width_weak_min = max(float(edge_width_weak_min), 0.01)
    edge_width_weak_max = max(float(edge_width_weak_max), edge_width_weak_min)
    edge_width_strong_min = max(float(edge_width_strong_min), 0.01)
    edge_width_strong_max = max(float(edge_width_strong_max), edge_width_strong_min)
    edge_alpha = float(np.clip(edge_alpha, 0, 1))

    # Node colour and size: size = min + range * (normalised importance ** power).
    node_colors = [cmap_nodes(norm_nodes(value)) for value in importance]
    node_sizes = [
        node_size_min
        + (node_size_max - node_size_min) * (value / max_imp) ** node_size_power
        for value in importance
    ]

    # Split the drawn pairs into weak and strong edges and scale their widths
    # within the matching width range.
    weak_edges, weak_edge_colors, weak_edge_widths = [], [], []
    strong_edges, strong_edge_colors, strong_edge_widths = [], [], []
    for i, j, strength in zip(rows, cols, pair_interactions):
        edge = (features[i], features[j])
        color = cmap_edges(norm_edges(strength))
        if strength >= highlight_threshold:
            strong_edges.append(edge)
            strong_edge_colors.append(color)
            strong_edge_widths.append(_linear_rescale(
                strength, highlight_threshold, edge_vmax,
                edge_width_strong_min, edge_width_strong_max,
            ))
        else:
            weak_edges.append(edge)
            weak_edge_colors.append(color)
            weak_edge_widths.append(_linear_rescale(
                strength, edge_vmin, highlight_threshold,
                edge_width_weak_min, edge_width_weak_max,
            ))

    # Edges are added after the layout was computed, so node positions are unaffected.
    G.add_edges_from(weak_edges + strong_edges)
    # Weak edges first so that strong edges are painted on top of them.
    if weak_edges:
        nx.draw_networkx_edges(
            G, pos, edgelist=weak_edges, width=weak_edge_widths,
            edge_color=weak_edge_colors, style=edge_linestyle,
            alpha=edge_alpha, ax=ax,
        )
    if strong_edges:
        nx.draw_networkx_edges(
            G, pos, edgelist=strong_edges, width=strong_edge_widths,
            edge_color=strong_edge_colors, style=edge_linestyle,
            alpha=edge_alpha, ax=ax,
        )

    # Nodes and their labels.
    nx.draw_networkx_nodes(
        G, pos, node_color=node_colors, node_size=node_sizes, ax=ax,
        alpha=0.98, node_shape=node_marker, linewidths=0.6, edgecolors='#2f2f2f',
    )
    nx.draw_networkx_labels(
        G, label_pos, font_size=11, font_weight='bold', ax=ax, font_family='serif',
    )

    ax.axis('off')
    ax.set_xlim(-1.55, 1.55)
    ax.set_ylim(-1.65, 1.55)

    title_text = 'Impact intensity'
    if top_n is not None:
        title_text += f' (Top {top_n} Features)'
    ax.set_title(title_text, loc='left', fontsize=12, pad=14, fontfamily='Arial')

    # Two small horizontal colour bars below the network.
    sm_node = plt.cm.ScalarMappable(cmap=cmap_nodes, norm=norm_nodes)
    sm_node.set_array([])
    sm_edge = plt.cm.ScalarMappable(cmap=cmap_edges, norm=norm_edges)
    sm_edge.set_array([])

    # Node-importance colour bar.
    cax_node = fig.add_axes([0.22, 0.115, 0.20, 0.016])
    cbar_node = fig.colorbar(sm_node, cax=cax_node, orientation='horizontal')
    cbar_node.set_ticks(np.linspace(0, max_imp, 6))
    cbar_node.ax.tick_params(labelsize=9, rotation=90, pad=2)
    cbar_node.outline.set_linewidth(0.8)

    # Interaction-strength colour bar.
    cax_edge = fig.add_axes([0.58, 0.115, 0.20, 0.016])
    cbar_edge = fig.colorbar(sm_edge, cax=cax_edge, orientation='horizontal')
    cbar_edge.set_ticks(np.linspace(edge_vmin, max_interaction, 4))
    cbar_edge.ax.tick_params(labelsize=9, rotation=90, pad=2)
    cbar_edge.outline.set_linewidth(0.8)

    # Colour-bar captions and "low -> high" arrows, placed in figure coordinates.
    arrow_kw = dict(arrowstyle='-|>', lw=1.0, color='black', mutation_scale=10)
    ax.text(0.32, 0.147, 'Importance', transform=fig.transFigure,
            ha='center', va='bottom', fontsize=12, fontfamily='Arial')
    ax.text(0.273, 0.131, 'Low', transform=fig.transFigure,
            ha='right', va='bottom', fontsize=10.5, fontfamily='serif')
    ax.text(0.367, 0.131, 'High', transform=fig.transFigure,
            ha='left', va='bottom', fontsize=10.5, fontfamily='serif')
    ax.annotate('', xy=(0.362, 0.134), xytext=(0.278, 0.134),
                xycoords=fig.transFigure, textcoords=fig.transFigure,
                arrowprops=dict(arrow_kw))
    ax.text(0.215, 0.123, 'Vimp', transform=fig.transFigure,
            ha='right', va='center', fontsize=12, fontfamily='serif')

    ax.text(0.68, 0.147, 'Interaction Intensity', transform=fig.transFigure,
            ha='center', va='bottom', fontsize=12, fontfamily='serif')
    ax.text(0.628, 0.131, 'Weak', transform=fig.transFigure,
            ha='right', va='bottom', fontsize=10.5, fontfamily='serif')
    ax.text(0.733, 0.131, 'Strong', transform=fig.transFigure,
            ha='left', va='bottom', fontsize=10.5, fontfamily='serif')
    ax.annotate('', xy=(0.728, 0.134), xytext=(0.633, 0.134),
                xycoords=fig.transFigure, textcoords=fig.transFigure,
                arrowprops=dict(arrow_kw))
    ax.text(0.575, 0.123, 'Vint', transform=fig.transFigure,
            ha='right', va='center', fontsize=12, fontfamily='serif')

    # Leave the bottom strip of the figure free for the colour bars.
    plt.tight_layout(rect=[0.02, 0.18, 0.98, 0.98])
    if save_pdf_path is not None:
        fig.savefig(save_pdf_path, format='pdf', bbox_inches='tight')
    plt.show()
    return fig, ax

def plot_waterfall(
    features,
    shap_values,
    expected_value,
    current_color_scheme,
    top_n=15,
    save_pdf_path=None,
    show_contribution_labels=True,
    contribution_label_fmt='{:+.3f}',
    contribution_label_fontsize=9,
):
    """Draw a waterfall chart of mean SHAP contributions.

    SHAP values are averaged over axis 0, the ``top_n`` features with the
    largest absolute mean are kept (largest first), and their means are stacked
    cumulatively starting from the base value. Each contribution is drawn as an
    arrow-shaped bar from its running start to its running end.

    Args:
        features: Sequence of feature names, ordered like the columns of
            ``shap_values``.
        shap_values: Array-like of SHAP values; averaged over axis 0
            (assumed to be samples x features -- confirm against the caller).
        expected_value: Base value; for a list/array the first flattened
            element is used.
        current_color_scheme: Accepted for interface compatibility; the bar
            colours are fixed inside this function and this argument is
            currently not read.
        top_n: Number of features to show.
        save_pdf_path: If not None, the figure is also saved there as PDF.
        show_contribution_labels: Whether to print each contribution next to
            the tip of its bar.
        contribution_label_fmt: ``str.format`` template for those labels.
        contribution_label_fontsize: Font size of those labels.

    Returns:
        The ``(fig, ax)`` pair of the waterfall plot.
    """
    mean_shap = np.mean(np.asarray(shap_values, dtype=float), axis=0)
    order = np.argsort(np.abs(mean_shap))[::-1][:top_n]
    sorted_features = [features[i] for i in order]
    shap_vals = mean_shap[order]

    if isinstance(expected_value, (list, np.ndarray)):
        base_value = float(np.array(expected_value).flatten()[0])
    else:
        base_value = float(expected_value)

    # Running start/end of each bar; the last end is the final value.
    starts, ends = [], []
    y_curr = base_value
    for v in shap_vals:
        starts.append(y_curr)
        y_curr += v
        ends.append(y_curr)
    final_value = y_curr

    pos_color = '#f1696d'
    neg_color = '#44c3df'

    fig, ax = plt.subplots(figsize=(13.8, 6.25), dpi=140)
    fig.patch.set_facecolor('white')
    ax.set_facecolor('white')

    x = np.arange(len(sorted_features))
    y_all = [base_value] + starts + ends
    y_range = max(np.max(y_all) - np.min(y_all), 1e-6)
    bar_w = 0.36
    tip_h = 0.085 * y_range  # global cap for the arrow-tip height

    # One arrow-shaped bar per feature, pointing in the direction of its
    # contribution, optionally labelled with the contribution value.
    half_w = bar_w / 2
    label_gap = 0.015 * y_range
    for xi, start, end, value in zip(x, starts, ends, shap_vals):
        upward = value >= 0
        direction = 1.0 if upward else -1.0
        if value != 0:
            # The tip never exceeds the bar itself; short bars become triangles.
            shoulder = end - direction * min(tip_h, abs(value))
            ax.add_patch(patches.Polygon(
                [
                    (xi - half_w, start),
                    (xi + half_w, start),
                    (xi + half_w, shoulder),
                    (xi, end),
                    (xi - half_w, shoulder),
                ],
                closed=True,
                facecolor=pos_color if upward else neg_color,
                edgecolor='#2f2f2f',
                linewidth=0.6,
                zorder=3,
            ))
        if show_contribution_labels:
            ax.text(
                xi, end + direction * label_gap,
                contribution_label_fmt.format(value),
                ha='center', va='bottom' if upward else 'top',
                fontsize=contribution_label_fontsize, fontfamily='Arial',
                zorder=4,
            )

    # Vertical guide line through every feature position.
    for xi in x:
        ax.axvline(xi, color='#9a9a9a', linestyle='-', linewidth=0.55, alpha=0.4, zorder=1)
    if len(x):
        # Patches do not trigger autoscaling by themselves; make sure the x
        # range covers the full bar widths before it is read back below.
        ax.autoscale_view(scalex=True, scaley=False)

    y_min, y_max = min(y_all), max(y_all)
    y_span = (y_max - y_min) if y_max > y_min else 1.0
    pad = 0.10 * y_span
    ax.set_ylim(y_min - pad, y_max + pad)
    y_low, y_high = ax.get_ylim()
    # Both annotation boxes are vertically centred, independent of the data.
    center_y = (y_low + y_high) / 2.0
    fx_y = center_y
    ex_y = center_y

    ax.set_xticks(x)
    ax.set_xticklabels(sorted_features, rotation=90, fontsize=10, fontfamily='Arial')
    # No ticks or labels on the left; the value axis is shown on the right.
    ax.tick_params(axis='y', which='both', left=False, labelleft=False, length=0)
    ax.set_ylabel('')

    ax_right = ax.twinx()
    ax_right.set_ylim(ax.get_ylim())
    ax_right.set_ylabel('SHAP value', fontsize=13, fontfamily='Arial')
    ax_right.tick_params(axis='y', labelsize=11)
    ax_right.yaxis.set_major_locator(plt.MaxNLocator(4))
    ax_right.yaxis.set_major_formatter(plt.FormatStrFormatter('%.1f'))

    for side in ['left', 'right', 'top', 'bottom']:
        ax.spines[side].set_linewidth(1.2)
        ax.spines[side].set_color('#2f2f2f')
    for sp in ax_right.spines.values():
        sp.set_visible(False)

    # Rotated boxes showing the base value (left) and the final value (right).
    x_left, x_right = ax.get_xlim()
    x_span = x_right - x_left
    ex_x = x_left - 0.035 * x_span
    fx_x = x_right - 0.020 * x_span
    box_kw = dict(facecolor='white', edgecolor='#4a4a4a', boxstyle='square,pad=0.25')
    ax.text(ex_x, ex_y, f'E[f(x)]={base_value:.4f}',
            rotation=90, va='center', ha='center', fontsize=12, fontfamily='Arial',
            bbox=box_kw, clip_on=False)
    ax.text(fx_x, fx_y, f'f(x)={final_value:.4f}',
            rotation=90, va='center', ha='center', fontsize=12, fontfamily='Arial',
            bbox=box_kw, clip_on=True)

    ax.set_title('Impact direction', loc='left', fontsize=20, fontfamily='Arial', pad=4)

    legend_elements = [
        patches.Patch(facecolor=pos_color, edgecolor='#2f2f2f', label='Positive'),
        patches.Patch(facecolor=neg_color, edgecolor='#2f2f2f', label='Negative'),
    ]
    legend_y = -0.46
    label_fs = 16
    impact_x = 0.42
    ax.legend(
        handles=legend_elements,
        ncol=2,
        loc='upper left',
        bbox_to_anchor=(impact_x + 0.01, legend_y),
        frameon=False,
        columnspacing=1.2,
        handletextpad=0.5,
        borderaxespad=0.0,
        # The size goes into ``prop``: matplotlib ignores ``fontsize`` whenever
        # ``prop`` is given, so passing both silently dropped the 16 pt size.
        prop={'family': 'Arial', 'size': label_fs},
    )

    fig.subplots_adjust(bottom=0.44, left=0.10, right=0.88, top=0.88)
    if save_pdf_path is not None:
        fig.savefig(save_pdf_path, format='pdf', bbox_inches='tight')
    plt.show()
    return fig, ax
def build_pairwise_interaction_table(feature_names, interaction_matrix):
    """Return the de-duplicated pairwise interaction table.

    For every unordered feature pair ``(i, j)`` with ``i < j`` the interaction
    value is ``(m[i, j] + m[j, i]) / 2``. Rows are sorted by that value in
    descending order. With fewer than two features the result is an empty
    frame that still has the three expected columns.

    Args:
        feature_names: Sequence of feature names, ordered like the matrix.
        interaction_matrix: Square array-like of pairwise interaction values.

    Returns:
        DataFrame with columns ``feature_1``, ``feature_2`` and
        ``interaction_value``.
    """
    matrix = np.asarray(interaction_matrix, dtype=float)
    symmetric = (matrix + matrix.T) / 2
    rows, cols = np.triu_indices(len(feature_names), k=1)
    records = [
        {
            'feature_1': feature_names[i],
            'feature_2': feature_names[j],
            'interaction_value': float(symmetric[i, j]),
        }
        for i, j in zip(rows, cols)
    ]
    # Explicit columns keep sort_values working when there are no pairs at all.
    table = pd.DataFrame(
        records, columns=['feature_1', 'feature_2', 'interaction_value'],
    )
    return table.sort_values('interaction_value', ascending=False).reset_index(drop=True)


if __name__ == "__main__":
    # Output directory for the CSV and PDF files.
    output_dir = './outputs'
    os.makedirs(output_dir, exist_ok=True)

    features = X.columns.tolist()

    # Some explainers return a list; in that case use its first element.
    shap_values_arr = shap_values[0] if isinstance(shap_values, list) else shap_values
    shap_interaction_arr = (
        shap_interaction_values[0]
        if isinstance(shap_interaction_values, list)
        else shap_interaction_values
    )

    # Per-feature importance (node size/colour): mean absolute SHAP value.
    importance_scaled = np.mean(np.abs(shap_values_arr), axis=0)
    # Pairwise interaction matrix (edges): mean absolute interaction value.
    mean_interaction_matrix = np.mean(np.abs(shap_interaction_arr), axis=0)

    # Importance ranking of all features.
    importance_df = pd.DataFrame({
        'feature': features,
        'importance': importance_scaled,
    }).sort_values('importance', ascending=False).reset_index(drop=True)
    importance_df.to_csv(
        os.path.join(output_dir, 'feature_importance_ranking_all_features.csv'),
        index=False,
        encoding='utf-8-sig',
    )

    # Pairwise interaction values of all features (upper triangle only).
    interaction_df = build_pairwise_interaction_table(features, mean_interaction_matrix)
    interaction_df.to_csv(
        os.path.join(output_dir, 'pairwise_feature_interactions_all_features.csv'),
        index=False,
        encoding='utf-8-sig',
    )

    # Tunable plot parameters: number of features, node sizes, edge widths.
    interaction_plot_params = {
        'top_n': 20,  # show the N most important features
        'node_size_min': 220,  # smallest node size
        'node_size_max': 2200,  # largest node size
        'edge_width_weak_min': 1.5,  # thinnest weak edge
        'edge_width_weak_max': 3.1,  # thickest weak edge
        'edge_width_strong_min': 2.6,  # thinnest strong edge
        'edge_width_strong_max': 7.6,  # thickest strong edge
    }

    # Also export importance and interactions of exactly the plotted features,
    # using the same top_n selection rule as the plotting function.
    top_n = interaction_plot_params.get('top_n', None)
    if top_n is not None and top_n < len(features):
        selected_idx = np.argsort(importance_scaled)[-top_n:][::-1]
    else:
        selected_idx = np.arange(len(features))
    selected_features = [features[i] for i in selected_idx]
    selected_importance = importance_scaled[selected_idx]

    selected_importance_df = pd.DataFrame({
        'feature': selected_features,
        'importance': selected_importance,
    }).sort_values('importance', ascending=False).reset_index(drop=True)
    selected_importance_df.to_csv(
        os.path.join(output_dir, 'feature_importance_ranking_for_plot.csv'),
        index=False,
        encoding='utf-8-sig',
    )

    selected_interaction_matrix = mean_interaction_matrix[np.ix_(selected_idx, selected_idx)]
    build_pairwise_interaction_table(
        selected_features, selected_interaction_matrix,
    ).to_csv(
        os.path.join(output_dir, 'pairwise_feature_interactions_for_plot.csv'),
        index=False,
        encoding='utf-8-sig',
    )

    # Circular interaction network, saved as vector PDF. The function defined
    # in this file is plot_interaction; the previously called name
    # plot_circular_interaction does not exist here and raised NameError.
    plot_interaction(
        features,
        importance_scaled,
        mean_interaction_matrix,
        save_pdf_path=os.path.join(output_dir, 'A_impact_intensity_feature_interaction_network.pdf'),
        **interaction_plot_params,
    )

    # Base value reported by the explainer (starting point of the waterfall).
    expected_value = explainer.expected_value
    # The waterfall is computed from SHAP values of the training split.
    shap_values_train = explainer.shap_values(X_train)
    shap_values_train_arr = (
        shap_values_train[0] if isinstance(shap_values_train, list) else shap_values_train
    )

    # Impact-direction waterfall, saved as vector PDF. The function defined in
    # this file is plot_waterfall; the previously called name
    # plot_impact_direction_waterfall does not exist here and raised NameError.
    plot_waterfall(
        features,
        shap_values_train_arr,
        expected_value,
        current_color_scheme,
        top_n=20,
        save_pdf_path=os.path.join(output_dir, 'B_impact_direction_waterfall.pdf'),
    )
以上内容为原创,转载需声明出处。
往期回顾
参考文献:Yin L, Wei W, Li H, et al. Hidden risks in greening: Unveiling the impact of bare land changes on landscape ecological risks in arid and semi-arid regions of China[J]. Environmental Impact Assessment Review, 2026, 117: 108244.
以上内容为原创,转载需声明出处。

🔥亲测有效,一键运行,助你快速上手!
🔥整理不易,欢迎点赞分享给更多小伙伴~