Tensorflow深度学习之三十一:tf.nn.top_k()

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/DaVinciL/article/details/82559904

一、简介

def top_k(input, k=1, sorted=True, name=None)

Finds values and indices of the k largest entries for the last dimension.

If the input is a vector (rank=1), finds the k largest entries in the vector and outputs their values and indices as vectors.Thus values[j] is the j-th largest entry in input, and its index is indices[j].

For matrices (resp. higher rank input), computes the top k entries in each row (resp. vector along the last dimension).Thus, values.shape = indices.shape = input.shape[:-1] + [k]

If two elements are equal, the lower-index element appears first.

翻译:
查找最后一个维度的前k 个最大条目的值和索引。

如果输入是向量(rank = 1),则在向量中找到前k个最大条目,并将它们的值和索引作为向量输出。因此values[j]input 中的第j 个最大条目,它的索引是indices [j]

对于矩阵(或更高级别的输入),计算每行中的顶部k条目(沿最后一个维度的每一个向量)。因此,values.shape = indices.shape = input.shape [: - 1] + [K]

如果两个元素相等,则首先显示lower-index(索引值较低)元素。

注:该函数返回的数据包含两部分,第一部分是返回的value值,第二部分返回的是对应的索引值。可以通过索引[0]或者[1]进行访问。

二、参数

参数
input 1-D or higher Tensor with last dimension at least k. 一个一维或者更高维度的Tensor,他的最后一维的数目至少为k。
k 0-D int32 Tensor. Number of top elements to look for along the last dimension (along each row for matrices). 一个整形Tensor,表示沿着最后一个维度去寻找的元素的数目。
sorted If true the resulting k elements will be sorted by the values in descending order. 如果该值被设置为True,则返回的k个元素会被按照值的从大到小的顺序进行排序。默认为True。
name Optional name for the operation. 可选参数,名称

三、代码

import tensorflow as tf
import numpy as np

# 建立一个长度为10的向量,内部数据随机生成。
a = tf.convert_to_tensor(np.random.random([10]))

# 取出前5个最大的数据,默认从大到小进行排序。
b = tf.nn.top_k(a, 5)

with tf.Session() as sess:
    print(sess.run(a))
    print(sess.run(b))
    print(sess.run(b[1]))

运行结果:

[0.09673178 0.2011694  0.77118243 0.20476724 0.3439558  0.69864978
 0.2118251  0.32904677 0.87435634 0.47136589]
TopKV2(values=array([0.87435634, 0.77118243, 0.69864978, 0.47136589, 0.3439558 ]), indices=array([8, 2, 5, 9, 4]))
[8 2 5 9 4]

当传入更高维度的数据时:

import tensorflow as tf
import numpy as np

# 定义一个三维的矩阵,内部数据随机产生
a = tf.convert_to_tensor(np.random.random([20, 20, 10]))

# 按照最后一个维度取出前5个最大的数据,默认从大到小进行排序。
b = tf.nn.top_k(a, 5)

with tf.Session() as sess:
    print(sess.run(a))
    print(sess.run(b))
    print(sess.run(b)[1].shape)

运行结果: 最值得注意的是最后一个返回值的shape,只有最后一个维度有所区别。

[[[0.8498202  0.05195572 0.8849565  ... 0.66397947 0.54824224 0.74318886]
  [0.49996231 0.91040108 0.21483549 ... 0.04122947 0.64088468 0.32510497]
  [0.90725498 0.68344152 0.43061874 ... 0.39102586 0.12769082 0.66023738]
  ...
  [0.60666856 0.56439855 0.28063549 ... 0.93124743 0.89449678 0.66979802]
  [0.76200935 0.06834749 0.85145249 ... 0.67836563 0.01516219 0.01993689]
  [0.47049275 0.50707521 0.36991098 ... 0.88998056 0.12763079 0.09845498]]

 [[0.57171017 0.15238957 0.08806684 ... 0.02480321 0.48453851 0.85199458]
  [0.35878106 0.30580091 0.22070303 ... 0.42346321 0.22950292 0.18906091]
  [0.90136589 0.41240145 0.52366428 ... 0.69907391 0.26080453 0.19672214]
  ...
  [0.39987234 0.93231962 0.02967131 ... 0.38570163 0.52938515 0.89505879]
  [0.66779964 0.62346695 0.84506223 ... 0.57041431 0.12558373 0.75406602]
  [0.04802938 0.96657687 0.07476398 ... 0.93957134 0.88229134 0.48934519]]

 [[0.03006041 0.0136604  0.75244466 ... 0.65651256 0.39410724 0.83654045]
  [0.71498666 0.56440115 0.95761964 ... 0.02704624 0.51868975 0.44324936]
  [0.41980744 0.63474661 0.58030962 ... 0.20945427 0.29488566 0.07749595]
  ...
  [0.11727653 0.9169551  0.02627972 ... 0.8763961  0.36451567 0.96754857]
  [0.28255761 0.22505311 0.74507012 ... 0.23504345 0.20330998 0.04071097]
  [0.73204599 0.50676066 0.0524236  ... 0.74684682 0.93345544 0.83705093]]

 ...

 [[0.64496108 0.66815738 0.17245006 ... 0.43895167 0.89021163 0.65442853]
  [0.8690804  0.44297673 0.48261915 ... 0.71620392 0.28584558 0.60172575]
  [0.31634969 0.39460366 0.25693086 ... 0.93440372 0.50671148 0.2486601 ]
  ...
  [0.71044313 0.32806087 0.70054147 ... 0.80219637 0.96946221 0.76465067]
  [0.35188569 0.83711553 0.01343541 ... 0.28523762 0.45159021 0.81395335]
  [0.52934446 0.23226338 0.28012356 ... 0.13028752 0.9962975  0.44482207]]

 [[0.89439131 0.60870675 0.21073087 ... 0.62333398 0.52917202 0.69767772]
  [0.94700397 0.14408882 0.96524112 ... 0.75613067 0.76415524 0.22070657]
  [0.58182603 0.63138273 0.24297734 ... 0.01150216 0.91135157 0.56416608]
  ...
  [0.73974793 0.93020208 0.82434553 ... 0.73215145 0.42041154 0.34463405]
  [0.59814222 0.49599991 0.4764923  ... 0.27145421 0.87418982 0.70327742]
  [0.61134091 0.96387942 0.31842696 ... 0.38037157 0.51440121 0.94851797]]

 [[0.22655945 0.05248473 0.47943931 ... 0.45506608 0.32513959 0.04213444]
  [0.33406586 0.34820628 0.59872586 ... 0.01636161 0.34377442 0.4370155 ]
  [0.98888032 0.62710205 0.92201311 ... 0.27882558 0.46042077 0.4403413 ]
  ...
  [0.49680129 0.41594056 0.93365285 ... 0.87372742 0.70665113 0.15976358]
  [0.48933501 0.31931995 0.92455068 ... 0.76884526 0.3875951  0.12877622]
  [0.16327613 0.35248604 0.90702435 ... 0.33775252 0.60606198 0.05021601]]]
TopKV2(values=array([[[0.8849565 , 0.8498202 , 0.74318886, 0.66397947, 0.54824224],
        [0.91040108, 0.88671867, 0.64088468, 0.6227422 , 0.55252928],
        [0.90725498, 0.80000614, 0.68344152, 0.66023738, 0.43061874],
        ...,
        [0.93124743, 0.89449678, 0.84382658, 0.69857909, 0.66979802],
        [0.9937199 , 0.85145249, 0.79742674, 0.76200935, 0.67836563],
        [0.88998056, 0.82158075, 0.63521181, 0.63428801, 0.50707521]],

       [[0.85199458, 0.62084884, 0.58125283, 0.57171017, 0.48453851],
        [0.42346321, 0.4149812 , 0.41135642, 0.35878106, 0.30580091],
        [0.90136589, 0.82570664, 0.77936049, 0.69907391, 0.65195422],
        ...,
        [0.93231962, 0.89505879, 0.70673449, 0.65217635, 0.61258099],
        [0.84506223, 0.77861116, 0.75406602, 0.71340457, 0.66779964],
        [0.96657687, 0.93957134, 0.9157956 , 0.88229134, 0.71152801]],

       [[0.83654045, 0.82614004, 0.75244466, 0.65651256, 0.5979959 ],
        [0.97347234, 0.95761964, 0.74878137, 0.71498666, 0.56440115],
        [0.69586341, 0.66677399, 0.63474661, 0.58030962, 0.50256754],
        ...,
        [0.96754857, 0.9654071 , 0.9169551 , 0.8763961 , 0.75280918],
        [0.89399932, 0.74507012, 0.72503987, 0.70364816, 0.30463687],
        [0.93345544, 0.84302607, 0.83705093, 0.74684682, 0.73204599]],

       ...,

       [[0.89021163, 0.87103578, 0.66815738, 0.65442853, 0.64496108],
        [0.98479821, 0.8690804 , 0.76023822, 0.71620392, 0.60994919],
        [0.99097209, 0.93440372, 0.50671148, 0.39460366, 0.31634969],
        ...,
        [0.96946221, 0.80219637, 0.76465067, 0.71044313, 0.70137228],
        [0.92060929, 0.83711553, 0.81395335, 0.45159021, 0.35812735],
        [0.9962975 , 0.93294538, 0.59146199, 0.52934446, 0.44482207]],

       [[0.89439131, 0.80978689, 0.78357729, 0.69767772, 0.62333398],
        [0.97394143, 0.96524112, 0.94700397, 0.76415524, 0.75613067],
        [0.99029969, 0.91135157, 0.63266259, 0.63138273, 0.58182603],
        ...,
        [0.96092679, 0.93020208, 0.82434553, 0.79132076, 0.73974793],
        [0.87418982, 0.7208272 , 0.70327742, 0.59814222, 0.49599991],
        [0.96387942, 0.96196898, 0.94851797, 0.61134091, 0.51440121]],

       [[0.56481086, 0.47943931, 0.45506608, 0.32513959, 0.31256481],
        [0.74507118, 0.72854902, 0.59872586, 0.4370155 , 0.35918261],
        [0.98888032, 0.92201311, 0.90066467, 0.81044143, 0.70324162],
        ...,
        [0.97444993, 0.93365285, 0.87372742, 0.83351393, 0.80496823],
        [0.92455068, 0.76884526, 0.62140043, 0.55610083, 0.48933501],
        [0.90702435, 0.9067023 , 0.68597384, 0.60606198, 0.35248604]]]), indices=array([[[2, 0, 9, 7, 8],
        [1, 4, 8, 6, 3],
        [0, 4, 1, 9, 2],
        ...,
        [7, 8, 6, 3, 9],
        [3, 2, 5, 0, 7],
        [7, 5, 3, 4, 1]],

       [[9, 3, 5, 0, 8],
        [7, 5, 6, 0, 1],
        [0, 4, 6, 7, 5],
        ...,
        [1, 9, 6, 3, 4],
        [2, 5, 9, 4, 0],
        [1, 7, 5, 8, 4]],

       [[9, 3, 2, 7, 4],
        [5, 2, 3, 0, 1],
        [6, 5, 1, 2, 4],
        ...,
        [9, 6, 1, 7, 4],
        [6, 2, 3, 4, 5],
        [8, 3, 9, 7, 0]],

       ...,

       [[8, 5, 1, 9, 0],
        [6, 0, 4, 7, 5],
        [3, 7, 8, 1, 0],
        ...,
        [8, 7, 9, 0, 6],
        [4, 1, 9, 8, 3],
        [8, 5, 3, 0, 9]],

       [[0, 4, 5, 9, 7],
        [5, 2, 0, 8, 7],
        [5, 8, 6, 1, 0],
        ...,
        [6, 1, 2, 5, 0],
        [8, 5, 9, 0, 1],
        [1, 5, 9, 0, 8]],

       [[6, 2, 7, 8, 4],
        [3, 6, 2, 9, 4],
        [0, 2, 5, 3, 4],
        ...,
        [3, 2, 7, 6, 5],
        [2, 7, 6, 4, 0],
        [2, 5, 3, 8, 1]]]))
(20, 20, 5)

猜你喜欢

转载自blog.csdn.net/DaVinciL/article/details/82559904