莫烦python 关系拟合 (回归)

import torch
import torch.nn.functional as F  #包含激励函数
import matplotlib.pyplot as plt  #画图工具包
a = torch.linspace(-1, 1, 100)  #linspace()函数从(-1,1)的区间均匀地取100个值,返回一个一维张量 torch.Size([100])
x = torch.unsqueeze(a, dim=1)  #使用unsqueeze()函数能增加a的维度,原本a是一维的,现在x是二维 torch.Size([100, 1])
y = x.pow(2)+0.2*torch.rand(x.size())  #加号后的部分作用是增加数据噪声

print(a.shape)
print(x.shape)
print(y.shape)
#plt.scatter(x.data.numpy(),y.data.numpy())   #scatter()画散点图,plot()画连续图
#plt.show()


class Net(torch.nn.Module):    #继承torch的Module
    def __init__(self, n_feature, n_hidden, n_output):
      super(Net, self).__init__()  #必须继承_init_()函数
      #定义网络的结构
      self.hidden = torch.nn.Linear(n_feature, n_hidden)  #nn.Linear(input_size,output_size) 输入节点数&输出节点数
      self.predict = torch.nn.Linear(n_hidden, n_output)

    def forward(self, x):
        # 正向传播输入值, 神经网络分析出输出值
        x = F.relu(self.hidden(x))
        x = self.predict(x)
        return x


net = Net(n_feature=1, n_hidden=10, n_output=1)

print(net)

para = list(net.parameters())
print(para)

optimizer = torch.optim.SGD(net.parameters(), lr=0.2) #net的参数是何时初始化的?参数里有什么值
loss_func = torch.nn.MSELoss()

plt.ion()
plt.show()

for t in range(100):
    prediction = net(x)  #net虽然是一个object,但是也可以当作函数使用,具体可以查看_call_(),给net输入训练数据集x,输出预测数据
    loss = loss_func(prediction, y)  #计算预测数据prediction和实际数值y之间的误差
    print(loss)
    optimizer.zero_grad() #梯度初始化为零,网络训练,loss.backward()会把梯度累计
    loss.backward()  #误差反向传播, 计算参数更新值
    optimizer.step() #将参数更新值施加到 net 的 parameters 上

    if t % 5 == 0:
        plt.clf()
        plt.scatter(x.data.numpy(), y.data.numpy())
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
        plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
        plt.pause(0.1)

plt.ioff()
plt.show()

其实原理还不是很明白,有以下几个疑问

1、self.hidden和self.prediction都是对象,但在forward()函数中也被当作函数,这和net可以net(x)是一个道理吗?也是因为_call_函数?

2、网络的权重等参数是什么时候初始化的?在net对象创建时就默认初始化了吗?

3、loss和optimizer之间的联系?

发布了52 篇原创文章 · 获赞 6 · 访问量 9015

猜你喜欢

转载自blog.csdn.net/PMPWDF/article/details/98191135
今日推荐