前言
自用,刚开始接触可能顺序会比较乱。
问题描述
- 已知一维数组
- 当数组中的值大于0.5的时候赋值1
- 小于等于0.5的时候赋值0
直接上代码
import jax.numpy as jnp
from jax import jit
# 使用@jit装饰器进行即时编译
@jit
def threshold_array_optimized(arr):
return (arr > 0.5).astype(jnp.int32)
# 创建一个示例一维数组
arr = jnp.array([0.2, 0.6, 0.4, 0.8, 0.1])
# 应用函数并打印结果
result = threshold_array_optimized(arr)
print(result)
输出:
[0 1 0 1 0]
解释
可以直接将条件表达式的结果(这本身就是一个布尔数组)转换为整数类型。在JAX中,True
会被转换为1
,而False
会被转换为0
,然后直接返回结果。