import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import networkx as nx
from matplotlib.lines import Line2D


# ---------------------------
# 1) Multi-class-compatible SHAP output
# ---------------------------
def _pick_class(arr, class_index=0):
    """Return one class's entry when SHAP output is a per-class sequence.

    In multi-class problems ``shap_values`` / ``shap_interaction_values``
    often come back as a list (or tuple) with one element per class. In that
    case element ``class_index`` is returned; any other object is returned
    unchanged.
    """
    if isinstance(arr, (list, tuple)):
        return arr[class_index]
    return arr


def _select_class_axis(arr, base_ndim, class_index=0):
    """Convert ``arr`` to an ndarray and drop a trailing per-class axis.

    ``base_ndim`` is the rank expected for a single output (2 for SHAP
    values, 3 for interaction values). When the array has exactly one extra
    dimension it is treated as a trailing class axis and ``class_index`` is
    selected from it; otherwise the array is returned as is.
    """
    arr = np.asarray(arr)
    if arr.ndim == base_ndim + 1:
        # Newer SHAP releases stack multi-output results along the last axis
        # instead of returning a list with one array per class.
        arr = arr[..., class_index]
    return arr


def _scale(fraction, lower, upper):
    """Linearly map ``fraction`` (0..1) onto the interval [lower, upper]."""
    return lower + fraction * (upper - lower)


def _symmetric_norm(values):
    """Build a zero-centred ``Normalize`` spanning +/- max(|values|).

    Falls back to +/- 1.0 when ``values`` contains no non-zero entry, so the
    colour scale never degenerates to a zero-width range.
    """
    vmax = np.max(np.abs(values)) if np.any(values) else 1.0
    return mcolors.Normalize(vmin=-vmax, vmax=vmax)


def _select_edges(features, matrix_abs, matrix_signed, edge_threshold,
                  max_edges):
    """Collect the edges to draw, ordered from weakest to strongest.

    Only the upper triangle (i < j) is visited, and only pairs whose absolute
    strength exceeds ``edge_threshold`` are kept. When ``max_edges`` is not
    ``None`` just the strongest ``max_edges`` edges survive.
    """
    n_features = len(features)
    edges = [
        (features[i], features[j], matrix_abs[i, j], matrix_signed[i, j])
        for i in range(n_features)
        for j in range(i + 1, n_features)
        if matrix_abs[i, j] > edge_threshold
    ]
    # Weak edges first so that strong edges are painted on top of them.
    edges.sort(key=lambda edge: edge[2])
    if max_edges is not None and len(edges) > max_edges:
        # Slice from the front: ``edges[-max_edges:]`` would keep *every*
        # edge when max_edges == 0. Negative values are treated as zero.
        edges = edges[len(edges) - max(max_edges, 0):]
    return edges


# ---------------------------
# 2) Node / edge statistics
# ---------------------------
def compute_shap_stats(explainer, X_test, class_index=0):
    """Aggregate SHAP values and SHAP interaction values over all samples.

    Parameters
    ----------
    explainer : object
        Anything exposing ``shap_interaction_values(X)`` and
        ``shap_values(X)`` (e.g. a SHAP tree explainer).
    X_test : array-like
        Samples to explain; passed straight through to ``explainer``.
    class_index : int, default 0
        Class to use when the explainer returns per-class output, either as
        a list/tuple of arrays or as an array with a trailing class axis.

    Returns
    -------
    tuple of np.ndarray
        ``(feature_importance_abs, feature_importance_signed,
        mean_interaction_matrix_abs, mean_interaction_matrix_signed)``:
        mean |SHAP| per feature (node size), mean SHAP per feature (node
        colour), mean |interaction| per feature pair (edge width) and mean
        interaction per feature pair (edge colour). Both matrices have their
        diagonal set to zero.
    """
    # Expected single-output shape: (n_samples, n_features, n_features).
    shap_interaction_values = _select_class_axis(
        _pick_class(explainer.shap_interaction_values(X_test), class_index),
        base_ndim=3,
        class_index=class_index,
    )
    # Expected single-output shape: (n_samples, n_features).
    shap_values = _select_class_axis(
        _pick_class(explainer.shap_values(X_test), class_index),
        base_ndim=2,
        class_index=class_index,
    )

    # Nodes: magnitude -> size, signed mean -> colour.
    feature_importance_abs = np.mean(np.abs(shap_values), axis=0)
    feature_importance_signed = np.mean(shap_values, axis=0)

    # Edges: magnitude -> width, signed mean -> colour. The diagonal holds
    # main effects rather than pairwise interactions, so it is zeroed out.
    mean_interaction_matrix_abs = np.mean(
        np.abs(shap_interaction_values), axis=0)
    np.fill_diagonal(mean_interaction_matrix_abs, 0)

    mean_interaction_matrix_signed = np.mean(shap_interaction_values, axis=0)
    np.fill_diagonal(mean_interaction_matrix_signed, 0)

    return (feature_importance_abs, feature_importance_signed,
            mean_interaction_matrix_abs, mean_interaction_matrix_signed)


# ---------------------------
# 3) Plotting: circular interaction network
# ---------------------------
def plot_circular_interaction(
        features,
        importance_abs,
        importance_signed,
        interaction_matrix_abs,
        interaction_matrix_signed,
        title="Shap Circular Interaction",
        edge_threshold=0.0,
        max_edges=None,
        figsize=(12, 8),
        node_size_range=(80, 900),
        width_range=(0.3, 8.0),
        alpha_range=(0.08, 0.85),
        cmap_edges_name="PRGn",
        cmap_nodes_name="RdBu_r",
        marker="o",
        edge_linestyle="solid",
        save_png="shap_circular_interaction.png",
        save_pdf="shap_circular_interaction.pdf",
        dpi=600):
    """Draw features on a circle and connect them by SHAP interactions.

    Parameters
    ----------
    features : iterable of str
        Feature names, in the same order as the statistics below.
    importance_abs, importance_signed : array-like, length n_features
        Per-feature magnitude (node size) and signed mean (node colour).
    interaction_matrix_abs, interaction_matrix_signed : array-like
        ``(n_features, n_features)`` matrices giving edge width/opacity and
        edge colour respectively. Only the upper triangle is read.
    title : str
        Axes title.
    edge_threshold : float
        Only edges whose absolute strength is strictly greater are drawn.
    max_edges : int or None
        Keep at most this many of the strongest edges; ``None`` = no limit.
    figsize : tuple
        Figure size passed to ``plt.subplots``.
    node_size_range, width_range, alpha_range : (min, max) tuples
        Output ranges for node size, edge width and edge opacity; values are
        mapped linearly relative to the largest magnitude.
    cmap_edges_name, cmap_nodes_name : str
        Matplotlib colormap names for edges and nodes.
    marker : str
        Marker used for the nodes and for the node-size legend.
    edge_linestyle : str
        Line style passed to ``nx.draw_networkx_edges``.
    save_png, save_pdf : str or None
        Output paths; pass ``None`` (or an empty string) to skip that file.
    dpi : int
        Resolution used for every saved file.

    Returns
    -------
    (fig, ax)
        The created Matplotlib figure and main axes.
    """
    features = list(features)
    # Coerce to arrays so pandas Series/DataFrame inputs index positionally.
    importance_abs = np.asarray(importance_abs)
    importance_signed = np.asarray(importance_signed)
    interaction_matrix_abs = np.asarray(interaction_matrix_abs)
    interaction_matrix_signed = np.asarray(interaction_matrix_signed)

    # --- Graph + circular layout ---
    G = nx.Graph()
    G.add_nodes_from(features)
    pos = nx.circular_layout(G)
    # Push labels slightly outside the ring of nodes.
    label_pos = {name: xy * 1.12 for name, xy in pos.items()}

    # --- Colour normalisation: symmetric so +/- look balanced ---
    norm_edges = _symmetric_norm(interaction_matrix_signed)
    norm_nodes = _symmetric_norm(importance_signed)
    # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9;
    # ``plt.get_cmap`` works on both old and new releases.
    cmap_edges = plt.get_cmap(cmap_edges_name)
    cmap_nodes = plt.get_cmap(cmap_nodes_name)

    # --- Reference magnitudes for the linear size/width mappings ---
    max_interaction_abs = (np.max(interaction_matrix_abs)
                           if np.any(interaction_matrix_abs) else 1.0)
    max_importance_abs = (np.max(importance_abs)
                          if np.any(importance_abs) else 1.0)

    interactions = _select_edges(
        features, interaction_matrix_abs, interaction_matrix_signed,
        edge_threshold, max_edges)

    # --- Canvas ---
    fig, ax = plt.subplots(figsize=figsize, subplot_kw={'aspect': 'equal'})

    # ---------------------------
    # 3.1 Edges
    # ---------------------------
    w_min, w_max = width_range
    a_min, a_max = alpha_range
    # One call per edge so that each edge gets its own width and opacity.
    for u, v, strength_abs, strength_signed in interactions:
        relative = strength_abs / max_interaction_abs
        nx.draw_networkx_edges(
            G, pos,
            edgelist=[(u, v)],
            width=_scale(relative, w_min, w_max),
            edge_color=[cmap_edges(norm_edges(strength_signed))],
            style=edge_linestyle,
            alpha=_scale(relative, a_min, a_max),
            ax=ax
        )

    # ---------------------------
    # 3.2 Nodes
    # ---------------------------
    ns_min, ns_max = node_size_range
    node_sizes = [
        _scale(imp_abs / max_importance_abs, ns_min, ns_max)
        for imp_abs in importance_abs
    ]
    node_colors = [cmap_nodes(norm_nodes(imp)) for imp in importance_signed]

    nx.draw_networkx_nodes(
        G, pos,
        nodelist=features,      # keep sizes/colours aligned with features
        node_size=node_sizes,
        node_color=node_colors,
        node_shape=marker,      # same marker as the node-size legend
        linewidths=0.8,
        edgecolors="white",
        alpha=0.98,
        ax=ax
    )

    # ---------------------------
    # 3.3 Outer labels
    # ---------------------------
    for node, (x, y) in label_pos.items():
        # Anchor text away from the circle so it never overlaps the node.
        ha = "left" if x >= 0 else "right"
        ax.text(x, y, node, fontsize=11, ha=ha, va="center")

    # Axes / limits / title
    ax.axis("off")
    ax.set_xlim(-1.55, 1.55)
    ax.set_ylim(-1.55, 1.55)
    ax.set_title(title, y=0.97, fontsize=15)

    # ---------------------------
    # 4) Left-hand legends: line width + node size
    # ---------------------------
    # Sample widths come from the same mapping used for the real edges.
    line_levels = [max_interaction_abs,
                   max_interaction_abs * 0.5,
                   max_interaction_abs * 0.1]
    line_labels = [f"{val:.2f}" for val in line_levels]
    legend_lines = [
        Line2D([0], [0], color="black",
               lw=_scale(val / max_interaction_abs, w_min, w_max))
        for val in line_levels
    ]

    legend1 = ax.legend(
        legend_lines, line_labels,
        loc="center left",
        bbox_to_anchor=(-0.10, 0.78),
        title="Interaction Strength\n(Line Width)",
        frameon=False,
        labelspacing=1.5
    )
    # A second ax.legend() call would replace the first one unless it is
    # re-attached as a plain artist.
    ax.add_artist(legend1)

    # Sample sizes come from the same mapping used for the real nodes.
    node_levels = [max_importance_abs,
                   max_importance_abs * 0.5,
                   max_importance_abs * 0.1]
    node_labels = [f"{val:.2f}" for val in node_levels]
    legend_nodes = []
    for val in node_levels:
        s = _scale(val / max_importance_abs, ns_min, ns_max)
        # Line2D markersize uses different units from scatter sizes; the
        # sqrt / 2.2 factor is an empirical visual match.
        legend_nodes.append(Line2D(
            [0], [0],
            marker=marker, color="w",
            markerfacecolor="black",
            markersize=np.sqrt(s) / 2.2,
            linestyle="None"
        ))

    ax.legend(
        legend_nodes, node_labels,
        loc="center left",
        bbox_to_anchor=(-0.10, 0.33),
        title="Feature Importance\n(Node Size)",
        frameon=False,
        labelspacing=2.6
    )

    # ---------------------------
    # 5) Right-hand colour bars: edge colour + node colour
    # ---------------------------
    sm_edge = cm.ScalarMappable(norm=norm_edges, cmap=cmap_edges)
    sm_edge.set_array([])
    # [left, bottom, width, height] in figure coordinates.
    cax_edge = fig.add_axes([0.86, 0.55, 0.018, 0.28])
    cbar_edge = fig.colorbar(sm_edge, cax=cax_edge)
    cbar_edge.set_label("Interaction Value (Signed)",
                        rotation=270, labelpad=16, fontsize=10)
    cbar_edge.outline.set_visible(False)

    sm_node = cm.ScalarMappable(norm=norm_nodes, cmap=cmap_nodes)
    sm_node.set_array([])
    cax_node = fig.add_axes([0.86, 0.18, 0.018, 0.28])
    cbar_node = fig.colorbar(sm_node, cax=cax_node)
    cbar_node.set_label("Feature Value (Signed)",
                        rotation=270, labelpad=16, fontsize=10)
    cbar_node.outline.set_visible(False)

    # Save: honour both output paths and the requested resolution.
    if save_png:
        fig.savefig(save_png, dpi=dpi, bbox_inches="tight")
    if save_pdf:
        fig.savefig(save_pdf, dpi=dpi, bbox_inches="tight")
    return fig, ax


# ---------------------------
# 6) Example usage (swap in your own explainer / X_test)
# ---------------------------
if __name__ == "__main__":
    # NOTE: ``explainer`` and ``X_test`` are not defined in this file; they
    # must already exist (e.g. from earlier notebook cells) before this runs.
    features = X_test.columns.tolist()
    (feature_importance_abs, feature_importance_signed,
     mean_abs, mean_signed) = compute_shap_stats(explainer, X_test)
    plot_circular_interaction(features, feature_importance_abs,
                              feature_importance_signed, mean_abs,
                              mean_signed)