1.clone()主要用于模块复用 数据进行复制,不共享同一内存,梯度可以回溯
c=torch.tensor(1.0,requires_grad=True)
b=c*2
d=b**2 (**)
b_=b.clone()
e_=b_**3
e_.backward(retain_graph=True)
"""
b.zero_() 这里的b是d.backward()的回溯节点(**),在回溯前不能进行in place 操作,
目的保证梯度计算正确,但如果是b_.zero_()就不会报错,因为clone不共享内存
"""
d.backward()
print(c.grad) #tensor(32.)
这里单独查看b_.grad或者b.grad都不存在,因为他们是中间变量,不需要保存,更新也是只更新叶子节点,此外要设置retain_graph=True,因为有一条线路上先进行了梯度回溯,为节省显存计算图会释放。
2.detach()主要用于数据的提取,共享同一内存,强制require_grad=False(即使设置为True也不进行梯度回溯)
c=torch.tensor(1.0,requires_grad=True)
b=c*2
w=b**2
b_=b.detach()
q=torch.tensor(1.0,requires_grad=True)
e_=q**b_
e_.backward()
#b_.zero_() 因为detach共享内存,这里进行in palce操作会报错
w.backward()
print(q.grad) #tensor(2.)