pythonJax小记(一):python: 使用Jax查找数组中特定值(持续更新,评论区可以补充)

python: 使用Jax查找数组中特定值


前言

自用,刚开始接触可能顺序会比较乱。


直接上代码

import jax.numpy as jnp
from jax import jit

def find_positions(matrix, targets):
    # 使用广播和比较来创建一个布尔数组,其中True代表匹配的元素
    match_positions = matrix == targets[:, None, None]
    # 获取匹配位置的索引
    positions = jnp.argwhere(match_positions)


# 示例二维数组和一维目标数组:在matrix数组中找到targets数组,并且返回位置
matrix = jnp.array([[1, 2, 3], 
					[4, 5, 4], 
					[7, 8, 9], 
					[8, 4, 2]])
					
targets = jnp.array([2, 4, 8])

# 查找位置
positions = find_positions(matrix, targets)

print("Positions:\n", positions)

输出:

Positions:
[[0 0 1]
[0 3 2]
[1 1 0]
[1 1 2]
[1 3 1]
[2 2 1]
[2 3 0]]

解释:

这段代码首先定义了一个函数find_positions,它接受一个二维数组matrix和一个一维数组targets作为输入,然后返回targets中每个元素在matrix中的位置。这个函数使用了JAX的广播机制来比较targets中的每个元素和matrix中的每个元素,产生一个布尔数组,最后通过jnp.argwhere找到匹配元素的位置。

  1. 广播机制: 当我们使用matrix == targets[:, None, None]这行代码时,我们利用了JAX(以及NumPy的)广播机制。targets[:, None, None]这个操作将targets数组从一维扩展到三维,使得每个元素都可以与matrix的每个元素进行比较。这样,我们得到一个布尔数组match_positions,其形状为(len(targets), matrix.shape[0], matrix.shape[1]),其中True值表示matrix中的元素与targets中的某个元素匹配。
match_positions:
[[[False  True False]
  [False False False]
  [False False False]
  [False False  True]]
 
[[False False False]
  [ True False  True]
  [False False False]
  [False  True False]]
 
[[False False False]
  [False False False]
  [False  True False]
  [ True False False]]]
  1. 查找匹配位置: jnp.argwhere(match_positions)这行代码用于查找match_positions中值为True的元素的索引,即找到所有目标值在二维数组中的位置。argwhere返回的是一个数组,其中每个元素是匹配位置的索引,形式为(target_index, row, column)
Positions:
[[0 0 1] # 2匹配的位置
[0 3 2] # 2匹配的位置
[1 1 0] # 4匹配的位置
[1 1 2] # 4匹配的位置
[1 3 1] # 4匹配的位置
[2 2 1] # 8匹配的位置
[2 3 0]] # 8匹配的位置
  1. matrix == targets[:, None, None]这个表达式中,使用的是双等号==,这意味着我们在做比较操作,而不是赋值操作。具体来说,这个表达式是在比较二维数组matrix中的每个元素是否等于targets数组中的每个值。由于使用了广播机制,targets[:, None, None]targets数组扩展成一个三维数组,使得可以与matrix的每个元素进行逐元素比较。这个操作返回一个布尔数组,其中的每个值表示相应的比较结果:True表示相等,False表示不相等

如果没有匹配到则返回空的array。

Positions:
[]

猜你喜欢

转载自blog.csdn.net/xzs1210652636/article/details/136265534