快速解决RuntimeError: one of the variables needed for gradient computation has been modified by an inpla

最近在使用生成对抗网络时,出现:
在这里插入图片描述
解决方法:
1、检查网络模型中,Relu等激活函数中,Replace是否为True,改为False。在这里插入图片描述
2、检查在前向传播forward中,是否有 a += 1这种操作,改成 a =a + 1
在这里插入图片描述
3、如果上述1、2不管用,这种情况在GAN中也较为常见。
下面展示一些 内联代码片


for epoch in range(nepoch):
    for i, (data, label) in enumerate(dataloader, 0):
        ###################################################################
	   (1)先用真实图像data和噪声生成的虚假图片fake_data训练鉴别器D        ###################################################################
        # 先用Real标签训练鉴别器D
        D.zero_grad()
        data, label, batch_size  = data.to(device), label.to(device), data.shape[0]
         real_label = torch.ones(batch_size).to(device)  # 定义真实的图片label为1
        fake_label = torch.zeros(batch_size).to(device)  # 定义假的图片的label为0
        #Real输入进鉴别器D的输出与Real标签进行loss
        output = D(data)
        errD_real = criterion(output, real_label)
        # 输入随机变量,用Fake标签训练鉴别器D
        # 随机产生一个潜在变量,然后通过decoder 产生生成图片
        z = torch.randn(batch_size, nz).to(device)
        # 通过vae的decoder把潜在变量z变成虚假图片
        fake_data = vae.decoder_fc(z).view(z.shape[0], 32, 7, 7)
        fake_data = vae.decoder(fake_data)
        #Fake输入进鉴别器D的输出与Fake标签进行loss
        output = D(fake_data)
        errD_fake = criterion(output, fake_label)
        errD = errD_real + errD_fake
        errD .backward()
        optimizerD.step()
        ###################################################
        # (2) Update G network which is the decoder of VAE
        ###################################################
        recon_data, mean, logstd = vae(data)
        vae.zero_grad()
        vae_loss = loss_function(recon_data, data, mean, logstd)
        vae_loss.backward(retain_graph=True)
        optimizerVAE.step()
        ###############################################
        # (3) Update G network: maximize log(D(G(z)))
        ###############################################
        vae.zero_grad()
        real_label = torch.ones(batch_size).to(device)  # 定义真实的图片label为1
        # output = D(recon_data)
        # 即在生成器更新完之后使用更新之后的生成器输出结果来计算判别器loss,用新计算出来的loss对判别器进行更新。
        # 使用.detach():将张量从计算图中分离出来
        output = D(recon_data.detach())
        errVAE = criterion(output, real_label)
        errVAE.backward()
        D_G_z2 = output.mean().item()
        optimizerVAE.step()

猜你喜欢

转载自blog.csdn.net/weixin_50557558/article/details/139459869