PyTorch的基础知识点记录

1 Model类的属性:m.parameters()
返回的是一个生成器。生成器返回的每一个元素是tensor,是网络的实际参数
(PyTorch version 1.6)

2 查看具体层的参数 (PyTorch version 1.10)
named_parameters, named_buffers

>>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> list(bn.named_parameters())
[('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)),
 ('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))]

>>> list(bn.named_buffers())
[('running_mean', tensor([0., 0., 0.])),
 ('running_var', tensor([1., 1., 1.])),
 ('num_batches_tracked', tensor(0))]

>>> bn.state_dict()
OrderedDict([('weight', tensor([1., 1., 1.])),
             ('bias', tensor([0., 0., 0.])),
             ('running_mean', tensor([0., 0., 0.])),
             ('running_var', tensor([1., 1., 1.])),
             ('num_batches_tracked', tensor(0))])

猜你喜欢

转载自blog.csdn.net/qq_29007291/article/details/117193087