【PyTorch修炼】零基础入门PyTorch之线性回归【逐行代码讲解】

前言

本节包含入门torch的一个小例子

逐行代码视频讲解在:https://www.bilibili.com/video/BV1nS4y1u76S?spm_id_from=333.999.0.0

主要关于线性回归的例子

  • 自主创建数据

  • 建立线性回归模型

  • 完成训练过程

  • 画图展示

线性模型

94a77f1d8fdaa1c1053e4965aa4c275a.png其中k是权重, b就是偏置项。

线性模型通俗来说就是拟合其中的k和b其实就是w和bias

比如在这里

33c9ce863f6896e9431a164a90b70152.png

代码整体

模拟数据

添加高斯白噪声(符合均值为0,方差为1的正态分布的一组随机数),x设置为512个点,也就是样本的个数为512

edca9992d8ea27bd93beb8714b158d8e.png

线性模型

因为输入的每个数值和输出的每个数值其实就是维度为1,feature_num=1,线性模型为

class LinearModel(nn.Module):
    def __init__(self, in_fea, out_fea):
        super(LinearModel, self).__init__()
        self.out = nn.Linear(in_fea, out_fea)
    def forward(self, x):
        x = self.out(x)
        return x
7b87ac96e286b0cb41ffb446066491a9.png

定义损失函数和优化器

optimizer = torch.optim.SGD(model.parameters(), lr=0.02)

loss_func = nn.MSELoss()

改变数据的维度为模型输入维度

62db9f00bd1215d9e37b1dc2a7728a86.png在feature那维度加一维

训练与可视化

其中套路为

  1. 前向推理

  2. 算loss

  3. 清空梯度

  4. 反向传播

  5. 更新权重

plt.ion()
for step in range(200):
    prediction = model(x)
    loss = loss_func(prediction, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step%10 == 0:
        plt.cla()
        plt.scatter(x.data.numpy(), y.data.numpy())
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
        plt.xlim(0,1.1)
        plt.ylim(0, 20)
        [w, b] = model.parameters()
        plt.text(0, 0.5, 'loss=%.4f, k=%.2f, b=%2f'%(loss.item(), w.item(), b.item() ),fontdict={'size': 20, 'color':  'red'})
        plt.pause(0.5)
plt.ioff()
plt.show()

其中ion是打开交互模式,ioff是关闭交互模式,就可以动态的画图看变化了

动态的看拟合的线变换

1d0c9107f38b4beb288412bf807596cc.png 0c0a9120f2fd09e7fa31dd8c9c4878f0.png 89ab799498523a521e86d7662f67cae9.png

推荐阅读:

我的2022届互联网校招分享

我的2021总结

浅谈算法岗和开发岗的区别

互联网校招研发薪资汇总

对于时间序列,你所能做的一切.

什么是时空序列问题?这类问题主要应用了哪些模型?主要应用在哪些领域?

公众号:AI蜗牛车

保持谦逊、保持自律、保持进步

91180d846d68ee78da035e9b1ad376da.png

发送【蜗牛】获取一份《手把手AI项目》(AI蜗牛车著)

发送【1222】获取一份不错的leetcode刷题笔记

发送【AI四大名著】获取四本经典AI电子书

猜你喜欢

转载自blog.csdn.net/qq_33431368/article/details/123516036
今日推荐