class Model(nn.Module):#2022.11.7修改前,这个Model能跑通#forMultivariate
def __init__(self,configs,channel=96,ratio=1):#channel针对ili数据集应该改成36 channel=input_length
super(Model, self).__init__()
self.seq_len = configs.seq_len
self.pred_len = configs.pred_len
self.union = 1*self.pred_len
self.Linear = nn.Linear(self.seq_len,self.union)
self.Linear_1 = nn.Linear(self.union, self.pred_len)
self.fc1 = nn.Linear(self.union,2*self.union)
self.fc2 = nn.Linear(self.union,2*self.union)
self.fc3 = nn.Linear(2*self.union,1)
def forward(self, x):
x = x.permute(0,2,1) # (B,L,C)=》(B,C,L)#forL
b, c, l = x.size() # (B,C,L)
y = self.Linear(x)#forL
y_r = self.Linear_1(y)
y_shuffle = y[torch.randperm(y.size(0))]#将 y按照批次维度打乱顺序得到y_shuffle
x = F.interpolate(x, size=self.union)#插值,mode选择默认
# y = F.interpolate(y, size=self.seq_len)#插值,mode选择默认
# #joint
# h1 = F.relu(self.Linear(x) + self.Linear(y))
# pred_xy = self.Linear(h1)
# # #marginal
# h2 = F.relu(self.Linear(x) + self.Linear(y_shuffle))
# pred_x_y = self.Linear(h2)
#joint
h1 = F.relu(self.fc1(x) + self.fc2(y))
h2 = F.relu(self.fc1(x) + self.fc2(y))
pred_xy = self.fc3(h1)
# #marginal
h2 = F.relu(self.fc1(x) + self.fc2(y_shuffle))
pred_x_y = self.fc3(h2)
return y_r.permute(0,2,1),pred_xy,pred_x_y
为了让中间层特征与输入对齐,使用toch.nn.Functional.interpolate.
参考资料
torch.nn.interpolate—torch上采样和下采样操作_两只蜡笔的小新的博客-CSDN博客_torch 下采样
torch.nn.functional中的interpolate插值函数_大梦冲冲冲的博客-CSDN博客_torch 插值
pytorch torch.nn.functional实现插值和上采样_weixin_30905133的博客-CSDN博客