tf.concat用法总结

版权声明:本博文欢迎分享与转载,转载请注明出处和作者。 https://blog.csdn.net/dream6104/article/details/88853616

tf.concat是连接两个矩阵的操作,tf.concat(values,dim,name='concat')

按照dim给定的维度进行拼接,即,相应的维度增加,例子如下:
 

 矩阵维度简单情形(shape为[2,3])

  t1 = [[1, 2, 3], [4, 5, 6]]
  t2 = [[7, 8, 9], [10, 11, 12]]

  拼接后结果:
  tf.concat([t1, t2], 0)  # [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
  tf.concat([t1, t2], 1)  # [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
 
  对拼接的结果shape
  tf.shape(tf.concat([t1, t2], 0))  #新的维度shape [4, 3]
  tf.shape(tf.concat([t1, t2], 1))  #新的维度shape [2, 6]

  
  这里解释了当axis=0和axis=1的情况,怎么理解这个axis呢?其实这和numpy中的np.concatenate()用法是一样的。

axis=0     代表在第0个维度拼接

axis=1     代表在第1个维度拼接 



  矩阵都是2*2*2维度的情形
  t1 = [[[1, 2], [2, 3]], [[4, 4], [5, 3]]]
  t2 = [[[7, 4], [8, 4]], [[2, 10], [15, 11]]]
  tf.concat([t1, t2], axis=-1)

  -1表示最后一个维度,最后一个维度增加
  
  输出结果为
  <tf.Tensor 'concat_2:0' shape=(2, 2, 4) dtype=int32>

猜你喜欢

转载自blog.csdn.net/dream6104/article/details/88853616