1. 广播机制定义
如果一个PyTorch操作支持广播,则其Tensor参数可以自动扩展为相等大小(不需要复制数据)。通常情况下,小一点的数组会被 broadcast 到大一点的,这样才能保持大小一致。
2. 广播机制规则
2.1 如果遵守以下规则,则两个tensor是“可广播的”:
- 每个tensor至少有一个维度;
- 遍历tensor所有维度时,从末尾开始遍历(从右往左开始遍历)(从后往前开始遍历),两个tensor存在下列情况:
- tensor维度相等。
- tensor维度不等且其中一个维度为1。
- tensor维度不等且其中一个维度不存在。
2.2 如果两个tensor是“可广播的”,则计算过程遵循下列规则:
- 如果两个tensor的维度不同,则在维度较小的tensor的前面增加维度,使它们维度相等。
- 对于每个维度,计算结果的维度值取两个tensor中较大的那个值。
- 两个tensor扩展维度的过程是将数值进行复制。
3.代码举例
3.1 相同维度,一定可以 broadcasting。
# 相同维度,一定可以 broadcasting
x=torch.ones(5,7,3)
y=torch.ones(5,7,3)
z = x+y
x.shape,y.shape,z.shape
输出结果如下:
(torch.Size([5, 7, 3]), torch.Size([5, 7, 3]), torch.Size([5, 7, 3]))
3.2 x和y不能被广播,因为x没有符合“至少有一个维度”,所以不可以broadcasting。
# x和y不能被广播,因为x没有符合“至少有一个维度”,所以不可以broadcasting
x=torch.ones((0,))
y=torch.ones(5,7,3)
z = x+y
x.shape,y.shape,z.shape
3.3 x 和 y 可以广播。
# x 和 y 可以广播
x=torch.ones(5,3,4,1)
y=torch.ones( 3,1,1)
z = x+y
x.shape,y.shape,z.shape
# 从尾部维度开始遍历
# 1st尾部维度: x和y相同,都为1。
# 2nd尾部维度: y为1,x为4,符合维度不等且其中一个维度为1,则广播为4。
# 3rd尾部维度: x和y相同,都为3。
# 4th尾部维度: y维度不存在,x为5,符合维度不等且其中一个维度不存在,则广播为5。
输出结果如下:
(torch.Size([5, 3, 4, 1]), torch.Size([3, 1, 1]), torch.Size([5, 3, 4, 1]))
3.4 x 和 y 不可以广播,因为倒数第三维度x为2,y为3,不符合维度不等且其中一个维度为1。
# x 和 y 不可以广播,因为倒数第三维度x为2,y为3,不符合维度不等且其中一个维度为1。
x=torch.ones(5,2,4,1)
y=torch.ones( 3,1,1)
z = x+y
x.shape,y.shape,z.shape
3.5 x 和 y 可以广播,在维度较小y前面增加维度,使它们维度相等,同时使他们维度大小相同。
# x 和 y 可以广播,在维度较小y前面增加维度,使它们维度相等。
x=torch.ones(5,2,4,1)
y=torch.ones(1,1)
z = x+y
x.shape,y.shape,z.shape
输出结果如下:
(torch.Size([5, 2, 4, 1]), torch.Size([1, 1]), torch.Size([5, 2, 4, 1]))
4. in - place 语义
in-place operation称为原地操作符,在pytorch中是指改变一个tensor的值的时候,不经过复制操作,而是直接在原来的内存上改变它的值。在pytorch中经常加后缀“”来代表原地操作符,例:.add _()、.scatter(),in-place操作不允许tensor使用广播机制那样来改变张量形状维度大小,如下例子所示。
# x 和 y 不可以广播
x=torch.empty(1,3,1)
y=torch.empty(3,1,7)
z = x.add_(y)
x.shape,y.shape,z.shape