前面几篇文章,已经基本完全的介绍了模型定义的各种细节,包括模型定义、损失反向传播、权重参数更新等。但我们使用 Sequential 快速搭建的网络模型,只能处理简单的业务,如果碰到复杂的业务场景,继承 nn.Module 自定义模型处理类,将会是更好的选择。
代码示例
import torch import torch.nn as nn # 用类定义模型 class Net(nn.Module): def __init__(self, D_in, H, D_out): super().__init__() self.linear1 = nn.Linear(D_in, H) self.relu = nn.ReLU(H) self.linear2 = nn.Linear(H, D_out) # 前向传播 def forward(self, x): x = self.linear1(x) x = self.relu(x) x = self.linear2(x) return x N, D_in, H, D_out = 64, 1000, 100, 10 x = torch.randn(N, D_in) y = torch.randn(N, D_out) model = Net(D_in, H, D_out) optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) loss_fn = nn.MSELoss(reduction='sum') for i in range(500): y_hat = model(x) loss = loss_fn(y_hat, y) print(i, loss.item()) loss.backward() optimizer.step() optimizer.zero_grad()
本文为 陈华 原创,欢迎转载,但请注明出处:http://ichenhua.cn/read/311