你还在用 Python 写模型,却把性能优化交给 C++ 和 CUDA?
一条来自 Keras 创始人 François Chollet 的推文,正在撕开 AI 开发的新裂缝:现在,你不用离开 Python,就能写出逼近硬件极限的高性能算子。
这不只是“方便”,更可能是对整个 AI 工程栈的重新洗牌。
“融合”两个字,省下 80% 的内存搬运
想象一下:你的神经网络里,一个全连接层后面紧跟着 ReLU 激活函数。在传统写法中,这需要两步:
- 先做矩阵乘法(matmul),结果从高速片上内存(VMEM/SMEM)写回高带宽主存(HBM);
两次昂贵的 HBM 读写,中间还夹着一次无谓的“落地”。
而 Pallas 做的事,简单粗暴:把 matmul 和 ReLU 合成一个算子,在数据还在片上内存时,直接激活。
“Operator fusion is the process of combining two or more ops into one ‘fused’ op… to squeeze even more performance out of the TPU or GPU.”
Keras 官方文档毫不掩饰野心:XLA 编译器虽然能自动融合部分算子,但要榨干最后一滴性能,你得自己动手。
于是,他们给出了一个FusedDense层的实现——外表是普通的Keras Layer,内核却是手写的 Pallas kernel。
代码对比:优雅 vs 硬核
先看“标准版”:
def call(self, inputs): y = keras.ops.matmul(inputs, self.w) return keras.ops.relu(y)
干净、易读、符合直觉。但代价是什么?每次操作都要和 HBM 打交道。
再看 Pallas 版本的核心:
def matmul_relu_kernel(a_ref, b_eq, c_ref): acc = pl.dot(a_ref[...], b_ref[...]) result = jnp.maximum(acc, 0) # ReLU here! c_ref[...] = result
关键就这一行:jnp.maximum(acc, 0)。ReLU 不是在主存里做的,而是在 TPU 的 Matrix Unit 或 GPU 的 Tensor Core 刚算完 matmul、数据还热乎的时候,当场激活。
没有中间变量,没有 HBM 往返——数据生命周期被压缩到极致。
但等等,性能真的提升了吗?
官方 benchmark 给出了一个反直觉的结果:
- 标准 Keras(matmul + ReLU):7.811 ms
慢了近 4.5 倍?!
别急。这恰恰暴露了真相:当前示例中的 Pallas 实现,尚未针对大矩阵优化调度策略。
文档坦承:“Inputs must be multiples of 128 for this demo”。而 benchmark 用的是 8192x8192 矩阵——理论上适合分块,但若 tile 策略不当,反而引入调度开销。
但这不是失败,而是邀请。
Pallas 的真正威力,在于它给了你控制 tiling 的权力。通过 BlockSpec,你可以精确指定:“每次从 HBM 搬 128 行 A 和 128 列 B 到片上内存,算完一块写一块。”
一旦调优得当,内存带宽瓶颈将大幅缓解——尤其在 Transformer 的 MLP 层这种密集计算场景。
训练才是真正的深水区
推理快只是开始。能让模型训练的自定义算子,才算真正可用。
问题来了:JAX 无法自动对 Pallas kernel 求导。
如果你直接用 fused_matmul 建模型训练,会撞上这个错误:
“Linearization failed to produce known values for all output primals…”
因为反向传播需要梯度,而 Pallas 是黑盒。
解决方案?手动写反向传播!
Keras 团队展示了如何用 jax.custom_vjp 注册前向和后向函数:
- 前向:保存输入 x、权重 w 和输出 y 作为“残差”;
- 后向:用 g * (y > 0) 计算 ReLU 梯度,再套用标准 matmul 反传公式。
最终,FusedDenseTrainable 成功跑通训练循环——从推理到训练,闭环打通。
Python 程序员赢了!
过去,想压榨 GPU/TPU 性能,你得:
- 或啃 Triton 教程,用类 Python 语法写 kernel;
- 再用 PyBind11 封装,塞回 Python 环境。
上下文切换成本高到劝退。
而 Pallas 的野心,藏在评论区一句热评里:
“Staying in Python for kernel-level performance is a game-changer. It lowers the barrier for serious hardware hacking without the context switch pain.”
你不再需要“出 Python”去搞硬件优化。
Keras + Pallas + JAX 的组合,让你在同一个语言、同一个心智模型里,从高层模型一路钻到硬件指令。
TPU 用 Mosaic,GPU 用 Triton——底层差异被抽象掉,你只关心算法逻辑。
这不只是工具升级,是范式迁移
硬件越来越快,但内存墙(Memory Wall)越来越厚。
HBM 带宽虽高,终究有限。谁减少数据搬运,谁就赢。
Pallas把“tiling”和“fusion”这两个HPC(高性能计算)老概念,包装成 Python 函数和 BlockSpec 对象,让 ML 工程师也能玩转内存层次优化。
更妙的是,它嵌在 Keras 里——那个以“简洁”著称的框架。极简 API 背后,藏着对硬件的极致掌控。
当写 kernel 不再是 CUDA 专家的特权,当 Python 程序员也能手搓高效算子,
AI 模型的性能天花板,会不会被一群“不懂底层”的人捅破?
或许,真正的智能,从来不在语言边界之内。