Single attention head

Below is a minimal causal (masked) single-head self-attention module in PyTorch, with explicit shape annotations, followed by a small numeric example run with fixed weights.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def causal_mask(T: int, device=None):
    """Returns a bool mask where True means *masked* (disallowed).

    Shape: (1, 1, T, T), suitable for broadcasting with (B, heads, T, T).
    """
    m = torch.triu(torch.ones((T, T), dtype=torch.bool, device=device), diagonal=1)
    return m.view(1, 1, T, T)
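As a quick sanity check (a small sketch, not part of the original listing), for T = 3 only positions above the diagonal are True, i.e. each position may attend to itself and earlier positions but not to the future:

# Inspect the mask for a length-3 sequence.
print(causal_mask(3).squeeze())
# Expected pattern (True = disallowed future position):
# tensor([[False,  True,  True],
#         [False, False,  True],
#         [False, False, False]])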
class SingleHeadSelfAttention(nn.Module):
    """Single-head attention (explicit shapes)."""

    def __init__(self, d_model: int, d_k: int, dropout: float = 0.0, trace_shapes: bool = False):
        super().__init__()
        self.q = nn.Linear(d_model, d_k, bias=False)
        self.k = nn.Linear(d_model, d_k, bias=False)
        self.v = nn.Linear(d_model, d_k, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.trace_shapes = trace_shapes

    def forward(self, x: torch.Tensor):
        # x: (B, T, d_model)
        B, T, _ = x.shape
        q = self.q(x)  # (B, T, d_k)
        k = self.k(x)  # (B, T, d_k)
        v = self.v(x)  # (B, T, d_k)
        if self.trace_shapes:
            print(f"q {q.shape} k {k.shape} v {v.shape}")
        scale = 1.0 / math.sqrt(q.size(-1))
        attn = torch.matmul(q, k.transpose(-2, -1)) * scale  # (B, T, T)
        mask = causal_mask(T, device=x.device)
        attn = attn.masked_fill(mask.squeeze(1), float('-inf'))
        w = F.softmax(attn, dim=-1)
        w = self.dropout(w)
        out = torch.matmul(w, v)  # (B, T, d_k)
        if self.trace_shapes:
            print(f"weights {w.shape} out {out.shape}")
        return out, w
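One way to verify the module (an optional sketch, assuming PyTorch 2.0+ where F.scaled_dot_product_attention is available; the helper name check_against_sdpa is made up here) is to compare it against the built-in fused attention with is_causal=True. With dropout disabled the two should agree up to floating-point tolerance:

def check_against_sdpa(attn_module: SingleHeadSelfAttention, x: torch.Tensor) -> bool:
    """Compare the explicit implementation with F.scaled_dot_product_attention."""
    attn_module.eval()  # ensure dropout is inactive
    with torch.no_grad():
        out, _ = attn_module(x)
        q, k, v = attn_module.q(x), attn_module.k(x), attn_module.v(x)
        ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    return torch.allclose(out, ref, atol=1e-6)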
X = np.array([[[0.1, 0.2, 0.3, 0.4],
               [0.5, 0.4, 0.3, 0.2],
               [0.0, 0.1, 0.0, 0.1]]], dtype=np.float32)
Wq = np.array([[ 0.2, -0.1],
               [ 0.0,  0.1],
               [ 0.1,  0.2],
               [-0.1,  0.0]], dtype=np.float32)
Wk = np.array([[ 0.1,  0.1],
               [ 0.0, -0.1],
               [ 0.2,  0.0],
               [ 0.0,  0.2]], dtype=np.float32)
Wv = np.array([[ 0.1,  0.0],
               [-0.1,  0.1],
               [ 0.2, -0.1],
               [ 0.0,  0.2]], dtype=np.float32)
# Simulated input
torch.manual_seed(0)
x = torch.tensor(X)

# Initialize the attention head
attn = SingleHeadSelfAttention(d_model=4, d_k=2, trace_shapes=True)

# Load the weights (model parameters); nn.Linear stores its weight as
# (out_features, in_features), hence the transpose.
with torch.no_grad():
    attn.q.weight.copy_(torch.tensor(Wq).t())
    attn.k.weight.copy_(torch.tensor(Wk).t())
    attn.v.weight.copy_(torch.tensor(Wv).t())

# Run inference
out, w = attn(x)
print(out.shape)
print(torch.isfinite(out).all())
print(torch.isfinite(w).all())
q torch.Size([1, 3, 2]) k torch.Size([1, 3, 2]) v torch.Size([1, 3, 2])
weights torch.Size([1, 3, 3]) out torch.Size([1, 3, 2])
torch.Size([1, 3, 2])
tensor(True)
tensor(True)
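The same numbers can be re-derived without torch. The sketch below (the helper name numpy_reference is made up here; it uses only the numpy arrays defined above and the out tensor from the run) mirrors the forward pass step by step: project, scale, mask, softmax, weight the values.

def numpy_reference(X, Wq, Wk, Wv):
    """Re-derive the causal attention output with plain numpy for the toy example."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv                            # each (1, 3, 2)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])    # (1, 3, 3)
    T = scores.shape[-1]
    # Mask future positions, then apply a numerically stable softmax.
    scores = np.where(np.triu(np.ones((T, T), dtype=bool), k=1), -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V                                          # (1, 3, 2)

print(np.allclose(numpy_reference(X, Wq, Wk, Wv), out.detach().numpy(), atol=1e-6))
# Should print True if the torch and numpy computations agree.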