pythonJax小记(二):python: 使用Jax知二维数组,根据每一个二维数组中的一维数组的第一个元素进行分类,重新生成若干二维数组(持续更新,评论区可以补充)

python: 使用Jax知二维数组,根据每一个二维数组中的一维数组的第一个元素进行分类,重新生成若干二维数组


前言

自用,刚开始接触可能顺序会比较乱。


直接上代码

分组后的二维数组中不包含用于分组依据的第一个元素

import jax.numpy as jnp

def group_by_first_element_without_first_index(arr):
    groups = {
    
    }
    for row in arr:
        key = row[0].item()  # 使用.item()来将JAX数组中的元素转换为Python标量
        # 去除第一个元素后的一维数组
        row_without_first = row[1:]
        if key in groups:
            groups[key].append(row_without_first)  # 添加到对应的列表中,不包含第一个元素
        else:
            groups[key] = [row_without_first]  # 创建新的列表

    # 将分组后的列表转换为JAX数组
    grouped_arrays = [jnp.array(group) for group in groups.values()]
    return grouped_arrays

# 示例二维数组
arr = jnp.array([[1, 2, 3], [1, 4, 5], [2, 5, 6], [3, 7, 8], [2, 8, 9]])

# 进行分组,不包含分组依据的第一个元素
grouped_without_first_index = group_by_first_element_without_first_index(arr)

# 打印结果
for group in grouped_without_first_index:
    print(group)

输出:

[[2 3]
 [4 5]]
[[5 6]
 [8 9]]
[[7 8]]

分组后的二维数组中包含用于分组依据的第一个元素

import jax.numpy as jnp

def group_by_first_element(arr):
    # 使用字典来存储分类结果,键是第一个元素,值是对应的数组列表
    groups = {
    
    }
    for row in arr:
        key = row[0].item()  # 使用 .item() 方法将JAX数组的第一个元素转换为Python标量
        if key in groups:
            groups[key].append(row)  # 如果键已存在,添加到对应的列表中
        else:
            groups[key] = [row]  # 如果键不存在,创建新的列表

    # 将分组后的列表转换为JAX数组
    grouped_arrays = [jnp.array(group) for group in groups.values()]
    return grouped_arrays
# 示例二维数组
arr = jnp.array([[1, 2, 3], [1, 4, 5], [2, 5, 6], [3, 7, 8], [2, 8, 9]])

# 进行分组
grouped = group_by_first_element(arr)

# 打印结果
for group in grouped:
    print(group)

输出:

[[1 2 3]
 [1 4 5]]
[[2 5 6]
 [2 8 9]]
[[3 7 8]]

猜你喜欢

转载自blog.csdn.net/xzs1210652636/article/details/136266600