多参数时Torchsummary的使用方法

概述

  • Torchsummary是深度学习中常用的一款用来描述网络结构和各层参数的工具。在构建网络模型的时候,我们可以通过它来检查网络模型中的各项参数是否正确,进一步,是否能够给出正确维度的输出信息。
  • 对于网络模型的输入信息来说,除了传统的单一输入,还有多输入的网络模型。在多个输入时,使用torchsummary就会出现报错信息: TypeError: can’t multiply sequence by non-int of type ‘tuple’
  • 本案例使用的版本Torchsummary=1.5.1
  • 参考:
    1. https://github.com/sksq96/pytorch-summary/issues/90
    2. https://blog.csdn.net/qq_43733107/article/details/126508616

问题

根据报错信息,可以定位到报错的代码在torchsummary/torchsummary.py的Line 100:

total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))

其中的input_size就是输入信息的维度,如常见的图像信息:(3, 64, 64)。而当多输入的时候,无法直接通过np.prod实现参数的相乘。

解决方法

方法一

total_input_size = abs(np.prod(sum((input_size),())) * batch_size * 4. / (1024 ** 2.))

方法二

total_input_size = abs(np.sum([np.prod(in_tuple) for in_tuple in input_size]) * batch_size * 4. / (1024 ** 2.))

将原代码注释后,任选上述两种方法之一都能实现功能。从代码中也能看到,一个是先加后乘,另一个是遍历乘法运算之后累加。

猜你喜欢

转载自blog.csdn.net/kakangel/article/details/130795893