Counting the parameters and FLOPs of a network model in PyTorch

Table of contents

1、torchstat 

2、thop

3、fvcore 

4、flops_counter

5、Custom statistical functions


The difference between FLOPS and FLOPs:

  • FLOPS: note the all-uppercase spelling. It stands for floating point operations per second, i.e. the number of floating-point operations performed per second, and can be understood as computation speed. It is a measure of hardware performance.
  • FLOPs: note the lowercase s, which marks the plural. It stands for floating point operations, i.e. the total number of floating-point operations, and can be understood as the amount of computation. It is used to measure the complexity of an algorithm or model.
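As a quick worked example of how the two relate: a model that costs 4 GFLOPs per forward pass, run on hardware sustaining 2 TFLOPS, needs on the order of 4×10^9 / 2×10^12 = 2 ms per pass (ignoring memory bandwidth and utilization effects).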

Before introducing the torchstat and thop packages, a quick summary:

  • The torchstat package can count the parameters and computations of convolutional and fully connected neural networks.
  • The thop package can count the parameters and computations of convolutional, fully connected, and recurrent neural networks. See the program examples below.

1、torchstat 

pip install torchstat -i https://pypi.tuna.tsinghua.edu.cn/simple

In practice, we can call the torchstat package to count a model's parameters and FLOPs for us. Note that unless you modify some of its code, this package only works for models whose input is a 3-channel image.

import torch
import torch.nn as nn
from torchstat import stat
 
 
class Simple(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(16, 32, 3, 1, padding=1, bias=False)
 
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x
 
 
model = Simple()
stat(model, (3, 244, 244))   # count the model's parameters and FLOPs; (3, 244, 244) is the input image size
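As a sanity check on what the tool reports, the convolution FLOPs of the Simple model above can be computed by hand with the standard formula FLOPs ≈ 2 × kh × kw × Cin × Cout × Hout × Wout, where the factor of 2 counts a multiply-add as two operations (some tools count it as one, i.e. report MACs):

conv1_flops = 2 * 3 * 3 * 3 * 16 * 244 * 244    # ≈ 51.4M
conv2_flops = 2 * 3 * 3 * 16 * 32 * 244 * 244   # ≈ 548.7M
print(conv1_flops + conv2_flops)                 # ≈ 600M in total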

If one line in the torchstat source is changed slightly, the package can also count the parameters and computations of a fully connected network. Of course, computing those by hand for a fully connected network is very quick as well =_=. Open the torchstat source and comment out the line highlighted in the original post's screenshot (the figure is not reproduced here), after which torchstat can be used on fully connected networks.
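For reference, a minimal sketch of the by-hand counts for a single nn.Linear layer (one multiply-accumulate per weight):

in_features, out_features = 10, 10
params = in_features * out_features + out_features   # weights + bias
macs = in_features * out_features                    # multiply-accumulates
flops = 2 * macs                                     # if a MAC counts as two ops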

2、thop

pip install thop -i https://pypi.tuna.tsinghua.edu.cn/simple

import torch
import torch.nn as nn
from thop import profile
 
class Simple(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
 
    def forward(self, x):
        x = self.fc1(x)
        return x
 
net = Simple()
input = torch.randn(1, 10)  # batch size = 1, input vector of length 10
macs, params = profile(net, inputs=(input, ))
print(' FLOPs: ', macs*2)   # by convention, FLOPs is usually taken to be twice the MACs
print('params: ', params)
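thop also provides a small helper, clever_format, for pretty-printing the raw counts (a usage sketch):

from thop import clever_format
macs_str, params_str = clever_format([macs, params], "%.3f")
print(macs_str, params_str)   # counts formatted with K/M/G suffixes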

3、fvcore 

pip install fvcore -i https://pypi.tuna.tsinghua.edu.cn/simple

Of these packages, fvcore is the nicest to use.

import torch
from torchvision.models import resnet50
from fvcore.nn import FlopCountAnalysis, parameter_count_table

# build a ResNet-50 network
model = resnet50(num_classes=1000)

# create the input tensor for the network (note that it is wrapped in a tuple)
tensor = (torch.rand(1, 3, 224, 224),)

# analyze FLOPs
flops = FlopCountAnalysis(model, tensor)
print("FLOPs: ", flops.total())

# analyze parameters
print(parameter_count_table(model))

The terminal output is as follows: FLOPs is 4089184256 and the model has about 25.6M parameters. (The parameter count differs slightly from my own hand calculation, mainly in the BN modules: only the two trainable parameters beta and gamma are counted, while moving_mean and moving_var are not. For details, see the issue I raised on the official repo.)
From the information printed to the terminal we can see that the BN layers are not included in the FLOPs count, and neither are the plain add operations in the pooling layers. (I have found that there is no uniform convention for counting FLOPs; the FLOPs-counting projects on GitHub all differ slightly, but their results are similar.)
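FlopCountAnalysis also exposes per-operator and per-module breakdowns, which make it easy to verify exactly which layers were counted (a short sketch using the flops object created above):

print(flops.by_operator())   # totals grouped by op type (conv, addmm, ...)
print(flops.by_module())     # totals grouped by submodule name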

Note: when using the fvcore module to calculate a model's FLOPs I ran into a problem, and the solution is recorded here. The error occurs on line 589 of jit_analysis.py. Debugging shows that the values in op_counts.values() are of type int32, while the computation only accepts int, float, np.float64, and np.int64, so a manual conversion is required. The original post shows the modification in a screenshot; a sketch of the kind of cast involved follows.
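A sketch of the conversion only (the exact surrounding code is an assumption, since the original screenshot is not reproduced):

for k in op_counts:
    op_counts[k] = int(op_counts[k])   # cast numpy int32 to a supported type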

4、flops_counter

pip install ptflops -i https://pypi.tuna.tsinghua.edu.cn/simple

This package is also fine to use; its results match fvcore.

from ptflops import get_model_complexity_info

# model is your own nn.Module instance; (112, 9, 9) is the (C, H, W) input shape it expects
macs, params = get_model_complexity_info(model, (112, 9, 9), as_strings=True,
                                         print_per_layer_stat=True, verbose=True)
print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
print('{:<30}  {:<8}'.format('Number of parameters: ', params))
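Note that, despite the variable name, ptflops reports multiply-accumulate operations (its formatted strings end in Mac, e.g. GMac); under the common two-ops-per-MAC convention, double the number to get FLOPs.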

5、Custom statistical functions
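When none of the packages fit (custom layers, RNNs with unusual input shapes), the counting can be done by hand with forward hooks: register a hook on every leaf module, let each hook compute that layer's operation count from its input/output shapes and append it to a list, then run one dummy forward pass and sum the lists. The script below follows that pattern.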

import torch
import numpy as np

def calc_flops(model, input):
    def conv_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        output_channels, output_height, output_width = output[0].size()

        kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) * (
            2 if multiply_adds else 1)
        bias_ops = 1 if self.bias is not None else 0

        params = output_channels * (kernel_ops + bias_ops)
        flops = batch_size * params * output_height * output_width

        list_conv.append(flops)

    def linear_hook(self, input, output):
        # input[0] is (batch, features) or (num_steps, batch, features)
        if input[0].dim() == 2:
            batch_size, num_steps = input[0].size(0), 1
        else:
            batch_size, num_steps = input[0].size(1), input[0].size(0)
        weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)
        bias_ops = self.bias.nelement() if self.bias is not None else 0

        flops = batch_size * num_steps * (weight_ops + bias_ops)
        list_linear.append(flops)

    def fsmn_hook(self, input, output):
        # input[0] is (batch, features) or (num_steps, batch, features)
        if input[0].dim() == 2:
            batch_size, num_steps = input[0].size(0), 1
        else:
            batch_size, num_steps = input[0].size(1), input[0].size(0)
        weight_ops = self.filter.nelement() * (2 if multiply_adds else 1)
        flops = batch_size * num_steps * weight_ops
        list_fsmn.append(flops)

    def gru_cell(input_size, hidden_size, bias=True):
        total_ops = 0
        # r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
        # z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
        state_ops = (hidden_size + input_size) * hidden_size + hidden_size
        if bias:
            state_ops += hidden_size * 2
        total_ops += state_ops * 2

        # n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
        total_ops += (hidden_size + input_size) * hidden_size + hidden_size
        if bias:
            total_ops += hidden_size * 2
        # r hadamard : r * (~)
        total_ops += hidden_size

        # h' = (1 - z) * n + z * h
        # hadamard hadamard add
        total_ops += hidden_size * 3

        return total_ops

    def gru_hook(self, input, output):
        if self.batch_first:
            batch_size = input[0].size(0)
            num_steps = input[0].size(1)
        else:
            batch_size = input[0].size(1)
            num_steps = input[0].size(0)
        bias = self.bias
        input_size = self.input_size
        hidden_size = self.hidden_size
        num_layers = self.num_layers
        total_ops = 0
        total_ops += gru_cell(input_size, hidden_size, bias)
        for i in range(num_layers - 1):
            total_ops += gru_cell(hidden_size, hidden_size, bias)
        total_ops *= batch_size
        total_ops *= num_steps

        list_lstm.append(total_ops)

    def lstm_cell(input_size, hidden_size, bias):
        total_ops = 0
        # each gate: W_i x + b_i + W_h h + b_h
        state_ops = (input_size + hidden_size) * hidden_size + hidden_size
        if bias:
            state_ops += hidden_size * 2
        total_ops += state_ops * 4   # four gates: input, forget, cell, output
        # c' = f * c + i * g  (two hadamard products and one add)
        total_ops += hidden_size * 3
        # h' = o * tanh(c')  (one hadamard product)
        total_ops += hidden_size
        return total_ops

    def lstm_hook(self, input, output):
        if self.batch_first:
            batch_size = input[0].size(0)
            num_steps = input[0].size(1)
        else:
            batch_size = input[0].size(1)
            num_steps = input[0].size(0)
        bias = self.bias
        input_size = self.input_size
        hidden_size = self.hidden_size
        num_layers = self.num_layers
        total_ops = 0
        total_ops += lstm_cell(input_size, hidden_size, bias)
        for i in range(num_layers - 1):
            total_ops += lstm_cell(hidden_size, hidden_size, bias)
        total_ops *= batch_size
        total_ops *= num_steps

        list_lstm.append(total_ops)

    def bn_hook(self, input, output):
        list_bn.append(input[0].nelement())

    def relu_hook(self, input, output):
        list_relu.append(input[0].nelement())

    def pooling_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        output_channels, output_height, output_width = output[0].size()

        # kernel_size may be an int or a (kh, kw) tuple
        k = self.kernel_size
        kernel_ops = k * k if isinstance(k, int) else k[0] * k[1]
        bias_ops = 0
        params = output_channels * (kernel_ops + bias_ops)
        flops = batch_size * params * output_height * output_width

        list_pooling.append(flops)

    def foo(net):
        childrens = list(net.children())
        if not childrens:
            print(net)
            if isinstance(net, torch.nn.Conv2d) or isinstance(net, torch.nn.ConvTranspose2d):
                net.register_forward_hook(conv_hook)
                # print('conv_hook_ready')
            if isinstance(net, torch.nn.Linear):
                net.register_forward_hook(linear_hook)
                # print('linear_hook_ready')
            if isinstance(net, torch.nn.BatchNorm2d):
                net.register_forward_hook(bn_hook)
                # print('batch_norm_hook_ready')
            if isinstance(net, torch.nn.ReLU) or isinstance(net, torch.nn.PReLU):
                net.register_forward_hook(relu_hook)
                # print('relu_hook_ready')
            if isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d):
                net.register_forward_hook(pooling_hook)
                # print('pooling_hook_ready')
            if isinstance(net, torch.nn.LSTM):
                net.register_forward_hook(lstm_hook)
                # print('lstm_hook_ready')
            if isinstance(net, torch.nn.GRU):
                net.register_forward_hook(gru_hook)

            # if isinstance(net, FSMNZQ):
            #     net.register_forward_hook(fsmn_hook)
                # print('fsmn_hook_ready')
            return
        for c in childrens:
            foo(c)

    multiply_adds = False
    list_conv, list_bn, list_relu, list_linear, list_pooling, list_lstm, list_fsmn = [], [], [], [], [], [], []
    foo(model)

    _ = model(input)

    total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling) + sum(
        list_lstm) + sum(list_fsmn))
    fsmn_flops = (sum(list_fsmn) + sum(list_linear))
    lstm_flops = sum(list_lstm)

    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print('The network has {} params.'.format(params))

    print(total_flops, fsmn_flops, lstm_flops)
    print('  + Number of FLOPs: %.2f M' % (total_flops / 1000 ** 2))
    return total_flops

if __name__ == '__main__':
    from torchvision.models import resnet18

    model = resnet18(num_classes=1000)
    input_tensor = torch.rand((1, 3, 224, 224))
    calc_flops(model, input_tensor)
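For reference, ResNet-18 at a 224×224 input is commonly quoted at roughly 1.8G multiply-accumulates. Since multiply_adds = False above counts one operation per multiply-add, the printed total should land in that ballpark, plus the small BN/ReLU/pooling contributions this script also sums.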


Origin blog.csdn.net/qq_45100200/article/details/127728053