6. PyTorch Deep Learning: Processing Multi-Dimensional Feature Inputs

Lecture 07: Processing Multi-Dimensional Feature Inputs

Source: Bilibili, 刘二大人

Source code:

import numpy as np
import torch
import matplotlib.pyplot as plt

# prepare dataset
xy = np.loadtxt('diabetes.csv.gz', delimiter=',', dtype=np.float32)
x_data = torch.from_numpy(xy[:, :-1])  # the first ':' selects all rows; ':-1' takes every column except the last
y_data = torch.from_numpy(xy[:, [-1]])  # indexing with [-1] (a list) keeps the column dimension, so the result is a matrix
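# Note on shapes: after the slicing above, x_data is a float32 tensor of
# shape (N, 8) and y_data is (N, 1), so the targets line up element-wise
# with the model's (N, 1) output when computing BCELoss.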


# design model using class
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)  # each input sample x has 8 features
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()  # treated as a layer of the network, not just a plain function call

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))  # y hat
        return x


model = Model()

# construct loss and optimizer
criterion = torch.nn.BCELoss(reduction='mean')  # 'reduction' replaces the deprecated size_average argument
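# BCELoss computes -mean(y*log(y_hat) + (1-y)*log(1-y_hat)); this is why the
# model's last layer is passed through Sigmoid, keeping y_hat inside (0, 1).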
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

epoch_list = []
loss_list = []
# training cycle: forward, backward, update
for epoch in range(100):
    # forward
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    print(epoch, loss.item())
    epoch_list.append(epoch)
    loss_list.append(loss.item())
    # backward
    optimizer.zero_grad()
    loss.backward()
    # update
    optimizer.step()

plt.plot(epoch_list, loss_list)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()
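As a side note (an alternative not shown in the lecture), the same 8 -> 6 -> 4 -> 1 stack of Linear + Sigmoid layers can also be written more compactly with torch.nn.Sequential. A minimal sketch:

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(8, 6),
    torch.nn.Sigmoid(),
    torch.nn.Linear(6, 4),
    torch.nn.Sigmoid(),
    torch.nn.Linear(4, 1),
    torch.nn.Sigmoid(),
)

# Sanity check with random data shaped like the diabetes features:
x = torch.randn(4, 8)   # batch of 4 samples, 8 features each
print(model(x).shape)   # torch.Size([4, 1])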

Output:

0 0.7537468671798706
1 0.7436003684997559
2 0.7344141602516174
3 0.7260976433753967
4 0.7185685634613037
5 0.7117521166801453
6 0.705579936504364
7 0.6999905109405518
8 0.6949279308319092
9 0.6903417706489563
10 0.6861862540245056
11 0.6824199557304382
12 0.6790056824684143
13 0.6759095788002014
14 0.6731014847755432
15 0.6705536842346191
16 0.6682416200637817
17 0.6661427617073059
18 0.6642369627952576
19 0.6625060439109802
20 0.6609333753585815
21 0.6595042943954468
22 0.6582051515579224
23 0.6570240259170532
24 0.6559497714042664
25 0.6549724340438843
26 0.6540831923484802
27 0.6532738208770752
28 0.6525369882583618
29 0.65186607837677
30 0.6512549519538879
31 0.6506983041763306
32 0.650191068649292
33 0.6497288346290588
34 0.649307370185852
35 0.6489231586456299
36 0.6485728025436401
37 0.6482532024383545
38 0.6479616761207581
39 0.6476957201957703
40 0.6474528312683105
41 0.6472313404083252
42 0.6470290422439575
43 0.6468443870544434
44 0.646675705909729
45 0.6465216279029846
46 0.6463809013366699
47 0.6462522745132446
48 0.6461347341537476
49 0.6460273265838623
50 0.645929217338562
51 0.6458394527435303
52 0.645757257938385
53 0.6456822752952576
54 0.6456134915351868
55 0.6455506086349487
56 0.6454931497573853
57 0.6454404592514038
58 0.6453922390937805
59 0.6453481912612915
60 0.645307719707489
61 0.6452706456184387
62 0.6452367305755615
63 0.6452056169509888
64 0.6451771259307861
65 0.6451509594917297
66 0.6451268792152405
67 0.6451048851013184
68 0.6450846791267395
69 0.6450660228729248
70 0.645048975944519
71 0.6450332999229431
72 0.6450188159942627
73 0.6450055241584778
74 0.644993245601654
75 0.6449820399284363
76 0.6449716687202454
77 0.6449620127677917
78 0.644953191280365
79 0.6449450254440308
80 0.6449373960494995
81 0.6449304223060608
82 0.644923985004425
83 0.644917905330658
84 0.6449123620986938
85 0.6449071764945984
86 0.6449023485183716
87 0.6448978185653687
88 0.6448936462402344
89 0.6448897123336792
90 0.6448860168457031
91 0.6448825597763062
92 0.6448794007301331
93 0.6448763608932495
94 0.6448734402656555
95 0.6448707580566406
96 0.6448682546615601
97 0.6448658108711243
98 0.6448635458946228
99 0.6448613405227661

Plot: loss vs. epoch (figure omitted).
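The printed loss flattens out around 0.645 by the end of training. To see what that means in terms of predictions, one can append a quick evaluation step after the training loop; a minimal sketch (an addition, not part of the original post), reusing model, x_data and y_data from the script above:

# Evaluate on the training data: threshold the sigmoid output at 0.5
# to turn probabilities into 0/1 labels and compare against y_data.
with torch.no_grad():                     # no gradients needed here
    y_prob = model(x_data)                # probabilities in (0, 1)
    y_label = (y_prob >= 0.5).float()     # hard 0/1 predictions
    accuracy = (y_label == y_data).float().mean().item()
print('training accuracy:', accuracy)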

Reposted from blog.csdn.net/weixin_46087812/article/details/114179060