使用PyTorch搭建线性回归模型
使用自动求导机制和简单计算函数搭建线性回归模型
1 | import torch |
1 | # 构造数据 |
1 | # 初始化参数 |
单轮迭代模拟
1 | # 迭代一次 |
1 | tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., |
1 | lr = 0.015 # 设置学习速率 |
1 | # 计算损失函数 |
1 | tensor(16.0647, grad_fn=<DivBackward0>) |
1 | # 反向传播计算梯度 |
1 | # 获取梯度 |
tensor([-7.1164])
tensor([-5.3900])
1 | # 更新参数 |
1 | print(k) |
1 | tensor([0.1067], requires_grad=True) |
1 | # 完成一个batch将梯度置零 |
tensor([0.])
循环迭代
1 | # 迭代多次 |
1 | Iter: 0, k: 0.2096, b: 0.1585, training loss: 14.8915 |
从打印出来的训练过程可以看出,迭代到400轮的时候就已经收敛,k=3.0000,b=2.0000。
上述方法仅仅使用到了pytorch中的自动求导机制进行参数训练,其实对于这些训练过程在PyTorch中还有很多方法可以替代。
例如:
- 更新参数部分,可以在提前实例化一个优化器,反向计算完梯度后直接使用optimizer.step()去更新参数。
- 参数的梯度置零,可以直接调用优化器的optimizer.zero_grad()方法去将所有参数的梯度全部置零。
调用PyTorch优化器搭建线性回归模型
1 | import torch |
定义模型
1 | class SimpleLinear(nn.Module): |
定义优化器
1 | linear = SimpleLinear() |
1 | optimizer = torch.optim.SGD(linear.parameters(), lr=0.015) |
定义损失函数
1 | loss_fun = nn.MSELoss() |
模型训练
1 | for epoch in range(500): |
1 | Epoch:0, k:-0.1556, b:-0.5560, loss:16.833160 |
PyTorch模型搭建与训练过程模拟
定义模型类
1 | class SimpleLinear: |
定义优化器类
1 | class Optimizer: |
定义损失函数类
1 | # 定义损失函数为MSE |
1 | loss_fun = LossFun() |
模型训练函数编写
1 | # 实例化模型 |
1 | Epoch:0, k:0.0915, b:0.0525, loss:12.655186 |
同样可以达到训练效果
使用PyTorch实现手写数字识别任务
1 | import math |
1 | from torchvision import datasets |
使用卷积神经网络实现手写数字识别
数据集定义
在线获取
在线获取PyTorch内置数据集 https://pytorch.org/vision/stable/datasets.html
当download参数为True时,会去校验指定的路径中是否有数据集文件,如果没有会去下载相应的数据集
1 | mnist_train_dataset = datasets.MNIST('../data/MNIST', |
1 | mnist_test_dataset = datasets.MNIST('../data/MNIST', |
1 | mnist_train_dataset[0][0].shape |
torch.Size([1, 28, 28])
1 | mnist_train_dataset[0][1] |
5
自定义DataSet类
1 | # 基本格式 |
1 | import numpy as np |
1 | class MNIST_Dataset(Dataset): |
1 | train_dataset = MNIST_Dataset('../data/self/mnist.npz', |
1 | test_dataset = MNIST_Dataset('../data/self/mnist.npz', |
1 | train_dataset[0][0].shape |
torch.Size([1, 28, 28])
1 | train_dataset[0][1] |
5
卷积神经网络搭建
1 | class Net(nn.Module): |
1 | model = Net() |
1 | ---------------------------------------------------------------- |
定义训练函数
1 | # model:模型 device:模型训练场所 optimizer:优化器 epoch:模型训练轮次 |
定义测试函数
1 | def test(model, device, test_loader, criterion): |
模型训练
1 | epochs = 10 # 迭代次数 |
1 | Train epoch 1: 60000/60000, [-------------------------------------------------->] 100% |
模型参数查看
1 | for i,v in model.named_parameters(): |
1 | conv1.weight torch.Size([6, 1, 5, 5]) |
模型检查点存储
随着模型结构的复杂化,数据量的复杂化,模型的训练所消耗的时间也会逐渐增大,这个时候就需要注意,如果我们的模型训练一半,由于一些不可控因素导致训练中断(例如:服务器资源竞争导致程序被杀死、意外断电等等),那么此时如果再从头开始训练会导致前面训练所消耗的时间和资源白白浪费掉了。
为了应对这种情况的发生,模型检查点技术就应运而生,模型检查点是在模型在每轮的训练过程中,存储模型训练中间状态的一种技术。
模型检查点一般会保存以下几个指标:
- epoch:当前训练的轮数
- step:当前轮数对应的批次数(使用频率低)
- model_state_dict:模型参数
- optimizer_state_dict:优化器参数
- loss:当前模型参数损失值
1 | save_file = f'checkpoint_{epoch}.pt' |
模型检查点读取
1 | resume = f'checkpoint_{epoch}.pt' # 指定恢复文件 |
将模型检查点引入训练
检查点的引入一般有以下两种方式:
- 固定轮数保存,例如:每5轮训练保存一次检查点。
- 指标监测保存,例如:监测loss值,当loss值有下降时存储一次,也即存储最优模型。
1 | def test(model, device, test_loader, criterion): |
1 | epochs = 10 |
1 | Train epoch 1: 60000/60000, [-------------------------------------------------->] 100% |
加载检查点继续训练
1 | epochs = 10 |
1 | loading from cnn_checkpoint_best.pt |
使用循环神经网络实现手写数字识别
数据集DataSet类重构
RNN要求数据输入的格式为3维:[batch_size, sequence_length, hidden_size],所以在数据层面不需要升维
1 | class MNIST_Dataset(Dataset): |
1 | train_dataset = MNIST_Dataset('../data/self/mnist.npz', train=True) |
1 | test_dataset = MNIST_Dataset('../data/self/mnist.npz', train=False) |
循环神经网络搭建
1 | class RNNNet(nn.Module): |
模型定义
1 | input_dim = 28 |
1 | summary(rnn_model.to('cuda'), (28, 28)) |
1 | ---------------------------------------------------------------- |
模型训练
1 | input_dim = 28 |
1 | Train epoch 1: 60000/60000, [-------------------------------------------------->] 100% |
模型继续训练
1 | input_dim = 28 |
1 | loading from rnn_checkpoint_best.pt |
模型更换为LSTM
模型定义
1 | class LSTMNet(nn.Module): |
1 | input_dim = 28 |
1 | lstm_model |
1 | LSTMNet( |
模型训练
1 | input_dim = 28 |
1 | Train epoch 1: 60000/60000, [-------------------------------------------------->] 100% |
1 | input_dim = 28 |
1 | loading from lstm_checkpoint_best.pt |
1 | input_dim = 28 |
1 | loading from lstm_checkpoint_best.pt |