Gan 网络生成图片

Gan 网络生成图片

本次文章使用 pytorch 框架对 Gan网络生产图片网络架构进行编写与实现, Gan网络架构自己可以自定义, 这里采用自己定义的Gan网络。(完整代码下面链接)

  • https://github.com/xiaoaleiBLUE/computer_vision(希望 starred一下)

一、Gan网络

  • GAN 应用: 数据生成, 图像翻译, 超分辨率(更高清), 图像补全。
  • Adversarial: 对抗对手的意思, 两个模型: Generator(生成模型)
  • Discriminator(判别模型, 分类模型)
  • 创作者(G)目标: 赝品骗过鉴别者 鉴别者(D)目标: 火眼金睛不被骗

二、Gan网络架构

  • Gan架构: Generator(生成模型), Discriminator(判别模型, 分类模型), 以下简称 G网络,D网络。
  • Gan 网络架构图:在这里插入图片描述

三、Gan网络实现基本思路

  • 在这里插入图片描述

四、本文章使用Gan网络架构

4.1 G, D网络架构

  • G网络架构(CxHxW通道排列形式) 100x1x1的tensor --> 512x4x4 --> 128x16x16 --> 64x32x32 --> 3x64x64
  • D网络架构(CxHxW通道排列形式) 3x64x64的三通道图像或者tensor --> 64x32x32 --> 128x16x16 --> 256x8x8 --> 512x4x4 --> 8192 --> 1
    在这里插入图片描述

4.2 ConvTranspose2d计算公式

  • out = (input-1) x stride - 2 x padding + kernel_size + out_padding
  • 比如G网络的 100 x 1 x 1的张量(tensor) --> 512 x 4 x 4
  • out = (1-1) x 1 - 2 x 0 + 4 + 0 = 4, 所以 H,W 为 4

4.3 G, D网络架构代码实现

  • G 网络, 生成者网络
# G 生成网络, 生成者网络
class G_model(nn.Module):

    def __init__(self):
        super(G_model, self).__init__()

        self.main = nn.Sequential(

            # 100*1*1(张量)  ----> 4*4*512
            # out = (input-1)*stride - 2*padding + kernel_size + out_padding
            # out = (1-1)*1 - 2*0 + 4 + 0 = 4
            nn.ConvTranspose2d(in_channels=100, out_channels=512, kernel_size=4, stride=1, padding=0, bias=False),
            # BN
            nn.BatchNorm2d(512),
            # Relu
            nn.ReLU(inplace=True),

            # 4*4*512  ----> 8*8*256
            # out = (4-1)*2 - 2*1 + 4 + 0 = 8
            nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False),
            # BN
            nn.BatchNorm2d(256),
            # Relu
            nn.ReLU(inplace=True),

            # 8*8*256  ----> 16*16*128
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False),
            # BN
            nn.BatchNorm2d(128),
            # Relu
            nn.ReLU(inplace=True),

            # 16*16*128  ----> 32*32*64
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False),
            # BN
            nn.BatchNorm2d(64),
            # Relu
            nn.ReLU(inplace=True),

            # 32*32*64  ----> 64*64*3
            nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=4, stride=2, padding=1, bias=False),
            # tanh()
            nn.Tanh()

        )

    def forward(self, x):

        return self.main(x)
  • D 网络, 判别者网络
class D_model(nn.Module):

    def __init__(self):
        super(D_model, self).__init__()

        self.main = nn.Sequential(

            # 64*64*3  ----> 32*32*64
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=4, stride=2, padding=1, bias=True),
            # Relu
            nn.LeakyReLU(0.2, inplace=True),

            # 32*32*64  ----> 16*16*128
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=True),
            # BN
            nn.BatchNorm2d(128),
            # Relu
            nn.LeakyReLU(0.2, inplace=True),

            # 16*16*128  ----> 8*8*256
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=True),
            # BN
            nn.BatchNorm2d(256),
            # Relu
            nn.LeakyReLU(0.2, inplace=True),

            # 8*8*256  ----> 4*4*512=8192
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1, bias=True),
            # BN
            nn.BatchNorm2d(512),
            # Relu
            nn.LeakyReLU(0.2, inplace=True),

            # flatten
            nn.Flatten(),

            # 全连接
            nn.Linear(8192, 1),

            # sigmoid
            nn.Sigmoid()

        )

    def forward(self, x):

        return self.main(x)

4.4 数据集相关操作

  • 文件夹读取图片 --> dataset --> dataloader
# 导入数据
img_size = 64

img_preprocess = transforms.Compose([
    # 缩放
    transforms.Resize(img_size),
    # 中心裁剪 64*64 的正方形
    transforms.CenterCrop(img_size),
    # PIL图像转为tensor  H*W*C ---> C*H*W
    transforms.ToTensor(),
    # 归一化到[-1,1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 从文件夹读取图片
dataset = tdst.ImageFolder(root='./data/', transform=img_preprocess)

# 加载成 dataloader
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

# 查看批次数量  12324 / 128 = 96.28, 向上取整
print(len(dataloader))
  • 显示dataloader 中的图片
# 显示图片, 其中 x[0]:对应的图片,有shape,  x[1]:对应类别标签,  x[0].shape:torch.Size([128, 3, 64, 64]) 128对应批次大小
for x in dataloader:
    # 设置画布大小
    fig = plt.figure(figsize=(8, 8))
    for i in range(16):

        plt.subplot(4, 4, i+1)

        # 转为 numpy, x[0][i].shape: torch.Size([3, 64, 64])
        img = x[0][i].numpy()

        # 通道顺序变换, 调整通道顺序为 PIL 格式
        img = np.transpose(img, (1, 2, 0))

        # 先转到[0,1],再乘以255
        img = (img + 1) / 2 * 255

        # 取整
        img = img.astype('int')

        # 显示
        plt.imshow(img)
        plt.axis('off')

    plt.show()
    break

4.5 优化器和损失函数

  • 对G网络, D网络分别定义损失函数和优化器
# loss 损失函数, 二分类的损失函数
loss_fn = nn.BCELoss()

# 优化器
D_optimizer = optim.Adam(D_model.parameters(), lr=0.0002, betas=(0.5, 0.999))
G_optimizer = optim.Adam(G_model.parameters(), lr=0.0002, betas=(0.5, 0.999))

五、训练

  • 按照上述Gan基本思路进行训练
# 开始训练
Epoch_num = 100

for epoch in range(Epoch_num):

    # 获取批次图像
    start_time = time.time()

    for i, data in enumerate(dataloader):
        # 训练判别网络D:真实数据标记为1
        # 每次update前清空梯度
        D_model.zero_grad()
        # 获取数据, data[0]是图片, data[1]是类别
        imgs_batch = data[0].to(device)
        # 动态获取图片的batch_size
        b_size = imgs_batch.size(0)
        # 计算输出
        output = D_model(imgs_batch).view(-1)
        # 构建全1向量 label
        ones_label = torch.full((b_size, ), 1, dtype=torch.float, device=device)
        # 计算 loss
        d_loss_real = loss_fn(output, ones_label)
        # 反向传播
        d_loss_real.backward()
        # 梯度更新
        D_optimizer.step()

        # 训练判别网络D:假数据标记为0
        # 清楚梯度
        D_model.zero_grad()
        # 构建随机张量
        noise_tensor = torch.randn(b_size, 100, 1, 1, device=device)
        # 生成假的图片
        generated_imgs = G_model(noise_tensor)
        # 假图片的输出, 此时不需要训练 G, 
        output = D_model(generated_imgs.detach()).view(-1)
        # 构建全 0 向量
        zeros_label = torch.full((b_size, ), 0, dtype=torch.float, device=device)
        # 计算 loss
        d_loss_fake = loss_fn(output, zeros_label)
        # 反向传播
        d_loss_fake.backward()
        # 梯度更新
        D_optimizer.step()

        # 训练生成网络G:假数据标记为1
        # 清楚梯度
        G_model.zero_grad()
        # 随机张量
        noise_tensor = torch.randn(b_size, 100, 1, 1, device=device)
        # 生成假的图片
        generated_imgs = G_model(noise_tensor)
        # 假图片的输出
        output = D_model(generated_imgs).view(-1)
        # 构建全1向量
        ones_label = torch.full((b_size, ), 1, dtype=torch.float, device=device)
        # 计算 loss
        g_loss = loss_fn(output, ones_label)
        # 反向传播
        g_loss.backward()
        # 生成网络梯度更新
        G_optimizer.step()

    # 打印训练时间
    print('第{}个epoch所用时间: {}s'.format(epoch, time.time() - start_time))
    # 每一个 epoch 输出结果
    # 用 no_grad 表示梯度不跟踪
    with torch.no_grad():
        # 生成 16 个随机向量
        fixed_noise = torch.randn(16, 100, 1, 1, device=device)
        # 生成图片
        fake_imgs = G_model(fixed_noise).detach().cpu().numpy()
        # 画布大小
        fig = plt.figure(figsize=(10, 10))

        for i in range(fake_imgs.shape[0]):
            plt.subplot(4, 4, i+1)
            img = np.transpose(fake_imgs[i], (1, 2, 0))
            img = (img + 1) / 2 * 255
            img = img.astype('int')
            plt.imshow(img)
            plt.axis('off')

        plt.show()

六、训练结果显示

  • 在这里插入图片描述
    在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/m0_60890175/article/details/130197076