# Example 1: find the coordinates of the elements of a 2-D tensor of shape
# (5, 1) whose value is greater than a threshold.
# NOTE: tf.Session is the TensorFlow 1.x graph-mode API; it is not exposed at
# the top level of TensorFlow 2.x (there it lives under tf.compat.v1).
import tensorflow as tf

score_threshold = 3
x = tf.constant([[1], [2], [3], [4], [5]])
# tf.greater builds an element-wise boolean mask (x > score_threshold).
# tf.where with a single argument returns the coordinates of the True entries:
# one row per match, one column per dimension of x.
index = tf.where(tf.greater(x, score_threshold))
with tf.Session() as sess:
    # Evaluate the graph while the session is open; the with-statement closes
    # the session afterwards.
    y = sess.run(index)
    print(y)
# Expected output ("结果" = "result"): the (row, column) coordinates of the
# values 4 and 5.
'''
结果:
[[3 0]
[4 0]]
'''
# 再如 (another example): the same threshold applied to a 1-D tensor.
# Example 2: the same threshold search on a 1-D tensor of shape (5,).
# NOTE: tf.Session is the TensorFlow 1.x graph-mode API; it is not exposed at
# the top level of TensorFlow 2.x (there it lives under tf.compat.v1).
import tensorflow as tf

score_threshold = 3
x = tf.constant([1, 2, 3, 4, 5])
# tf.where returns one row per True entry of the mask and one column per
# dimension of x, so for a 1-D input each row holds a single index.
index = tf.where(tf.greater(x, score_threshold))
with tf.Session() as sess:
    # Evaluate the graph while the session is open; the with-statement closes
    # the session afterwards.
    y = sess.run(index)
    print(y)
# Expected output ("结果" = "result"): the indices of the values 4 and 5.
'''
结果:
[[3]
[4]]
'''