调试网络模型【PyTorch版】

测试网络流程是否顺畅

if __name__ == '__main__':
    x=torch.randn(2,3,256,256)
    net=Union_Seg_1_v1()
    print(net(x).shape)

其中,torch.randn(batch_size , channel , size[0] , size[1] )

batch_size : 运行一次输入的数据量个数

channel : 输入的通道数

size : 输入图像的规模(长和宽)

首先,定义输入数据格式

然后,定义网络

将数据输入网络,并打印输出数据的格式

打印网络结构

from torchsummary import summary

# 需要使用device来指定网络在GPU还是CPU运行
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
net=DPNet_v1()
model = net.to(device)
# input_size=(channel,size,size)
summary(model, input_size=(3,256,256))

需要使用torchsummary包。pip install torchsummary 或者 conda install -c ravelbio torchsummary

猜你喜欢

转载自blog.csdn.net/qq_41704436/article/details/131147158