深度可分离卷积结构(depthwise separable convolution)计算复杂度分析

前言

这个例子说明了什么叫做空间可分离卷积,这种方法并不应用在深度学习中,只是用来帮你理解这种结构。

在神经网络中,我们通常会使用深度可分离卷积结构(depthwise separable convolution)。

这种方法在保持通道分离的前提下,接上一个深度卷积结构,即可实现空间卷积。接下来通过一个例子让大家更好地理解。

算法原理

假设有一个3×3大小的卷积层,其输入通道为16、输出通道为32。具体为,32个3×3大小的卷积核会遍历16个通道中的每个数据,从而产生16×32=512个特征图谱。进而通过叠加每个输入通道对应的特征图谱后融合得到1个特征图谱。最后可得到所需的32个输出通道。

针对这个例子应用深度可分离卷积,用1个3×3大小的卷积核遍历16通道的数据,得到了16个特征图谱。在融合操作之前,接着用32个1×1大小的卷积核遍历这16个特征图谱,进行相加融合。这个过程使用了16×3×3+16×32×1×1=656个参数,远少于上面的16×32×3×3=4608个参数。

这个例子就是深度可分离卷积的具体操作,其中上面的深度乘数(depth multiplier)设为1,这也是目前这类网络层的通用参数。

这么做是为了对空间信息和深度信息进行去耦。从Xception模型的效果可以看出,这种方法是比较有效的。由于能够有效利用参数,因此深度可分离卷积也可以用于移动设备中。

这里写图片描述

下面的代码是mobilenet的一个参数列表,计算的普通卷积与深度分离卷积的计算复杂程度比较 : 代码链接

# Tensorflow mandates these.
from collections import namedtuple
import functools

import tensorflow as tf

slim = tf.contrib.slim

# Conv and DepthSepConv namedtuple define layers of the MobileNet architecture
# Conv defines 3x3 convolution layers
# DepthSepConv defines 3x3 depthwise convolution followed by 1x1 convolution.
# stride is the stride of the convolution
# depth is the number of channels or filters in a layer
Conv = namedtuple('Conv', ['kernel', 'stride', 'depth'])
DepthSepConv = namedtuple('DepthSepConv', ['kernel', 'stride', 'depth'])

# _CONV_DEFS specifies the MobileNet body
_CONV_DEFS = [
    Conv(kernel=[3, 3], stride=2, depth=32),
    DepthSepConv(kernel=[3, 3], stride=1, depth=64),
    DepthSepConv(kernel=[3, 3], stride=2, depth=128),
    DepthSepConv(kernel=[3, 3], stride=1, depth=128),
    DepthSepConv(kernel=[3, 3], stride=2, depth=256),
    DepthSepConv(kernel=[3, 3], stride=1, depth=256),
    DepthSepConv(kernel=[3, 3], stride=2, depth=512),
    DepthSepConv(kernel=[3, 3], stride=1, depth=512),
    DepthSepConv(kernel=[3, 3], stride=1, depth=512),
    DepthSepConv(kernel=[3, 3], stride=1, depth=512),
    DepthSepConv(kernel=[3, 3], stride=1, depth=512),
    DepthSepConv(kernel=[3, 3], stride=1, depth=512),
    DepthSepConv(kernel=[3, 3], stride=2, depth=1024),
    DepthSepConv(kernel=[3, 3], stride=1, depth=1024)
]

input_size = 160
inputdepth = 3
conv_defs = _CONV_DEFS
sumcost = 0
for i, conv_def in enumerate(conv_defs):
    stride = conv_def.stride
    kernel = conv_def.kernel
    outdepth = conv_def.depth
    output_size = round((input_size - int(kernel[0] / 2) * 2) / stride)
    if isinstance(conv_def, Conv):
        sumcost += output_size * output_size * kernel[0] * kernel[0] * inputdepth * outdepth
    if isinstance(conv_def, DepthSepConv):
        sumcost += output_size * output_size * kernel[0] * kernel[0] * inputdepth * outdepth
    inputdepth = outdepth
    input_size = output_size
print("src conv:    ", sumcost)

input_size = 160
inputdepth = 3
conv_defs = _CONV_DEFS
sumcost1 = 0
for i, conv_def in enumerate(conv_defs):
    stride = conv_def.stride
    kernel = conv_def.kernel
    outdepth = conv_def.depth
    output_size = round((input_size - int(kernel[0] / 2) * 2) / stride)
    if isinstance(conv_def, Conv):
        sumcost1 += output_size * output_size * kernel[0] * kernel[0] * inputdepth * outdepth
    if isinstance(conv_def, DepthSepConv):
        #sumcost += output_size * output_size * kernel[0] * kernel[0] * inputdepth * outdepth
        sumcost1 += output_size * output_size *(inputdepth * kernel[0] * kernel[0]  + inputdepth * outdepth * 1 * 1)
    inputdepth = outdepth
    input_size = output_size
print("DepthSepConv:", sumcost1)
print("compare:", sumcost1 / sumcost)
src conv:       1045417824
DepthSepConv:   126373376
compare: 0.12088312739538674

猜你喜欢

转载自blog.csdn.net/qq_29462849/article/details/80794159