前言
本节包含入门torch的一个小例子
逐行代码视频讲解在:https://www.bilibili.com/video/BV1nS4y1u76S?spm_id_from=333.999.0.0
主要关于线性回归的例子
自主创建数据
建立线性回归模型
完成训练过程
画图展示
线性模型
其中k是权重, b就是偏置项。
线性模型通俗来说就是拟合其中的k和b其实就是w和bias
比如在这里
代码整体
模拟数据
添加高斯白噪声(符合均值为0,方差为1的正态分布的一组随机数),x设置为512个点,也就是样本的个数为512
线性模型
因为输入的每个数值和输出的每个数值其实就是维度为1,feature_num=1,线性模型为
class LinearModel(nn.Module):
def __init__(self, in_fea, out_fea):
super(LinearModel, self).__init__()
self.out = nn.Linear(in_fea, out_fea)
def forward(self, x):
x = self.out(x)
return x
定义损失函数和优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.02)
loss_func = nn.MSELoss()
改变数据的维度为模型输入维度
在feature那维度加一维
训练与可视化
其中套路为
前向推理
算loss
清空梯度
反向传播
更新权重
plt.ion()
for step in range(200):
prediction = model(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if step%10 == 0:
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.xlim(0,1.1)
plt.ylim(0, 20)
[w, b] = model.parameters()
plt.text(0, 0.5, 'loss=%.4f, k=%.2f, b=%2f'%(loss.item(), w.item(), b.item() ),fontdict={'size': 20, 'color': 'red'})
plt.pause(0.5)
plt.ioff()
plt.show()
其中ion是打开交互模式,ioff是关闭交互模式,就可以动态的画图看变化了
动态的看拟合的线变换
推荐阅读:
什么是时空序列问题?这类问题主要应用了哪些模型?主要应用在哪些领域?
公众号:AI蜗牛车
保持谦逊、保持自律、保持进步
发送【蜗牛】获取一份《手把手AI项目》(AI蜗牛车著)
发送【1222】获取一份不错的leetcode刷题笔记
发送【AI四大名著】获取四本经典AI电子书