python: 使用Jax知二维数组,根据每一个二维数组中的一维数组的第一个元素进行分类,重新生成若干二维数组
前言
自用,刚开始接触可能顺序会比较乱。
直接上代码
分组后的二维数组中不包含用于分组依据的第一个元素
import jax.numpy as jnp
def group_by_first_element_without_first_index(arr):
groups = {
}
for row in arr:
key = row[0].item() # 使用.item()来将JAX数组中的元素转换为Python标量
# 去除第一个元素后的一维数组
row_without_first = row[1:]
if key in groups:
groups[key].append(row_without_first) # 添加到对应的列表中,不包含第一个元素
else:
groups[key] = [row_without_first] # 创建新的列表
# 将分组后的列表转换为JAX数组
grouped_arrays = [jnp.array(group) for group in groups.values()]
return grouped_arrays
# 示例二维数组
arr = jnp.array([[1, 2, 3], [1, 4, 5], [2, 5, 6], [3, 7, 8], [2, 8, 9]])
# 进行分组,不包含分组依据的第一个元素
grouped_without_first_index = group_by_first_element_without_first_index(arr)
# 打印结果
for group in grouped_without_first_index:
print(group)
输出:
[[2 3]
[4 5]]
[[5 6]
[8 9]]
[[7 8]]
分组后的二维数组中包含用于分组依据的第一个元素
import jax.numpy as jnp
def group_by_first_element(arr):
# 使用字典来存储分类结果,键是第一个元素,值是对应的数组列表
groups = {
}
for row in arr:
key = row[0].item() # 使用 .item() 方法将JAX数组的第一个元素转换为Python标量
if key in groups:
groups[key].append(row) # 如果键已存在,添加到对应的列表中
else:
groups[key] = [row] # 如果键不存在,创建新的列表
# 将分组后的列表转换为JAX数组
grouped_arrays = [jnp.array(group) for group in groups.values()]
return grouped_arrays
# 示例二维数组
arr = jnp.array([[1, 2, 3], [1, 4, 5], [2, 5, 6], [3, 7, 8], [2, 8, 9]])
# 进行分组
grouped = group_by_first_element(arr)
# 打印结果
for group in grouped:
print(group)
输出:
[[1 2 3]
[1 4 5]]
[[2 5 6]
[2 8 9]]
[[3 7 8]]