pytorch的summary和画网络结构图

前言

在 TensorFlow 2.x 里面,可以直接调用 model.summary() 和 plot_model() 来查看和保存网络的结构。但是在 PyTorch 里面没有这么简单,需要借助额外的包 torchsummary 和 torchviz。

以一个3D分割网络为例

import os
os.environ["CUDA_VISIBLE_DEVICES"]='-1'

from monai.networks import nets
import torchsummary
import torch
import torchviz

def get_model(img_size_,
              in_channels_,
              num_classes_,
              feature_size_,
              depths_,
              ):
    """Construct a MONAI ``SwinUNETR`` network from the given settings.

    Every argument is forwarded unchanged to ``nets.SwinUNETR``.

    Args:
        img_size_: Forwarded as ``img_size``.
        in_channels_: Forwarded as ``in_channels``.
        num_classes_: Forwarded as ``out_channels``.
        feature_size_: Forwarded as ``feature_size``.
        depths_: Forwarded as ``depths``.

    Returns:
        The newly created ``nets.SwinUNETR`` instance.
    """
    # Map the underscore-suffixed parameters onto the SwinUNETR keyword names.
    swin_kwargs = {
        "img_size": img_size_,
        "in_channels": in_channels_,
        "out_channels": num_classes_,
        "feature_size": feature_size_,
        "depths": depths_,
    }
    return nets.SwinUNETR(**swin_kwargs)

if __name__ == '__main__':
    # Settings for the example 3D segmentation network.
    img_size = (64, 64, 32)
    in_channels = 1
    num_classes = 4
    feature_size = 48
    depths = [2, 4, 2, 2]
    swinuntr = get_model(img_size_=img_size,
                         in_channels_=in_channels,
                         num_classes_=num_classes,
                         feature_size_=feature_size,
                         depths_=depths)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = swinuntr.to(device)

    # Per-sample input shape without the batch axis: (in_channels, h, w, d).
    # Built from in_channels (not a literal 1) so the summary and the dummy
    # tensor below always match the channel count the model was built with.
    input_size = (in_channels,) + tuple(img_size)
    torchsummary.summary(model, input_size=input_size)

    # -> batch_size,inchannel,h,w,d
    rand_data = torch.rand((1,) + input_size).to(device)
    # The forward pass is deliberately NOT wrapped in torch.no_grad():
    # torchviz.make_dot draws the autograd graph attached to the output.
    graph = torchviz.make_dot(model(rand_data),
                              params=dict(model.named_parameters()))
    graph.render("swin", format="pdf")

结果

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv3d-1       [-1, 48, 32, 32, 16]             432
        PatchEmbed-2       [-1, 48, 32, 32, 16]               0
           Dropout-3       [-1, 48, 32, 32, 16]               0
         LayerNorm-4       [-1, 32, 32, 16, 48]              96
            Linear-5             [-1, 343, 144]           7,056
           Softmax-6          [-1, 3, 343, 343]               0
           Dropout-7          [-1, 3, 343, 343]               0
            Linear-8              [-1, 343, 48]           2,352
           Dropout-9              [-1, 343, 48]               0
  WindowAttention-10              [-1, 343, 48]               0
         Identity-11       [-1, 32, 32, 16, 48]               0
        LayerNorm-12       [-1, 32, 32, 16, 48]              96
           Linear-13      [-1, 32, 32, 16, 192]           9,408
             GELU-14      [-1, 32, 32, 16, 192]               0
          Dropout-15      [-1, 32, 32, 16, 192]               0
           Linear-16       [-1, 32, 32, 16, 48]           9,264
          Dropout-17       [-1, 32, 32, 16, 48]               0
         MLPBlock-18       [-1, 32, 32, 16, 48]               0
         Identity-19       [-1, 32, 32, 16, 48]               0
SwinTransformerBlock-20       [-1, 32, 32, 16, 48]               0
        LayerNorm-21       [-1, 32, 32, 16, 48]              96
           Linear-22             [-1, 343, 144]           7,056
          Softmax-23          [-1, 3, 343, 343]               0
          Dropout-24          [-1, 3, 343, 343]               0
           Linear-25              [-1, 343, 48]           2,352
          Dropout-26              [-1, 343, 48]               0
  WindowAttention-27              [-1, 343, 48]               0
         Identity-28       [-1, 32, 32, 16, 48]               0
        LayerNorm-29       [-1, 32, 32, 16, 48]              96
           Linear-30      [-1, 32, 32, 16, 192]           9,408
             GELU-31      [-1, 32, 32, 16, 192]               0
          Dropout-32      [-1, 32, 32, 16, 192]               0
           Linear-33       [-1, 32, 32, 16, 48]           9,264
          Dropout-34       [-1, 32, 32, 16, 48]               0
         MLPBlock-35       [-1, 32, 32, 16, 48]               0
         Identity-36       [-1, 32, 32, 16, 48]               0
SwinTransformerBlock-37       [-1, 32, 32, 16, 48]               0
        LayerNorm-38       [-1, 16, 16, 8, 384]             768
           Linear-39        [-1, 16, 16, 8, 96]          36,864
     PatchMerging-40        [-1, 16, 16, 8, 96]               0
       BasicLayer-41        [-1, 96, 16, 16, 8]               0
        LayerNorm-42        [-1, 16, 16, 8, 96]             192
           Linear-43             [-1, 343, 288]          27,936
          Softmax-44          [-1, 6, 343, 343]               0
          Dropout-45          [-1, 6, 343, 343]               0
           Linear-46              [-1, 343, 96]           9,312
          Dropout-47              [-1, 343, 96]               0
  WindowAttention-48              [-1, 343, 96]               0
         Identity-49        [-1, 16, 16, 8, 96]               0
        LayerNorm-50        [-1, 16, 16, 8, 96]             192
           Linear-51       [-1, 16, 16, 8, 384]          37,248
             GELU-52       [-1, 16, 16, 8, 384]               0
          Dropout-53       [-1, 16, 16, 8, 384]               0
           Linear-54        [-1, 16, 16, 8, 96]          36,960
          Dropout-55        [-1, 16, 16, 8, 96]               0
         MLPBlock-56        [-1, 16, 16, 8, 96]               0
         Identity-57        [-1, 16, 16, 8, 96]               0
SwinTransformerBlock-58        [-1, 16, 16, 8, 96]               0
        LayerNorm-59        [-1, 16, 16, 8, 96]             192
           Linear-60             [-1, 343, 288]          27,936
          Softmax-61          [-1, 6, 343, 343]               0
          Dropout-62          [-1, 6, 343, 343]               0
           Linear-63              [-1, 343, 96]           9,312
          Dropout-64              [-1, 343, 96]               0
  WindowAttention-65              [-1, 343, 96]               0
         Identity-66        [-1, 16, 16, 8, 96]               0
        LayerNorm-67        [-1, 16, 16, 8, 96]             192
           Linear-68       [-1, 16, 16, 8, 384]          37,248
             GELU-69       [-1, 16, 16, 8, 384]               0
          Dropout-70       [-1, 16, 16, 8, 384]               0
           Linear-71        [-1, 16, 16, 8, 96]          36,960
          Dropout-72        [-1, 16, 16, 8, 96]               0
         MLPBlock-73        [-1, 16, 16, 8, 96]               0
         Identity-74        [-1, 16, 16, 8, 96]               0
SwinTransformerBlock-75        [-1, 16, 16, 8, 96]               0
        LayerNorm-76        [-1, 16, 16, 8, 96]             192
           Linear-77             [-1, 343, 288]          27,936
          Softmax-78          [-1, 6, 343, 343]               0
          Dropout-79          [-1, 6, 343, 343]               0
           Linear-80              [-1, 343, 96]           9,312
          Dropout-81              [-1, 343, 96]               0
  WindowAttention-82              [-1, 343, 96]               0
         Identity-83        [-1, 16, 16, 8, 96]               0
        LayerNorm-84        [-1, 16, 16, 8, 96]             192
           Linear-85       [-1, 16, 16, 8, 384]          37,248
             GELU-86       [-1, 16, 16, 8, 384]               0
          Dropout-87       [-1, 16, 16, 8, 384]               0
           Linear-88        [-1, 16, 16, 8, 96]          36,960
          Dropout-89        [-1, 16, 16, 8, 96]               0
         MLPBlock-90        [-1, 16, 16, 8, 96]               0
         Identity-91        [-1, 16, 16, 8, 96]               0
SwinTransformerBlock-92        [-1, 16, 16, 8, 96]               0
        LayerNorm-93        [-1, 16, 16, 8, 96]             192
           Linear-94             [-1, 343, 288]          27,936
          Softmax-95          [-1, 6, 343, 343]               0
          Dropout-96          [-1, 6, 343, 343]               0
           Linear-97              [-1, 343, 96]           9,312
          Dropout-98              [-1, 343, 96]               0
  WindowAttention-99              [-1, 343, 96]               0
        Identity-100        [-1, 16, 16, 8, 96]               0
       LayerNorm-101        [-1, 16, 16, 8, 96]             192
          Linear-102       [-1, 16, 16, 8, 384]          37,248
            GELU-103       [-1, 16, 16, 8, 384]               0
         Dropout-104       [-1, 16, 16, 8, 384]               0
          Linear-105        [-1, 16, 16, 8, 96]          36,960
         Dropout-106        [-1, 16, 16, 8, 96]               0
        MLPBlock-107        [-1, 16, 16, 8, 96]               0
        Identity-108        [-1, 16, 16, 8, 96]               0
SwinTransformerBlock-109        [-1, 16, 16, 8, 96]               0
       LayerNorm-110         [-1, 8, 8, 4, 768]           1,536
          Linear-111         [-1, 8, 8, 4, 192]         147,456
    PatchMerging-112         [-1, 8, 8, 4, 192]               0
      BasicLayer-113         [-1, 192, 8, 8, 4]               0
       LayerNorm-114         [-1, 8, 8, 4, 192]             384
          Linear-115             [-1, 196, 576]         111,168
         Softmax-116         [-1, 12, 196, 196]               0
         Dropout-117         [-1, 12, 196, 196]               0
          Linear-118             [-1, 196, 192]          37,056
         Dropout-119             [-1, 196, 192]               0
 WindowAttention-120             [-1, 196, 192]               0
        Identity-121         [-1, 8, 8, 4, 192]               0
       LayerNorm-122         [-1, 8, 8, 4, 192]             384
          Linear-123         [-1, 8, 8, 4, 768]         148,224
            GELU-124         [-1, 8, 8, 4, 768]               0
         Dropout-125         [-1, 8, 8, 4, 768]               0
          Linear-126         [-1, 8, 8, 4, 192]         147,648
         Dropout-127         [-1, 8, 8, 4, 192]               0
        MLPBlock-128         [-1, 8, 8, 4, 192]               0
        Identity-129         [-1, 8, 8, 4, 192]               0
SwinTransformerBlock-130         [-1, 8, 8, 4, 192]               0
       LayerNorm-131         [-1, 8, 8, 4, 192]             384
          Linear-132             [-1, 196, 576]         111,168
         Softmax-133         [-1, 12, 196, 196]               0
         Dropout-134         [-1, 12, 196, 196]               0
          Linear-135             [-1, 196, 192]          37,056
         Dropout-136             [-1, 196, 192]               0
 WindowAttention-137             [-1, 196, 192]               0
        Identity-138         [-1, 8, 8, 4, 192]               0
       LayerNorm-139         [-1, 8, 8, 4, 192]             384
          Linear-140         [-1, 8, 8, 4, 768]         148,224
            GELU-141         [-1, 8, 8, 4, 768]               0
         Dropout-142         [-1, 8, 8, 4, 768]               0
          Linear-143         [-1, 8, 8, 4, 192]         147,648
         Dropout-144         [-1, 8, 8, 4, 192]               0
        MLPBlock-145         [-1, 8, 8, 4, 192]               0
        Identity-146         [-1, 8, 8, 4, 192]               0
SwinTransformerBlock-147         [-1, 8, 8, 4, 192]               0
       LayerNorm-148        [-1, 4, 4, 2, 1536]           3,072
          Linear-149         [-1, 4, 4, 2, 384]         589,824
    PatchMerging-150         [-1, 4, 4, 2, 384]               0
      BasicLayer-151         [-1, 384, 4, 4, 2]               0
       LayerNorm-152         [-1, 4, 4, 2, 384]             768
          Linear-153             [-1, 32, 1152]         443,520
         Softmax-154           [-1, 24, 32, 32]               0
         Dropout-155           [-1, 24, 32, 32]               0
          Linear-156              [-1, 32, 384]         147,840
         Dropout-157              [-1, 32, 384]               0
 WindowAttention-158              [-1, 32, 384]               0
        Identity-159         [-1, 4, 4, 2, 384]               0
       LayerNorm-160         [-1, 4, 4, 2, 384]             768
          Linear-161        [-1, 4, 4, 2, 1536]         591,360
            GELU-162        [-1, 4, 4, 2, 1536]               0
         Dropout-163        [-1, 4, 4, 2, 1536]               0
          Linear-164         [-1, 4, 4, 2, 384]         590,208
         Dropout-165         [-1, 4, 4, 2, 384]               0
        MLPBlock-166         [-1, 4, 4, 2, 384]               0
        Identity-167         [-1, 4, 4, 2, 384]               0
SwinTransformerBlock-168         [-1, 4, 4, 2, 384]               0
       LayerNorm-169         [-1, 4, 4, 2, 384]             768
          Linear-170             [-1, 32, 1152]         443,520
         Softmax-171           [-1, 24, 32, 32]               0
         Dropout-172           [-1, 24, 32, 32]               0
          Linear-173              [-1, 32, 384]         147,840
         Dropout-174              [-1, 32, 384]               0
 WindowAttention-175              [-1, 32, 384]               0
        Identity-176         [-1, 4, 4, 2, 384]               0
       LayerNorm-177         [-1, 4, 4, 2, 384]             768
          Linear-178        [-1, 4, 4, 2, 1536]         591,360
            GELU-179        [-1, 4, 4, 2, 1536]               0
         Dropout-180        [-1, 4, 4, 2, 1536]               0
          Linear-181         [-1, 4, 4, 2, 384]         590,208
         Dropout-182         [-1, 4, 4, 2, 384]               0
        MLPBlock-183         [-1, 4, 4, 2, 384]               0
        Identity-184         [-1, 4, 4, 2, 384]               0
SwinTransformerBlock-185         [-1, 4, 4, 2, 384]               0
       LayerNorm-186        [-1, 2, 2, 1, 3072]           6,144
          Linear-187         [-1, 2, 2, 1, 768]       2,359,296
    PatchMerging-188         [-1, 2, 2, 1, 768]               0
      BasicLayer-189         [-1, 768, 2, 2, 1]               0
 SwinTransformer-190  [[-1, 48, 32, 32, 16], [-1, 96, 16, 16, 8], [-1, 192, 8, 8, 4], [-1, 384, 4, 4, 2], [-1, 768, 2, 2, 1]]               0
          Conv3d-191       [-1, 48, 64, 64, 32]           1,296
  InstanceNorm3d-192       [-1, 48, 64, 64, 32]               0
       LeakyReLU-193       [-1, 48, 64, 64, 32]               0
          Conv3d-194       [-1, 48, 64, 64, 32]          62,208
  InstanceNorm3d-195       [-1, 48, 64, 64, 32]               0
          Conv3d-196       [-1, 48, 64, 64, 32]              48
  InstanceNorm3d-197       [-1, 48, 64, 64, 32]               0
       LeakyReLU-198       [-1, 48, 64, 64, 32]               0
    UnetResBlock-199       [-1, 48, 64, 64, 32]               0
 UnetrBasicBlock-200       [-1, 48, 64, 64, 32]               0
          Conv3d-201       [-1, 48, 32, 32, 16]          62,208
  InstanceNorm3d-202       [-1, 48, 32, 32, 16]               0
       LeakyReLU-203       [-1, 48, 32, 32, 16]               0
          Conv3d-204       [-1, 48, 32, 32, 16]          62,208
  InstanceNorm3d-205       [-1, 48, 32, 32, 16]               0
       LeakyReLU-206       [-1, 48, 32, 32, 16]               0
    UnetResBlock-207       [-1, 48, 32, 32, 16]               0
 UnetrBasicBlock-208       [-1, 48, 32, 32, 16]               0
          Conv3d-209        [-1, 96, 16, 16, 8]         248,832
  InstanceNorm3d-210        [-1, 96, 16, 16, 8]               0
       LeakyReLU-211        [-1, 96, 16, 16, 8]               0
          Conv3d-212        [-1, 96, 16, 16, 8]         248,832
  InstanceNorm3d-213        [-1, 96, 16, 16, 8]               0
       LeakyReLU-214        [-1, 96, 16, 16, 8]               0
    UnetResBlock-215        [-1, 96, 16, 16, 8]               0
 UnetrBasicBlock-216        [-1, 96, 16, 16, 8]               0
          Conv3d-217         [-1, 192, 8, 8, 4]         995,328
  InstanceNorm3d-218         [-1, 192, 8, 8, 4]               0
       LeakyReLU-219         [-1, 192, 8, 8, 4]               0
          Conv3d-220         [-1, 192, 8, 8, 4]         995,328
  InstanceNorm3d-221         [-1, 192, 8, 8, 4]               0
       LeakyReLU-222         [-1, 192, 8, 8, 4]               0
    UnetResBlock-223         [-1, 192, 8, 8, 4]               0
 UnetrBasicBlock-224         [-1, 192, 8, 8, 4]               0
          Conv3d-225         [-1, 768, 2, 2, 1]      15,925,248
  InstanceNorm3d-226         [-1, 768, 2, 2, 1]               0
       LeakyReLU-227         [-1, 768, 2, 2, 1]               0
          Conv3d-228         [-1, 768, 2, 2, 1]      15,925,248
  InstanceNorm3d-229         [-1, 768, 2, 2, 1]               0
       LeakyReLU-230         [-1, 768, 2, 2, 1]               0
    UnetResBlock-231         [-1, 768, 2, 2, 1]               0
 UnetrBasicBlock-232         [-1, 768, 2, 2, 1]               0
 ConvTranspose3d-233         [-1, 384, 4, 4, 2]       2,359,296
          Conv3d-234         [-1, 384, 4, 4, 2]       7,962,624
  InstanceNorm3d-235         [-1, 384, 4, 4, 2]               0
       LeakyReLU-236         [-1, 384, 4, 4, 2]               0
          Conv3d-237         [-1, 384, 4, 4, 2]       3,981,312
  InstanceNorm3d-238         [-1, 384, 4, 4, 2]               0
          Conv3d-239         [-1, 384, 4, 4, 2]         294,912
  InstanceNorm3d-240         [-1, 384, 4, 4, 2]               0
       LeakyReLU-241         [-1, 384, 4, 4, 2]               0
    UnetResBlock-242         [-1, 384, 4, 4, 2]               0
    UnetrUpBlock-243         [-1, 384, 4, 4, 2]               0
 ConvTranspose3d-244         [-1, 192, 8, 8, 4]         589,824
          Conv3d-245         [-1, 192, 8, 8, 4]       1,990,656
  InstanceNorm3d-246         [-1, 192, 8, 8, 4]               0
       LeakyReLU-247         [-1, 192, 8, 8, 4]               0
          Conv3d-248         [-1, 192, 8, 8, 4]         995,328
  InstanceNorm3d-249         [-1, 192, 8, 8, 4]               0
          Conv3d-250         [-1, 192, 8, 8, 4]          73,728
  InstanceNorm3d-251         [-1, 192, 8, 8, 4]               0
       LeakyReLU-252         [-1, 192, 8, 8, 4]               0
    UnetResBlock-253         [-1, 192, 8, 8, 4]               0
    UnetrUpBlock-254         [-1, 192, 8, 8, 4]               0
 ConvTranspose3d-255        [-1, 96, 16, 16, 8]         147,456
          Conv3d-256        [-1, 96, 16, 16, 8]         497,664
  InstanceNorm3d-257        [-1, 96, 16, 16, 8]               0
       LeakyReLU-258        [-1, 96, 16, 16, 8]               0
          Conv3d-259        [-1, 96, 16, 16, 8]         248,832
  InstanceNorm3d-260        [-1, 96, 16, 16, 8]               0
          Conv3d-261        [-1, 96, 16, 16, 8]          18,432
  InstanceNorm3d-262        [-1, 96, 16, 16, 8]               0
       LeakyReLU-263        [-1, 96, 16, 16, 8]               0
    UnetResBlock-264        [-1, 96, 16, 16, 8]               0
    UnetrUpBlock-265        [-1, 96, 16, 16, 8]               0
 ConvTranspose3d-266       [-1, 48, 32, 32, 16]          36,864
          Conv3d-267       [-1, 48, 32, 32, 16]         124,416
  InstanceNorm3d-268       [-1, 48, 32, 32, 16]               0
       LeakyReLU-269       [-1, 48, 32, 32, 16]               0
          Conv3d-270       [-1, 48, 32, 32, 16]          62,208
  InstanceNorm3d-271       [-1, 48, 32, 32, 16]               0
          Conv3d-272       [-1, 48, 32, 32, 16]           4,608
  InstanceNorm3d-273       [-1, 48, 32, 32, 16]               0
       LeakyReLU-274       [-1, 48, 32, 32, 16]               0
    UnetResBlock-275       [-1, 48, 32, 32, 16]               0
    UnetrUpBlock-276       [-1, 48, 32, 32, 16]               0
 ConvTranspose3d-277       [-1, 48, 64, 64, 32]          18,432
          Conv3d-278       [-1, 48, 64, 64, 32]         124,416
  InstanceNorm3d-279       [-1, 48, 64, 64, 32]               0
       LeakyReLU-280       [-1, 48, 64, 64, 32]               0
          Conv3d-281       [-1, 48, 64, 64, 32]          62,208
  InstanceNorm3d-282       [-1, 48, 64, 64, 32]               0
          Conv3d-283       [-1, 48, 64, 64, 32]           4,608
  InstanceNorm3d-284       [-1, 48, 64, 64, 32]               0
       LeakyReLU-285       [-1, 48, 64, 64, 32]               0
    UnetResBlock-286       [-1, 48, 64, 64, 32]               0
    UnetrUpBlock-287       [-1, 48, 64, 64, 32]               0
          Conv3d-288        [-1, 4, 64, 64, 32]             196
    UnetOutBlock-289        [-1, 4, 64, 64, 32]               0
================================================================
Total params: 62,212,756
Trainable params: 62,212,756
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.50
Forward/backward pass size (MB): 1658.32
Params size (MB): 237.32
Estimated Total Size (MB): 1896.14
----------------------------------------------------------------

生成的 PDF 文件(swin.pdf):
(原文此处为网络结构图的截图,图片未能保留)

猜你喜欢

转载自blog.csdn.net/sdhdsf132452/article/details/129686222