前言
自用笔记，刚开始接触 JAX，内容顺序可能会比较乱。
直接上代码
import jax.numpy as jnp
from jax import jit
@jit
def _extractValues(matrix, positions):
    """Gather the elements of ``matrix`` addressed by (row, column) pairs.

    Args:
        matrix: Array that is indexed along its first two axes.
        positions: Integer array read column-wise: column 0 supplies the
            row indices and column 1 supplies the column indices, one
            pair per row of ``positions``.

    Returns:
        Array whose i-th entry is
        ``matrix[positions[i, 0], positions[i, 1]]``.

    Note:
        NOTE(review): JAX indexing does not raise on out-of-range indices
        (reads are clamped to the valid range per the JAX documentation),
        so invalid positions produce wrong values rather than an error —
        validate positions at the call site if that matters.
    """
    # Paired integer-array indexing: both index arrays are walked in
    # lockstep, yielding one element per (row, column) pair.
    return matrix[positions[:, 0], positions[:, 1]]
# Demo input: a 4x9 integer matrix.
matrix = jnp.array([
    [5, 2, 4, 2, 4, 1, 3, 9, 4],
    [3, 4, 0, 2, 8, 8, 0, 9, 5],
    [6, 4, 0, 7, 3, 0, 0, 2, 7],
    [2, 7, 1, 6, 9, 1, 6, 2, 4],
])
# (row, column) pairs to look up; results come back in this order.
positions = jnp.array([
    [0, 0], [1, 0], [2, 0], [2, 1], [3, 0], [3, 1],
    [3, 2], [0, 3], [0, 4], [0, 5], [1, 3], [1, 4],
    [1, 5], [2, 4], [2, 5], [2, 6], [3, 5], [3, 6],
])
# The helper in this file is named `_extractValues`; calling the
# undefined name `extract_values` raised NameError.
extracted_values = _extractValues(matrix, positions)
print("Extracted Values: ", extracted_values)
输出:
Extracted Values:  [5 3 6 4 2 7 1 2 4 1 2 8 8 3 0 0 1 6]