"""Lead-lag correlation/regression maps of HadISST SST against the Nino-3.4 index.

Pipeline (see ``main``): read the Nino-3.4 index and the HadISST SST field,
pair them by calendar month, remove the SST seasonal cycle, compute lagged
regression / correlation / p-value maps, and plot the correlation maps as a
3x2 panel figure with significance hatching.
"""
import warnings

import numpy as np
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats
import cartopy.crs as ccrs

plt.rcParams['figure.figsize'] = (10, 10)
plt.rcParams['font.size'] = 15

# Lags (months) evaluated by calculate_lag_regression by default.
# lag < 0: SST leads the index by |lag| months; lag > 0: SST lags the index.
DEFAULT_LAGS = (-12, -6, -3, 0, 3, 6, 12)

# Grid points with fewer jointly non-NaN samples than this are set to NaN.
MIN_VALID_SAMPLES = 10

# No real SST lies anywhere near this value (in degC or K), so anything below
# it is treated as a flag and masked. HadISST is documented to mark
# sea-ice-covered boxes with -1000.
SST_FLAG_THRESHOLD = -100.0


def _select_variable(ds, candidates):
    """Return the first variable named in *candidates* that exists in *ds*.

    Falls back to the dataset's first data variable when none of the
    candidate names are present.
    """
    for name in candidates:
        if name in ds.data_vars:
            return ds[name]
    return ds[next(iter(ds.data_vars))]


def read_nino34_data(file_path):
    """Read a Nino-3.4 index time series from a netCDF file.

    Args:
        file_path: path to the netCDF file.

    Returns:
        ``(values, time)`` -- a numpy array of index values and the xarray
        ``time`` coordinate -- or ``(None, None)`` if reading fails (the
        reason is printed).
    """
    try:
        with xr.open_dataset(file_path) as ds:
            nino_data = _select_variable(
                ds, ['nino34', 'nino3.4', 'NINO34', 'NINO3.4', 'sst', 'temp'])
            if nino_data.ndim > 1:
                # Collapse any extra dimensions into a single time series.
                nino_data = nino_data.mean(
                    dim=[d for d in nino_data.dims if d != 'time'])
            # Load before the `with` block closes the file.
            nino_data = nino_data.load()
            return nino_data.values, nino_data.time
    except Exception as err:  # best-effort reader: report and signal failure
        print(f"Failed to read Nino-3.4 data from {file_path!r}: {err}")
        return None, None


def read_hadisst_data(file_path, start_year=1870, end_year=2024):
    """Read HadISST SST for the years ``start_year``..``end_year``.

    Values below ``SST_FLAG_THRESHOLD`` (flag values such as HadISST's
    sea-ice marker) are replaced with NaN.

    Args:
        file_path: path to the HadISST netCDF file.
        start_year, end_year: inclusive year bounds of the time selection.

    Returns:
        ``(sst, time, lon, lat)`` -- SST as a numpy array (presumably
        (time, latitude, longitude), which the rest of this script assumes),
        the xarray ``time`` coordinate, and the longitude/latitude arrays --
        or four ``None`` if reading fails (the reason is printed).
    """
    try:
        with xr.open_dataset(file_path) as ds:
            sst_data = _select_variable(
                ds, ['sst', 'temp', 'sea_surface_temperature', 'SST'])
            # Load before the `with` block closes the file.
            sst_data = sst_data.sel(
                time=slice(f'{start_year}', f'{end_year}')).load()
            lon = ds.longitude.values
            lat = ds.latitude.values
            sst_time = sst_data.time
    except Exception as err:  # best-effort reader: report and signal failure
        print(f"Failed to read HadISST data from {file_path!r}: {err}")
        return None, None, None, None
    sst = sst_data.values
    with np.errstate(invalid='ignore'):
        sst = np.where(sst < SST_FLAG_THRESHOLD, np.nan, sst)
    return sst, sst_time, lon, lat


def _regress_grid(field, idx, min_valid=MIN_VALID_SAMPLES):
    """Regress every grid point of *field* on the series *idx* in one pass.

    Args:
        field: array of shape (time, n_lat, n_lon).
        idx: 1-D series; both inputs are truncated to their common length.
        min_valid: points with fewer jointly non-NaN samples are left NaN.

    Returns:
        ``(slope, p_value, r_value)`` arrays of shape (n_lat, n_lon), taken
        from ``scipy.stats.linregress`` (two-sided p-value).
    """
    field = np.asarray(field)
    idx = np.asarray(idx)
    n_time = min(field.shape[0], len(idx))
    field, idx = field[:n_time], idx[:n_time]

    grid_shape = field.shape[1:]
    slope = np.full(grid_shape, np.nan)
    pval = np.full(grid_shape, np.nan)
    corr = np.full(grid_shape, np.nan)

    idx_ok = ~np.isnan(idx)  # loop-invariant part of the validity mask
    for j in range(grid_shape[0]):
        for i in range(grid_shape[1]):
            series = field[:, j, i]
            valid = idx_ok & ~np.isnan(series)
            if valid.sum() < min_valid:
                continue
            res = stats.linregress(idx[valid], series[valid])
            slope[j, i], pval[j, i], corr[j, i] = (
                res.slope, res.pvalue, res.rvalue)
    return slope, pval, corr


def regress_field(field, idx):
    """Regress each grid point of *field* (time, lat, lon) on the series *idx*.

    Returns:
        ``(reg, pval)``: regression slopes and two-sided p-values, NaN where
        fewer than ``MIN_VALID_SAMPLES`` jointly valid samples exist.
    """
    reg, pval, _ = _regress_grid(field, idx)
    return reg, pval


def calculate_lag_regression(x, y, max_lag=12, lags=None):
    """Compute lagged regression, p-value and correlation maps of *y* on *x*.

    Args:
        x: 1-D index series (time,).
        y: field (time, n_lat, n_lon); other shapes are reshaped to
            (time, n_points, 1).
        max_lag: only lags with ``|lag| <= max_lag`` are evaluated.
        lags: candidate lags in months (default ``DEFAULT_LAGS``).
            lag < 0 pairs y(t) with x(t - lag), i.e. y leads x;
            lag > 0 pairs y(t + lag) with x(t), i.e. y lags x.

    Returns:
        dict with keys ``'lags'``, ``'regression_coeffs'``, ``'p_values'``,
        ``'correlations'`` (the last three map lag -> (n_lat, n_lon) array),
        ``'n_lat'`` and ``'n_lon'``.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if y.ndim != 3:
        y = y.reshape(y.shape[0], -1, 1)
    # Truncate to a common length so x and y slices always pair up.
    n_time = min(len(x), y.shape[0])
    x, y = x[:n_time], y[:n_time]
    n_lat, n_lon = y.shape[1], y.shape[2]

    candidates = DEFAULT_LAGS if lags is None else lags
    lags = [lag for lag in candidates if abs(lag) <= max_lag]

    reg_coeffs, p_values, correlations = {}, {}, {}
    for lag in lags:
        print(f"calculate {lag:3d}")
        if lag <= 0:
            field_slice, index_slice = slice(0, n_time + lag), slice(-lag, None)
        else:
            field_slice, index_slice = slice(lag, n_time), slice(0, n_time - lag)
        reg_coeffs[lag], p_values[lag], correlations[lag] = _regress_grid(
            y[field_slice], x[index_slice])

    return {'lags': lags,
            'regression_coeffs': reg_coeffs,
            'p_values': p_values,
            'correlations': correlations,
            'n_lat': n_lat,
            'n_lon': n_lon}


def plot_lag_regression_results(results, lon, lat, significance_level=0.05,
                                output_file='lead_lag.png'):
    """Plot lagged correlation maps with significance hatching.

    Panels are drawn for lags -12, -6, -3 (left column) and 0, 6, 12 (right
    column) when present in ``results['lags']``. Points with
    ``p < significance_level`` are hatched.

    Args:
        results: dict returned by ``calculate_lag_regression``.
        lon, lat: 1-D coordinate arrays matching the map columns/rows.
        significance_level: p-value threshold for hatching.
        output_file: path the figure is saved to.
    """
    lags = results['lags']
    correlations = results['correlations']
    p_values = results['p_values']

    # Wrap longitudes into [-180, 180) and sort so the data are contiguous.
    lon_shift = (np.asarray(lon) + 180) % 360 - 180
    lon_order = np.argsort(lon_shift)
    lon_shift = lon_shift[lon_order]
    sig_lon, sig_lat = np.meshgrid(lon_shift, lat)

    fig = plt.figure(figsize=(10, 10))
    key_lags = [-12, -6, -3, 0, 6, 12]
    axes = []
    image = None
    for i, lag in enumerate(key_lags):
        if lag not in lags:
            continue
        # Column-major panel order on a 3x2 grid.
        panel = 2 * (i % 3) + (i // 3) + 1
        ax = plt.subplot(3, 2, panel,
                         projection=ccrs.PlateCarree(central_longitude=180))
        axes.append(ax)

        # Reorder columns to match the sorted longitudes.
        corr_data = correlations[lag][:, lon_order]
        p_data = p_values[lag][:, lon_order]

        # Correlation coefficients.
        image = ax.contourf(lon_shift, lat, corr_data,
                            levels=np.linspace(-1, 1, 21), cmap='seismic',
                            extend='both', transform=ccrs.PlateCarree())
        # Significance hatching.
        hatching = ax.contourf(sig_lon, sig_lat, p_data,
                               [0, significance_level, 1], zorder=2,
                               hatches=['//\\', None], colors="none",
                               transform=ccrs.PlateCarree())
        hatching.set_edgecolor('g')
        hatching.set_linewidth(0)

        # Map decorations.
        ax.coastlines()
        ax.gridlines(draw_labels=True, xlocs=np.arange(-180, 181, 60),
                     ylocs=np.arange(-90, 91, 30))
        ax.set_extent([-180, 180, -90, 90], crs=ccrs.PlateCarree())
        ax.set_title(f'Lag = {lag:d} months', fontsize=15)

    plt.tight_layout()
    if image is not None:
        # Add the colorbar after tight_layout so it matches the final panel
        # positions; centre it under the bottom row.
        positions = [ax.get_position() for ax in axes]
        bottom = min(pos.y0 for pos in positions)
        left = min(pos.x0 for pos in positions)
        right = max(pos.x1 for pos in positions)
        width = positions[0].width
        cax = fig.add_axes([(left + right - width) / 2, bottom - 0.12,
                            width, 0.01])
        fig.colorbar(image, cax=cax, orientation='horizontal')
    plt.savefig(output_file, dpi=600, bbox_inches='tight')
    plt.show()


def _monthly_anomalies(field, months):
    """Remove the mean annual cycle from *field* along its first (time) axis.

    Args:
        field: array whose first axis is time.
        months: calendar month (1-12) of each time step.

    Returns:
        A floating-point copy of *field* with each calendar month's NaN-aware
        mean subtracted.
    """
    field = np.asarray(field)
    months = np.asarray(months)
    anomalies = field.astype(np.promote_types(field.dtype, np.float32))
    with warnings.catch_warnings():
        # All-NaN columns (e.g. land) legitimately warn "Mean of empty slice".
        warnings.simplefilter('ignore', category=RuntimeWarning)
        for month in np.unique(months):
            rows = months == month
            anomalies[rows] -= np.nanmean(anomalies[rows], axis=0)
    return anomalies


def main():
    """Run the Nino-3.4 / HadISST lead-lag analysis and plot the result."""
    nino_file = "nino34.long.anom.nc"
    hadisst_file = "HadISST_sst.nc"

    nino_data, nino_time = read_nino34_data(nino_file)
    sst_data, sst_time, lon, lat = read_hadisst_data(hadisst_file)
    if nino_data is None or sst_data is None:
        raise SystemExit("Could not read the input data (see messages above).")

    # Pair samples by calendar month, not by raw timestamp. The products may
    # stamp a month differently (e.g. first vs middle of the month). A
    # timestamp-range comparison can then drop a month from only one series
    # and shift every subsequent pair by one month.
    nino_months = pd.to_datetime(np.asarray(nino_time)).to_period('M')
    sst_months = pd.to_datetime(np.asarray(sst_time)).to_period('M')
    common = nino_months.intersection(sst_months).sort_values()
    if common.empty:
        raise SystemExit("The Nino-3.4 and HadISST records share no months.")
    nino_aligned = np.asarray(nino_data)[nino_months.get_indexer(common)]
    sst_aligned = sst_data[sst_months.get_indexer(common)]

    # The index file is an anomaly product (per its name). Remove the SST
    # seasonal cycle as well, so correlations are not swamped by
    # annual-cycle variance.
    sst_anom = _monthly_anomalies(sst_aligned, common.month)
    del sst_data, sst_aligned  # release the raw copies before the lag loop

    results = calculate_lag_regression(nino_aligned, sst_anom, max_lag=12)
    plot_lag_regression_results(results, lon, lat)


if __name__ == "__main__":
    main()