type
status
date
slug
summary
tags
category
icon
password
这里我们会实现线性回归

3.2.1 生成数据集

这里用有噪声的线性模型构造一个人造数据集
这里normal 是用来生成满足正态分布的随机数
这里定义了一个函数,返回人造的数据集

3.2.2 读取数据集

这里我们弄一个类似于 data_loader 的东西,用于读取并遍历数据集

3.2.3 初始化模型参数

这里给出参数的初始值,从初始值开始更新参数
这里打开梯度记录,用于后续更新参数

3.2.4 定义模型

这里我们使用线性回归模型,用于基于参数进行运算

3.2.5 定义损失函数

这里使用之前的 1/2 均方损失作为 loss function
这里对 y 的预测值做一个 reshape,以确保与真实值的形状一样

3.2.6 定义优化算法

在计算的时候不保留梯度,并且在根据梯度更新参数后将梯度归零重置

3.2.7 训练

这里把前面定义好的函数以及参数赋值
循环一定的次数,计算损失,反向传播,更新参数
这里使用 backward 进行反向传播计算梯度时,需要先求和,因为 backward 只能对标量使用,而这里的 l 是一批量的损失
 
3.3 线性回归的简洁实现3.1 线性回归
Loading...