type
status
date
slug
summary
tags
category
icon
password
这里我们会实现线性回归
3.2.1 生成数据集
这里用有噪声的线性模型构造一个人造数据集
这里
normal
是用来生成满足正态分布的随机数这里定义了一个函数,返回人造的数据集
3.2.2 读取数据集
这里我们弄一个类似于
data_loader
的东西,用于读取并遍历数据集3.2.3 初始化模型参数
这里给出参数的初始值,从初始值开始更新参数
这里打开梯度记录,用于后续更新参数
3.2.4 定义模型
这里我们使用线性回归模型,用于基于参数进行运算
3.2.5 定义损失函数
这里使用之前的 1/2 均方损失作为 loss function
这里对 y 的预测值做一个 reshape,以确保与真实值的形状一样
3.2.6 定义优化算法
在计算的时候不保留梯度,并且在根据梯度更新参数后将梯度归零重置
3.2.7 训练
这里把前面定义好的函数以及参数赋值
循环一定的次数,计算损失,反向传播,更新参数
这里使用
backward
进行反向传播计算梯度时,需要先求和,因为 backward
只能对标量使用,而这里的 l
是一批量的损失- 作者:昊卿
- 链接:hqhq1025.tech/article/d2l/linear-regression-scratch
- 声明:本文采用 CC BY-NC-SA 4.0 许可协议,转载请注明出处。
相关文章