import numpy as np
import pandas as pd
import shap
import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize, LinearSegmentedColormap
from matplotlib.patches import Circle, Wedge
from matplotlib.ticker import FormatStrFormatter

try:
    from scipy.stats import pearsonr
except ImportError:  # scipy is optional: without it the significance stars are skipped
    pearsonr = None

columns = ['Pregnancies', 'Glucose', 'Blood', 'Skin', 'Insulin', 'BMI', 'Diab', 'Age']


def _pearson_pvalues(values):
    """Return an (n, n) matrix of Pearson p-values for the columns of *values*.

    *values* is a 2-D float array (samples x features). The diagonal is 0.0.
    Returns None when scipy is not installed.
    """
    if pearsonr is None:
        return None
    n = values.shape[1]
    pvals = np.ones((n, n), dtype=float)
    np.fill_diagonal(pvals, 0.0)
    # The test is symmetric in its two arguments, so each unordered pair is
    # computed once and mirrored instead of being evaluated twice.
    for i in range(n):
        for j in range(i + 1, n):
            _, p = pearsonr(values[:, i], values[:, j])
            pvals[i, j] = pvals[j, i] = p
    return pvals


def pearson_corr_heatmap_mixed(ax, fig, x, feature_names, cmap, norm, cbar_pos, start_angle=-90.0):
    """Draw a mixed-style Pearson correlation matrix on *ax* and add a colorbar.

    Cell (row i, column j) of the n x n grid shows:
      * diagonal (i == j): the feature name;
      * lower triangle (i > j): the coefficient r as bold text coloured by *cmap*;
      * upper triangle (i < j): a circle whose coloured wedge sweeps |r| of a
        full turn from *start_angle* (one direction for r >= 0, the other for
        r < 0), plus significance stars (*** p<0.001, ** p<0.01, * p<0.05)
        when scipy is available.

    Parameters
    ----------
    ax : matplotlib Axes that receives the matrix.
    fig : Figure on which the colorbar axes is created.
    x : 2-D array-like, samples x features. A DataFrame is accepted; its own
        column labels are ignored in favour of *feature_names*.
    feature_names : sequence of labels for the columns of *x*, in order.
    cmap, norm : colormap and normalisation applied to the coefficients.
    cbar_pos : [left, bottom, width, height] of the colorbar axes, in figure
        coordinates.
    start_angle : angle in degrees (data coordinates) where each wedge starts.

    Returns
    -------
    The matplotlib Colorbar that was created. Earlier versions returned None,
    so callers that ignore the result are unaffected.
    """
    # Convert to a plain array first: pd.DataFrame(df, columns=...) *selects*
    # columns by name instead of relabelling them, so passing a DataFrame whose
    # labels differ from feature_names would otherwise produce all-NaN data.
    values = np.asarray(x, dtype=float)
    df = pd.DataFrame(values, columns=feature_names)
    corr = df.corr(method="pearson").to_numpy()
    n = corr.shape[0]
    pvals = _pearson_pvalues(values)

    # Inverted y limits put row 0 at the top, like a printed matrix.
    ax.set_xlim(-0.5, n - 0.5)
    ax.set_ylim(n - 0.5, -0.5)
    ax.set_aspect("equal")
    edges = np.arange(n + 1) - 0.5
    ax.hlines(edges, -0.5, n - 0.5, colors="#cfcfcf", lw=1)
    ax.vlines(edges, -0.5, n - 0.5, colors="#cfcfcf", lw=1)

    r_cell = 0.42  # circle radius in data units (cells are 1 x 1)
    start = float(start_angle)
    for i in range(n):
        for j in range(n):
            r = float(corr[i, j])
            xc, yc = j, i  # column -> x, row -> y
            # cmap(norm(nan)) is the colormap's "bad" colour (transparent by
            # default), which would make the text invisible; use grey instead.
            color = cmap(norm(r)) if np.isfinite(r) else "#999999"

            if i == j:
                ax.text(xc, yc, feature_names[i], ha="center", va="center", fontsize=14, color="#333333")
                continue
            if i > j:
                ax.text(xc, yc, f"{r:.2f}", ha="center", va="center", fontsize=12, color=color, fontweight="bold")
                continue

            ax.add_patch(Circle((xc, yc), r_cell, facecolor="white", edgecolor="none"))
            sweep = 360.0 * abs(r)
            if sweep > 0:  # also False for NaN, so undefined r draws no wedge
                if r >= 0:
                    theta1, theta2 = start - sweep, start
                else:
                    theta1, theta2 = start, start + sweep
                ax.add_patch(Wedge((xc, yc), r_cell, theta1=theta1, theta2=theta2, facecolor=color, edgecolor="none"))
            # Outline drawn after the wedge so the wedge does not hide half of it.
            ax.add_patch(Circle((xc, yc), r_cell, fill=False, edgecolor="black", lw=1))

            if pvals is not None:
                p = float(pvals[i, j])
                stars = "***" if p < 0.001 else "**" if p < 0.01 else "*" if p < 0.05 else ""
                if stars:
                    ax.text(xc, yc, stars, ha="center", va="center", fontsize=12, color="#222222")

    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ["top", "right", "bottom", "left"]:
        ax.spines[spine].set_visible(False)

    cax = fig.add_axes(cbar_pos)
    sm = ScalarMappable(norm=norm, cmap=cmap)
    sm.set_array([])
    cbar = fig.colorbar(sm, cax=cax)
    cbar.set_ticks([-1, -0.5, 0, 0.5, 1])
    return cbar


# ---------------------------------------------------------------------------
# Combined figure: SHAP violin/beeswarm + mean(|SHAP|) bars on the left,
# Pearson correlation matrix on the right.
# NOTE(review): `shap_values_test` and `X_test` are not defined in this file;
# they are assumed to come from earlier code (e.g. notebook cells).
# ---------------------------------------------------------------------------

# Output resolution is set by savefig(dpi=800) below. Creating the figure at
# 1200 dpi only made every canvas draw (and plt.show) rasterise a huge buffer.
fig = plt.figure(figsize=(18, 5), dpi=100)
max_display = min(20, len(columns))

# Axes rectangles in figure coordinates: [left, bottom, width, height].
_left_bee = [0.06, 0.18, 0.36, 0.72]
_left_cb = [0.43, 0.18, 0.012, 0.72]
_right_corr = [0.52, 0.18, 0.4, 0.72]
_right_cb = [0.95, 0.18, 0.015, 0.72]

ax1 = fig.add_axes(_left_bee)

# Draw the violin/beeswarm plot first (bottom x axis: SHAP value).
# plot_size=None: with its default "auto", shap.summary_plot resizes the
# current figure to 8 inches wide, discarding the 18x5 layout chosen above.
plt.sca(ax1)
shap.summary_plot(
    shap_values_test,
    X_test,
    feature_names=columns,
    plot_type="violin",
    max_display=max_display,
    show=False,
    color_bar=True,
    plot_size=None,
)
ax1 = plt.gca()
# shap's colorbar takes space from ax1; put ax1 back where the layout wants it.
ax1.set_position(_left_bee)
fig.canvas.draw()

# Move shap's "Feature value" colorbar to its slot right of the beeswarm.
_cb_ax = next(
    (a for a in fig.axes if a is not ax1 and a.get_ylabel() == "Feature value"),
    None,
)
if _cb_ax is not None:
    _cb_ax.set_position(_left_cb)

# Symmetric x range around 0: ticks at the data extent, limits 12% wider.
_xmin, _xmax = ax1.get_xlim()
_m_data = max(abs(_xmin), abs(_xmax))
if np.isfinite(_m_data) and _m_data > 0:
    ax1.set_xlim(-_m_data * 1.12, _m_data * 1.12)
    ax1.set_xticks([-_m_data, -_m_data / 2.0, 0.0, _m_data / 2.0, _m_data])
    ax1.xaxis.set_major_formatter(FormatStrFormatter("%.2f"))

# Remember the feature-name ticks now: ax2 (below) shares ax1's y axis, so the
# bar plot drawn on it overwrites them and they must be restored afterwards.
_bee_y_ticks = ax1.get_yticks()
_bee_y_labels = [t.get_text() for t in ax1.get_yticklabels()]
if not any(_bee_y_labels):
    fig.canvas.draw()
    _bee_y_ticks = ax1.get_yticks()
    _bee_y_labels = [t.get_text() for t in ax1.get_yticklabels()]

# Top x axis: mean(|SHAP|) bar chart sharing the y axis with the beeswarm.
ax2 = ax1.twiny()
ax2.set_position(ax1.get_position())
plt.sca(ax2)
shap.summary_plot(
    shap_values_test,
    X_test,
    feature_names=columns,
    plot_type="bar",
    max_display=max_display,
    show=False,
    color_bar=False,
    plot_size=None,
)

# Show the bars as a faint background with the beeswarm drawn on top.
ax2.set_ylim(ax1.get_ylim())
ax2.set_yticks([])
ax2.set_ylabel("")
for bar in ax2.patches:
    bar.set_facecolor("#f2a7b5")
    bar.set_alpha(0.35)
    bar.set_edgecolor("none")
    bar.set_linewidth(0)
ax2.set_zorder(0)
ax1.set_zorder(1)
ax1.patch.set_visible(False)  # transparent ax1 background lets the bars show through
# Label each bar at its base with mean(|SHAP|) and its percentage share of the
# total over the displayed features.
# Each barh rectangle's width is the mean(|SHAP|) shap plotted for that
# feature. Reading it back keeps every value paired with its own bar. The old
# approach re-ranked features with a separate argsort, which can mis-pair
# values and bars when scores tie.

# Restore the feature-name ticks that the shared-y bar plot overwrote.
if len(_bee_y_labels) > 0:
    ax1.set_yticks(_bee_y_ticks)
    ax1.set_yticklabels(_bee_y_labels)

_x_min, _x_max = ax2.get_xlim()
_x0 = 0.0 if _x_min <= 0.0 <= _x_max else _x_min
_x_text = _x0 + 0.01 * (_x_max - _x_min)  # small offset right of the bar base

_bars = list(ax2.patches)
_bar_values = [float(b.get_width()) for b in _bars]
_total_top = float(np.sum(_bar_values)) if _bar_values else 0.0
if not np.isfinite(_total_top):
    _total_top = 0.0
for bar, _v in zip(_bars, _bar_values):
    _y = bar.get_y() + bar.get_height() / 2.0
    _pct = (100.0 * _v / _total_top) if _total_top > 0 else 0.0
    ax2.text(
        _x_text,
        _y,
        f"{_v:.3f}({_pct:.2f}%)",
        va="center",
        ha="left",
        fontsize=9,
        color="black",
        zorder=3,
    )

# Full box around the beeswarm and outward ticks next to the feature names.
for spine in ["top", "right", "bottom", "left"]:
    ax1.spines[spine].set_visible(True)
    ax1.spines[spine].set_linewidth(1.0)
ax1.tick_params(
    axis="y",
    which="major",
    left=True,
    labelleft=True,
    length=8,
    width=1.0,
    direction="out",
    pad=4,
)
ax1.yaxis.grid(True, linestyle=":", linewidth=0.6, alpha=0.4)
ax1.set_axisbelow(True)

# Axis labels: SHAP value on the bottom axis, mean(|SHAP|) on the top axis.
ax1.set_xlabel("SHAP value (impact on model output)", fontsize=12)
ax1.set_ylabel("")
ax2.set_xlabel("Mean (|SHAP| value)", fontsize=12)
ax2.xaxis.set_label_position("top")
ax2.xaxis.tick_top()

# Right panel: Pearson correlation matrix.
ax_corr = fig.add_axes(_right_corr)
# Prefer a global `X` (presumably the full feature matrix) when one exists;
# otherwise fall back to the test split.
try:
    _X_corr = X
except NameError:
    _X_corr = X_test
# Pass a plain array: given a DataFrame, pd.DataFrame(df, columns=["X1", ...])
# *selects* columns by name, which yields all-NaN correlations.
_X_corr = np.asarray(_X_corr, dtype=float)
_feature_names = [f"X{i + 1}" for i in range(_X_corr.shape[1])]
_norm = Normalize(vmin=-1, vmax=1)
_cmap = LinearSegmentedColormap.from_list(
    "red_purple_blue",
    ["#1e88e5", "#8e24aa", "#ff1744"],
    N=256,
)
pearson_corr_heatmap_mixed(ax_corr, fig, _X_corr, _feature_names, _cmap, _norm, _right_cb, start_angle=-90.0)

# Caption mapping the short labels X1, X2, ... back to the feature names.
_x_map = " ".join(f"{short}:{name}" for short, name in zip(_feature_names, columns))
fig.text(0.5, 0.02, _x_map, ha="center", va="bottom", fontsize=12)

fig.savefig("SHAP_与_Pearson_组合.png", dpi=800, bbox_inches="tight")
plt.show()