Preface
In TensorFlow 2.x you can simply call model.summary() and plot_model() to inspect and save a network's structure. PyTorch has no such built-in one-liner; you need the extra packages torchsummary and torchviz.
Take a 3D segmentation network (MONAI's SwinUNETR) as an example:
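(Both packages are on PyPI: pip install torchsummary torchviz. Note that torchviz renders through Graphviz, so the Graphviz binaries must also be installed on the system.)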
import os

os.environ["CUDA_VISIBLE_DEVICES"] = '-1'  # hide all GPUs -> run on CPU; remove this line to use the GPU

from monai.networks import nets
import torchsummary
import torch
import torchviz


def get_model(img_size_,
              in_channels_,
              num_classes_,
              feature_size_,
              depths_,
              ):
    net = nets.SwinUNETR(
        img_size=img_size_,
        in_channels=in_channels_,
        out_channels=num_classes_,
        feature_size=feature_size_,
        depths=depths_,
    )
    return net


if __name__ == '__main__':
    img_size = (64, 64, 32)
    in_channels = 1
    num_classes = 4
    feature_size = 48
    depths = [2, 4, 2, 2]
    swinuntr = get_model(img_size_=img_size,
                         in_channels_=in_channels,
                         num_classes_=num_classes,
                         feature_size_=feature_size,
                         depths_=depths)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = swinuntr.to(device)

    # torchsummary takes the input size WITHOUT the batch dimension:
    # (in_channels, h, w, d) -- the batch axis is added internally.
    torchsummary.summary(model, input_size=(1,) + img_size)

    # torchviz needs a real forward pass; here the dummy input DOES carry
    # the batch dimension: (batch_size, in_channels, h, w, d)
    rand_data = torch.rand((1, 1) + img_size).to(device)
    torchviz.make_dot(model(rand_data),
                      params=dict(model.named_parameters())).render("swin", format="pdf")
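Side note: torchsummary appears to be unmaintained these days; torchinfo is its actively maintained successor and prints the same kind of table, with the one difference that its input_size includes the batch dimension. A minimal sketch, assuming torchinfo is installed:

from torchinfo import summary  # pip install torchinfo

# unlike torchsummary, input_size here INCLUDES the batch dimension
summary(model, input_size=(1, 1, 64, 64, 32))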
Result
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv3d-1 [-1, 48, 32, 32, 16] 432
PatchEmbed-2 [-1, 48, 32, 32, 16] 0
Dropout-3 [-1, 48, 32, 32, 16] 0
LayerNorm-4 [-1, 32, 32, 16, 48] 96
Linear-5 [-1, 343, 144] 7,056
Softmax-6 [-1, 3, 343, 343] 0
Dropout-7 [-1, 3, 343, 343] 0
Linear-8 [-1, 343, 48] 2,352
Dropout-9 [-1, 343, 48] 0
WindowAttention-10 [-1, 343, 48] 0
Identity-11 [-1, 32, 32, 16, 48] 0
LayerNorm-12 [-1, 32, 32, 16, 48] 96
Linear-13 [-1, 32, 32, 16, 192] 9,408
GELU-14 [-1, 32, 32, 16, 192] 0
Dropout-15 [-1, 32, 32, 16, 192] 0
Linear-16 [-1, 32, 32, 16, 48] 9,264
Dropout-17 [-1, 32, 32, 16, 48] 0
MLPBlock-18 [-1, 32, 32, 16, 48] 0
Identity-19 [-1, 32, 32, 16, 48] 0
SwinTransformerBlock-20 [-1, 32, 32, 16, 48] 0
LayerNorm-21 [-1, 32, 32, 16, 48] 96
Linear-22 [-1, 343, 144] 7,056
Softmax-23 [-1, 3, 343, 343] 0
Dropout-24 [-1, 3, 343, 343] 0
Linear-25 [-1, 343, 48] 2,352
Dropout-26 [-1, 343, 48] 0
WindowAttention-27 [-1, 343, 48] 0
Identity-28 [-1, 32, 32, 16, 48] 0
LayerNorm-29 [-1, 32, 32, 16, 48] 96
Linear-30 [-1, 32, 32, 16, 192] 9,408
GELU-31 [-1, 32, 32, 16, 192] 0
Dropout-32 [-1, 32, 32, 16, 192] 0
Linear-33 [-1, 32, 32, 16, 48] 9,264
Dropout-34 [-1, 32, 32, 16, 48] 0
MLPBlock-35 [-1, 32, 32, 16, 48] 0
Identity-36 [-1, 32, 32, 16, 48] 0
SwinTransformerBlock-37 [-1, 32, 32, 16, 48] 0
LayerNorm-38 [-1, 16, 16, 8, 384] 768
Linear-39 [-1, 16, 16, 8, 96] 36,864
PatchMerging-40 [-1, 16, 16, 8, 96] 0
BasicLayer-41 [-1, 96, 16, 16, 8] 0
LayerNorm-42 [-1, 16, 16, 8, 96] 192
Linear-43 [-1, 343, 288] 27,936
Softmax-44 [-1, 6, 343, 343] 0
Dropout-45 [-1, 6, 343, 343] 0
Linear-46 [-1, 343, 96] 9,312
Dropout-47 [-1, 343, 96] 0
WindowAttention-48 [-1, 343, 96] 0
Identity-49 [-1, 16, 16, 8, 96] 0
LayerNorm-50 [-1, 16, 16, 8, 96] 192
Linear-51 [-1, 16, 16, 8, 384] 37,248
GELU-52 [-1, 16, 16, 8, 384] 0
Dropout-53 [-1, 16, 16, 8, 384] 0
Linear-54 [-1, 16, 16, 8, 96] 36,960
Dropout-55 [-1, 16, 16, 8, 96] 0
MLPBlock-56 [-1, 16, 16, 8, 96] 0
Identity-57 [-1, 16, 16, 8, 96] 0
SwinTransformerBlock-58 [-1, 16, 16, 8, 96] 0
LayerNorm-59 [-1, 16, 16, 8, 96] 192
Linear-60 [-1, 343, 288] 27,936
Softmax-61 [-1, 6, 343, 343] 0
Dropout-62 [-1, 6, 343, 343] 0
Linear-63 [-1, 343, 96] 9,312
Dropout-64 [-1, 343, 96] 0
WindowAttention-65 [-1, 343, 96] 0
Identity-66 [-1, 16, 16, 8, 96] 0
LayerNorm-67 [-1, 16, 16, 8, 96] 192
Linear-68 [-1, 16, 16, 8, 384] 37,248
GELU-69 [-1, 16, 16, 8, 384] 0
Dropout-70 [-1, 16, 16, 8, 384] 0
Linear-71 [-1, 16, 16, 8, 96] 36,960
Dropout-72 [-1, 16, 16, 8, 96] 0
MLPBlock-73 [-1, 16, 16, 8, 96] 0
Identity-74 [-1, 16, 16, 8, 96] 0
SwinTransformerBlock-75 [-1, 16, 16, 8, 96] 0
LayerNorm-76 [-1, 16, 16, 8, 96] 192
Linear-77 [-1, 343, 288] 27,936
Softmax-78 [-1, 6, 343, 343] 0
Dropout-79 [-1, 6, 343, 343] 0
Linear-80 [-1, 343, 96] 9,312
Dropout-81 [-1, 343, 96] 0
WindowAttention-82 [-1, 343, 96] 0
Identity-83 [-1, 16, 16, 8, 96] 0
LayerNorm-84 [-1, 16, 16, 8, 96] 192
Linear-85 [-1, 16, 16, 8, 384] 37,248
GELU-86 [-1, 16, 16, 8, 384] 0
Dropout-87 [-1, 16, 16, 8, 384] 0
Linear-88 [-1, 16, 16, 8, 96] 36,960
Dropout-89 [-1, 16, 16, 8, 96] 0
MLPBlock-90 [-1, 16, 16, 8, 96] 0
Identity-91 [-1, 16, 16, 8, 96] 0
SwinTransformerBlock-92 [-1, 16, 16, 8, 96] 0
LayerNorm-93 [-1, 16, 16, 8, 96] 192
Linear-94 [-1, 343, 288] 27,936
Softmax-95 [-1, 6, 343, 343] 0
Dropout-96 [-1, 6, 343, 343] 0
Linear-97 [-1, 343, 96] 9,312
Dropout-98 [-1, 343, 96] 0
WindowAttention-99 [-1, 343, 96] 0
Identity-100 [-1, 16, 16, 8, 96] 0
LayerNorm-101 [-1, 16, 16, 8, 96] 192
Linear-102 [-1, 16, 16, 8, 384] 37,248
GELU-103 [-1, 16, 16, 8, 384] 0
Dropout-104 [-1, 16, 16, 8, 384] 0
Linear-105 [-1, 16, 16, 8, 96] 36,960
Dropout-106 [-1, 16, 16, 8, 96] 0
MLPBlock-107 [-1, 16, 16, 8, 96] 0
Identity-108 [-1, 16, 16, 8, 96] 0
SwinTransformerBlock-109 [-1, 16, 16, 8, 96] 0
LayerNorm-110 [-1, 8, 8, 4, 768] 1,536
Linear-111 [-1, 8, 8, 4, 192] 147,456
PatchMerging-112 [-1, 8, 8, 4, 192] 0
BasicLayer-113 [-1, 192, 8, 8, 4] 0
LayerNorm-114 [-1, 8, 8, 4, 192] 384
Linear-115 [-1, 196, 576] 111,168
Softmax-116 [-1, 12, 196, 196] 0
Dropout-117 [-1, 12, 196, 196] 0
Linear-118 [-1, 196, 192] 37,056
Dropout-119 [-1, 196, 192] 0
WindowAttention-120 [-1, 196, 192] 0
Identity-121 [-1, 8, 8, 4, 192] 0
LayerNorm-122 [-1, 8, 8, 4, 192] 384
Linear-123 [-1, 8, 8, 4, 768] 148,224
GELU-124 [-1, 8, 8, 4, 768] 0
Dropout-125 [-1, 8, 8, 4, 768] 0
Linear-126 [-1, 8, 8, 4, 192] 147,648
Dropout-127 [-1, 8, 8, 4, 192] 0
MLPBlock-128 [-1, 8, 8, 4, 192] 0
Identity-129 [-1, 8, 8, 4, 192] 0
SwinTransformerBlock-130 [-1, 8, 8, 4, 192] 0
LayerNorm-131 [-1, 8, 8, 4, 192] 384
Linear-132 [-1, 196, 576] 111,168
Softmax-133 [-1, 12, 196, 196] 0
Dropout-134 [-1, 12, 196, 196] 0
Linear-135 [-1, 196, 192] 37,056
Dropout-136 [-1, 196, 192] 0
WindowAttention-137 [-1, 196, 192] 0
Identity-138 [-1, 8, 8, 4, 192] 0
LayerNorm-139 [-1, 8, 8, 4, 192] 384
Linear-140 [-1, 8, 8, 4, 768] 148,224
GELU-141 [-1, 8, 8, 4, 768] 0
Dropout-142 [-1, 8, 8, 4, 768] 0
Linear-143 [-1, 8, 8, 4, 192] 147,648
Dropout-144 [-1, 8, 8, 4, 192] 0
MLPBlock-145 [-1, 8, 8, 4, 192] 0
Identity-146 [-1, 8, 8, 4, 192] 0
SwinTransformerBlock-147 [-1, 8, 8, 4, 192] 0
LayerNorm-148 [-1, 4, 4, 2, 1536] 3,072
Linear-149 [-1, 4, 4, 2, 384] 589,824
PatchMerging-150 [-1, 4, 4, 2, 384] 0
BasicLayer-151 [-1, 384, 4, 4, 2] 0
LayerNorm-152 [-1, 4, 4, 2, 384] 768
Linear-153 [-1, 32, 1152] 443,520
Softmax-154 [-1, 24, 32, 32] 0
Dropout-155 [-1, 24, 32, 32] 0
Linear-156 [-1, 32, 384] 147,840
Dropout-157 [-1, 32, 384] 0
WindowAttention-158 [-1, 32, 384] 0
Identity-159 [-1, 4, 4, 2, 384] 0
LayerNorm-160 [-1, 4, 4, 2, 384] 768
Linear-161 [-1, 4, 4, 2, 1536] 591,360
GELU-162 [-1, 4, 4, 2, 1536] 0
Dropout-163 [-1, 4, 4, 2, 1536] 0
Linear-164 [-1, 4, 4, 2, 384] 590,208
Dropout-165 [-1, 4, 4, 2, 384] 0
MLPBlock-166 [-1, 4, 4, 2, 384] 0
Identity-167 [-1, 4, 4, 2, 384] 0
SwinTransformerBlock-168 [-1, 4, 4, 2, 384] 0
LayerNorm-169 [-1, 4, 4, 2, 384] 768
Linear-170 [-1, 32, 1152] 443,520
Softmax-171 [-1, 24, 32, 32] 0
Dropout-172 [-1, 24, 32, 32] 0
Linear-173 [-1, 32, 384] 147,840
Dropout-174 [-1, 32, 384] 0
WindowAttention-175 [-1, 32, 384] 0
Identity-176 [-1, 4, 4, 2, 384] 0
LayerNorm-177 [-1, 4, 4, 2, 384] 768
Linear-178 [-1, 4, 4, 2, 1536] 591,360
GELU-179 [-1, 4, 4, 2, 1536] 0
Dropout-180 [-1, 4, 4, 2, 1536] 0
Linear-181 [-1, 4, 4, 2, 384] 590,208
Dropout-182 [-1, 4, 4, 2, 384] 0
MLPBlock-183 [-1, 4, 4, 2, 384] 0
Identity-184 [-1, 4, 4, 2, 384] 0
SwinTransformerBlock-185 [-1, 4, 4, 2, 384] 0
LayerNorm-186 [-1, 2, 2, 1, 3072] 6,144
Linear-187 [-1, 2, 2, 1, 768] 2,359,296
PatchMerging-188 [-1, 2, 2, 1, 768] 0
BasicLayer-189 [-1, 768, 2, 2, 1] 0
SwinTransformer-190 [[-1, 48, 32, 32, 16], [-1, 96, 16, 16, 8], [-1, 192, 8, 8, 4], [-1, 384, 4, 4, 2], [-1, 768, 2, 2, 1]] 0
Conv3d-191 [-1, 48, 64, 64, 32] 1,296
InstanceNorm3d-192 [-1, 48, 64, 64, 32] 0
LeakyReLU-193 [-1, 48, 64, 64, 32] 0
Conv3d-194 [-1, 48, 64, 64, 32] 62,208
InstanceNorm3d-195 [-1, 48, 64, 64, 32] 0
Conv3d-196 [-1, 48, 64, 64, 32] 48
InstanceNorm3d-197 [-1, 48, 64, 64, 32] 0
LeakyReLU-198 [-1, 48, 64, 64, 32] 0
UnetResBlock-199 [-1, 48, 64, 64, 32] 0
UnetrBasicBlock-200 [-1, 48, 64, 64, 32] 0
Conv3d-201 [-1, 48, 32, 32, 16] 62,208
InstanceNorm3d-202 [-1, 48, 32, 32, 16] 0
LeakyReLU-203 [-1, 48, 32, 32, 16] 0
Conv3d-204 [-1, 48, 32, 32, 16] 62,208
InstanceNorm3d-205 [-1, 48, 32, 32, 16] 0
LeakyReLU-206 [-1, 48, 32, 32, 16] 0
UnetResBlock-207 [-1, 48, 32, 32, 16] 0
UnetrBasicBlock-208 [-1, 48, 32, 32, 16] 0
Conv3d-209 [-1, 96, 16, 16, 8] 248,832
InstanceNorm3d-210 [-1, 96, 16, 16, 8] 0
LeakyReLU-211 [-1, 96, 16, 16, 8] 0
Conv3d-212 [-1, 96, 16, 16, 8] 248,832
InstanceNorm3d-213 [-1, 96, 16, 16, 8] 0
LeakyReLU-214 [-1, 96, 16, 16, 8] 0
UnetResBlock-215 [-1, 96, 16, 16, 8] 0
UnetrBasicBlock-216 [-1, 96, 16, 16, 8] 0
Conv3d-217 [-1, 192, 8, 8, 4] 995,328
InstanceNorm3d-218 [-1, 192, 8, 8, 4] 0
LeakyReLU-219 [-1, 192, 8, 8, 4] 0
Conv3d-220 [-1, 192, 8, 8, 4] 995,328
InstanceNorm3d-221 [-1, 192, 8, 8, 4] 0
LeakyReLU-222 [-1, 192, 8, 8, 4] 0
UnetResBlock-223 [-1, 192, 8, 8, 4] 0
UnetrBasicBlock-224 [-1, 192, 8, 8, 4] 0
Conv3d-225 [-1, 768, 2, 2, 1] 15,925,248
InstanceNorm3d-226 [-1, 768, 2, 2, 1] 0
LeakyReLU-227 [-1, 768, 2, 2, 1] 0
Conv3d-228 [-1, 768, 2, 2, 1] 15,925,248
InstanceNorm3d-229 [-1, 768, 2, 2, 1] 0
LeakyReLU-230 [-1, 768, 2, 2, 1] 0
UnetResBlock-231 [-1, 768, 2, 2, 1] 0
UnetrBasicBlock-232 [-1, 768, 2, 2, 1] 0
ConvTranspose3d-233 [-1, 384, 4, 4, 2] 2,359,296
Conv3d-234 [-1, 384, 4, 4, 2] 7,962,624
InstanceNorm3d-235 [-1, 384, 4, 4, 2] 0
LeakyReLU-236 [-1, 384, 4, 4, 2] 0
Conv3d-237 [-1, 384, 4, 4, 2] 3,981,312
InstanceNorm3d-238 [-1, 384, 4, 4, 2] 0
Conv3d-239 [-1, 384, 4, 4, 2] 294,912
InstanceNorm3d-240 [-1, 384, 4, 4, 2] 0
LeakyReLU-241 [-1, 384, 4, 4, 2] 0
UnetResBlock-242 [-1, 384, 4, 4, 2] 0
UnetrUpBlock-243 [-1, 384, 4, 4, 2] 0
ConvTranspose3d-244 [-1, 192, 8, 8, 4] 589,824
Conv3d-245 [-1, 192, 8, 8, 4] 1,990,656
InstanceNorm3d-246 [-1, 192, 8, 8, 4] 0
LeakyReLU-247 [-1, 192, 8, 8, 4] 0
Conv3d-248 [-1, 192, 8, 8, 4] 995,328
InstanceNorm3d-249 [-1, 192, 8, 8, 4] 0
Conv3d-250 [-1, 192, 8, 8, 4] 73,728
InstanceNorm3d-251 [-1, 192, 8, 8, 4] 0
LeakyReLU-252 [-1, 192, 8, 8, 4] 0
UnetResBlock-253 [-1, 192, 8, 8, 4] 0
UnetrUpBlock-254 [-1, 192, 8, 8, 4] 0
ConvTranspose3d-255 [-1, 96, 16, 16, 8] 147,456
Conv3d-256 [-1, 96, 16, 16, 8] 497,664
InstanceNorm3d-257 [-1, 96, 16, 16, 8] 0
LeakyReLU-258 [-1, 96, 16, 16, 8] 0
Conv3d-259 [-1, 96, 16, 16, 8] 248,832
InstanceNorm3d-260 [-1, 96, 16, 16, 8] 0
Conv3d-261 [-1, 96, 16, 16, 8] 18,432
InstanceNorm3d-262 [-1, 96, 16, 16, 8] 0
LeakyReLU-263 [-1, 96, 16, 16, 8] 0
UnetResBlock-264 [-1, 96, 16, 16, 8] 0
UnetrUpBlock-265 [-1, 96, 16, 16, 8] 0
ConvTranspose3d-266 [-1, 48, 32, 32, 16] 36,864
Conv3d-267 [-1, 48, 32, 32, 16] 124,416
InstanceNorm3d-268 [-1, 48, 32, 32, 16] 0
LeakyReLU-269 [-1, 48, 32, 32, 16] 0
Conv3d-270 [-1, 48, 32, 32, 16] 62,208
InstanceNorm3d-271 [-1, 48, 32, 32, 16] 0
Conv3d-272 [-1, 48, 32, 32, 16] 4,608
InstanceNorm3d-273 [-1, 48, 32, 32, 16] 0
LeakyReLU-274 [-1, 48, 32, 32, 16] 0
UnetResBlock-275 [-1, 48, 32, 32, 16] 0
UnetrUpBlock-276 [-1, 48, 32, 32, 16] 0
ConvTranspose3d-277 [-1, 48, 64, 64, 32] 18,432
Conv3d-278 [-1, 48, 64, 64, 32] 124,416
InstanceNorm3d-279 [-1, 48, 64, 64, 32] 0
LeakyReLU-280 [-1, 48, 64, 64, 32] 0
Conv3d-281 [-1, 48, 64, 64, 32] 62,208
InstanceNorm3d-282 [-1, 48, 64, 64, 32] 0
Conv3d-283 [-1, 48, 64, 64, 32] 4,608
InstanceNorm3d-284 [-1, 48, 64, 64, 32] 0
LeakyReLU-285 [-1, 48, 64, 64, 32] 0
UnetResBlock-286 [-1, 48, 64, 64, 32] 0
UnetrUpBlock-287 [-1, 48, 64, 64, 32] 0
Conv3d-288 [-1, 4, 64, 64, 32] 196
UnetOutBlock-289 [-1, 4, 64, 64, 32] 0
================================================================
Total params: 62,212,756
Trainable params: 62,212,756
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.50
Forward/backward pass size (MB): 1658.32
Params size (MB): 237.32
Estimated Total Size (MB): 1896.14
----------------------------------------------------------------
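As a quick sanity check, the "Total params" and "Params size" figures above can be reproduced directly from the model:

n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(n_params)                # 62,212,756 trainable parameters
print(n_params * 4 / 2 ** 20)  # ~237.32 MB at float32 (4 bytes per parameter)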
The generated PDF contains the full forward computation graph rendered by torchviz (figure not reproduced here).
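For completeness, PyTorch can also export the graph for TensorBoard without either of the packages above; a minimal sketch, assuming tensorboard is installed and the model is traceable:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/swin")  # log directory; the name is an arbitrary choice
writer.add_graph(model, rand_data)   # traces one forward pass and records the graph
writer.close()
# inspect with: tensorboard --logdir runs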