Separable convolution in TensorFlow

API documentation for separable convolution (Chinese translation of the official TensorFlow docs):

http://www.tensorfly.cn/tfdoc/api_docs/python/nn.html#separable_conv2d

Depthwise convolution (the first stage of a separable convolution):

tf.nn.depthwise_conv2d(input, filter, strides, padding, name=None)

Parameters:

  • input: 4-D with shape [batch, in_height, in_width, in_channels].
  • filter: 4-D with shape [filter_height, filter_width, in_channels, channel_multiplier].
  • strides: 1-D of size 4. The stride of the sliding window for each dimension of input.
  • padding: A string, either 'VALID' or 'SAME'. The padding algorithm.
  • name: A name for this operation (optional).

Returns:

A 4-D Tensor of shape [batch, out_height, out_width, in_channels * channel_multiplier].
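To make the shapes above concrete, here is a minimal pure-Python sketch of what depthwise_conv2d computes, assuming 'VALID' padding and ignoring the dilation (rate) argument that later TensorFlow versions added. The helper name depthwise_conv2d_ref is my own, not a TensorFlow function. Each input channel is filtered on its own; nothing is mixed across channels. Run on the input from the example below, it should reproduce that example's first output.

```python
def depthwise_conv2d_ref(x, w, strides):
    """Pure-Python sketch of depthwise convolution with 'VALID' padding.

    x: nested list, shape [batch, in_height, in_width, in_channels]
    w: nested list, shape [filter_height, filter_width, in_channels, channel_multiplier]
    strides: [1, stride_height, stride_width, 1]
    Returns a nested list, shape [batch, out_height, out_width, in_channels * channel_multiplier].
    """
    in_h, in_w, in_c = len(x[0]), len(x[0][0]), len(x[0][0][0])
    f_h, f_w, mult = len(w), len(w[0]), len(w[0][0][0])
    s_h, s_w = strides[1], strides[2]
    out_h = (in_h - f_h) // s_h + 1  # 'VALID': only windows that fit entirely inside the input
    out_w = (in_w - f_w) // s_w + 1
    out = []
    for img in x:  # batch dimension
        rows = []
        for i in range(out_h):
            cols = []
            for j in range(out_w):
                pixel = []
                for k in range(in_c):      # every input channel is filtered on its own...
                    for q in range(mult):  # ...by `mult` separate filters
                        acc = 0.0
                        for di in range(f_h):
                            for dj in range(f_w):
                                acc += w[di][dj][k][q] * img[s_h * i + di][s_w * j + dj][k]
                        pixel.append(acc)  # ends up at output channel k * mult + q
                cols.append(pixel)
            rows.append(cols)
        out.append(rows)
    return out


# Input from the example below; wrapping each row in a list mimics tf.expand_dims(a, 2).
a = [
    [[1, 14, 3], [4, 2, 1], [4, 5, 6]],
    [[1, -10, 3], [4, 5, 6], [5, -100, 6]],
    [[1, 14, 3], [4, 5, 6], [1, 4, 2]],
]
x = [[[row] for row in img] for img in a]            # shape [3, 3, 1, 3]
w = [[[[1.0] for _ in range(3)]] for _ in range(2)]  # all-ones filter, shape [2, 1, 3, 1]
c = depthwise_conv2d_ref(x, w, [1, 1, 1, 1])
print(c[0])  # [[[5.0, 16.0, 4.0]], [[8.0, 7.0, 7.0]]]
```

The explicit loops are slow but mirror the index arithmetic one to one, which is the point here.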

Example:

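import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.get_variable)

a = tf.constant(
    [
        [[1, 14, 3], [4, 2, 1], [4, 5, 6]],
        [[1, -10, 3], [4, 5, 6], [5, -100, 6]],
        [[1, 14, 3], [4, 5, 6], [1, 4, 2]]
    ], dtype=tf.float32)
# expand a to 4-D: shape [3, 3, 1, 3] = [batch, in_height, in_width, in_channels]
a = tf.expand_dims(a, 2)

# filter shape [filter_height=2, filter_width=1, in_channels=3, channel_multiplier=1], all ones
depthwise_filter = tf.get_variable(name='depthwise_filter', shape=[2, 1, 3, 1], initializer=tf.ones_initializer())
c = tf.nn.depthwise_conv2d(a, depthwise_filter, [1, 1, 1, 1], 'VALID')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(c))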
 
# output, shape [3, 2, 1, 3]:
[[[[  5.  16.   4.]]

  [[  8.   7.   7.]]]


 [[[  5.  -5.   9.]]

  [[  9. -95.  12.]]]


 [[[  5.  19.   9.]]

  [[  5.   9.   8.]]]]
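Why does the output have shape [3, 2, 1, 3]? After tf.expand_dims(a, 2) the input is [batch=3, in_height=3, in_width=1, in_channels=3], and the 2 x 1 filter covers two adjacent rows, so each output value is the sum of two neighbouring rows within one channel (for example 1 + 4 = 5 and 14 + 2 = 16 in the first image). The helper below is a sketch of TensorFlow's output-size rules, ceil((in - filter + 1) / stride) for 'VALID' and ceil(in / stride) for 'SAME'; conv_output_shape is a hypothetical name, not a TensorFlow API.

```python
import math


def conv_output_shape(input_shape, filter_shape, strides, padding):
    """Sketch: output shape of tf.nn.depthwise_conv2d for 'VALID' / 'SAME' padding."""
    batch, in_h, in_w, in_c = input_shape
    f_h, f_w, f_c, mult = filter_shape
    assert f_c == in_c, "the filter's in_channels must match the input's"
    s_h, s_w = strides[1], strides[2]
    if padding == 'VALID':    # no padding: the window must fit inside the input
        out_h = math.ceil((in_h - f_h + 1) / s_h)
        out_w = math.ceil((in_w - f_w + 1) / s_w)
    elif padding == 'SAME':   # zero-padding so that out = ceil(in / stride)
        out_h = math.ceil(in_h / s_h)
        out_w = math.ceil(in_w / s_w)
    else:
        raise ValueError("padding must be 'VALID' or 'SAME'")
    return [batch, out_h, out_w, in_c * mult]


print(conv_output_shape([3, 3, 1, 3], [2, 1, 3, 1], [1, 1, 1, 1], 'VALID'))  # [3, 2, 1, 3]
print(conv_output_shape([3, 3, 1, 3], [2, 1, 3, 2], [1, 1, 1, 1], 'VALID'))  # [3, 2, 1, 6]
print(conv_output_shape([3, 3, 1, 3], [2, 1, 3, 1], [1, 1, 1, 1], 'SAME'))   # [3, 3, 1, 3]
```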

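Note the filter shape:

filter: 4-D with shape [filter_height, filter_width, in_channels, channel_multiplier].

The last dimension, channel_multiplier, is usually 1. I was not sure what it means, so here I change it to 2: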

import tensorflow as tf  # TensorFlow 1.x API

a = tf.constant(
    [
        [[1, 14, 3], [4, 2, 1], [4, 5, 6]],
        [[1, -10, 3], [4, 5, 6], [5, -100, 6]],
        [[1, 14, 3], [4, 5, 6], [1, 4, 2]]
    ], dtype=tf.float32)
a = tf.expand_dims(a, 2)  # shape [3, 3, 1, 3]

# same all-ones filter as before, but with channel_multiplier = 2
depthwise_filter = tf.get_variable(name='depthwise_filter', shape=[2, 1, 3, 2], initializer=tf.ones_initializer())
c = tf.nn.depthwise_conv2d(a, depthwise_filter, [1, 1, 1, 1], 'VALID')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(c))

# output, shape [3, 2, 1, 6]:
[[[[  5.   5.  16.  16.   4.   4.]]

  [[  8.   8.   7.   7.   7.   7.]]]


 [[[  5.   5.  -5.  -5.   9.   9.]]

  [[  9.   9. -95. -95.  12.  12.]]]


 [[[  5.   5.  19.  19.   9.   9.]]

  [[  5.   5.   9.   9.   8.   8.]]]]

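The output looks doubled, but that is only because of the all-ones filter. Each input channel k is convolved with channel_multiplier separate filters filter[:, :, k, q], and the result of filter q is written to output channel k * channel_multiplier + q, so the output has in_channels * channel_multiplier = 6 channels. Here both filters of every channel are all ones, so each value appears twice in a row. In practice I have not yet seen this parameter set to anything other than 1.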
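A small pure-Python check makes the channel_multiplier behaviour visible: keep filter q = 0 all ones, as in the example, but make filter q = 1 all twos. If the parameter only duplicated the output, both copies would still match; instead the second copy of each channel comes out twice as large. The helper depthwise_valid is hypothetical and assumes 'VALID' padding with stride 1 only.

```python
def depthwise_valid(x, w):
    """Depthwise convolution, 'VALID' padding, stride 1 (pure-Python sketch)."""
    f_h, f_w, in_c, mult = len(w), len(w[0]), len(w[0][0]), len(w[0][0][0])
    out_h, out_w = len(x[0]) - f_h + 1, len(x[0][0]) - f_w + 1
    return [[[[sum(w[di][dj][k][q] * img[i + di][j + dj][k]
                   for di in range(f_h) for dj in range(f_w))
               for k in range(in_c) for q in range(mult)]  # output channel k * mult + q
              for j in range(out_w)]
             for i in range(out_h)]
            for img in x]


a = [
    [[1, 14, 3], [4, 2, 1], [4, 5, 6]],
    [[1, -10, 3], [4, 5, 6], [5, -100, 6]],
    [[1, 14, 3], [4, 5, 6], [1, 4, 2]],
]
x = [[[row] for row in img] for img in a]  # like tf.expand_dims(a, 2): shape [3, 3, 1, 3]

# Both filters have shape [2, 1, 3, 2]; the last index selects filter q.
ones = [[[[1.0, 1.0] for _ in range(3)]] for _ in range(2)]           # q = 0 and q = 1 all ones
ones_and_twos = [[[[1.0, 2.0] for _ in range(3)]] for _ in range(2)]  # q = 1 is all twos

print(depthwise_valid(x, ones)[0][0][0])           # [5.0, 5.0, 16.0, 16.0, 4.0, 4.0]
print(depthwise_valid(x, ones_and_twos)[0][0][0])  # [5.0, 10.0, 16.0, 32.0, 4.0, 8.0]
```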

Separable convolution:

tf.nn.separable_conv2d(input, depthwise_filter, pointwise_filter, strides, padding, name=None)

Parameters:

  • input: 4-D Tensor with shape [batch, in_height, in_width, in_channels].
  • depthwise_filter: 4-D Tensor with shape [filter_height, filter_width, in_channels, channel_multiplier]. Contains in_channels convolutional filters of depth 1.
  • pointwise_filter: 4-D Tensor with shape [1, 1, channel_multiplier * in_channels, out_channels]. Pointwise filter to mix channels after depthwise_filter has convolved spatially.
  • strides: 1-D of size 4. The strides for the depthwise convolution for each dimension of input.
  • padding: A string, either 'VALID' or 'SAME'. The padding algorithm.

Returns:

A 4-D Tensor of shape [batch, out_height, out_width, out_channels].
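separable_conv2d is a depthwise convolution followed by a 1 x 1 (pointwise) convolution that mixes channels at every pixel. Below is a minimal pure-Python sketch of that pointwise step. The helper pointwise_conv is hypothetical, and the input c is copied from the depthwise output printed in the example below. With an all-ones pointwise filter of shape [1, 1, 3, 4], it should reproduce that example's b.

```python
def pointwise_conv(x, p):
    """1 x 1 convolution (pure-Python sketch): mixes channels at every pixel.

    x: nested list, shape [batch, height, width, c_in]
    p: nested list, shape [1, 1, c_in, c_out]
    """
    c_in, c_out = len(p[0][0]), len(p[0][0][0])
    return [[[[sum(pixel[m] * p[0][0][m][n] for m in range(c_in))  # weighted sum over channels
               for n in range(c_out)]
              for pixel in row]
             for row in img]
            for img in x]


# Depthwise output c from the example below, shape [3, 2, 1, 3].
c = [
    [[[5, 16, 4]], [[8, 7, 7]]],
    [[[5, -5, 9]], [[9, -95, 12]]],
    [[[5, 19, 9]], [[5, 9, 8]]],
]
pointwise_filter = [[[[1.0] * 4 for _ in range(3)]]]  # all ones, shape [1, 1, 3, 4]
b = pointwise_conv(c, pointwise_filter)
print(b[0][0][0])  # [25.0, 25.0, 25.0, 25.0]
print(b[1][1][0])  # [-74.0, -74.0, -74.0, -74.0]
```

The pointwise step always runs with stride 1 here; in this sketch the strides argument would only affect the depthwise stage.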

Example:

import tensorflow as tf  # TensorFlow 1.x API

a = tf.constant([
    [[1, 14, 3], [4, 2, 1], [4, 5, 6]],
    [[1, -10, 3], [4, 5, 6], [5, -100, 6]],
    [[1, 14, 3], [4, 5, 6], [1, 4, 2]]
    ], dtype=tf.float32)
a = tf.expand_dims(a, 2)  # shape [3, 3, 1, 3]

depthwise_filter = tf.get_variable(name='depthwise_filter', shape=[2, 1, 3, 1], initializer=tf.ones_initializer())
# depthwise (per-channel) convolution only
c = tf.nn.depthwise_conv2d(a, depthwise_filter, [1, 1, 1, 1], 'VALID')

# pointwise filter shape [1, 1, channel_multiplier * in_channels = 3, out_channels = 4]
pointwise_filter = tf.get_variable(name='pointwise_filter', shape=[1, 1, 3, 4], initializer=tf.ones_initializer())
# separable convolution: takes two filters, the depthwise one and the pointwise one
b = tf.nn.separable_conv2d(a, depthwise_filter, pointwise_filter, [1, 1, 1, 1], padding='VALID')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(c))
    print(sess.run(b))

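# output
c (shape [3, 2, 1, 3]):
[[[[  5.  16.   4.]]

  [[  8.   7.   7.]]]


 [[[  5.  -5.   9.]]

  [[  9. -95.  12.]]]


 [[[  5.  19.   9.]]

  [[  5.   9.   8.]]]]

b (shape [3, 2, 1, 4]):
[[[[ 25.  25.  25.  25.]]

  [[ 22.  22.  22.  22.]]]


 [[[  9.   9.   9.   9.]]

  [[-74. -74. -74. -74.]]]


 [[[ 33.  33.  33.  33.]]

  [[ 22.  22.  22.  22.]]]]

Because the pointwise filter is all ones, each value of b is the sum of the three channels of c at the same position (for example 5 + 16 + 4 = 25), repeated in all 4 output channels.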


Reposted from blog.csdn.net/biubiubiu888/article/details/82024765