pytorch - data process - 赋值

reference

https://icepoint666.github.io/2019/05/27/pytorch-clone/

https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter02_prerequisite/2.2_tensor?id=_222-%e6%93%8d%e4%bd%9c

直接赋值 tensor 易错点

  • torch中对于直接赋值的这种操作一定要小心,先看看是不是需要直接赋值,还是重新开辟一块内存来存放
  • 我们可以使用类似NumPy的索引操作来访问Tensor的一部分,需要注意的是:索引出来的结果与原数据共享内存,也即修改一个,另一个会跟着修改
  • view()返回的新Tensor与源Tensor虽然可能有不同的size,但是是共享data的,也即更改其中的一个,另外一个也会跟着改变。(顾名思义,view仅仅是改变了对这个张量的观察角度,内部数据并未改变)

1.索引出来的结果与原数据共享内存

import torch

x = torch.rand(5, 3)
print(x, '\n')

y = x[0, :]
y += 1
print(y)
print(x[0, :]) # 源tensor也被改了
tensor([[0.8773, 0.6006, 0.7565],
        [0.2363, 0.7644, 0.4871],
        [0.7156, 0.1406, 0.3292],
        [0.8075, 0.4868, 0.7283],
        [0.4470, 0.4558, 0.3665]]) 

tensor([1.8773, 1.6006, 1.7565])
tensor([1.8773, 1.6006, 1.7565])

2.view()

view()返回的新Tensor与源Tensor虽然可能有不同的size。

y = x.view(15)

z = x.view(-1, 5)  # -1所指的维度可以根据其他维度的值推出来

print(x.size(), y.size(), z.size())


torch.Size([5, 3]) torch.Size([15]) torch.Size([3, 5])

但是,但是是共享data的,也即更改其中的一个,另外一个也会跟着改变

(顾名思义,view仅仅是改变了对这个张量的观察角度,内部数据并未改变)。

x += 1

print(x)
print(y) # 也加了1

tensor([[2.8773, 2.6006, 2.7565],
        [1.2363, 1.7644, 1.4871],
        [1.7156, 1.1406, 1.3292],
        [1.8075, 1.4868, 1.7283],
        [1.4470, 1.4558, 1.3665]])

tensor([2.8773, 2.6006, 2.7565, 1.2363, 1.7644, 1.4871, 1.7156, 1.1406, 1.3292,
        1.8075, 1.4868, 1.7283, 1.4470, 1.4558, 1.3665])

3.如何实现 赋值

用 .clone() 创建一个新的副本,开辟新的内存以存贮数据。

x_cp = x.clone().view(15)
x -= 1

print(x)

print(x_cp)     # copy version, do not change


tensor([[1.8773, 1.6006, 1.7565],
        [0.2363, 0.7644, 0.4871],
        [0.7156, 0.1406, 0.3292],
        [0.8075, 0.4868, 0.7283],
        [0.4470, 0.4558, 0.3665]])

tensor([2.8773, 2.6006, 2.7565, 1.2363, 1.7644, 1.4871, 1.7156, 1.1406, 1.3292,
        1.8075, 1.4868, 1.7283, 1.4470, 1.4558, 1.3665])

猜你喜欢

转载自blog.csdn.net/Zhou_Dao/article/details/115259966