RNNs are one of the harder parts of deep learning to build an intuition for. This article walks through MNIST digit recognition with the simplest possible LSTM model, as a way to understand an RNN's input and output parameters. Because a vanilla RNN performs poorly on this task, the improved LSTM variant is used instead.
Code Example
1. Load the Dataset
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader

mnist = MNIST('./mnist/', train=True, transform=transforms.ToTensor(), download=True)
loader = DataLoader(mnist, batch_size=100, shuffle=False)

# for train_x, train_y in loader:
#     print(train_x.shape)  # torch.Size([100, 1, 28, 28])
#     print(train_y.shape)  # torch.Size([100])
#     exit()
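As a quick sanity check (not part of the original snippet; it only reuses the loader defined above), the sketch below pulls one batch and shows how the [100, 1, 28, 28] image tensor becomes the (batch, timestep, input) layout the LSTM expects: the channel dimension is dropped and each of the 28 image rows becomes one timestep with 28 features.

# Sketch: inspect one batch and reshape it for the LSTM (reuses `loader` from above)
train_x, train_y = next(iter(loader))
print(train_x.shape)                 # torch.Size([100, 1, 28, 28])
seq_x = train_x.reshape(-1, 28, 28)  # drop the channel dim -> (batch, timestep, input)
print(seq_x.shape)                   # torch.Size([100, 28, 28])
print(train_y.shape)                 # torch.Size([100])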
2. Define the Model
import torch.nn as nn

class Module(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.LSTM(28, 32, batch_first=True)
        self.out = nn.Linear(32, 10)

    def forward(self, x):
        x, hn = self.rnn(x)
        x = self.out(x[:, -1, :])
        return x

module = Module()
# print(module)
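To make the LSTM's parameters and return values concrete, here is a small standalone sketch (my addition, using the same layer sizes as the model above). nn.LSTM(28, 32, batch_first=True) maps 28 input features per timestep to 32 hidden units; its output holds the hidden state at every timestep, while h_n and c_n hold only the final hidden and cell state. The classifier above keeps output[:, -1, :], i.e. the last timestep, as a summary of the whole sequence.

import torch
import torch.nn as nn

# Sketch: a dummy batch of 4 sequences, each with 28 timesteps of 28 features
rnn = nn.LSTM(28, 32, batch_first=True)
x = torch.randn(4, 28, 28)
output, (h_n, c_n) = rnn(x)

print(output.shape)            # torch.Size([4, 28, 32]) - hidden state at every timestep
print(h_n.shape)               # torch.Size([1, 4, 32])  - final hidden state (num_layers=1)
print(c_n.shape)               # torch.Size([1, 4, 32])  - final cell state
print(output[:, -1, :].shape)  # torch.Size([4, 32])     - what the classifier above uses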
3. Train the Model
import torch

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(module.parameters(), lr=0.05)

for epoch in range(100):
    for i, (x, y) in enumerate(loader):
        # reshape to (batch, timestep, input)
        x = x.reshape(-1, x.shape[2], x.shape[3])
        y_seq = module(x)
        loss = loss_fn(y_seq, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # print progress
        if i % 50 == 0:
            y_hat = torch.argmax(y_seq, dim=1)
            accuracy = (y_hat == y).sum() / len(y_hat)
            print('epoch', epoch, 'loss:', loss.item(), 'accuracy:', accuracy)
4. Test the Model
test_mnist = MNIST('./mnist/', train=False, transform=transforms.ToTensor(), download=True)
test_loader = DataLoader(test_mnist, batch_size=100, shuffle=False)

for i, (x, y) in enumerate(test_loader):
    # reshape to (batch, timestep, input)
    x = x.reshape(-1, x.shape[2], x.shape[3])
    y_seq = module(x)
    loss = loss_fn(y_seq, y)
    # print progress
    if i % 50 == 0:
        y_hat = torch.argmax(y_seq, dim=1)
        accuracy = (y_hat == y).sum() / len(y_hat)
        print('loss:', loss.item(), 'accuracy:', accuracy)
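The loop above only reports the accuracy of individual batches. As an optional extension (a sketch of mine, reusing module and test_loader from above rather than anything in the original post), the snippet below accumulates correct predictions over the whole test set under torch.no_grad() and prints a single overall accuracy.

# Sketch: overall accuracy on the full test set (assumes `module` and `test_loader` exist)
module.eval()                  # no dropout/batchnorm here, but eval() is good practice
correct, total = 0, 0
with torch.no_grad():          # gradients are not needed for evaluation
    for x, y in test_loader:
        x = x.reshape(-1, x.shape[2], x.shape[3])
        y_hat = torch.argmax(module(x), dim=1)
        correct += (y_hat == y).sum().item()
        total += len(y)
print('test accuracy:', correct / total)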
The goal of this project is to understand an RNN's input and output parameters; the intermediate hidden state will be covered in a follow-up project.
This article is an original work by 陈华. Reposting is welcome, but please credit the source: http://ichenhua.cn/read/248