type
status
date
slug
summary
tags
category
icon
password
光学一些理论知识,或者浏览代码感觉对我的提升有限,再遇到一些真正要自己敲的场景总是发现自己的想法无法转化为代码,因此准备尝试在 gpt 的教学和引导下,尝试自己搭建一个个 pipeline,进行一些经典模型的复现和应用,希望能提高自己相关的能力。
导入库
首先需要导入一些必要的库
导入
torch
,然后是 torch
中的 nn
(神经网络模型) 以及 optim
(优化器)导入
torchvision
中的 datasets
(数据集,用于导入需要的 mnist 数据),transforms
(对图片预处理)导入
Dataloader
,用于加载数据加载数据集
transform
定义图像预处理的方式:这里选择用 totensor 把图片转换成 tensor 格式的数据
dataset
这里是借助 dataset 库选择想要使用的数据集
其中,需要指定根目录,用于保存数据
./data
表示下载或保存到当前目录下的 data 文件夹train
表示加载数据的类型,到底是训练集还是测试集(数据集有两部分)transform
是指定每张图像预处理的方式,这里选择最初定义好的 transformdownload
就是选择是否要从网络下载到本地(如果本地没有数据)dataloader
Dataset
只是一个“数据集合”,能按索引取数据,但不能自动分批
DataLoader
可以:- 自动按批次(batch)读取
- 自动打乱数据顺序(shuffle)
- 支持并行加载(num_workers)
- 方便用于
for x, y in loader:
迭代训练
对应两种数据,loader 也分为两种
Dataloader
这里的使用方法其实是实例化 Dataloader:指定加载的数据;设置 batchsize;选择是否 shuffle
batch_size
对于训练数据一般是 64 或者 128(大小会对训练效果有影响,具体可见李宏毅课程);对于测试数据,为了加快测试的速度,设置的会很大,这里是 1000shuffle
也就是是否要打乱:训练数据为了避免模型记住顺序,导致过拟合,要打乱;测试数据无所谓了定义模型
需要定义一个 MLP 类,从
nn.module
中继承init
这里其实就是进行一个初始化,首先使用 super 进行 nn.module 的初始化
然后根据自己模型各个 layer 的设置定义几个 function,以及非线性激活函数(这里是 relu),作为几个方法在后面调用
forward
在
nn.module
里,调用的时候好像是执行的 forward
,因此我们这里也需要相应地定义一下 forward 函数,需要接受参数 x
这里就是按顺序把各个层给用上,然后传递
x
的值即可这里的 relu 方法可以和层的变换套用(因为都是“函数”)
具体参数的设置
反正我当时作为一个小白,是不知道这里的几个 linear 层的数据咋得出的,这里来详细讲一讲
首先是
flatten
,这个其实是因为接受的是图像,因此需要把图像进行展平原本尺寸是 28x28,展平后则是 784
后面的线性隐藏层就是根据前一层的输出来决定输入即可
一般来说都是维度越来越小的,因为要实现压缩
关于怎么得知图像大小,这里可以通过从 dataset 中任取一个元素来得到:
这三个数分别表示:通道数、宽、高

设置训练设备
这里首先查看自己哪些训练设备可用,有 cuda 的话肯定优先 cuda
我这里用的是 mac,因此多加了对于 mps 的判断,mps 是苹果联合 pytorch 搞得 metal 的加速
创建模型并传入设备
这里将设置好的模型
MLP
进行实例化,并指定训练设备 device
设置损失函数
这里使用交叉熵损失来衡量 loss
正是因为这里使用了交叉熵,因此在模型的 forward 中的最后一个 linear 才没有套上 relu,因为交叉熵自带 softmax,如果套上 relu 会影响效果
🔍 为啥不能在最后一层加 ReLU?🔹
ReLU
的作用:
- 把负数变为 0,正数不变
- 是一个非线性激活函数
🔹 如果你加了 ReLU:
- 你把一部分原本可能代表“低得分”的类别(负数)都变成了 0
- 这会让 softmax 后的概率计算失真
- 交叉熵无法得到真实的类别差异,模型训练效果会很差!
定义优化器
这里使用 Adam 优化器
首先,需要设置模型需要更新的参数:这里利用 parameter 方法把模型的参数传入
然后,设置学习率,这个具体大小的设置可以看相关的笔记
定义评估函数
这里需要定义一个评估函数,用于测试模型在测试集上的情况,用于循环训练的时候调用
这里对我比较困惑的三行做一下详细的解释:
💡 作用:
output
的形状是[batch_size, 10]
,每一行代表一张图对 10 个类别的得分(logits)
torch.max(..., dim=1)
会在每一行里找最大值所在的位置,即预测的类别
💡 作用:
predicted == label
会生成一个布尔张量,比如[True, False, True, ...]
.sum()
会统计有多少个True
(预测对了的数量)
.item()
把 tensor 转成数字,才能和correct
累加
💡 作用:
label
的长度就是当前 batch 的样本数量,比如 batch_size = 64 就是 64
- 每个 batch 都加一遍,就能知道总共评估了多少张图像
训练
✅ 训练主循环逐行简洁解释:
👉 设置训练轮数(epoch)为 10,每一轮都会完整遍历一遍训练集
👉 创建两个空列表,用于记录每轮的训练损失和测试准确率,后续可用于画图
👉 进入训练主循环,总共跑
num_epochs
轮👉 切换模型到“训练模式”(启用 dropout、batchnorm 等训练行为)
👉 用于累加当前 epoch 中每个 batch 的 loss,方便后面计算平均损失
👉 遍历训练集的每个小批量(batch),每次拿一组图像 + 标签
👉 把数据和标签搬到训练设备(CPU/GPU/MPS),确保模型输入一致
👉 前向传播:模型接收输入,输出每个样本对各类别的得分(logits)
👉 用交叉熵计算预测和真实标签之间的差距(越小越好)
👉 清除上一次反向传播留下的梯度值,防止梯度累加
👉 反向传播:根据 loss 自动计算每个参数的梯度值
👉 使用优化器更新模型参数(执行一次梯度下降)
👉 累加当前 batch 的 loss,
.item()
是把 tensor 转为标量👉 当前 epoch 训练完后,使用验证函数在测试集上跑一遍,返回准确率(0~1)
👉 把本轮的平均训练损失添加到列表中,用于可视化
👉 把本轮测试准确率添加到列表中,用于可视化
👉 打印当前 epoch 的训练损失和测试准确率,便于实时观察模型训练状态
✅ 一轮训练完成后,你会得到:
- 一个模型经过训练
- 一个训练损失列表:
train_losses
- 一个测试准确率列表:
test_accuracies
- 控制台实时输出每轮的表现
- 可以画出训练曲线,或者保存模型等后续操作
画图
✅ 第一段:画训练损失曲线
👉 导入
matplotlib
的 pyplot
模块,这是 Python 中最常用的画图工具👉 创建一个新的图像窗口,设置图像大小为 6x4 英寸(适中)
👉 把训练损失
train_losses
画成一条折线图:- 横轴是 epoch(默认从 0 开始)
- 纵轴是 loss(越小越好)
label='Train Loss'
会显示图例
color='blue'
设置线条颜色
👉 设置横坐标标签为“Epoch”
👉 设置纵坐标标签为“Loss”,表示训练误差
👉 设置整个图像的标题,清晰表明这是“训练损失曲线”
👉 打开网格线,便于观察每一轮的数值位置
👉 显示图例(即“Train Loss”那一行小说明文字)
👉 显示图像(必要步骤,否则画不出来)
✅ 第二段:画测试准确率曲线(完全类似)
✅ 最终你会看到两张图:
- 训练损失曲线:
- 应该是不断下降的
- 如果 loss 不降,说明模型没学到
- 如果震荡不稳定,可能学习率太高
- 测试准确率曲线:
- 应该逐步上升
- 如果先升后降,可能开始过拟合
- 如果长时间不升,可能模型太弱或数据不足
🎯 Bonus:如何保存图像?
如果你想把图保存为文件,比如
loss.png
、accuracy.png
,可以在 plt.show()
之前加一句:保存模型
✅ PyTorch 推荐保存方式:保存模型的 state_dict
(参数字典)
✅ 一行保存模型参数:
🧠 含义:
'mlp_mnist.pth'
是你保存的文件名(可以改)
model.state_dict()
是模型所有可学习参数的集合(权重、偏置等)
保存的是“参数字典”,不是整个模型对象,灵活且更轻便
✅ 你可以在训练完成后加上这一行:
✅ 加载模型的方式(恢复训练或推理)
当你下次想用这个模型,只需:
🔍 注意:
- 你需要用相同结构的模型类(比如
MLP()
)
model.eval()
很重要!否则有些层(如 Dropout)还处于训练状态
🧠 一句话总结保存加载流程:
训练后:torch.save(model.state_dict(), 'xxx.pth')使用时:重新建模型结构 +load_state_dict()
+eval()
- 作者:昊卿
- 链接:hqhq1025.tech/article/mlp_pipeline
- 声明:本文采用 CC BY-NC-SA 4.0 许可协议,转载请注明出处。
相关文章