tf.unstack讲解

tf.unstack

tf.unstack(
    value,
    num=None,
    axis=0,
    name='unstack'
)
'''
Args:
	value: A rank R > 0 Tensor to be unstacked.
	num: An int. The length of the dimension axis. Automatically inferred if None (the default).
	axis: An int. The axis to unstack along. Defaults to the first dimension. Negative values wrap around, so the valid range is [-R, R).
	name: A name for the operation (optional).
Returns:
	The list of Tensor objects unstacked from value.
'''

参数说明

value:需要分解的tensor
axis:沿着哪一个维度进行分解
num: 一个整数An int. The length of the dimension axis. Automatically inferred if None (the default).

将一个tensor(或者一个矩阵)分解,和tf.stcak作用相反
输入是一个tensor,输出一个有N(The length of the dimension axis)个tensor组成的list
given a tensor of shape (A, B, C, D);

  1. 如果 axis ==0 输出列表的 第 i’th tensor 的值是 value[i, :, :, :] , shape 是(B, C, D).
  2. 如果 axis ==1 输出列表的 第 i’th tensor 的值是 value[:, i, :, :] , shape 是(A, C, D).
  3. 如果 axis ==2 输出列表的 第 i’th tensor 的值是 value[:, :, i, :] , shape 是(A, B, D).
  4. 如果 axis ==3 输出列表的 第 i’th tensor 的值是 value[:, :, :, i] , shape 是(A, B, C).

例子1

import tensorflow as tf
a = tf.constant([[1, 2, 3], [4, 5, 6]])
b = tf.unstack(a, axis=0)
c = tf.unstack(a, axis=1)
with tf.Session() as sess:
		print(sess.run(a))
		print(sess.run(b))
		print(sess.run(c))

输出

[[1 2 3]
 [4 5 6]]
 
[array([1, 2, 3], dtype=int32), array([4, 5, 6], dtype=int32)]

[array([1, 4], dtype=int32), array([2, 5], dtype=int32), array([3, 6], dtype=int32)]

例子2

import tensorflow as tf
import numpy as np
a=tf.constant(np.random.randint(0,10,size=[,2,3]))
c_0=tf.unstack(a,axis=0)
c_1=tf.unstack(a,axis=1)
c_2=tf.unstack(a,axis=2)
with tf.Session() as sess:
    print("A:\n",sess.run(a))
    print("C_0:\n",sess.run(c_0))
    print("C_1:\n",sess.run(c_1))
    print("C_2:\n",sess.run(c_2))

输出

A:
 [[[4 5 8]
  [2 1 3]]]
C_0:
 [array([[4, 5, 8],
       [2, 1, 3]])]
C_1:
 [array([[4, 5, 8]]), array([[2, 1, 3]])]
C_2:
 [array([[4, 2]]), array([[5, 1]]), array([[8, 3]])]

猜你喜欢

转载自blog.csdn.net/qq_32806793/article/details/85223906