PyTorch学习(四)--用PyTorch实现线性回归

教程视频:https://www.bilibili.com/video/BV1tE411s7QT

废话不多说,代码如下:

import  torch

x_data = torch.Tensor([[1.0],[2.0],[3.0]])
y_data = torch.Tensor([[2.0],[4.0],[6.0]])

class LinearModel(torch.nn.Module):
    def __init__(self):#构造函数
        super(LinearModel,self).__init__()
        self.linear = torch.nn.Linear(1,1)#构造对象,并说明输入输出的维数,第三个参数默认为true,表示用到b
    def forward(self, x):
        y_pred = self.linear(x)#可调用对象,计算y=wx+b
        return  y_pred

model = LinearModel()#实例化模型

criterion = torch.nn.MSELoss(size_average=False)
#model.parameters()会扫描module中的所有成员,如果成员中有相应权重,那么都会将结果加到要训练的参数集合上
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)#lr为学习率

for epoch in range(1000):
    y_pred = model(x_data)
    loss = criterion(y_pred,y_data)
    print(epoch,loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print('w=',model.linear.weight.item())
print('b=',model.linear.bias.item())

x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

结果:

0 56.52023696899414
1 25.170454025268555
2 11.214292526245117
3 5.001270771026611
4 2.2352840900421143
5 1.0038176774978638
6 0.45547759532928467
7 0.21124869585037231
8 0.10240332782268524
9 0.05382827669382095
10 0.03208546340465546
……
90 0.004652736708521843
91 0.004585907328873873
92 0.004519954323768616
93 0.00445501459762454
94 0.004390999674797058
95 0.004327872302383184
96 0.004265678580850363
97 0.004204379860311747
98 0.004143938422203064
99 0.00408441387116909
w= 2.042545795440674
b= -0.09671643376350403
y_pred = tensor([[8.0735]])

不同优化器,他们的性能在使用上有什么区别?直接看图
以下包含了Adagrad Adam adamax ASGD RMSprop Rprop SGD七种优化器的loss下降图。其实还有一种优化器LBFGS,使用时需要传递闭包等等,我会在之后补上,暂时没有。
在这里插入图片描述

小知识点:可调用对象

如果要使用一个可调用对象,那么在类的声明的时候要定义一个 call()函数就OK了,就像这样

class Foobar:
	def __init__(self):
		pass
	def __call__(self,*args,**kwargs):
		pass

其中参数*args代表把前面n个参数变成n元组,**kwargsd会把参数变成一个词典,举个例子:

 def func(*args,**kwargs):
 	print(args)
 	print(kwargs)

#调用一下
func(1,2,3,4,x=3,y=5)

结果:
(1,2,3,4)
{‘x’:3,‘y’:5}

发布了10 篇原创文章 · 获赞 0 · 访问量 130

猜你喜欢

转载自blog.csdn.net/weixin_44841652/article/details/105068509