[TensorFlow] tf.sparse_split usage: splitting a SparseTensor with tf.sparse_split

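In real TensorFlow applications, sparse matrices are usually stored as a SparseTensor, which saves a great deal of memory.
When such a matrix has to be split, the intuitive approach is to convert the SparseTensor to a dense tensor, split it, and convert the pieces back to SparseTensor. That round trip is slow, and materializing the dense tensor defeats the purpose of saving space in the first place.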

  • The right way to do it is tf.sparse_split:

def sparse_split(keyword_required=KeywordRequired(),
                 sp_input=None,
                 num_split=None,
                 axis=None,
                 name=None,
                 split_dim=None):
  """Split a `SparseTensor` into `num_split` tensors along `axis`.

  If the `sp_input.dense_shape[axis]` is not an integer multiple of `num_split`
  each slice starting from 0:`shape[axis] % num_split` gets extra one
  dimension. For example, if `axis = 1` and `num_split = 2` and the
  input is:

      input_tensor = shape = [2, 7]
      [    a   d e  ]
      [b c          ]

  Graphically the output tensors are:

      output_tensor[0] =
      [    a ]
      [b c   ]

      output_tensor[1] =
      [ d e  ]
      [      ]

  Args:
    keyword_required: Python 2 standin for * (temporary for argument reorder)
    sp_input: The `SparseTensor` to split.
    num_split: A Python integer. The number of ways to split.
    axis: A 0-D `int32` `Tensor`. The dimension along which to split.
    name: A name for the operation (optional).
    split_dim: Deprecated old name for axis.

  Returns:
    `num_split` `SparseTensor` objects resulting from splitting `value`.

  Raises:
    TypeError: If `sp_input` is not a `SparseTensor`.
    ValueError: If the deprecated `split_dim` and `axis` are both non None.
  """
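The uneven-split rule in the docstring is easy to misread, so here is a minimal pure-Python sketch of it. It is not TensorFlow's implementation: split_sizes and split_coo are hypothetical helper names, and the coordinates chosen for a, d, e, b, c are assumed from the docstring's picture.

```python
def split_sizes(dim_size, num_split):
    """Width of each slice: the first (dim_size % num_split) slices get one extra."""
    base, extra = divmod(dim_size, num_split)
    return [base + 1 if i < extra else base for i in range(num_split)]


def split_coo(indices, values, dense_shape, num_split, axis):
    """Partition COO entries into num_split pieces along axis (illustrative only)."""
    sizes = split_sizes(dense_shape[axis], num_split)
    # Starting offset of each slice along the split axis.
    starts = [sum(sizes[:i]) for i in range(num_split)]
    pieces = []
    for size in sizes:
        shape = list(dense_shape)
        shape[axis] = size
        pieces.append({"indices": [], "values": [], "dense_shape": shape})
    for idx, val in zip(indices, values):
        # Find the slice whose [start, start + size) range contains this entry.
        for piece, start, size in zip(pieces, starts, sizes):
            if start <= idx[axis] < start + size:
                local = list(idx)
                local[axis] -= start  # re-base the coordinate inside the slice
                piece["indices"].append(local)
                piece["values"].append(val)
                break
    return pieces


# The docstring's [2, 7] example, with assumed positions for a, d, e, b, c.
indices = [[0, 2], [0, 4], [0, 5], [1, 0], [1, 1]]
values = ["a", "d", "e", "b", "c"]
left, right = split_coo(indices, values, [2, 7], num_split=2, axis=1)
print(left)
print(right)
```

With 7 columns and num_split = 2, the first piece is 4 columns wide and the second 3, and column indices in the second piece are counted from its own left edge.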
  • Usage example:
import tensorflow as tf

a = tf.SparseTensor(indices=[[0,0],[1,1]],values=[1,2],dense_shape=(2,2))

b,c = tf.sparse_split(sp_input=a,num_split=2,axis=1)

with tf.Session() as sess:
    print(sess.run(a))
    print(sess.run(b))
    print(sess.run(c))

Output:

SparseTensorValue(indices=array([[0, 0],
       [1, 1]]), values=array([1, 2], dtype=int32), dense_shape=array([2, 2]))
SparseTensorValue(indices=array([[0, 0]]), values=array([1], dtype=int32), dense_shape=array([2, 1]))
SparseTensorValue(indices=array([[1, 0]]), values=array([2], dtype=int32), dense_shape=array([2, 1]))
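  • Note:
    The example above was run with Python 3. With Python 2, keyword_required has to be passed explicitly; otherwise the call fails with the error below. Passing arguments positionally also triggers it, since the first positional value is bound to keyword_required: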

      Keyword arguments are required for this function
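The signature shown earlier puts keyword_required first and documents it as a "Python 2 standin for *". A minimal sketch of how such a sentinel guard can work follows. KeywordOnlySentinel and split_like are hypothetical names for illustration, and the isinstance check is an assumption about the mechanism, not TensorFlow's actual code.

```python
class KeywordOnlySentinel(object):
    """Hypothetical marker type; stands in for TensorFlow's KeywordRequired."""


def split_like(keyword_required=KeywordOnlySentinel(),
               sp_input=None, num_split=None, axis=None):
    # A positional first argument lands in keyword_required and replaces the
    # default sentinel, so the isinstance check below fails.
    if not isinstance(keyword_required, KeywordOnlySentinel):
        raise ValueError("Keyword arguments are required for this function.")
    return sp_input, num_split, axis


# Keyword call: the default sentinel stays in place and the guard passes.
print(split_like(sp_input="x", num_split=2, axis=1))

try:
    split_like("x", 2, 1)  # positional call trips the guard
except ValueError as err:
    print(err)
```

Under this reading, the error says nothing about the input itself; it only reports that the first slot did not hold the sentinel.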
    

The Python 2 call looks like this:

from tensorflow.python.ops.sparse_ops import KeywordRequired
import tensorflow as tf

a = tf.SparseTensor(indices=[[0,0],[1,1]],values=[1,2],dense_shape=(2,2))

b,c = tf.sparse_split(keyword_required=KeywordRequired(),sp_input=a,num_split=2,axis=1)

with tf.Session() as sess:
    print(sess.run(a))
    print(sess.run(b))
    print(sess.run(c))

This resolves the error.
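The snippets in this post use the TensorFlow 1.x API (tf.sparse_split and tf.Session). For readers on TensorFlow 2.x, the same op is exposed as tf.sparse.split and runs eagerly. The sketch below assumes a TF 2.x installation and reuses the same input.

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution enabled

a = tf.SparseTensor(indices=[[0, 0], [1, 1]], values=[1, 2], dense_shape=(2, 2))

# tf.sparse.split is the TF 2.x name for the op; it returns a list of SparseTensors.
b, c = tf.sparse.split(sp_input=a, num_split=2, axis=1)

# No Session is needed in eager mode; .numpy() reads the values directly.
for piece in (b, c):
    print(piece.indices.numpy().tolist(),
          piece.values.numpy().tolist(),
          piece.dense_shape.numpy().tolist())
```

The TF 2.x signature has no keyword_required parameter, so no workaround is involved there.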

Reposted from blog.csdn.net/voidfaceless/article/details/103168407