利用切片读取张量的行并生成新的张量

方法:用y读取cf_class的行索引。

注意:(1)y中的元素不要超过cf_class行索引的最大值,否侧会报索引错误。

           (2)y必须转化成long格式。

cf_class = torch.tensor([[7,8,9],[4,5,6],[4,3,2],[6,56,2]])
print('cf_class.shape:',cf_class.shape)
print('cf_class[0]:',cf_class[0])
print('cf_class[1]:',cf_class[1])
print('cf_class[2]:',cf_class[2])
print('cf_class[3]:',cf_class[3])
y = torch.tensor([1,3,2,0,2,1,3])
print(y)
y = y.long()
print('y:',y)
print('cf_class[y]:',cf_class[y])
print('cf_class[y].shape:',cf_class[y].shape)

输出:

cf_class.shape: torch.Size([4, 3])
cf_class[0]: tensor([7, 8, 9])
cf_class[1]: tensor([4, 5, 6])
cf_class[2]: tensor([4, 3, 2])
cf_class[3]: tensor([ 6, 56,  2])
tensor([1, 3, 2, 0, 2, 1, 3])
tensor([1, 3, 2, 0, 2, 1, 3])
cf_class[y]: tensor([[ 4,  5,  6],
                     [ 6,56,  2],
                     [ 4,  3,  2],
                     [ 7,  8,  9],
                     [ 4,  3,  2],
                     [ 4,  5,  6],
                     [ 6, 56,  2]])
cf_class[y].shape: torch.Size([7, 3])

猜你喜欢

转载自blog.csdn.net/qq_54708219/article/details/129811274