Pytorch-如何查看网络的模块和参数

Pytorch-如何查看网络的模块和参数

在实验的过程中,我们经常需要知道当前网络由哪些模块组成,以及这些模块当前的参数是什么。

查看模块

以如下网络为例:

m = nn.Sequential(nn.Linear(2, 2),
                  nn.ReLU(),
                  nn.Sequential(nn.BatchNorm2d(2), nn.ReLU()),
                  nn.Sequential(nn.Sigmoid(), nn.ReLU()))

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-iWF48Mfr-1580999753365)(Pytorch-如何查看网络的模块和参数/1.jpg)]

children()

只展示网络的子节点。

print(list(m.children()))

[Linear(in_features=2, out_features=2, bias=True),
 ReLU(),
 Sequential(
  (0): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (1): ReLU()
),
 Sequential(
  (0): Sigmoid()
  (1): ReLU()
)]

named_children()

children()的输出加上了编号。

print(list(m.named_children()))

[('0', Linear(in_features=2, out_features=2, bias=True)),
 ('1', ReLU()),
 ('2',
  Sequential(
  (0): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (1): ReLU()
)),
 ('3', Sequential(
  (0): Sigmoid()
  (1): ReLU()
))]

modules()

以前序dfs遍历并输出这颗树。

print(list(m.modules()))

[Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): ReLU()
  (2): Sequential(
    (0): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): ReLU()
  )
  (3): Sequential(
    (0): Sigmoid()
    (1): ReLU()
  )
),
 Linear(in_features=2, out_features=2, bias=True),
 ReLU(),
 Sequential(
  (0): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (1): ReLU()
),
 BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 ReLU(),
 Sequential(
  (0): Sigmoid()
  (1): ReLU()
),
 Sigmoid(),
 ReLU()]

named_modules()

modules()的输出加上了编号。

print(list(m.named_modules()))

[('',
  Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): ReLU()
  (2): Sequential(
    (0): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): ReLU()
  )
  (3): Sequential(
    (0): Sigmoid()
    (1): ReLU()
  )
)),
 ('0', Linear(in_features=2, out_features=2, bias=True)),
 ('1', ReLU()),
 ('2',
  Sequential(
  (0): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (1): ReLU()
)),
 ('2.0',
  BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
 ('2.1', ReLU()),
 ('3', Sequential(
  (0): Sigmoid()
  (1): ReLU()
)),
 ('3.0', Sigmoid()),
 ('3.1', ReLU())]

查看参数

以BN层为例。m = nn.BatchNorm2d(2)

state_dict()

查看一个网络所有的参数。

print(m.state_dict().keys())

odict_keys(['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked'])

parameters()和named_parameters()

查看网络中需要更新的参数。parameters()只显示参数,named_parameters还显示参数的名称。

for k, v in m.named_parameters():
    print(k)
    print(v)

weight
Parameter containing:
tensor([1., 1.], requires_grad=True)
bias
Parameter containing:
tensor([0., 0.], requires_grad=True)

for v in m.parameters():
    print(v)

Parameter containing:
tensor([1., 1.], requires_grad=True)
Parameter containing:
tensor([0., 0.], requires_grad=True)

buffers()和named_buffers()

查看网络中不需要更新的参数,如BN层中的running_mean, running_var和num_batches_tracked。

扫描二维码关注公众号,回复: 9350565 查看本文章

一个模块中不需要更新的参数有两种:

  • 普通的类成员变量,形如self.xxx
  • buffer变量,需要调用self.register_buffer()方法将一个变量注册成buffer变量。

buffer变量存在的意义就在于: m.cuda()的时候,会自动把所有的parameters和buffers也移动到GPU上,而普通的类成员变量仍然存在于CPU中。

for k, v in m.named_buffers():
    print(k)
    print(v)

running_mean
tensor([0., 0.])
running_var
tensor([1., 1.])
num_batches_tracked
tensor(0)

for v in m.buffers():
    print(v)

tensor([0., 0.])
tensor([1., 1.])
tensor(0)
发布了173 篇原创文章 · 获赞 28 · 访问量 6万+

猜你喜欢

转载自blog.csdn.net/ECNU_LZJ/article/details/104203675