tf.split函数说明

版权声明:转载请标明出处,谢谢! https://blog.csdn.net/kdongyi/article/details/82910632

函数形式:

tf.split(
    value,
    num_or_size_splits,
    axis=0,
    num=None,
    name='split'
)

参数:

  1. value :输入的tensor
  2. num_or_size_splits :每个分割后的张量的尺寸,如果是个整数n,就将输入的tensor分为n个子tensor。如果是个tensor T,就将输入的tensor分为len(T)个子tensor。 
  3. axis :A 0-D int32 Tensor;表示分割的尺寸;必须在[-rank(value), rank(value))范围内;默认为0。
  4. num=None :可选的,用于指定无法从 size_splits 的形状推断出的输出数。
  5. name :操作的名称(可选)

 用途:将张量分割成子张量。 

  1. 如果num_or_size_splits是整数类型,num_split,则value沿维度 axis 分割成为num_split更小的张量。要求num_split均匀分配value.shape[axis]。

  2. 如果num_or_size_splits不是整数类型,则它被认为是一个张量size_splits,然后将value分割成len(size_splits)块。第i部分的形状与value的大小相同,除了沿维度axis之外的大小size_splits[i]。


代码实例: 

import tensorflow as tf

value = [[1, 2, 3], [4, 5, 6]]

split1, split2 = tf.split(value, [0, 2], 0)
split3, split4, split5 = tf.split(value, [1, 0, 2], 1)

with tf.Session() as sess:
	print("第一个变换结果:")
	print(sess.run(split1))
	print("**************")
	print(sess.run(split2))
	print("**************\r\n")

	print("第二个变换结果:")
	print(sess.run(split3))
	print("**************")
	print(sess.run(split4))
	print("**************")
	print(sess.run(split5))

运行结果:

第一个变换结果:
[]
**************
[[1 2 3]
 [4 5 6]]
**************

第二个变换结果:
[[1]
 [4]]
**************
[]
**************
[[2 3]
 [5 6]]

 

猜你喜欢

转载自blog.csdn.net/kdongyi/article/details/82910632