前言
自用,刚开始接触可能顺序会比较乱。
直接上代码
import jax.numpy as jnp
from jax import jit
def find_positions(matrix, targets):
# 使用广播和比较来创建一个布尔数组,其中True代表匹配的元素
match_positions = matrix == targets[:, None, None]
# 获取匹配位置的索引
positions = jnp.argwhere(match_positions)
# 示例二维数组和一维目标数组:在matrix数组中找到targets数组,并且返回位置
matrix = jnp.array([[1, 2, 3],
[4, 5, 4],
[7, 8, 9],
[8, 4, 2]])
targets = jnp.array([2, 4, 8])
# 查找位置
positions = find_positions(matrix, targets)
print("Positions:\n", positions)
输出:
Positions:
[[0 0 1]
[0 3 2]
[1 1 0]
[1 1 2]
[1 3 1]
[2 2 1]
[2 3 0]]
解释:
这段代码首先定义了一个函数find_positions
,它接受一个二维数组matrix
和一个一维数组targets
作为输入,然后返回targets
中每个元素在matrix
中的位置。这个函数使用了JAX的广播机制来比较targets
中的每个元素和matrix
中的每个元素,产生一个布尔数组,最后通过jnp.argwhere
找到匹配元素的位置。
- 广播机制: 当我们使用
matrix == targets[:, None, None]
这行代码时,我们利用了JAX(以及NumPy的)广播机制。targets[:, None, None]
这个操作将targets
数组从一维扩展到三维,使得每个元素都可以与matrix
的每个元素进行比较。这样,我们得到一个布尔数组match_positions
,其形状为(len(targets), matrix.shape[0], matrix.shape[1])
,其中True
值表示matrix
中的元素与targets
中的某个元素匹配。
match_positions:
[[[False True False]
[False False False]
[False False False]
[False False True]]
[[False False False]
[ True False True]
[False False False]
[False True False]]
[[False False False]
[False False False]
[False True False]
[ True False False]]]
- 查找匹配位置:
jnp.argwhere(match_positions)
这行代码用于查找match_positions
中值为True
的元素的索引,即找到所有目标值在二维数组中的位置。argwhere
返回的是一个数组,其中每个元素是匹配位置的索引,形式为(target_index, row, column)
。
Positions:
[[0 0 1] # 2匹配的位置
[0 3 2] # 2匹配的位置
[1 1 0] # 4匹配的位置
[1 1 2] # 4匹配的位置
[1 3 1] # 4匹配的位置
[2 2 1] # 8匹配的位置
[2 3 0]] # 8匹配的位置
- 在
matrix == targets[:, None, None]
这个表达式中,使用的是双等号==
,这意味着我们在做比较操作,而不是赋值操作。具体来说,这个表达式是在比较二维数组matrix
中的每个元素是否等于targets
数组中的每个值。由于使用了广播机制,targets[:, None, None]
将targets
数组扩展成一个三维数组,使得可以与matrix
的每个元素进行逐元素比较。这个操作返回一个布尔数组,其中的每个值表示相应的比较结果:True
表示相等,False
表示不相等
如果没有匹配到则返回空的array。
Positions:
[]