type
status
date
slug
summary
tags
category
icon
password
这里我们调用深度学习框架进行简洁的实现
3.3.1 生成数据集
3.3.2 读取数据集
这里调用
dataloader
进行数据读取然后,
data_iter
便可以通过 iter
进行迭代,逐一获得 data3.3.3 定义模型
这里调用
torch.nn
通过指定层的类型 Linear
,以及输入、输出特征,通过 Sequential
把各个层串联起来3.3.4 初始化模型参数
这里通过
net[0]
访问网络中第一个图层,使用weight.data
和bias.data
方法访问参数,进行初始化3.3.5 定义损失函数
这里选择 MSELoss,也叫平方 范数
默认情况下,它返回所有样本损失的平均值。
3.3.6 定义优化算法
optim
模块中有很多优化算法,这里我们选择 SGD 算法(随机梯度下降),并指定参数和学习率3.3.7 训练
- 作者:昊卿
- 链接:hqhq1025.tech/article/%20d2l/linear-regression-concise
- 声明:本文采用 CC BY-NC-SA 4.0 许可协议,转载请注明出处。