[Deep Learning] Building and Training a Network in Torch with nngraph

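require 'nn'

-- Shared trunk: a linear layer mapping 3 input features to 5 hidden units
model = nn.Sequential()
model:add(nn.Linear(3,5))

-- Two parallel heads: ConcatTable feeds the same input to each branch
-- and returns a table with one output per branch
prl = nn.ConcatTable()
prl:add(nn.Linear(5,1))
prl:add(nn.Linear(5,1))
model:add(prl)

-- One MSE loss per head; ParallelCriterion sums them (default weight 1 each)
criterion = nn.ParallelCriterion()
criterion:add(nn.MSECriterion())
criterion:add(nn.MSECriterion())

-- A batch of 5 samples with 3 features, and one 5x1 target per head
input = torch.rand(5,3)
target = {torch.rand(5,1), torch.rand(5,1)}

-- Forward pass: output is a table of two 5x1 tensors, err is the summed loss
output = model:forward(input)
err = criterion:forward(output,target)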
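
The listing above builds the two-head model with nn.Sequential and nn.ConcatTable and stops after the forward pass. As a rough sketch of what the title describes, the same topology can also be written with nngraph and pushed through a single gradient step. The node names (x, h, y1, y2, gmod) and the learning rate of 0.01 are assumptions chosen for illustration; they do not come from the original post.

```lua
require 'nn'
require 'nngraph'

-- Calling a module instance on a node wires it into the graph.
local x  = nn.Identity()()        -- input placeholder node
local h  = nn.Linear(3, 5)(x)     -- shared layer: 3 features -> 5 hidden units
local y1 = nn.Linear(5, 1)(h)     -- first head
local y2 = nn.Linear(5, 1)(h)     -- second head
local gmod = nn.gModule({x}, {y1, y2})

-- Same loss setup as above: one MSE term per head, summed.
local criterion = nn.ParallelCriterion()
criterion:add(nn.MSECriterion())
criterion:add(nn.MSECriterion())

local input  = torch.rand(5, 3)
local target = {torch.rand(5, 1), torch.rand(5, 1)}

local learningRate = 0.01         -- assumed value, tune for real data

-- One plain SGD step.
gmod:zeroGradParameters()                               -- clear accumulated gradients
local output = gmod:forward(input)                      -- table of two 5x1 tensors
local err = criterion:forward(output, target)           -- summed MSE over both heads
local gradOutput = criterion:backward(output, target)   -- table of per-head gradients
gmod:backward(input, gradOutput)                        -- backpropagate through the graph
gmod:updateParameters(learningRate)                     -- params = params - lr * grad

print(err)
```

Repeating the last block over mini-batches gives a basic training loop. The same five calls should work unchanged on the nn.Sequential model above, since both expose the standard nn.Module interface.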

Reference: https://groups.google.com/forum/#!topic/torch7/OLjblK6iVl0

Reposted from blog.csdn.net/Sun7_She/article/details/78027091