import torch
batch_size=1
sep_len=3 #一个样本列中,所含独立样本x的个数
input_size=4
hidden_size=2
cell=torch.nn.RNNCell(input_size=input_size,hidden_size=hidden_size)
#(seq,batch,feature)
dataset=torch.randn(sep_len,batch_size,input_size)
hidden=torch.zeros(batch_size,hidden_size)#定义h0
for idx,input in enumerate(dataset): #一个batch_size的x1+h0=h1 x2+h1=h2 以此迭代
print('='*20,idx,'='*20)
print('input_size',input.shape)
hidden=cell(input,hidden)
print('output_size', hidden.shape)
print(hidden)
循环神经网络:RNNCell的定义
猜你喜欢
转载自blog.csdn.net/qq_21686871/article/details/114407865
今日推荐
周排行