(一)PyTorch搭建神经网络

版权声明:找不到大腿的时候,让自己变成大腿. https://blog.csdn.net/Xin_101/article/details/88744647

1 神经网络结构

import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib.font_manager import FontProperties
from mpl_toolkits.mplot3d import Axes3D
from torch.autograd import Variable
# Chinese-capable font used for the plot titles, labels and legends.
# The path is the Ubuntu location of ukai.ttc. matplotlib only opens the
# file when text is drawn, so a missing file would crash plt.show()/savefig
# with FileNotFoundError. Fall back to the default font instead; Chinese
# glyphs may then render as placeholder boxes, but the script completes.
FONT_PATH = '/usr/share/fonts/truetype/arphic/ukai.ttc'
if os.path.exists(FONT_PATH):
    font = FontProperties(fname=FONT_PATH)
else:
    font = FontProperties()
PI = np.pi
class DynamicNet(torch.nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear."""

    def __init__(self, D_in, H, D_out):
        """Build the network layers.

        :param D_in: Number of input features per sample.
        :param H: Number of hidden-layer units.
        :param D_out: Number of output features per sample.
        """
        super(DynamicNet, self).__init__()
        # Weights and bias between the input layer and the hidden layer.
        self.input_linear = torch.nn.Linear(D_in, H)
        # Weights and bias between the hidden layer and the output layer.
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """Return the network prediction for ``x``.

        :param x: Input tensor whose last dimension has size ``D_in``.
        :return: Tensor whose last dimension has size ``D_out``.
        """
        # clamp(min=0) is the ReLU activation applied to the hidden layer's
        # output before it is fed to the output layer.
        h_relu = self.input_linear(x).clamp(min=0)
        y_pred = self.output_linear(h_relu)
        return y_pred
# Network structure parameters: number of samples (full batch), input
# features, hidden units, output features.
N, D_in, H, D_out = 100, 1, 100, 1
# Number of full-batch training iterations.
EPOCHS = 100

# Input data: N evenly spaced points on [-pi, pi], reshaped to (N, 1) so
# each row is one sample with a single feature.
x = torch.unsqueeze(torch.linspace(-PI, PI, N), dim=1)
# Target data: sin(x) plus uniform noise drawn from [0, 0.2).
y = torch.sin(x) + 0.2 * torch.rand(x.size())
# Shapes of the input and target tensors. ``size`` must be called;
# formatting the bare attribute prints "<built-in method size ...>".
print("x size: {}".format(x.size()))
print("y size: {}".format(y.size()))

# Plot the noisy target data on its own.
plt.figure(figsize=(8, 8))
plt.scatter(x.numpy(), y.numpy(), marker='*', s=10, label="$y=sin(x)$")
plt.legend()
plt.grid()
plt.title("理论图", fontproperties=font)
plt.show()

# Model, summed squared-error loss, and SGD with momentum.
model = DynamicNet(D_in, H, D_out)
criterion = torch.nn.MSELoss(reduction="sum")
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

losses = []
for t in range(EPOCHS):
    # Forward pass: predict y from x.
    y_pred = model(x)
    # Loss between prediction and target.
    loss = criterion(y_pred, y)
    losses.append(loss.item())
    # Clear gradients accumulated by the previous iteration.
    optimizer.zero_grad()
    # Backward pass: gradient of the loss w.r.t. the model parameters.
    loss.backward()
    # Update the model parameters.
    optimizer.step()

# y_pred, loss and t still hold the values from the last iteration: the
# prediction computed before the final optimizer step.
print("第{}轮训练, 损失值:{}".format(t, loss.item()))

plt.figure(figsize=(8, 8))

# Subplot 1: target data.
plt.subplot(2, 2, 1)
plt.scatter(x.numpy(), y.numpy(), marker='*', s=10, label="$y=sin(x)$")
plt.legend()
plt.grid()
plt.xlabel("x轴", fontproperties=font)
plt.ylabel("y轴", fontproperties=font)
plt.title("理论图", fontproperties=font)

# Subplot 2: target data versus the final prediction.
plt.subplot(2, 2, 2)
plt.scatter(x.numpy(), y.numpy(), label="理论值", marker="+", s=10)
plt.plot(x.numpy(), y_pred.detach().numpy(), 'r-', label="预测值")
plt.legend(prop=font)
plt.xlabel("x轴", fontproperties=font)
plt.ylabel("y轴", fontproperties=font)
plt.title("预测图", fontproperties=font)
plt.grid()

# Subplot 3: loss per training iteration.
plt.subplot(2, 2, 3)
plt.plot(losses, label="Loss")
plt.title("损失图", fontproperties=font)
plt.legend()
plt.grid()
plt.xlabel("训练轮次", fontproperties=font)
plt.ylabel("损失值", fontproperties=font)

# One layout call with the spacing the original's last call applied.
plt.subplots_adjust(wspace=0.5, hspace=0.3)

# savefig does not create missing directories.
os.makedirs("./images", exist_ok=True)
plt.savefig("./images/train_use_directly.png", format="png")
plt.show()

2 结果

[图片:./images/train_use_directly.png,依次为理论图、预测图与损失图]

图2.1 训练结果

3 总结

(1) 神经网络的搭建遵循"输入层—隐藏层—输出层"的网络结构,结构说明参见:(一)Tensorflow搭建神经网络。
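(2) PyTorch框架与Tensorflow框架类似,不同之处在于网络参数的存储形状:torch.nn.Linear的权重形状为(out_features, in_features),与Tensorflow中常用的(in_features, out_features)互为转置;偏置在两者中都是长度为out_features的一维向量。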


[参考文献]
[1]https://pytorch.org/docs/stable/torch.html
[2]https://pytorch.org/docs/stable/tensors.html
[3]https://pytorch.org/docs/stable/nn.html
[4]https://pytorch.org/docs/stable/optim.html


猜你喜欢

转载自blog.csdn.net/Xin_101/article/details/88744647