keras中的 concatenate() 详解

最近看模态融合,用到了 keras 中的 concatenate() 函数,之前没有搞明白其中的 axis 这个参数是什么意思。后来经过一番研究,总算是搞明白了。

先看代码

import numpy as np
import keras.backend as K
import tensorflow as tf

a = K.variable(np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]))
b = K.variable(np.array([[[9, 10], [11, 12]], [[13, 14], [15, 16]]]))

c1 = K.concatenate([a, b], axis=0)
c2 = K.concatenate([a, b], axis=1)
c3 = K.concatenate([a, b], axis=2)
#试试默认的参数,其实就是从倒数第一个维度进行融合的。
c4 = K.concatenate([a, b])
c5 = K.concatenate([a, b],axis=-1)

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print('***************')
    print(a.shape,b.shape)
    print('***************')
    print('*****C1******',c1.shape)
    print(sess.run(c1))
    print()
    print('*****C2******',c2.shape)
    print(sess.run(c2))
    print()
    print('*****C3******',c3.shape)
    print(sess.run(c3))
    print()
    print('*****C4******',c4.shape)
    print(sess.run(c4))
    print('*****C5******',c5.shape)
    print(sess.run(c5))
    

再看看输出的效果:

在这里插入图片描述



axis=n表示从第n个维度进行拼接,对于一个三维矩阵,axis的取值可以为[-3, -2, -1, 0, 1, 2]。
axis=-2,意思是从倒数第2个维度进行拼接,对于三维矩阵而言,这就等同于axis=1。
axis=-1,意思是从倒数第1个维度进行拼接,对于三维矩阵而言,这就等同于axis=2。

简单点理解:

可能从图像上来理解比较复杂,但是如果从数学的角度来 看很简单,就比如上边的例子

两个 (2,2,2)(2,2,2)的数组进行融合,,

  • 第一个维度融合就是(4,2,2),即 axis=0
  • 第二个维度融合就是(2,4,2),即 axis=1
  • 第三个维度融合就是(2,2,4),即 axis=2

参考文献

[1]https://blog.csdn.net/leviopku/article/details/82380710
[2]https://zhuanlan.zhihu.com/p/58672698

猜你喜欢

转载自blog.csdn.net/zhaozhao236/article/details/109434254
今日推荐