torch 常用函数讲解【Ⅰ】

is_tensor

torch.is_tensor(obj)

参数:obj 为对象。

作用:判断 obj 是否为 Tensor 对象。与 isinstance(obj, Tensor) 的作用相似,但是更推荐使用 isintsance

tensor

torch.tensor(data, dtype=None, device=None, requires_grad=False)

主要参数:

  • date:可以是 listtuplendarrayscaler,或者是其他类型。
  • dtype:函数返回的张量的元素类型。可选值为 torch.dtype 类型,比如 torch.inttorch.floattorch.float64 等。
  • device:保存张量的设备。可选值为 torch.device 类型。当 device=None 时,如果 data 是张量,那么返回张量的 devicedata 所在设备;如果 data 不是张量,那么返回张量的 device 为 CPU。
  • requires_grad:是否计算并保存梯度。

作用:返回元素为 data,元素类型为 dtype,在 device 设备上,进行自动梯度计算和保存(requires_grad=True)的张量。

data 为张量时,函数 tensor 会拷贝 data 的数据重新创建一个张量,即不存在内存共享。此时,不推荐使用函数 tensor,而是使用与之等价的 data.clone().detach().requires_grad_(False)

举例:

a = torch.tensor([1, 2, 3])
b = torch.tensor(a) # UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
a, b
## tensor([1, 2, 3]) tensor([1, 2, 3])
a[0] = 0 # 修改 a 的元素不影响 b 的元素,则说明二者不存在内存共享。
a, b
## tensor([0, 2, 3]) tensor([1, 2, 3])

numel

torch.numel(input)

参数:inputTensor

作用:返回 input 的元素个数。

a = torch.tensor([1, 2, 3, 4])
b = torch.randn(1, 2, 3, 4)
torch.numel(a)
## 4
torch.numel(b)
## 24

is_nonzero

torch.is_nonzero(input)

参数:inputTensor

作用:如果 input 只有一个元素,且经过类型转换后非零,那么返回 True。当单元素为 00.False 时会返回 True。特别地,如果 input(包括稀疏张量)不止一个元素,则抛出 RuntimeError。

举例:

torch.is_nonzero(torch.tensor([0]))
## False
torch.is_nonzero(torch.tensor([[[0.]]]))
## False
torch.is_nonzero(torch.tensor([False]))
## False
torch.is_nonzero(torch.tensor([1.5]))
## True
torch.is_nonzero(torch.tensor([1, 3, 5]))
## RuntimeError

sparse_coo_tensor

torch.sparse_coo_tensor(indices, values, size=None, dtype=None, device=None, requires_grad=False)

参数:

  • indices:可以是 listtuplendarrayscaler,或者是其他类型。要求存在两个维度,第一维的数量表示稀疏张量对应的完整张量的维度个数,第二维的数量表示非零元素数量。

  • values:可以是 listtuplendarrayscaler,或者是其他类型。要求是一维。

  • size:稀疏张量需要显式指明对应的完整张量的大小。如果没有显式指明,则认为稀疏张量对应的完整张量应该在包含全部指定的非零元素的前提下,包含尽可能少的零元素。可以是 listtuple、或 torch.Size

  • dtype:稀疏张量的元素类型。可选值为 torch.dtype 类型,比如 torch.inttorch.floattorch.float64 等。如果为 None,则设置为参数 values 的类型。

  • device:保存张量的设备。可选值为 torch.device 类型。

  • requires_grad:是否计算并保存梯度。

作用:创建稀疏矩阵。

解释:首先了解稀疏矩阵的大致存储形式,三元组 (row_id, column_id, element_value)。对于函数 sparse_coo_tensor 来说,参数 indices 指明了非零元素的位置 (row_id, column_id),参数 values 指明了对应位置的非零值。具体地,(indices[0][i], indices[1][i], values[i]) 对应矩阵中的一个非零元素,(0, 2, 9) 表示矩阵 (0, 2) 位置的元素为 9。如下图所示。(由此扩展到高维张量也非常容易理解)

COO Style

# 生成三维稀疏张量
i = [[0, 1, 2], [2, 1, 0], [0, 0, 0]]
v = [-1, 0, 1]
torch.sparse_coo_tensor(indices=i, values=v, size=(5, 5, 5))
## tensor(indices=tensor([[0, 1, 2],
##                        [2, 1, 0],
##                        [0, 0, 0]]),
##        values=tensor([-1,  0,  1]),
##        size=(5, 5, 5), nnz=3, layout=torch.sparse_coo)

通过 a.to_dense() 可以查看稀疏张量 a 对应的完整张量。

参考:

[1] torch.sparse_coo_tensor()函数 - CSDN

[2] Pytorch torch.sparse_coo_tensor()_- CSDN

from_numpy

torch.from_numpy(ndarray)

参数:ndarray 为 Numpy 的数据类型。

作用:根据 ndarray 创建张量,且共享内存

a = numpy.array([1, 2, 3])
t = torch.from_numpy(a)
t
## tensor([1,  2,  3])
t[0] = -1
a
## tensor([-1,  2,  3])

as_tensor

torch.as_tensor(data, dtype=None, device=None)

参数:

  • data:可以是 listtuplendarrayscaler,或者是其他类型。
  • dtype:可选值为 torch.dtype 类型,比如 torch.inttorch.floattorch.float64 等。
  • device:保存张量的设备。可选值为 torch.device 类型。

作用:将 data 转换为张量,尽可能共享内存、保存历史梯度data 的梯度信息)。

解释:只有 datatensor 或者 ndarray 类型,且同时满足下面两个条件才会共享内存。

① 对参数 dtype 的要求:dtype=None 或与 datadtype 一致。

② 对参数 device 的要求:当 datatensor 时,device=None 或与 datadevice 一致;由于无法用 GPU 计算 ndarray,所以当 datandarray 时,参数 device 必须是默认(None)或 CPU。

特别地,当 datandarraydtypedevice 满足要求 ① 和 ② 时,本质上调用的是 torch.from_numpy()

举例:

a = np.array([1, 2])
b = torch.as_tensor(a)
b[0]=0 # 测试是否共享内存
a, b
# share memory

a = np.array([1, 2])
b = torch.as_tensor(a, dtype=torch.float64)
b[0]=0
a, b
# not share memory

a = torch.tensor([1, 2])
b = torch.as_tensor(a)
b[0]=0
a, b
# share memory

a = torch.tensor([1, 2])
b = torch.as_tensor(a, dtype=torch.float64)
b[0]=0
a, b
# not share memory

a = torch.tensor([1, 2], device=torch.device('cuda')) # 需要配置好 CUDA 环境
b = torch.as_tensor(a)
b[0]=0
a, b
# share memory

a = [1, 2] # list 不可能与返回的 tensor 共享内存
b = torch.as_tensor(a)
b[0]=0
a, b
# not share memory

torch.as_strided

torch.as_strided(input, size, stride, storage_offset=None)

参数:

  • inputinputTensor
  • size:函数返回的视图的大小。
  • stride:返回的视图由步长元组 stride 生成。要求与 size 维度一致,且均至少为 2。stride 的第一维表示生成视图第一维的步长,第二维表示生成视图第二维的步长,以此类推。
  • storage_offset:起始位置。

作用:返回参数为 sizestridestorage_offsetinput 视图。

解释:所谓“视图”,可以通俗地理解为由源张量部分元素构成,且这些元素与源张量共享内存。

在介绍 as_stride 函数之前先讲解一下张量的底层存储。张量的底层存储是按照行优先的原则存储的,比如二维张量 [[1, 2], [3, 4]] 对应的底层存储为 [1, 2, 3, 4],三维张量 [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]] 对应的底层存储为 [1, 2, ..., 11, 12]

as_strided 函数从源张量的存储序列的 storage_offset 位置开始,按照步长 stride 选取元素,直到这些元素构成了大小为 size 的张量视图。详细见下面的例子及解析。

举例:

a = torch.tensor([1, 2, 3, 4])
torch.as_strided(a, size=(2, 2), stride=(0, 0), storage_offset=0)
## tensor([[1, 1],
##         [1, 1]])

上面代码表示生成 (2, 2) 的二维张量,第一维和第二维均按照步长 0 选取。确定 a 的底层存储序列 [1, 2, 3, 4]storage_offset=0 说明从存储序列索引为 0 的位置开始看,所以视图第一行的第一个元素可以确定为底层序列的第一个元素 1stride=(0, 0) 的第一维为 0 说明行步长为 0,即视图第二行的第一个元素与视图第一行的第一个元素的距离为 0,即视图第二行的第一个元素也是存储序列的第一个元素 1,至此行数达到 size 对视图第一维的要求。stride=(0, 0) 表明列步长也为 0,因此可以判断出视图第一行的第二个元素为与第一行的第一个元素距离为 0,即视图第一行的第二个元素为存储序列的第一个元素 1;视图第二行的第二个元素同样与第二行的第一个元素距离为 0,所以设置为 1

a = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
b = torch.as_strided(a, size=(3, 3), stride=(2, 1), storage_offset=1) 
b
## tensor([[2, 3, 4],
##         [4, 5, 6],
##         [6, 7, 8]])
a[1][0] = 0 # 将 a 中的 6 换为 0,b 中两个为 6 的位置都换为了 0
b
## tensor([[2, 3, 4],
##         [4, 5, 0],
##         [0, 7, 8]])

上面代码就比较明显了。offset 确定了视图第一行第一列的元素,之后根据行步长为 2 确定每一行的第一个元素,最后根据每一行的第一个元素和列步长确定整个视图。

a = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
torch.as_strided(a, size=(2, 2, 2), stride=(2, 1, 3), storage_offset=3)
## tensor([[[ 4,  7],
##          [ 5,  8]],
## 
##         [[ 6,  9],
##          [ 7, 10]]])

上面代码生成三维视图,道理不变。将三维视图理解为多个二维矩阵,需要先确定第一个矩阵的第一个元素,根据 stride 的第一个元素确定每个矩阵的第一个元素,问题转换为对每个矩阵生成视图,这与上面讲到的过程一致。

参考:

[1] pytorch中tensor的底层存储方式,维度变换permute/view/reshape,维度大小和数目 - CSDN

[2] torch.as_strided()详解 - CSDN

zeros、ones、empty、full

torch.zeros(size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)

torch.ones(size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)

torch.empty(size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False, memory_format=torch.contiguous_format)

torch.full(size, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)

主要参数:

  • size:返回的张量大小;当 out 不是 None 时亦可理解为被赋值的张量的大小。
  • fill_value:仅出现在 full 函数中,要填充到返回张量中的标量。
  • out:被赋值的张量。
  • dtype:返回的张量元素类型。可选值为 torch.dtype 类型,比如 torch.inttorch.floattorch.float64 等。
  • layouttorch.layout 表示张量的内存分布。torch.sparse_coo 表示稀疏张量,torch.strided 表示稠密张量,是最常见的内存分布方式。
  • device:保存张量的设备。可选值为 torch.device 类型。
  • requires_grad:是否计算并保存梯度。

作用:

  • zeros:返回一个全零张量或者将已存在张量的元素全部赋值为零。赋值操作不常用。
  • ones:与 zeros 唯一的不同在于返回全一张量。
  • empty:与上面两个函数相比多两个不常用参数。返回内存中的数,即不做任何赋值操作。
  • full:相较于 zeros 多了 fill_value 参数。

举例:

torch.zeros(5)
## tensor([ 0.,  0.,  0.,  0.,  0.])
torch.zeros(size=(2, 2), layout=torch.strided) # 稠密张量
## tensor([[0., 0.],
##		   [0., 0.]])
torch.zeros(size=(2, 2), layout=torch.sparse_coo) # 稀疏张量
## tensor(indices=tensor([], size=(2, 0)),
##        values=tensor([], size=(0,)),
##        size=(2, 2), nnz=0, layout=torch.sparse_coo)
a = torch.tensor([[1., 2.], [3., 4.]])
b = torch.zeros(size=a.shape, out=a)
b[0] = 1 # 测试内存共享
a, b
## tensor([[1., 1.],
##         [0., 0.]]) 
## tensor([[1., 1.],
##         [0., 0.]])

torch.empty(size=(2, 3), dtype=torch.float64)
## tensor([[4.9407e-324, 9.8813e-324, 1.4822e-323],
##         [1.9763e-323, 2.4703e-323, 2.9644e-323]], dtype=torch.float64)

torch.full(size=(2, 3), fill_value=3.1415926)
## tensor([[3.1416, 3.1416, 3.1416],
##         [3.1416, 3.1416, 3.1416]])

解释:关于 torch.strided,每个稠密(完整)张量都可以调用其 .stride() 方法查看张量对应的步长:

torch.tensor([1, 2, 3]).stride()
## (1,)
torch.tensor([[1, 2, 3], [4, 5, 6]]).stride()
## (3, 1)
torch.tensor([[[1], [2], [3]], [[4], [5], [6]]]).stride()
## (3, 1, 1)

这么看来,我们创建的张量也只不过是底层数据的”视图“。

另外,这些函数的 out 不常用,而且不方便,因为需要保证这些函数中的其他参数与 out 张量的属性匹配。由于不常用,不过多展示。

参考:

[1] PyTorch学习笔记(2)——randn_like()、layout、memory_format_torch.randn_like - CSDN

zeros_like、ones_like、empty_like、full_like

torch.zeros_like(input, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format)

torch.ones_like(input, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format)

torch.empty_like(input, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format)

torch.full_like(input, fill_value, dtype=None, layout=torch.strided, device=None, requires_grad=False, memory_format=torch.preserve_format)

主要参数:

  • inputinputTensor
  • fill_value:仅出现在 full_like 函数中,要填充到返回张量中的标量。
  • dtype:返回的张量元素类型。可选值为 torch.dtype 类型,比如 torch.inttorch.floattorch.float64 等。
  • layouttorch.layout 表示张量的内存分布。torch.sparse_coo 表示稀疏张量,torch.strided 表示稠密张量,是最常见的内存分布方式。
  • device:保存张量的设备。可选值为 torch.device 类型。
  • requires_grad:是否计算并保存梯度。

作用:

  • zeros_like:生成与 input 大小相同,且满足 dtype 等参数要求的全零新张量。
  • zeros_like:生成全一新张量。
  • empty_like:生成新张量,不修改内存值。
  • full_like:生成新张量,由 full_value 填充。

举例:

input = torch.empty(2, 3)
torch.ones_like(input)
## tensor([[ 1.,  1.,  1.],
##         [ 1.,  1.,  1.]])

heaviside

torch.heaviside(input, values, out=None)

参数:

  • inputinputTensor
  • valuesvaluesTensor。元素值为 00. 或者 False 的位置将被赋值为 values

作用:对于数值型的张量 inputheaviside 函数返回一个新的同型数值张量,input 大于 0 的对应位置的值设置为 1,小于 0 的对应位置设置为 0,等于 0 的设置为 values;对于布尔型的张量 inputheaviside 函数返回一个新的同型布尔张量,inputFalse 的对应位置设置为 values,为 True 的位置不变,仍然为 True。注意,要求张量 values 必须与 input 类型相同。

举例:

a = torch.tensor([-1, 0, 3])
v = torch.tensor([100]) # 发生广播,[100, 100, 100]
torch.heaviside(a, v)
## tensor([  0, 100,   1])
a = torch.tensor([-1, 0, 3]) # 发生广播,[[-1, 0, 3], [-1, 0, 3], [-1, 0, 3]]
v = torch.tensor([[1], [10], [100]]) ## 发生广播,[[1, 1, 1], [10, 10, 10], [100, 100, 100]]
torch.heaviside(a, v)
## tensor([[  0,   1,   1],
##         [  0,  10,   1],
##         [  0, 100,   1]])
a = torch.tensor([-1, 0, 3]) # 发生广播,[[-1, 0, 3], [-1, 0, 3], [-1, 0, 3]]
v = torch.tensor([[5, 10, 15], [1, 2, 3], [100, 10, 1]])
torch.heaviside(a, v)
## tensor([[ 0, 10,  1],
##         [ 0,  2,  1],
##         [ 0, 10,  1]])
a = torch.tensor([[-2, 0, 3], [0, -4, 0], [0, 0, 0]])
v = torch.tensor([1, 10, 100]) ## 发生广播,[[1, 10, 100], [1, 10, 100], [1, 10, 100]]
torch.heaviside(a, v)
## tensor([[  0,  10,   1],
##         [  1,   0, 100],
##         [  1,  10, 100]])

arange

torch.arange(start=0, end, step=1, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)

主要参数:

  • start:起始值,默认为 0。
  • end:最终值。上限或者下限(不含)。
  • step:步长。区别于 stride
  • dtype:可选值为 torch.dtype 类型,比如 torch.inttorch.floattorch.float64 等。
  • layouttorch.layout 表示张量的内存分布。torch.sparse_coo 表示稀疏张量,torch.strided 表示稠密张量,是最常见的内存分布方式。
  • device:保存张量的设备。可选值为 torch.device 类型。
  • requires_grad:是否计算并保存梯度。

作用:返回以 start 为开始,步长为 step 最大值小于(最小值大于)end 的一维张量。

torch.arange(5) # 默认 start=0,step=1
## tensor([0, 1, 2, 3, 4])
torch.arange(1, 7, 2) # 此时 end 表示上限
## tensor([1, 3, 5])
(torch.arange(5, -5, -2) # 此时 end 表示下限
## tensor([ 5,  3,  1, -1, -3])

注意,torch.range() 函数已经丢弃不用了。

linspace、logspace

torch.linspace(start, end, steps, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)

torch.logspace(start, end, steps, base=10.0, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)

主要参数:

  • start:起始值。
  • end:结尾值。
  • steps:元素总数。
  • base:仅出现在 logspace 函数中,底数。
  • other parameters:与上面其他函数中的同名参数同义。

作用:返回以 start 为开始,end 为结尾, 元素个数为 steps 的一维张量。

举例:

torch.linspace(3, 10, steps=5)
## tensor([  3.0000,   4.7500,   6.5000,   8.2500,  10.0000])
torch.linspace(start=-10, end=10, steps=1)
## tensor([-10.])

torch.logspace(start=-10, end=10, steps=5)
## tensor([ 1.0000e-10,  1.0000e-05,  1.0000e+00,  1.0000e+05,  1.0000e+10])
torch.logspace(start=0.1, end=1.0, steps=5)
## tensor([  1.2589,   2.1135,   3.5481,   5.9566,  10.0000])
torch.logspace(start=2, end=2, steps=1, base=2)
## tensor([4.0])

可以认为 logspace 函数是在得到 linspace 的基础上,将 linspace 的结果作为指数的结果作为输出。

eye

torch.eye(n, m=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)

参数:

  • n:行数。
  • m:列数,默认等于行数。
  • other parameters:与上面其他函数中的同名参数同义。

作用:返回对角线上的元素为 1,其他位置的元素为 0 的二维张量。

举例:

torch.eye(3)
## tensor([[ 1.,  0.,  0.],
##         [ 0.,  1.,  0.],
##         [ 0.,  0.,  1.]])
torch.eye(2, 3)
## tensor([[ 1.,  0.,  0.],
##         [ 0.,  1.,  0.]])

cat

torch.cat(tensors, dim=0, out=None)

参数:

  • tensors:张量序列。要求全部张量有相同的 dtype,且除了 dim 维,其他维度必须对应相同。
  • dim:被操作的是哪一维。

作用:在 dim 维对张量序列进行拼接。可以被视为 torch.split()torch.chunk() 的反向操作。与 torch.concat()torch.concatenate() 等价。

举例:

a = torch.tensor([[1, 2, 3], [4, 5, 6]])
torch.cat((-1 * a, a, 2 * a), dim=0)
## tensor([[-1, -2, -3],
##         [-4, -5, -6],
##         [ 1,  2,  3],
##         [ 4,  5,  6],
##         [ 2,  4,  6],
##         [ 8, 10, 12]])
torch.cat((-1 * a, a, 2 * a), dim=1)
## tensor([[-1, -2, -3,  1,  2,  3,  2,  4,  6],
##         [-4, -5, -6,  4,  5,  6,  8, 10, 12]])

a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([[-1], [-2]])
c = torch.tensor([[0, 0], [0, 0]])
torch.cat((a, b, c), dim=1)
## tensor([[ 1,  2,  3, -1,  0,  0],
##         [ 4,  5,  6, -2,  0,  0]])

补充对 dim 的巧妙理解。

我们讲的第一维(dim=0)一般对应张量的 .shape 属性(或 .size() 方法)返回的列表的第一个元素(idx=0),同理,第二维(dim=1)对应列表的第二个元素(idx=1)。严谨地讲,列表中的数值表示每个维度对应的长度。cat 函数在(从 0 开始的) dim 维对张量进行拼接,意味着这些张量在这个维度上很可能长度不一致。当然,这不是肯定的,更准确来说,拼接后的张量与每个源张量相比,发生变化的维度就是 dim。这不仅仅使用于 cat 函数,这适用于绝大部分带 dim 参数的函数。

最外层括号对应 dim=0,往内的括号以此类推。被拼接的维度对应长度发生改变,就是该维对应的括号内的直接元素个数变多。

chunk

torch.chunk(input, chunks, dim=0)

参数:

  • inputinputTensor
  • chunks:返回的视图数量。
  • dim:切分的维度。

作用:对 inputdim 维进行切分,返回不超过 chunks 个视图构成的元组。当 inputdim 维对应的长度(即 input.size(dim))能够被 chunks 整除,那么返回的每个视图在 dim 维具有相同的长度;如果不能够整除,那么保证除了最后一个视图外,其它视图在 dim 维的长度相同即可。另外,可能出现返回的视图个数不足 chunks 的情况。

举例:

torch.arange(11).chunk(6)
## (tensor([0, 1]),
##  tensor([2, 3]),
##  tensor([4, 5]),
##  tensor([6, 7]),
##  tensor([8, 9]),
##  tensor([10]))
torch.arange(12).chunk(6)
## (tensor([0, 1]),
##  tensor([2, 3]),
##  tensor([4, 5]),
##  tensor([6, 7]),
##  tensor([8, 9]),
##  tensor([10, 11]))
torch.arange(13).chunk(6)
## (tensor([0, 1, 2]),
##  tensor([3, 4, 5]),
##  tensor([6, 7, 8]),
## 	tensor([ 9, 10, 11]),
##  tensor([12]))

另外,与 torch.tensor_split() 等除了在功能上不同外,torch.chunk() 返回的块数可能少于要求块数;而 tensor_split() 等可以通过返回空张量来保证块数的要求。

tensor_split、dsplit、vsplit、hsplit、split

torch.tensor_split(input, indices_or_sections, dim=0)

torch.dsplit(input, indices_or_sections)

torch.vsplit(input, indices_or_sections)

torch.hsplit(input, indices_or_sections)

torch.split(tensor, split_size_or_sections, dim=0)

参数:

  • inputinputTensor。对于 tensor_splithsplit 而言,input 至少为一维;对于 vsplit 而言,至少为二维;对于 dsplit 而言,至少为三维。

  • indices_or_sectionsindices_or_sections 可以是 int 类型的列表、元组、张量或者标量。

    为标量时,表示在 dim 维对 input 进行切分,切分方法如下:如果 dim 的长度可以被 indices_or_sections=n 整除,那么每个视图的 dim 维长度为 input.size(dim) / n;如果无法整除,那么前 int(input.size(dim) % n) 个视图的 dim 维长度为 int(input.size(dim) / n) + 1,而剩下视图的 dim 维长度为 int(input.size(dim) / n),这可以保证每个视图 dim 维长度之和与 inputdim 维长度。通俗来讲,就是保证每块尽可能一样大,如果没法保证一样大,那么前面的块必须比后面的块大。

    为列表、元组或张量时,每个元素表示位置索引。比如对于向量 input 按照 indices_or_sections=[2, 3, 6] 进行切分,将 indices_or_sections 中的每个元素视为一个分隔索引,这样整个向量就被分为 len(indices_or_sections) + 1 块,即 [:2][2:3][3:6][6:],因此 tensor_split 将返回四个张量 input[:2]input[2:3]input[3:6]input[6:]

  • split_size_or_sections:仅出现在 split 函数中。为标量时,与 indices_or_sections 一致;为列表时,每个元素表示每个视图 dim 维的长度,这与 indices_or_sections 的位置索引含义不同。因此,为列表时,列表内的元素之和必须等于 inputdim 维的长度,即满足 sum(split_size_or_sections) = input.size(dim)

  • dim:仅出现在 tensor_split 函数中。

作用:

  • tensor_split:根据 indices_or_sectionsdiminput 进行切分,返回每部分的视图构成的元组。本质上是进行切片操作,该函数是基于 Numpy 中的 np.array_split()
  • dsplit:等价于 tensor_split(input, indices_or_sections, dim=2),这意味着 dsplit 的输入张量 input 至少为三维。该函数是基于 Numpy 中的 np.dsplit()
  • vsplit:等价于torch.tensor_split(input, indices_or_sections, dim=0)。该函数是基于 Numpy 中的 np.vsplit()
  • hsplit:如果 input 只有一维,则等价于 torch.tensor_split(input, indices_or_sections, dim=0);如果 input 为更高维,那么等价于 torch.tensor_split(input, indices_or_sections, dim=1)。该函数是基于 Numpy 中的 np.hsplit()
  • split:根据 split_size_or_sectionsdiminput 进行切分,返回每部分的视图构成的元组。区别于 np.split()np.split() 对于输入为列表时的切分方式与 tensor_splitdsplit 一致。

另外,与 torch.chunk() 除了在功能上不同外,torch.chunk() 返回的块数可能少于要求块数;而 tensor_split()dsplitvsplithsplit 都可以通过返回空张量来保证块数的要求。

举例:

# torch.tensor_split
x = torch.arange(8)
torch.tensor_split(x, 3)
## (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7]))

x = torch.arange(7)
torch.tensor_split(x, 3)
## (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))
torch.tensor_split(x, (1, 6))
## (tensor([0]), tensor([1, 2, 3, 4, 5]), tensor([6]))

x = torch.arange(14).reshape(2, 7)
x
## tensor([[ 0,  1,  2,  3,  4,  5,  6],
##         [ 7,  8,  9, 10, 11, 12, 13]])
torch.tensor_split(x, 3, dim=1)
## (tensor([[0, 1, 2],
##          [7, 8, 9]]),
##  tensor([[ 3,  4],
##          [10, 11]]),
##  tensor([[ 5,  6],
##          [12, 13]]))
torch.tensor_split(x, (1, 6), dim=1)
## (tensor([[0],
##          [7]]),
##  tensor([[ 1,  2,  3,  4,  5],
##          [ 8,  9, 10, 11, 12]]),
##  tensor([[ 6],
##          [13]]))
# torch.dsplit
t = torch.arange(16.0).reshape(2, 2, 4)
t
## tensor([[[ 0.,  1.,  2.,  3.],
##          [ 4.,  5.,  6.,  7.]],
##         [[ 8.,  9., 10., 11.],
##          [12., 13., 14., 15.]]])
torch.dsplit(t, 2)
## (tensor([[[ 0.,  1.],
##           [ 4.,  5.]],
##          [[ 8.,  9.],
##           [12., 13.]]]),
##  tensor([[[ 2.,  3.],
##           [ 6.,  7.]],
##          [[10., 11.],
##           [14., 15.]]]))
torch.dsplit(t, [3, 6]) # t 的 dim=2 维长度为 4,而参数 indices_or_sections 最大值为 6,所以出现下面的输出
## (tensor([[[ 0.,  1.,  2.],
##           [ 4.,  5.,  6.]],
##          [[ 8.,  9., 10.],
##           [12., 13., 14.]]]),
##  tensor([[[ 3.],
##           [ 7.]],
##          [[11.],
##           [15.]]]),
## tensor([], size=(2, 2, 0)))
# torch.split
a = torch.arange(10).reshape(5,2)
a
## tensor([[0, 1],
##         [2, 3],
##         [4, 5],
##         [6, 7],
##         [8, 9]])
torch.split(a, 2)
## (tensor([[0, 1],
##          [2, 3]]),
##  tensor([[4, 5],
##          [6, 7]]),
##  tensor([[8, 9]]))
torch.split(a, [1,4])
## (tensor([[0, 1]]),
##  tensor([[2, 3],
##          [4, 5],
##          [6, 7],
##          [8, 9]]))

stack

torch.stack(tensors, dim=0, out=None)

参数:

  • tensors:张量序列。要求全部张量具有相同的大小。
  • dim:被操作的是哪一维。特别地,stackdim 最大可以取到 len(tensors[0].size()),而不是 len(tensors[0].size()) - 1,这与 stack 的功能相关。

作用:沿着一个新的维度拼接张量。

解释:

对于 stack 的操作可以总结出如下结论:当 D=0 时,stack 的输出张量的维度为 (N, m, n);当 D=1 时,stack 的输出张量的维度为 (m, N, n);当 D=2 时,stack 的输出张量的维度为 (m, n, N)。对于高维张量也是同理的。

可见,stack 在保证同高维索引不变的前提下,完全忽视低维嵌套,只是将不同张量在第 D 维进行堆叠。举个例子,在自然语言处理中,一个 batch 对应张量的一般形式为 (batch_size, seq_len, word_embedding),本质上这是将每个语句在 dim=0 上进行堆叠得到的,即 torch.stack(tensors, dim=0),其中 tensors 由每个语句 (seq_len, word_embedding) 构成。通过下面具体例子和过程描述来加深理解。

举例:

a = torch.arange(9).reshape(3, 3)
b = -1 * a
a, b
## tensor([[0, 1, 2],
##         [3, 4, 5],
##         [6, 7, 8]])
## tensor([[ 0, -1, -2],
##         [-3, -4, -5],
##         [-6, -7, -8]])
torch.stack((a, b), dim=0)
## tensor([[[ 0,  1,  2],
##          [ 3,  4,  5],
##          [ 6,  7,  8]],
## 
##         [[ 0, -1, -2],
##          [-3, -4, -5],
##          [-6, -7, -8]]])

ab 增加一维得到 AB,形式为 (1, 3, 3)。在 AB 的第 0 维度里,每个元素依次相连(堆叠),每对连接元素用 [] 包装。由于 AB 的第 0 维均只有一个元素,分别为 ab,所以对二者进行堆叠得到 [a, b],输出即为 [a, b]

torch.stack((a, b), dim=1)
## tensor([[[ 0,  1,  2],
##          [ 0, -1, -2]],
## 
##         [[ 3,  4,  5],
##          [-3, -4, -5]],
##
##         [[ 6,  7,  8],
##          [-6, -7, -8]]])

A 的第一维有三个元素:[0, 1, 2][3, 4, 5][6, 7, 8]B 的第一维有三个元素:[0, -1, -2][-3, -4, -5][-6, -7, -8]。将 [0, 1, 2][0, -1, -2] 堆叠,将 [3, 4, 5][-3, -4, -5] 堆叠,将 [6, 7, 8][-6, -7, -8] 堆叠,分别得到 [[0, 1, 2], [0, -1, -2]][[3, 4, 5], [-3, -4, -5]][[6, 7, 8], [-6, -7, -8]]。我们知道,对于第 0 维而言,无论是 [0, 1, 2][3, 4, 5][6, 7, 8] 还是 [0, -1, -2][-3, -4, -5][-6, -7, -8] 都来自同一个第 0 维张量,因此,将 [[0, 1, 2], [0, -1, -2]][[3, 4, 5], [-3, -4, -5]][[6, 7, 8], [-6, -7, -8]] 堆叠在一起得到最终的输出。

torch.stack((a, b), dim=2)
## tensor([[[ 0,  0],
##          [ 1, -1],
##          [ 2, -2]],
## 
##         [[ 3, -3],
##          [ 4, -4],
##          [ 5, -5]],
##
##         [[ 6, -6],
##          [ 7, -7],
##          [ 8, -8]]])

A2 维度里的元素:012,……,8B2 维度里的元素:0-1-2,……,-8。对应位置堆叠,得到 [0, 0][1, -1] 、……、[8, -8]。由于 [0, 1, 2][0, -1, -2] 的第 1 维索引是 0[3, 4, 5][-3, -4, -5] 的第 1 维索引为 1[6, 7, 8][-6, -7, -8] 的第 1 维索引为 2,所以把第一维索引为 0 的、索引为 1 的和索引为 2 的单独堆叠,即 [[0, 0], [1, -1], [2, -2]][[3, -3], [4, -4], [5, -5]][[6, -6], [7, -7], [8, -8]]。最后将以上三组堆叠。

torch.cat() 类似都是对张量进行拼接,但是拼接的方式有所不同。

参考:

[1] torch.stack()函数的使用理解_- CSDN

hstack、vstack、dstack、row_stack、column_stack

torch.hstack(tensors, out=None)

torch.vstack(tensors, out=None)

torch.dstack(tensors, out=None)

torch.row_stack(tensors, out=None)

torch.column_stack(tensors, out=None)

参数:tensors:张量序列。保证按照函数的要求变形后的每个张量大小一致即可。

作用:

  • hstack:沿着列维度堆叠。对于一维张量而言,相当于在 dim=0 维上拼接;对于其他高维张量而言,相当于在 dim=1 维上拼接。

  • vstack:沿着行维度堆叠。相当于在 dim=0 维上拼接。特别地,对于一维张量,需要向通过 torch.atleast_2d() 将一维张量转换为二维张量,其效果就是在张量最外层套上一层括号,再进行拼接。

  • dstack:沿着深度维度堆叠。对于一维和二维张量,先通过 torch.atleast_3d() 将其转换为三维张量,原理是外层嵌套括号,之后在 dim=2 维上拼接。

  • row_stack:与 vstack 完全一致。

  • column_stack:与 hstack 唯一的不同在于,对于零维和一维张量 t 首先会转换大小为 (t.numel(), 1),再调用 hstack 函数。

举例:

# torch.hstack
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
torch.hstack((a,b))
## tensor([1, 2, 3, 4, 5, 6])

a = torch.tensor([[1],[2],[3]])
b = torch.tensor([[4],[5],[6]])
torch.hstack((a,b))
## tensor([[1, 4],
##         [2, 5],
##         [3, 6]])
# torch.vstack
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
torch.vstack((a,b))
## tensor([[1, 2, 3],
##         [4, 5, 6]])

a = torch.tensor([[1],[2],[3]])
b = torch.tensor([[4],[5],[6]])
torch.vstack((a,b))
## tensor([[1],
##         [2],
##         [3],
##         [4],
##         [5],
##         [6]])
# torch.dstack
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
torch.dstack((a,b))
## tensor([[[1, 4],
##          [2, 5],
##          [3, 6]]])

a = torch.tensor([[1],[2],[3]])
b = torch.tensor([[4],[5],[6]])
torch.dstack((a,b))
## tensor([[[1, 4]],
##         [[2, 5]],
##         [[3, 6]]])
# torch.column_stack
a = torch.arange(5)
b = torch.arange(10).reshape(5, 2)
torch.column_stack((a, b, b))
## tensor([[0, 0, 1, 0, 1],
##         [1, 2, 3, 2, 3],
##         [2, 4, 5, 4, 5],
##         [3, 6, 7, 6, 7],
##         [4, 8, 9, 8, 9]])

index_select、narrow

torch.index_select(input, dim, index, out=None)

torch.narrow(input, dim, start, length)

参数:

  • inputinputTensor
  • dim:被操作的是哪一维。
  • index:一维 int 张量或一维 long 张量。
  • start:零维 int 型张量或者 int 标量。
  • lengthint 标量。

作用:

index_select:返回一个全新的张量(即不共享内存)。输出张量和输入张量的维数相同,即 len(input.size())len(output.size());且除 dim 维二者长度可能不一致,其他维上长度均相同,这是因为该函数只对 dim 维按照 index 给出的索引选取。

narrow:返回的张量与 input 共享内存。在第 dim 维从 start 选取到 start + length

举例:

# torch.index_select
x = torch.randn(3, 4)
x
## tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
##         [-0.4664,  0.2647, -0.1228, -1.1068],
##         [-1.1734, -0.6571,  0.7230, -0.6004]])
indices = torch.tensor([0, 2])
torch.index_select(x, 0, indices)
## tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
##         [-1.1734, -0.6571,  0.7230, -0.6004]])
torch.index_select(x, 1, indices)
## tensor([[ 0.1427, -0.5414],
##         [-0.4664, -0.1228],
##         [-1.1734,  0.7230]])
# torch.narrow
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
torch.narrow(x, 0, 0, 2)
## tensor([[ 1,  2,  3],
##         [ 4,  5,  6]])
torch.narrow(x, 1, 1, 2)
## tensor([[ 2,  3],
##         [ 5,  6],
##         [ 8,  9]])

masked_select

torch.masked_select(input, mask, out=None)

参数:

  • inputinputTensor
  • mask:布尔类型张量。

作用:返回由 maskTrue 对应位置的 input 中元素构成的一维张量。允许 mask 广播。

x = torch.randn(3, 4)
x
## tensor([[-0.5066,  0.0834, -0.0053, -1.2769],
##         [-0.0784, -0.9557,  0.2613, -1.1404],
##         [ 0.9312,  1.4116, -0.0264, -0.3287]])
mask = torch.ge(x, 0.5) # 大于等于 0.5 的位置为 True
mask
## tensor([[False, False, False, False],
##         [False, False, False, False],
##         [ True,  True, False, False]])
torch.masked_select(x, mask)
## tensor([0.9312, 1.4116])

torch.masked_select(x, torch.tensor([True, False, False, False])) # 广播
## tensor([-0.5066, -0.0784,  0.9312])

torch.masked_select(x, torch.tensor([[True, False, False, False]])) # 广播
## tensor([-0.5066, -0.0784,  0.9312])
                                     
torch.masked_select(x, torch.tensor([[False], [True], [False]])) #广播
## tensor([ 0.9312,  1.4116, -0.0264, -0.3287])

transpose、swapdims、t

torch.transpose(input, dim0, dim1)

torch.swapdims(input, dim0, dim1)

torch.t(input)

参数:

  • inputinputTensor
  • dim0dim1:两个不同的维度。

作用:

  • transpose:交换两个维度。如果是稠密张量,则贡献内存;如果是稀疏张量,则不共享内存。

  • swapdims:与 transpose 一致。

  • t:对于零维和一维张量返回本身;对于二维矩阵张量转置;无法处理三维及以上的张量。共享内存。

解释:交换维度的原理可以认为是,对于每个元素而言,交换 dim0dim1 对应的索引值。比如:张量 a = [[[1], [2]], [[3], [4]]]a.shape = (2, 2, 1),元素 1 对应的索引为 a[0][0][0] 中的 000,元素 2 对应的索引为 a[0][1][0] 中的 010,同理可得,元素 3 索引为 100,元素 4 索引为 110。如果交换的两个维度是 dim0=0dim1=1,那么交换后元素 1 的索引为 000,元素 2 的索引为 100,元素 3 的索引为 010,元素 4 的索引为 110,即交换元素的第一个和第二个索引值,按照交换后的索引对张量进行重构。

从交换索引位置的角度来理解就简单多了。

举例:

t = torch.arange(6).reshape((3, 2, 1))
t
## tensor([[[0],
##          [1]],
##
##         [[2],
##          [3]],
##
##         [[4],
##          [5]]])
torch.transpose(t, 0, 1)
## tensor([[[0],
##          [2],
##          [4]],
## 
##         [[1],
##          [3],
##          [5]]])
# torch.t
torch.t(torch.tensor(1))
## tensor(1)
torch.t(torch.tensor([1, 2, 3]))
## tensor([ 1, 2, 3])
x = torch.randn(2, 3)
x
## tensor([[ 0.4875,  0.9158, -0.5872],
##         [ 0.3938, -0.6929,  0.6932]])
torch.t(x)
## tensor([[ 0.4875,  0.3938],
##         [ 0.9158, -0.6929],
##         [-0.5872,  0.6932]])

reshape、movedim、moveaxis、permute

torch.reshape(input, shape)

torch.movedim(input, source, destination)

torch.moveaxis(input, source, destination)

torch.permute(input, dims)

参数:

  • inputinputTensor

  • shapeint 类型的元组。需要满足 shape 中元素的乘积与 input.shape 中元素的乘积相同。shape 元组中允许出现 -1-1 处对应的实际值自动推断给出,即 input.shape 除以 shape 中的非 -1 元素。很显然,只能出现一个 -1

  • sourcedestination :二者为同维 int 元组,或均为 int 标量。按顺序交换 sourcedestination 对应位置的索引。

  • dimsint 类型的元组。索引的全排列,即数字 0 ~ len(input.shape)-1 的一种全排列。

作用:

  • reshape:返回 input 大小变为 shape 的张量。reshape 变换前后的张量按照从低维到高维按顺序展平的序列是一致的。具体看下面与其他几个函数对比的例子。
  • movedim:当传入非元组参数时,与 torch.transpose 一致。当传入元组参数时,可以认为在按顺序执行多次 torch.transpose
  • moveaxis:与 movedim 一致。
  • permute:按照 dims 给定的维度序列对维度进行重排,本质上也是 transpose 函数的运行过程。

举例:

# torch.reshape
t = torch.arange(6).reshape((-1, 2, 1)) # -1 对应 3 # reshape 前的张量展开为 [1,2,3,4,5,6],reshape 后的张量展开相同
t
## tensor([[[0],
##          [1]],
## 
##         [[2],
##          [3]],
##
##         [[4],
##          [5]]])
# torch.movedim
a = torch.movedim(t, 0, 1) # 可以看出对 a 展开的序列为 [0,2,4,1,3,5],这与 reshape 不同
a, a.size()
## tensor([[[0],
##          [2],
##          [4]],
##
##         [[1],
##          [3],
##          [5]]])
## torch.Size([2, 3, 1])

b = torch.movedim(t, (1,2), (0,1)) # 等价于 torch.movedim( torch.movedim(t, 1, 0), 2, 1),即先后两次 movedim,第一次参数为 1,0,第二次参数为 2,1
b, b.size()
## tensor([[[0, 2, 4]],
## 
##         [[1, 3, 5]]])
## torch.Size([2, 1, 3])
# torch.permute
x = torch.permute(t, (2, 0, 1)) # t.shape = [3,2,1],permute 之后为 [1,3,2]
x, x.size()
## tensor([[[0, 1],
##          [2, 3],
##          [4, 5]]])
## torch.Size([1, 3, 2])
# 对比这几个函数,发现展开后序列不同的根本原因在于 strided 不同
t = torch.arange(6).reshape((3, 2, 1)) 
t.stride() # reshape
## (2, 1, 1)
torch.movedim(t, 0, 1).stride() # movedim
## (1, 2, 1)
torch.transpose(t, 0, 1).stride() # transpose
## (1, 2, 1)
torch.permute(t, (1, 0, 2)).stride() # permute
## (1, 2, 1)

select

torch.select(input, dim, index)

参数:

  • inputinputTensor

  • dim:被操作的是哪一维。

  • index:选择 dim 维的哪个索引。

作用:与直接对张量切片起到相同的作用。tensor.select(0, index) 等价于 tensor[index]tensor.select(2, index) 等价于 tensor[:,:,index]

举例:

a = torch.arange(9).reshape((3,3))
torch.select(a, dim=1, index=0)
## tensor([0, 3, 6])
a[:, 0]
## tensor([0, 3, 6])

squeeze

torch.squeeze(input, dim=None)

参数:

  • inputinputTensor

  • dim:被操作的是哪一维。

作用:如果没给出 dim,去掉 input 中长度为 1 的维度;如果给出 dimdim 维长度为 1,那么去掉该维;其他情况无效。共享内存。

举例:

a = torch.tensor([1])
torch.squeeze(a)
## tensor(1)
b = torch.arange(4).reshape((1, 2, 2, 1))
torch.squeeze(b, 0).shape
## torch.Size([2, 2, 1])
torch.squeeze(b, 1).shape
## torch.Size([1, 2, 2, 1])

take

torch.take(input, index)

参数:

  • inputinputTensor

  • indexlong 类型张量。

作用:返回具有给定索引的输入元素的全新张量。输入张量被视为一维张量。结果与索引的形状相同。

举例:

a = torch.arange(6).reshape((2, 3))
a
## tensor([[0, 1, 2],
##         [3, 4, 5]])
torch.take(a, index=torch.tensor([0, 1, 2]))
## tensor([0, 1, 2])
torch.take(a, index=torch.tensor([[0, 1], [4, 5]]))
## tensor([[0, 1],
##         [4, 5]])

gather

torch.gather(input, dim, index, sparse_grad=False, out=None)

参数:

  • inputinputTensor
  • dim:被操作的是哪一维。
  • index:索引张量。index 必须与 input 的维度数相同,并且对于任意的 d != dim index.size(d) <= input.size(d)

作用:沿着 dim 维将值聚集起来。通过一个三维张量下的通用公式表示:当 dim=0 时,out[i][j][k] = input[index[i][j][k]][j][k];当 dim=1 时,out[i][j][k] = input[i][index[i][j][k]][k];当 dim=2 时,out[i][j][k] = input[i][j][index[i][j][k]]

举例:

t = torch.tensor([[1, 2], [3, 4]])
torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
## tensor([[ 1,  1],
##         [ 4,  3]])

tile

torch.tile(input, dims)

参数:

  • inputinputTensor
  • dims:每个维度中的重复次数。

作用:通过重复输入元素构造张量。与 Numpy 中的 tile 函数类似。input 的维数和 dims 的长度相比,如果 input 的维数小,那么在 input 的高维增维;如果 dims 长度小,那么在 dims 前面补 1。比如 input.size() = (2, 3, 4, 5)dims = (2, 1),那么 dims 应该被视为 (1, 1, 2, 1)

举例:

x = torch.tensor([1, 2, 3])
x.tile((2,))
## tensor([1, 2, 3, 1, 2, 3])

y = torch.arange(9).reshape((3, 3, 1))
torch.tile(y, (2,)) # 相当于 (1,1,2),即第0维度出现1次,第1维度出现1次,第2维度出现2次
## tensor([[[0, 0],
##          [1, 1],
##          [2, 2]],
## 
##         [[3, 3],
##          [4, 4],
##          [5, 5]],
## 
##         [[6, 6],
##          [7, 7],
##          [8, 8]]])

unbind

torch.unbind(input, dim=0)

参数:

  • inputinputTensor
  • dim:被操作的是哪一维。

作用:返回切片组成的元组。对于 dim=2 维长度为 L 的张量,unbind 返回 (input[:,:,0], input[:,:1], ..., input[;,:,L-1])

举例:

a = torch.arange(6).reshape((1, 2, 3))
torch.unbind(a, 0)
## (tensor([[0, 1, 2],
##          [3, 4, 5]]),)
torch.unbind(a, 1)
## (tensor([[0, 1, 2]]), tensor([[3, 4, 5]]))
torch.unbind(a, 2)
## (tensor([[0, 3]]), tensor([[1, 4]]), tensor([[2, 5]]))

where

torch.where(condition, x, y)

参数:

  • condition:布尔类型张量。
  • x:标量或张量。为真时对应的值。
  • y:标量或张量。为假时对应的值。

作用:condition 中为 True 或者可转换为 True 的位置上的值被赋值为与 x 位置对应上的值,y 为对应位置为 False 的位置被赋予的值,即满足公式 out_i = x_i if condition_i else y_iconditionxy 必须都是可广播的。

举例:

a = torch.randn((2, 3))
a
## tensor([[-0.7267,  0.0872, -0.8346],
##         [ 0.2803, -0.9863,  0.0977]])
x, y = 1, 0
torch.where(a > 0, x, y)
## tensor([[0, 1, 0],
##         [1, 0, 1]])
x, y = torch.tensor([1, 2, 3]), torch.tensor([[-1], [-2]]) # 广播
torch.where(a > 0, x, y)
## tensor([[-1,  2, -1],
##         [ 1, -2,  3]])

猜你喜欢

转载自blog.csdn.net/weixin_46221946/article/details/129492466