import torchimport torch.nn as nnimport numpy as npimport matplotlib.pyplot as plt# 1. 定义神经网络 (MLP)class PINN(nn.Module): def __init__(self): super(PINN, self).__init__() # 一个简单的全连接网络:输入1维(x) -> 隐藏层 -> 输出1维(u) self.net = nn.Sequential( nn.Linear(1, 20), nn.Tanh(), nn.Linear(20, 20), nn.Tanh(), nn.Linear(20, 20), nn.Tanh(), nn.Linear(20, 1) ) def forward(self, x): return self.net(x)# 2. 定义物理机制 (计算 PDE 的残差)def physics_loss(model, x): # 关键:需要开启 x 的梯度追踪,以便对 x 求导 x.requires_grad = True # 网络预测 u u = model(x) # 利用自动微分计算 du/dx # grad_outputs=torch.ones_like(u) 是标量求导的标准写法 u_x = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u), create_graph=True)[0] # 定义物理方程残差: f = du/dx - 2x # 理想情况下 f 应该为 0 residual = u_x - 2 * x # 返回残差的均方误差 return torch.mean(residual ** 2)# 3. 数据准备与训练model = PINN()optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 边界条件数据: u(0) = 1x_boundary = torch.tensor([[0.0]], dtype=torch.float32)u_boundary_true = torch.tensor([[1.0]], dtype=torch.float32)# 物理方程的采样点 (Collocation Points)# 在 [0, 2] 区间内随机采样,用于训练物理约束x_physics = torch.linspace(0, 2, 100).view(-1, 1).requires_grad_(True)# 训练循环epochs = 2000for epoch in range(epochs): optimizer.zero_grad() # A. 计算数据损失 (Boundary Condition Loss) u_boundary_pred = model(x_boundary) loss_data = torch.mean((u_boundary_pred - u_boundary_true) ** 2) # B. 计算物理损失 (Physics Loss) loss_phys = physics_loss(model, x_physics) # C. 总损失 loss = loss_data + loss_phys loss.backward() optimizer.step() if epoch % 200 == 0: print(f"Epoch {epoch}: Loss = {loss.item():.5f} (Data: {loss_data.item():.5f}, Phys: {loss_phys.item():.5f})")# 4. 验证结果with torch.no_grad(): test_x = torch.linspace(0, 2, 50).view(-1, 1) pred_u = model(test_x) true_u = test_x ** 2 + 1 # 解析解 # 此时你可以绘图对比 pred_u 和 true_u,通常会非常吻合 plt.plot(test_x, pred_u, label='Predicted u(x)') plt.plot(test_x,true_u, label='True u(x)') plt.xlabel('x') plt.ylabel('u(x)') plt.legend() plt.grid(True) plt.show() # %%# 随便给个数,比如 x = 1.5x_test = torch.tensor([[1.5]], dtype=torch.float32)# 让模型预测u_pred = model(x_test)# 打印结果print(f"模型预测 u(1.5) = {u_pred.item():.4f}")print(f"真实解析解 u(1.5) = {1.5**2 + 1:.4f}")