import numpy as np
import matplotlib.pyplot as plt
from itertools import chain


# ========== GF(2^8) arithmetic ==========
def gf_mult(a, b):
    """Multiply two elements of GF(2^8) as used by AES.

    Reduction is modulo x^8 + x^4 + x^3 + x + 1 (0x11b). This polynomial
    is irreducible but not primitive (x does not generate the
    multiplicative group).

    Args:
        a, b: Field elements in 0..255 (Python ints or NumPy integer scalars).

    Returns:
        The product as a Python int in 0..255.
    """
    # Normalise to Python ints so the result type does not depend on
    # whether NumPy scalars were passed in.
    a, b = int(a), int(b)
    p = 0
    for _ in range(8):
        if b & 1:
            p ^= a
        hi_bit = a & 0x80
        a = (a << 1) & 0xff
        if hi_bit:
            # x^8 == x^4 + x^3 + x + 1 (mod 0x11b)
            a ^= 0x1b
        b >>= 1
    return p


# ========== S-box generation ==========
def build_sbox():
    """Build the AES forward S-box (FIPS-197 section 5.1.1).

    Each entry is the multiplicative inverse in GF(2^8) (0 maps to 0),
    followed by the affine transform
        s = b ^ rotl(b,1) ^ rotl(b,2) ^ rotl(b,3) ^ rotl(b,4) ^ 0x63.

    Returns:
        A list of 256 ints.
    """
    sbox = [0] * 256
    for i in range(256):
        # Multiplicative inverse by exhaustive search; 0 has no inverse and
        # is mapped to 0 by definition.
        inv = 0
        if i != 0:
            inv = next(j for j in range(1, 256) if gf_mult(i, j) == 1)
        # Affine transform. The constant 0x63 must be added exactly once:
        # adding it twice would cancel it out (x ^ c ^ c == x).
        s = inv
        b = inv
        for _ in range(4):
            b = ((b << 1) | (b >> 7)) & 0xff
            s ^= b
        sbox[i] = s ^ 0x63
    return sbox


SBOX = build_sbox()
# uint8 lookup table so byte substitution is a single vectorised indexing op
# and keeps the state dtype at uint8.
_SBOX_TABLE = np.array(SBOX, dtype=np.uint8)


def _to_byte_array(data, size, what):
    """Return *data* (ndarray, sequence of ints, or bytes-like) as a fresh
    flat uint8 array of exactly *size* elements; raise ValueError otherwise."""
    if isinstance(data, (bytes, bytearray, memoryview)):
        arr = np.frombuffer(bytes(data), dtype=np.uint8).copy()
    else:
        # np.array copies, so callers' arrays are never aliased.
        arr = np.array(data, dtype=np.uint8).ravel()
    if arr.size != size:
        raise ValueError(f"{what} must be exactly {size} bytes, got {arr.size}")
    return arr


def sub_bytes(state):
    """SubBytes: replace every byte of *state* by its S-box image.

    Returns a new uint8 array with the same shape as *state*.
    """
    return _SBOX_TABLE[np.asarray(state)]


# ========== ShiftRows ==========
def shift_rows(state):
    """ShiftRows: cyclically shift row r of the 4x4 state left by r places.

    Returns a new array; *state* is not modified.
    """
    s = np.array(state, copy=True)
    for r in range(1, 4):  # row 0 is not shifted
        s[r] = np.roll(s[r], -r)
    return s


# ========== MixColumns ==========
MIX_MATRIX = np.array([[2, 3, 1, 1],
                       [1, 2, 3, 1],
                       [1, 1, 2, 3],
                       [3, 1, 1, 2]], dtype=np.uint8)


def mix_columns(state):
    """MixColumns: multiply each state column by MIX_MATRIX over GF(2^8).

    Returns a new (4, 4) uint8 array.
    """
    state = np.asarray(state)
    out = np.zeros((4, 4), dtype=np.uint8)
    for c in range(4):
        col = state[:, c]
        for r in range(4):
            acc = 0
            for k in range(4):
                acc ^= gf_mult(MIX_MATRIX[r, k], col[k])
            out[r, c] = acc
    return out


# ========== Key expansion (AES-128) ==========
RCON = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36]


def key_expansion(key):
    """Expand a 16-byte AES-128 key into 11 round keys (FIPS-197 section 5.2).

    Args:
        key: 16 key bytes (1-D uint8 array, sequence of ints, or bytes).

    Returns:
        uint8 array of shape (11, 4, 4). round_keys[i][:, c] is key word
        w[4*i + c]. Each column holds one word, so a round key can be XORed
        directly onto the column-major state.

    Raises:
        ValueError: if *key* is not exactly 16 bytes.
    """
    key = _to_byte_array(key, 16, "key")
    w = np.zeros((44, 4), dtype=np.uint8)
    w[:4] = key.reshape(4, 4)
    for i in range(4, 44):
        temp = w[i - 1].copy()
        if i % 4 == 0:
            # RotWord (rotate left by one byte), SubWord, then XOR Rcon into byte 0.
            temp = _SBOX_TABLE[np.roll(temp, -1)]
            temp[0] ^= RCON[i // 4 - 1]
        w[i] = w[i - 4] ^ temp
    # Group the 44 words into 11 round keys; transpose so words become columns.
    return np.stack([w[4 * i:4 * i + 4].T for i in range(11)])


# ========== AES encryption of one block ==========
def aes_encrypt(plaintext, key):
    """Encrypt one 16-byte block with AES-128, recording every round state.

    Args:
        plaintext: 16 input bytes (1-D uint8 array, sequence of ints, or
            bytes). It is copied and never modified.
        key: 16 key bytes, as accepted by key_expansion.

    Returns:
        (states, rkeys):
            states: list of 12 uint8 (4, 4) arrays. They are the input
                state, the state after the initial AddRoundKey, and the
                state after each of rounds 1..10. The ciphertext bytes are
                states[-1].T.ravel().
            rkeys: the (11, 4, 4) round keys from key_expansion(key).

    Raises:
        ValueError: if plaintext or key is not exactly 16 bytes.
    """
    # Column-major load (FIPS-197 section 3.4): byte n goes to row n % 4,
    # column n // 4. _to_byte_array returns a fresh array, so the caller's
    # plaintext cannot be altered through this view.
    state = _to_byte_array(plaintext, 16, "plaintext").reshape(4, 4).T
    rkeys = key_expansion(key)
    states = [state.copy()]

    # Initial AddRoundKey
    state = state ^ rkeys[0]
    states.append(state)

    # Rounds 1..10; the final round omits MixColumns.
    for r in range(1, 11):
        state = shift_rows(sub_bytes(state))
        if r != 10:
            state = mix_columns(state)
        state = state ^ rkeys[r]  # new array each round, so `states` entries stay intact
        states.append(state)
    return states, rkeys


# ========== Visualisation ==========
_STATE_LABELS = ['Plain', 'InitKey', 'Round1', 'Round2', 'Round3', 'Round4',
                 'Round5', 'Round6', 'Round7', 'Round8', 'Round9', 'Round10']


def _draw_byte_grid(ax, mat, cmap, title):
    """Draw a 4x4 byte matrix on *ax* as a heatmap with hex labels per cell.

    The matrix is shown in FIPS-197 layout: element [r, c] is at display
    row r, column c. Each label sits on the cell whose colour it describes.
    """
    ax.imshow(mat, cmap=cmap, vmin=0, vmax=255, origin='upper')
    ax.set_title(title, fontsize=8)
    ax.set_xticks([])
    ax.set_yticks([])
    for (r, c), val in np.ndenumerate(mat):
        # imshow puts mat[r, c] at x=c, y=r, so the label goes at (c, r).
        ax.text(c, r, f'{int(val):02x}', ha='center', va='center',
                color='w' if val < 128 else 'k', fontsize=6)


def _plot_sbox(sbox):
    """Figure 1: the S-box as a 16x16 heatmap indexed by (high, low) nibble."""
    sbox_2d = np.asarray(sbox).reshape(16, 16)
    plt.figure(figsize=(6, 6))
    plt.imshow(sbox_2d, cmap='magma', origin='lower', aspect='auto')
    plt.colorbar(label='S-box value')
    plt.title('AES S-box Heatmap (16×16)', fontsize=14)
    plt.xticks(range(16), [f'{i:x}' for i in range(16)])
    plt.yticks(range(16), [f'{i:x}' for i in range(16)])
    plt.xlabel('low nibble')
    plt.ylabel('high nibble')
    plt.tight_layout()
    plt.show()


def _plot_states(states):
    """Figure 2: one annotated heatmap per recorded encryption state."""
    fig, axes = plt.subplots(1, len(states), figsize=(18, 3.5))
    for ax, st, label in zip(axes, states, _STATE_LABELS):
        _draw_byte_grid(ax, st, 'viridis', label)
    fig.suptitle('AES-128 State Evolution (Block Cipher)', fontsize=14, y=1.05)
    plt.tight_layout()
    plt.show()


def _plot_round_keys(rkeys):
    """Figure 3: one annotated heatmap per round key."""
    fig, axes = plt.subplots(1, len(rkeys), figsize=(16, 3.5))
    for i, (ax, rk) in enumerate(zip(axes, rkeys)):
        _draw_byte_grid(ax, rk, 'plasma', f'RoundKey{i}')
    fig.suptitle('AES-128 Round Keys', fontsize=14, y=1.05)
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    # FIPS-197 Appendix B example, bytes in input (stream) order. aes_encrypt
    # loads them column-major, so the printed FIPS-197 state matrix reads
    # down its columns. Expected ciphertext (states[-1].T.ravel()):
    # 39 25 84 1d 02 dc 09 fb dc 11 85 97 19 6a 0b 32
    plain = np.array([0x32, 0x43, 0xf6, 0xa8, 0x88, 0x5a, 0x30, 0x8d,
                      0x31, 0x31, 0x98, 0xa2, 0xe0, 0x37, 0x07, 0x34], dtype=np.uint8)
    key = np.array([0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6,
                    0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, 0x4f, 0x3c], dtype=np.uint8)
    states, rkeys = aes_encrypt(plain, key)

    _plot_sbox(SBOX)
    _plot_states(states)
    _plot_round_keys(rkeys)