PyTorch has become one of the most influential deep learning frameworks in AI today, holding an important position in both academic research and industrial applications. It uses Python as its primary programming interface, offering a flexible and intuitive way to build models so that researchers can quickly implement and validate new algorithmic ideas. At the same time, PyTorch's core computation, operator implementations, and autograd engine are largely carried out by a high-performance C/C++ backend, preserving development productivity while guaranteeing computational performance and system extensibility. It is precisely this architecture of "Python frontend for expression, C/C++ backend for execution" that gives PyTorch its strength in large-scale model training, high-performance tensor computation, and cross-hardware support. Digging into the interface between Python and C in PyTorch therefore not only helps in understanding the system design that combines efficient execution with flexible programming, but also lays an important foundation for analyzing the framework's internal operator call flow, performance optimization paths, and custom operator extension mechanisms.
import torch
import torch.nn as nn
import torch.optim as optim

# ----------------------------
# 1. Create training data
# ----------------------------
# Input: [x, y]
# Target: z = x^2 + y^3
torch.manual_seed(0)
X = torch.randn(10000, 2)        # 10000 samples, (x, y)
y = X[:, 0]**2 + X[:, 1]**3      # true function
Line 11 of this script, the power expression `y = X[:, 0]**2 + X[:, 1]**3`, ultimately ends up in the following C++ function:
#0  at::native::AVX512::pow_tensor_scalar_optimized_kernel<float, double, double>(at::TensorIteratorBase&, double)::{lambda(float)#2}::operator()(float) const (__closure=0x7fffffff87be, base=-1.1523602) at /home/zzy/Program/Torch/pytorch/aten/src/ATen/native/cpu/PowKernel.cpp:67
67            return base * base * base;
(gdb) list
62          [](Vec base) -> Vec { return base * base; }
63        );
64      } else if (exp == 3.0) {
65        cpu_kernel_vec(iter,
66          [](scalar_t base) -> scalar_t {
67            return base * base * base;
68          },
69          [](Vec base) -> Vec { return base * base * base; }
70        );
This code comes from PyTorch's optimized CPU pow kernel (PowKernel.cpp) and handles a special-cased optimization for raising tensor elements to a scalar exponent. When the exponent is exp == 3.0, the code does not call the generic pow() function; instead, it registers a dedicated compute kernel via cpu_kernel_vec that computes base * base * base directly for each input element base. Replacing the generic power function with plain multiplications avoids function-call overhead and improves performance. The kernel also provides both a scalar version ([](scalar_t base)) and a vectorized version ([](Vec base)); the latter uses SIMD instructions such as AVX512 to process several elements at once, further improving CPU parallel throughput. The gdb call stack above shows exactly this exp == 3 fast path executing its scalar lambda.
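The branching structure of this fast path can be sketched in pure Python. This is a sketch only: the function name and the list-based "tensor" are my own stand-ins, and the real kernel additionally registers the vectorized Vec lambda next to the scalar one.

```python
# Illustrative sketch of the exponent special-casing in PowKernel.cpp.
# Each branch picks a dedicated per-element kernel instead of the generic pow.
def pow_tensor_scalar(values, exp):
    if exp == 2.0:
        kernel = lambda base: base * base          # x**2 -> one multiply
    elif exp == 3.0:
        kernel = lambda base: base * base * base   # x**3 -> two multiplies
    else:
        kernel = lambda base: base ** exp          # generic fallback
    return [kernel(v) for v in values]

print(pow_tensor_scalar([2.0, -1.5], 3.0))   # [8.0, -3.375]
```

The point of the design is that the branch is taken once per tensor operation, while the chosen lambda runs once per element, so the per-element cost stays as small as possible.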
Along the execution path, the call chain looks like this:
Python API (torch.pow / Tensor.pow)
          ↓
Python wrapper (_handle_torch_function...)
          ↓
ATen dispatcher
          ↓
C++ kernel (e.g. the AVX512-optimized implementation in PowKernel.cpp)
The topmost Python API is torch.pow / Tensor.pow (reached here through the ** operator), and the bottommost C++ kernel is the C++ snippet above. The Python wrapper layer in the middle is implemented in pytorch/torch/_tensor.py:
 30 def _handle_torch_function_and_wrap_type_error_to_not_implemented(f):
 31     assigned = functools.WRAPPER_ASSIGNMENTS
 32
 33     @functools.wraps(f, assigned=assigned)
 34     def wrapped(*args, **kwargs):
 35         try:
 36             # See https://github.com/pytorch/pytorch/issues/75462
 37             if has_torch_function(args):
 38                 return handle_torch_function(wrapped, args, *args, **kwargs)
 39             return f(*args, **kwargs)
 40         except TypeError:
 41             return NotImplemented
 42
 43     return wrapped
In plain terms, this code performs a "check and forward" step before a PyTorch operator runs. It first looks at the incoming arguments for any special Tensor type that implements __torch_function__; if one is found, execution of the operator is handed over to that custom type. Otherwise, it calls the original function and proceeds normally. If a type mismatch occurs along the way, it returns NotImplemented so that Python can go on to try other candidate implementations. In other words, this is the dispatch entry point through which PyTorch makes its operators extensible and type-tolerant at the Python layer.
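The three behaviors of the wrapper (forward to an override, run the normal path, or yield NotImplemented) can be demonstrated with a self-contained sketch. The stand-ins for has_torch_function and handle_torch_function below are my own simplifications; PyTorch implements the real ones in C.

```python
import functools

# Simplified stand-in: does any argument opt into the override protocol?
def has_torch_function(args):
    return any(hasattr(a, "__torch_function__") for a in args)

# Simplified stand-in: hand execution to the first overriding argument.
def handle_torch_function(func, args, kwargs):
    for a in args:
        if hasattr(a, "__torch_function__"):
            return a.__torch_function__(func, args, kwargs)

def wrap_op(f):
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        try:
            if has_torch_function(args):
                return handle_torch_function(wrapped, args, kwargs)
            return f(*args, **kwargs)   # normal path
        except TypeError:
            return NotImplemented       # let Python try other implementations
    return wrapped

@wrap_op
def add(a, b):
    return a + b

class Interceptor:
    def __torch_function__(self, func, args, kwargs):
        return "intercepted"

print(add(1, 2))              # 3            -- normal path
print(add(Interceptor(), 2))  # intercepted  -- override path
print(add(1, "x"))            # NotImplemented -- TypeError swallowed
```

Returning NotImplemented rather than raising is what lets Python's binary-operator protocol fall back to the other operand's reflected method.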
So how is this statement actually put in place?
I first modified _tensor.py to make it debuggable:
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -7,7 +7,7 @@
 from collections import OrderedDict
 from copy import deepcopy
 from numbers import Number
 from typing import Any, Callable, cast, Optional, Union
-
+import pdb
 import torch
 import torch._C as _C
 from torch._namedtensor_internals import (
@@ -29,7 +29,7 @@ from torch.overrides import (
 
 def _handle_torch_function_and_wrap_type_error_to_not_implemented(f):
     assigned = functools.WRAPPER_ASSIGNMENTS
-
+    pdb.set_trace()
     @functools.wraps(f, assigned=assigned)
     def wrapped(*args, **kwargs):
         try:
Then I ran Python under gdb:
gdb Python-3.13.2/build/python
(gdb) r -B ./learn1.py
During the run, pdb's breakpoint is hit first:
  /home/zzy/Program/Torch/pytorch/torch/_tensor.py(102)<module>()
-> class Tensor(torch._C.TensorBase):
  /home/zzy/Program/Torch/pytorch/torch/_tensor.py(1096)Tensor()
-> @_handle_torch_function_and_wrap_type_error_to_not_implemented
> /home/zzy/Program/Torch/pytorch/torch/_tensor.py(32)_handle_torch_function_and_wrap_type_error_to_not_implemented()
-> pdb.set_trace()
What this output means: when the methods of the Tensor class are defined in PyTorch, they are decorated with _handle_torch_function_and_wrap_type_error_to_not_implemented. The decorator wraps each method before it can ever run: a call to the method actually enters the wrapped function generated by the decorator, which first checks whether any argument wants to take over the operator via the __torch_function__ mechanism; if so, execution is handed to that object, otherwise the original method implementation is called. The call relationship here is simply Python's decorator mechanism at work: a method call passes through the decorator's wrapper first, and the wrapper decides whether to run the original function.
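This also explains why the pdb breakpoint fires while class Tensor is still being defined: decorators in a class body run as the body executes, long before any method is called. A small sketch (illustrative names only, not PyTorch's):

```python
events = []

def record_and_wrap(f):
    events.append(f"decorating {f.__name__}")   # runs at class-definition time
    def wrapped(*args, **kwargs):
        events.append(f"calling {f.__name__}")  # runs at call time
        return f(*args, **kwargs)
    return wrapped

class Demo:
    @record_and_wrap
    def pow(self, e):
        return 2 ** e

print(events)          # ['decorating pow'] -- recorded before any call
print(Demo().pow(3))   # 8
print(events)          # ['decorating pow', 'calling pow']
```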
Pressing Ctrl+C in gdb reveals how this mechanism translates to the C/C++ level:
#9  0x000055555569eb33 in PyOS_Readline (sys_stdin=0x7ffff7e038e0 <_IO_2_1_stdin_>, sys_stdout=0x7ffff7e045c0 <_IO_2_1_stdout_>, prompt=prompt@entry=0x7fffd49a0160 "(Pdb) ") at ../Parser/myreadline.c:412
#10 0x00005555557ec4b6 in builtin_input_impl (module=module@entry=0x7ffff7ba6930, prompt=0x7fffd493e200) at ../Python/bltinmodule.c:2311
#11 0x00005555557ec73c in builtin_input (module=0x7ffff7ba6930, args=args@entry=0x7ffff7fb2ca8, nargs=nargs@entry=1) at ../Python/clinic/bltinmodule.c.h:1014
#12 0x0000555555712b7e in cfunction_vectorcall_FASTCALL (func=0x7ffff7ba7290, args=0x7ffff7fb2ca8, nargsf=<optimized out>, kwnames=<optimized out>) at ../Objects/methodobject.c:425
#13 0x00005555556bc1bb in _PyObject_VectorcallTstate (tstate=0x555555bd5cc0 <_PyRuntime+299040>, callable=0x7ffff7ba7290, args=0x7ffff7fb2ca8, nargsf=9223372036854775809, kwnames=0x0) at ../Include/internal/pycore_call.h:168
#14 0x00005555556bc297 in PyObject_Vectorcall (callable=<optimized out>, args=<optimized out>, nargsf=<optimized out>, kwnames=<optimized out>) at ../Objects/call.c:327
#15 0x00005555557f453b in _PyEval_EvalFrameDefault (tstate=0x555555bd5cc0 <_PyRuntime+299040>, frame=0x7ffff7fb2c20, throwflag=0) at ../Python/generated_cases.c.h:813
#16 0x0000555555817276 in _PyEval_EvalFrame (tstate=<optimized out>, frame=<optimized out>, throwflag=<optimized out>) at ../Include/internal/pycore_ceval.h:119
#17 0x0000555555812674 in _PyEval_Vector (tstate=0x555555bd5cc0 <_PyRuntime+299040>, func=0x7fffd49b9310, locals=0x0, args=0x7fffffff14c0, argcount=4, kwnames=0x0) at ../Python/ceval.c:1814
#18 0x00005555556bbe52 in _PyFunction_Vectorcall (func=<optimized out>, stack=<optimized out>, nargsf=<optimized out>, kwnames=<optimized out>) at ../Objects/call.c:413
#19 0x00005555556befea in _PyObject_VectorcallTstate (tstate=tstate@entry=0x555555bd5cc0 <_PyRuntime+299040>, callable=callable@entry=0x7fffd49b9310, args=args@entry=0x7fffffff14c0, nargsf=nargsf@entry=4, kwnames=kwnames@entry=0x0) at ../Include/internal/pycore_call.h:168
#20 0x00005555556bf162 in method_vectorcall (method=<optimized out>, args=0x7fffffff1570, nargsf=<optimized out>, kwnames=0x0) at ../Objects/classobject.c:92
#21 0x00005555558a644d in _PyObject_VectorcallTstate (tstate=0x555555bd5cc0 <_PyRuntime+299040>, callable=0x7fffd49599d0, args=args@entry=0x7fffffff1570, nargsf=nargsf@entry=3, kwnames=kwnames@entry=0x0) at ../Include/internal/pycore_call.h:168
#22 0x00005555558a6927 in call_trampoline (tstate=tstate@entry=0x555555bd5cc0 <_PyRuntime+299040>, callback=<optimized out>, frame=frame@entry=0x7fffd4a665d0, what=<optimized out>, arg=<optimized out>) at ../Python/sysmodule.c:1028
#23 0x00005555558a7e34 in trace_trampoline (self=<optimized out>, frame=0x7fffd4a665d0, what=<optimized out>, arg=<optimized out>) at ../Python/sysmodule.c:1064
#24 0x0000555555878d7f in sys_trace_instruction_func (self=0x7fffd4abbbc0, args=<optimized out>, nargsf=<optimized out>, kwnames=0x0) at ../Python/legacy_tracing.c:297
#25 0x000055555587187d in _PyObject_VectorcallTstate (tstate=tstate@entry=0x555555bd5cc0 <_PyRuntime+299040>, callable=0x7fffd4abbbc0, args=args@entry=0x7fffffff1678, nargsf=nargsf@entry=9223372036854775810, kwnames=kwnames@entry=0x0) at ../Include/internal/pycore_call.h:168
#26 0x0000555555871b19 in call_one_instrument (interp=interp@entry=0x555555ba6470 <_PyRuntime+104400>, tstate=tstate@entry=0x555555bd5cc0 <_PyRuntime+299040>, args=args@entry=0x7fffffff1678, nargsf=nargsf@entry=9223372036854775810, tool=<optimized out>, event=event@entry=6) at ../Python/instrumentation.c:907
#27 0x0000555555874836 in _Py_call_instrumentation_instruction (tstate=0x555555bd5cc0 <_PyRuntime+299040>, frame=<optimized out>, instr=<optimized out>) at ../Python/instrumentation.c:1372
#28 0x0000555555800a04 in _PyEval_EvalFrameDefault (tstate=0x555555bd5cc0 <_PyRuntime+299040>, frame=0x7ffff7fb28f8, throwflag=0) at ../Python/generated_cases.c.h:3332
#29 0x0000555555817276 in _PyEval_EvalFrame (tstate=<optimized out>, frame=<optimized out>, throwflag=<optimized out>) at ../Include/internal/pycore_ceval.h:119
#30 0x0000555555812674 in _PyEval_Vector (tstate=0x555555bd5cc0 <_PyRuntime+299040>, func=0x7fffd4844d10, locals=0x7fffd4968ad0, args=0x0, argcount=0, kwnames=0x0) at ../Python/ceval.c:1814
#31 0x00005555557ee369 in builtin___build_class__ (self=<optimized out>, args=args@entry=0x7ffff7fb27e0, nargs=nargs@entry=3, kwnames=kwnames@entry=0x0) at ../Python/bltinmodule.c:202
(gdb) f 31
#31 0x00005555557ee369 in builtin___build_class__ (self=<optimized out>, args=args@entry=0x7ffff7fb27e0, nargs=nargs@entry=3, kwnames=kwnames@entry=0x0) at ../Python/bltinmodule.c:202
202         cell = _PyEval_Vector(tstate, (PyFunctionObject *)func, ns, NULL, 0, NULL);
(gdb) p name
$2 = (PyObject *) 0x7ffff77f5490
(gdb) p (char *)((void *)name + 0x28)
$3 = 0x7ffff77f54b8 "Tensor"
This gdb output means: the Python interpreter is executing the class definition statement class Tensor(...) and is creating the class object through the builtin __build_class__ function.
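What __build_class__ accomplishes can be sketched in pure Python: a class statement executes the class body to populate a namespace dict, then calls the metaclass (type by default) with the name, bases, and that namespace. This is a simplified sketch with illustrative names, not the interpreter's actual code path:

```python
# Manually doing what `class Demo: ...` does under the hood.
namespace = {}
# Executing the class body fills the namespace; we fill it by hand here.
namespace["scale"] = 2                                # a class attribute
namespace["pow"] = lambda self, e: self.scale ** e    # a method

# __build_class__ ultimately hands (name, bases, namespace) to the metaclass.
Demo = type("Demo", (object,), namespace)

print(Demo.__name__)    # Demo
print(Demo().pow(3))    # 8
```

The `cell = _PyEval_Vector(...)` line visible in frame #31 is the step that runs the class body's code object with the namespace as its locals.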
Next, trace how the Python interpreter compiled the statement above:
#31 0x00005555557ee369 in builtin___build_class__ (self=<optimized out>, args=args@entry=0x7ffff7fb27e0, nargs=nargs@entry=3, kwnames=kwnames@entry=0x0) at ../Python/bltinmodule.c:202
202         cell = _PyEval_Vector(tstate, (PyFunctionObject *)func, ns, NULL, 0, NULL);
(gdb) b 202
Breakpoint 1 at 0x5555557ee34d: file ../Python/bltinmodule.c, line 202.
(gdb) p name
$4 = (PyObject *) 0x7ffff77f5490
(gdb) condition 1 name==0x7ffff77f5490
(gdb) x/10gx frame->instr_ptr
0x555559413cb8: 0x0172005c0095005e  0x0372015302720053
0x555559413cc8: 0x0253055c045c0025  0x001a035300110027
0x555559413cd8: 0x0772001a04530672  0x06530872001a0553
0x555559413ce8: 0x001a07530972001a  0x0b72001a08530a72
0x555559413cf8: 0x0a530c72001a0953  0x001a0c53012e0b53
(gdb) watch *(unsigned long *)0x555559413cb8
Hardware watchpoint 2: *(unsigned long *)0x555559413cb8
(gdb) condition 2 *(unsigned long *)0x555559413cb8==0x0172005c0095005e
When the watchpoint fires, I arrive at:
(gdb) bt
#0  __memcpy_evex_unaligned_erms () at ../sysdeps/x86_64/multiarch/memmove-vec-unaligned-erms.S:505
#1  0x00005555556c14ba in memcpy (__len=1424, __src=<optimized out>, __dest=0x555559413cb8) at /usr/include/x86_64-linux-gnu/bits/string_fortified.h:29
#2  init_code (co=co@entry=0x555559413bf0, con=con@entry=0x7fffffff4cf0) at ../Objects/codeobject.c:524
#3  0x00005555556c25ba in _PyCode_New (con=con@entry=0x7fffffff4cf0) at ../Objects/codeobject.c:688
#4  0x00005555557e04a1 in makecode (umd=umd@entry=0x555559298b88, a=a@entry=0x7fffffff4dd0, const_cache=const_cache@entry=0x7fffd4b7c110, constslist=constslist@entry=0x7fffd4995630, maxdepth=maxdepth@entry=22, nlocalsplus=nlocalsplus@entry=1, code_flags=0, filename=0x7fffd4b53a40) at ../Python/assemble.c:623
#5  0x00005555557e0698 in _PyAssemble_MakeCodeObject (umd=0x555559298b88, const_cache=0x7fffd4b7c110, consts=0x7fffd4995630, maxdepth=22, instrs=0x7fffffff4ec0, nlocalsplus=1, code_flags=0, filename=0x7fffd4b53a40) at ../Python/assemble.c:754
#6  0x000055555583cbd7 in optimize_and_assemble_code_unit (u=0x555559298810, const_cache=0x7fffd4b7c110, code_flags=0, filename=0x7fffd4b53a40) at ../Python/compile.c:7678
#7  0x000055555583ccae in optimize_and_assemble (c=0x7fffd4aed740, addNone=1) at ../Python/compile.c:7705
#8  0x0000555555822372 in compiler_class_body (c=0x7fffd4aed740, s=0x5555594843a8, firstlineno=102) at ../Python/compile.c:2606
#9  0x0000555555822770 in compiler_class (c=0x7fffd4aed740, s=0x5555594843a8) at ../Python/compile.c:2668
#10 0x000055555582a81b in compiler_visit_stmt (c=0x7fffd4aed740, s=0x5555594843a8) at ../Python/compile.c:4047
This call stack shows the Python interpreter compiling the body of class Tensor into an executable code object. In plain terms: when the interpreter encounters a class definition like class Tensor(...), it does not execute it immediately; the class body first goes through a compilation phase in which the Python code inside it (method definitions, decorators, and so on) is turned into an internal code object. The frames compiler_class_body → optimize_and_assemble → makecode → _PyCode_New → init_code → memcpy are the implementation path of exactly that process: the compiler first emits instructions, then assembles them into a code object, and finally memcpy copies the bytecode and related data into the newly created code object. So this stack tells us the interpreter is still in the compilation stage of the class definition, generating the bytecode structures for the Tensor class; it has not yet begun executing the method logic inside the class.
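This compilation result can be observed from Python itself: compiling a module that contains a class statement produces a nested code object for the class body, stored in the module code object's co_consts. A minimal sketch (the source string here is my own toy example, not PyTorch's _tensor.py):

```python
import types

src = """class Tensor:
    def pow(self, e):
        return self ** e
"""
# Compile the module source; this exercises the same compiler path
# (compiler_class_body -> ... -> _PyCode_New) seen in the gdb stack.
module_code = compile(src, "<demo>", "exec")

# The class body is a nested code object among the module's constants.
class_body = next(c for c in module_code.co_consts
                  if isinstance(c, types.CodeType))

print(class_body.co_name)         # Tensor
print(class_body.co_firstlineno)  # 1
```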
Moving to frame #10 of the call stack:
#10 0x000055555582a81b in compiler_visit_stmt (c=0x7fffd4aed740, s=0x5555594843a8) at ../Python/compile.c:4047
4047            return compiler_class(c, s);
(gdb) p (char *)((void *)s->v.FunctionDef.name + 0x28)
$17 = 0x7ffff77f54b8 "Tensor"
(gdb) p s->lineno
$18 = 102
(gdb) p s->end_lineno
$19 = 1783
This output shows the Python compiler processing the Tensor class definition in the source. In compiler_visit_stmt, the compiler sees that the current statement is a class definition and calls compiler_class to compile it. The printed s->v.FunctionDef.name (read through the FunctionDef arm of the statement union, whose name field sits at the same offset as ClassDef's) shows that the class being processed is "Tensor", while s->lineno = 102 and s->end_lineno = 1783 indicate that the class spans lines 102 through 1783 of the source file. In plain terms, the Python interpreter is compiling the whole class Tensor(...) block into its internal bytecode structures, in preparation for creating the Tensor class object.
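The same name / lineno / end_lineno fields that gdb printed from the C AST node are exposed at the Python level by the ast module, which parses source into the same statement structure the compiler walks. A small sketch with a toy class:

```python
import ast

# Parse a two-line class definition; the resulting node mirrors the
# ClassDef statement that compiler_visit_stmt inspected in gdb.
src = "class Tensor:\n    pass\n"
stmt = ast.parse(src).body[0]

print(type(stmt).__name__)                      # ClassDef
print(stmt.name, stmt.lineno, stmt.end_lineno)  # Tensor 1 2
```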
Anthropic CEO Dario Amodei has pointed out that in an era of rapid AI progress, human judgment and the capacity for value-laden decisions remain irreplaceable. As AI automates more and more technical tasks, the core value of humans will increasingly lie in critical thinking, ethical judgment, and oversight of AI behavior. In a world where "AI can generate almost anything," basic critical thinking and judgment may become among the most important abilities of all [1].
I believe that persistent learning is an important way for humans to keep their judgment sharp!
[1] https://www.justearthnews.com/economy-details/1100/coding-at-risk-anthropic-ceo-dario-amodei-says-human-centric-roles-may-last-longer.html