(10-3-04)基于多模态模型的文生图系统:多模态成对抗网络(GAN)模型(4)实现多模态GAN模型

10.4.4  实现多模态GAN模型

文件CLIP-GAN.py实现了条件生成对抗网络(GAN)模型,并在GAN模型中集成了 CLIP(对比性语言-图像预训练)功能,最终实现了多模态GAN模型。集成 CLIP 实现了文本与图像之间的紧密联系,使得生成的图像不仅在视觉上与描述相符,而且在语义上也更加一致。因此,集成 CLIP 可以提高生成图像的质量、多样性和语义准确性,为生成对抗网络引入了更多的语义理解能力。

注意:在本项目中,集成 CLIP 的主要原因是实现多模态,即让生成对抗网络(GAN)模型能够同时处理文本和图像信息,从而生成与给定文本描述相匹配的图像。通过将文本和图像嵌入到一个共同的语义空间中,CLIP 模型使得文本和图像之间的语义关联得以建模。这样一来,GAN 模型可以利用 CLIP 提供的文本描述信息来指导生成器生成更符合描述要求的图像。

文件CLIP-GAN.py的具体实现流程如下所示。

(1)定义类CLIP_IMG_ENCODER实现CLIP 图像编码器模块,用于将图像编码为 CLIP 模型可理解的格式,并提取图像的局部特征和编码的图像嵌入。模块首先对输入图像进行预处理和标准化,然后通过 CLIP 的视觉 transformer 模型提取特征。最终输出局部特征和编码的图像嵌入。

class CLIP_IMG_ENCODER(nn.Module):
    """
       CLIP_IMG_ENCODER 模块用于使用 CLIP 的视觉 transformer 对图像进行编码。
    """
    def __init__(self, CLIP):
        """
        初始化 CLIP_IMG_ENCODER 模块。
        Args:
            CLIP (CLIP): 预训练的 CLIP 模型。
        """
        super(CLIP_IMG_ENCODER, self).__init__()
        model = CLIP.visual
        self.define_module(model)
        # 冻结 CLIP 模型的参数
        for param in self.parameters():
            param.requires_grad = False

    def define_module(self, model):
        """
        定义 CLIP 视觉 transformer 模型的各个层和模块。
        Args:
            model (nn.Module): CLIP 视觉 transformer 模型。
        """
        # 从 CLIP 模型中提取所需的模块
        self.conv1 = model.conv1  # 卷积层
        self.class_embedding = model.class_embedding  # 类别嵌入层
        self.positional_embedding = model.positional_embedding  # 位置嵌入层
        self.ln_pre = model.ln_pre  # 预归一化线性归一化层
        self.transformer = model.transformer  # Transformer 块
        self.ln_post = model.ln_post  # 后归一化线性归一化层
        self.proj = model.proj  # 投影矩阵

    @property
    def dtype(self):
        """
         获取卷积层权重的数据类型。
        """
        return self.conv1.weight.dtype

    def transf_to_CLIP_input(self, inputs):
        """
        将输入图像转换为 CLIP 预期的格式。
        Args:
            inputs (torch.Tensor): 输入图像。

        Returns:
            torch.Tensor: 转换后的图像。
        """
        device = inputs.device
        # 检查输入图像张量的大小
        if len(inputs.size()) != 4:
            raise ValueError('期望 (B, C, X, Y) 张量。')
        else:
            # 标准化输入图像
            mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).unsqueeze(-1).unsqueeze(-1).unsqueeze(0).to(device)
            var = torch.tensor([0.26862954, 0.26130258, 0.27577711]).unsqueeze(-1).unsqueeze(-1).unsqueeze(0).to(device)
            inputs = F.interpolate(inputs * 0.5 + 0.5, size=(224, 224))
            inputs = ((inputs + 1) * 0.5 - mean) / var
            return inputs

    def forward(self, img: torch.Tensor):
        """
        CLIP_IMG_ENCODER 模块的前向传播。
        Args:
            img (torch.Tensor): 输入图像。
        Returns:
            torch.Tensor: 从图像中提取的局部特征。
            torch.Tensor: 编码的图像嵌入。
        """
        # 将输入图像转换为 CLIP 预期的格式,并将其数据类型设置为适当的类型
        x = self.transf_to_CLIP_input(img)
        x = x.type(self.dtype)
        # 通过卷积层传递图像
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        grid = x.size(-1)

        # 重塑并对换张量以进行 transformer 输入
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # 添加类别和位置嵌入
        x = torch.cat(
            [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
             x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        # NLD (Batch Size - Length - Dimension) -> LND (Length - Batch Size - Dimension)
        x = x.permute(1, 0, 2)
        # 使用 transformer 块提取局部特征
        selected = [1, 4, 8]
        local_features = []
        for i in range(12):
            x = self.transformer.resblocks[i](x)
            if i in selected:
                local_features.append(
                    x.permute(1, 0, 2)[:, 1:, :].permute(0, 2, 1).reshape(-1, 768, grid, grid).contiguous().type(
                        img.dtype))
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_post(x[:, 0, :])
        if self.proj is not None:
            x = x @ self.proj  # 使用投影矩阵和张量执行矩阵乘法
        return torch.stack(local_features, dim=1), x.type(img.dtype)

(2)定义类CLIP_TXT_ENCODER 实现CLIP 文本编码器模块,用于将文本输入编码为 CLIP 模型可理解的格式,并提取句子嵌入和 transformer 输出。模块首先对输入文本进行标记嵌入,并添加位置嵌入,然后通过 CLIP 的 transformer 模型进行前向传播。最终输出编码的句子嵌入和 transformer 输出。

class CLIP_TXT_ENCODER(nn.Module):
    """
        CLIP_TXT_ENCODER 模块用于使用 CLIP 的 transformer 对文本输入进行编码。
    """
    def __init__(self, CLIP):
        """
        初始化 CLIP_TXT_ENCODER 模块。
        Args:
            CLIP (CLIP): 预训练的 CLIP 模型。
        """
        super(CLIP_TXT_ENCODER, self).__init__()
        self.define_module(CLIP)
        # 冻结 CLIP 模型的参数
        for param in self.parameters():
            param.requires_grad = False

    def define_module(self, CLIP):
        """
        定义 CLIP transformer 模型的各个模块。
        Args:
            CLIP (CLIP): 预训练的 CLIP 模型。
        """
        self.transformer = CLIP.transformer  # Transformer 块
        self.vocab_size = CLIP.vocab_size  # Transformer 词汇表的大小
        self.token_embedding = CLIP.token_embedding  # token 嵌入块
        self.positional_embedding = CLIP.positional_embedding  # 位置嵌入块
        self.ln_final = CLIP.ln_final  # 线性归一化层
        self.text_projection = CLIP.text_projection  # 文本的投影矩阵

    @property
    def dtype(self):
        """
        获取 transformer 中第一层权重的数据类型。
        """
        return self.transformer.resblocks[0].mlp.c_fc.weight.dtype

    def forward(self, text):
        """
        CLIP_TXT_ENCODER 模块的前向传播。
        Args:
            text (torch.Tensor): 输入文本标记。

        Returns:
            torch.Tensor: 编码的句子嵌入。
            torch.Tensor: 输入文本的 transformer 输出。
        """
        # 嵌入输入文本标记
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
        # 添加位置嵌入
        x = x + self.positional_embedding.type(self.dtype)
        # 对 transformer 输入进行维度转换
        x = x.permute(1, 0, 2)  # NLD -> LND
        # 通过 transformer 进行前向传播
        x = self.transformer(x)
        # 将维度重新排列为原始形状
        x = x.permute(1, 0, 2)  # LND -> NLD
        # 应用层归一化
        x = self.ln_final(x).type(self.dtype)  # shape = [batch_size, n_ctx, transformer.width]
        # 从文本末尾(eot_token:每个序列中的最高数字)提取句子嵌入
        sent_emb = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        # 返回句子嵌入和 transformer 输出
        return sent_emb, x

(3)定义类CLIP_Mapper 实现CLIP 映射器模块,用于将图像和提示映射为 CLIP 模型可理解的特征。模块首先将输入图像和提示转换为适当的数据类型,并通过 CLIP 的 transformer 模型进行前向传播。最终输出映射的特征。

class CLIP_Mapper(nn.Module):
    """
    CLIP_Mapper 模块用于使用 CLIP 的 transformer 对图像和提示进行映射。
    """
    def __init__(self, CLIP):
        """
        初始化 CLIP_Mapper 模块。
        Args:
            CLIP (CLIP): 预训练的 CLIP 模型。
        """
        super(CLIP_Mapper, self).__init__()
        model = CLIP.visual
        self.define_module(model)
        # 冻结 CLIP 视觉模型的参数
        for param in model.parameters():
            param.requires_grad = False

    def define_module(self, model):
        """
        定义 CLIP 视觉模型的各个模块。
        Args:
            model: 预训练的 CLIP 视觉模型。
        """
        self.conv1 = model.conv1
        self.class_embedding = model.class_embedding
        self.positional_embedding = model.positional_embedding
        self.ln_pre = model.ln_pre
        self.transformer = model.transformer

    @property
    def dtype(self):
        """
        获取第一个卷积层权重的数据类型。
        """
        return self.conv1.weight.dtype

    def forward(self, img: torch.Tensor, prompts: torch.Tensor):
        """
        CLIP_Mapper 模块的前向传播。
        Args:
            img (torch.Tensor): 输入图像张量。
            prompts (torch.Tensor): 用于映射的提示标记。

        Returns:
            torch.Tensor: 从 CLIP 模型中映射的特征。
        """
        # 将输入图像和提示转换为适当的数据类型
        x = img.type(self.dtype)
        prompts = prompts.type(self.dtype)
        grid = x.size(-1)
        # 重塑输入图像张量
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # 将类别嵌入附加到输入张量
        x = torch.cat(
            [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
             x],
            dim=1
        )  # shape = [*, grid ** 2 + 1, width]

        # 将位置嵌入附加到输入张量
        x = x + self.positional_embedding.to(x.dtype)
        # 执行层归一化
        x = self.ln_pre(x)
        # NLD -> LND
        x = x.permute(1, 0, 2)
        # 本地特征
        selected = [1, 2, 3, 4, 5, 6, 7, 8]
        begin, end = 0, 12
        prompt_idx = 0
        for i in range(begin, end):
            # 将提示添加到输入张量
            if i in selected:
                prompt = prompts[:, prompt_idx, :].unsqueeze(0)
                prompt_idx = prompt_idx + 1
                x = torch.cat((x, prompt), dim=0)
                x = self.transformer.resblocks[i](x)
                x = x[:-1, :, :]
            else:
                x = self.transformer.resblocks[i](x)
        # 重塑并返回映射的特征
        return x.permute(1, 0, 2)[:, 1:, :].permute(0, 2, 1).reshape(-1, 768, grid, grid).contiguous().type(img.dtype)

(4)定义类CLIP_Adapter实现CLIP调整期模块,用于将生成器的特征调整为与 CLIP 模型的输入要求相匹配。它接收来自生成器的输出特征和条件向量作为输入,通过多个映射块的特征块对特征进行调整,并将其融合。然后,将融合的特征映射到 CLIP 模型的输入空间,并返回处理后的特征。

class CLIP_Adapter(nn.Module):
    """
    CLIP_Adapter 模块用于调整来自生成器的特征以匹配 CLIP 模型的输入要求。
    """
    def __init__(self, in_ch, mid_ch, out_ch, G_ch, CLIP_ch, cond_dim, k, s, p, map_num, CLIP):
        """
        初始化 CLIP_Adapter 模块。
        Args:
            in_ch (int): 输入通道数。
            mid_ch (int): 中间层的通道数。
            out_ch (int): 输出通道数。
            G_ch (int): 生成器输出的通道数。
            CLIP_ch (int): CLIP 模型输入的通道数。
            cond_dim (int): 条件向量的维度。
            k (int): 卷积层的核大小。
            s (int): 卷积层的步长。
            p (int): 卷积层的填充。
            map_num (int): 映射块的数量。
            CLIP: 预训练的 CLIP 模型。
        """
        super(CLIP_Adapter, self).__init__()
        self.CLIP_ch = CLIP_ch
        self.FBlocks = nn.ModuleList([])
        # 定义映射块(M_Block)并将它们添加到给定数量的特征块(FBlock)中。
        self.FBlocks.append(M_Block(in_ch, mid_ch, out_ch, cond_dim, k, s, p))
        for i in range(map_num - 1):
            self.FBlocks.append(M_Block(out_ch, mid_ch, out_ch, cond_dim, k, s, p))
        # 用于融合适应特征的卷积层
        self.conv_fuse = nn.Conv2d(out_ch, CLIP_ch, 5, 1, 2)
        # CLIP 映射器模块,将适应特征映射到 CLIP 的输入空间
        self.CLIP_ViT = CLIP_Mapper(CLIP)
        # 用于进一步处理映射特征的卷积层
        self.conv = nn.Conv2d(768, G_ch, 5, 1, 2)
        # 用于条件的全连接层
        self.fc_prompt = nn.Linear(cond_dim, CLIP_ch * 8)

    def forward(self, out, c):
        """
        CLIP_Adapter 模块的前向传播。接受来自生成器的输出特征和条件向量作为输入,
        使用具有多个映射块的特征块对特征进行调整,融合它们,将其映射到 CLIP 的输入空间,并返回处理后的特征。
        Args:
            out (torch.Tensor): 生成器的输出特征。
            c (torch.Tensor): 条件向量。
        Returns:
            torch.Tensor: 为生成器调整和映射的特征。
        """
        # 从条件向量生成提示
        prompts = self.fc_prompt(c).view(c.size(0), -1, self.CLIP_ch)
        # 通过包含多个映射块的特征块传递特征
        for FBlock in self.FBlocks:
            out = FBlock(out, c)
        # 融合适应特征
        fuse_feat = self.conv_fuse(out)
        # 将融合特征映射到 CLIP 的输入空间
        map_feat = self.CLIP_ViT(fuse_feat, prompts)
        # 进一步处理映射特征并返回
        return self.conv(fuse_feat + 0.1 * map_feat)

(5)定义类NetG实现生成器网络模块,用于根据文本和噪声合成图像。将输入噪声向量通过全连接层转换为特征图,并使用 CLIP Mapper 适配特征。然后,通过 GBlocks 逐渐上采样特征表示,并融合文本和视觉特征。最终将特征表示转换为 RGB 图像。

class NetG(nn.Module):
    """
    用于根据文本和噪声合成图像的生成器网络。
    """
    def __init__(self, ngf, nz, cond_dim, imsize, ch_size, mixed_precision, CLIP):
        """
        初始化生成器网络。
        参数:
            ngf (int): 生成器滤波器的数量。
            nz (int): 输入噪声向量的维度。
            cond_dim (int): 条件向量的维度。
            imsize (int): 生成图像的大小。
            ch_size (int): 生成图像的输出通道数。
            mixed_precision (bool): 是否使用混合精度训练。
            CLIP: 用于特征适配的 CLIP 模型。
        """
        super(NetG, self).__init__()
        # 定义属性
        self.ngf = ngf
        self.mixed_precision = mixed_precision
        # 构建 CLIP Mapper
        self.code_sz, self.code_ch, self.mid_ch = 7, 64, 32
        self.CLIP_ch = 768
        # 将噪声向量转换为特征图的全连接层(维度为 code_sz * code_sz * code_ch)
        self.fc_code = nn.Linear(nz, self.code_sz * self.code_sz * self.code_ch)
        self.mapping = CLIP_Adapter(self.code_ch, self.mid_ch, self.code_ch, ngf * 8, self.CLIP_ch, cond_dim + nz, 3, 1, 1, 4, CLIP)
        # 构建 GBlocks
        self.GBlocks = nn.ModuleList([])
        in_out_pairs = list(get_G_in_out_chs(ngf, imsize))
        imsize = 4
        for idx, (in_ch, out_ch) in enumerate(in_out_pairs):
            if idx < (len(in_out_pairs) - 1):
                imsize = imsize * 2
            else:
                imsize = 224
            self.GBlocks.append(G_Block(cond_dim + nz, in_ch, out_ch, imsize))

        # 使用带有 leakyReLU 激活函数的序列层进行 RGB 图像转换
        self.to_rgb = nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(out_ch, ch_size, 3, 1, 1),
        )

    def forward(self, noise, c, eval=False):  # x=noise, c=ent_emb
        """
        生成器网络的前向传播。
        Args:
            noise (torch.Tensor): 输入噪声向量。
            c (torch.Tensor): 条件信息,通常是表示输出属性的嵌入。
            eval (bool, optional): 指示网络是否处于评估模式的标志。默认为 False。
        Returns:
            torch.Tensor: 生成的 RGB 图像。
        """
        # 启用自动混合精度训练的上下文管理器
        with torch.cuda.amp.autocast() if self.mixed_precision and not eval else dummy_context_mgr() as mp:
            # 将噪声和条件信息连接起来
            cond = torch.cat((noise, c), dim=1)
            # 通过全连接层传递噪声以生成特征图,并使用 CLIP Mapper 适配特征
            out = self.mapping(self.fc_code(noise).view(noise.size(0), self.code_ch, self.code_sz, self.code_sz), cond)
            # 应用 GBlocks 逐渐上采样特征表示,融合文本和视觉特征
            for GBlock in self.GBlocks:
                out = GBlock(out, cond)

            # 将最终特征表示转换为 RGB 图像
            out = self.to_rgb(out)
        return out

(6)定义类NetD实现鉴别器网络模块,用于评估图像的真实性。NetD包含了一系列 D_Block 模块用于处理特征图,并使用主要的 D_Block 进行最终处理。

class NetD(nn.Module):
    """
    用于评估图像真实性的鉴别器网络。
    Attributes:
        DBlocks (nn.ModuleList): 用于处理特征图的 D_Block 模块列表。
        main (D_Block): 主要的 D_Block 模块用于最终处理。
    """
    def __init__(self, ndf, imsize, ch_size, mixed_precision):
        """
        初始化鉴别器网络
        Args:
        ndf (int): 初始特征中的通道数。
        imsize (int): 输入图像的大小(假设为正方形)。
        ch_size (int): 输出特征图中的通道数。
        mixed_precision (bool): 是否使用混合精度训练的标志。
        """
        super(NetD, self).__init__()
        self.mixed_precision = mixed_precision
        # 定义 DBlock
        self.DBlocks = nn.ModuleList([
            D_Block(768, 768, 3, 1, 1, res=True, CLIP_feat=True),
            D_Block(768, 768, 3, 1, 1, res=True, CLIP_feat=True),
        ])
        # 定义主要的 DBlock 用于最终处理
        self.main = D_Block(768, 512, 3, 1, 1, res=True, CLIP_feat=False)

    def forward(self, h):
        """
        鉴别器网络的前向传播。
        Args:
            h (torch.Tensor): 输入特征图。
        Returns:
            torch.Tensor: 鉴别器输出。
        """
        with torch.cuda.amp.autocast() if self.mixed_precision else dummy_context_mgr() as mpc:
            # 初始特征图
            out = h[:, 0]
            # 通过每个 DBlock 传递输入特征
            for idx in range(len(self.DBlocks)):
                out = self.DBlocks[idx](out, h[:, idx + 1])
            # 通过主要的 DBlock 进行最终处理
            out = self.main(out)
        return out

(7)定义类NetC实现分类器/比较器网络模块,用于对生成器输出和条件文本的联合特征进行分类。类NetC包含了一系列卷积层用于特征提取,并使用这些特征进行分类。

class NetC(nn.Module):
    """
    分类器/比较器网络,用于对生成器输出和条件文本的联合特征进行分类。
    Attributes:
        cond_dim (int): 条件信息的维度。
        mixed_precision (bool): 是否使用混合精度训练的标志。
        joint_conv (nn.Sequential): 定义分类器层的序列模块。
    """
    def __init__(self, ndf, cond_dim, mixed_precision):
        super(NetC, self).__init__()
        self.cond_dim = cond_dim
        self.mixed_precision = mixed_precision
        # 定义分类器层,包括连续的二维卷积层和 LeakyReLU 作为激活函数
        self.joint_conv = nn.Sequential(
            nn.Conv2d(512 + 512, 128, 4, 1, 0, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 1, 4, 1, 0, bias=False),
        )

    def forward(self, out, cond):
        """
        分类器网络的前向传播。
        Args:
            out (torch.Tensor): 生成器输出的特征图。
            cond (torch.Tensor): 条件信息向量
        """
        with torch.cuda.amp.autocast() if self.mixed_precision else dummy_context_mgr() as mpc:
            # 重新整形并重复条件信息向量以匹配特征图大小
            cond = cond.view(-1, self.cond_dim, 1, 1)
            cond = cond.repeat(1, 1, 7, 7)
            # 连接特征图和条件信息
            h_c_code = torch.cat((out, cond), 1)

            # 通过分类器层
            out = self.joint_conv(h_c_code)
        return out

(8)定义类M_Block实现了多尺度块(M_Block)模块,用于实现生成器中的特征映射处理,包含两个卷积层和条件模块,实现了残差连接和快捷连接,可以根据输入的条件信息调整特征映射,用于生成器网络中的特征处理和条件信息融合。

class M_Block(nn.Module):
    """
    多尺度块,由卷积层和条件模块组成。
    Attributes:
        conv1 (nn.Conv2d): 第一个卷积层。
        fuse1 (DFBlock): 第一个卷积层的条件模块。
        conv2 (nn.Conv2d): 第二个卷积层。
        fuse2 (DFBlock): 第二个卷积层的条件模块。
        learnable_sc (bool): 指示快捷连接是否可学习的标志。
        c_sc (nn.Conv2d): 用于快捷连接的卷积层。
    """
    def __init__(self, in_ch, mid_ch, out_ch, cond_dim, k, s, p):
        """
        初始化多尺度块。
        Args:
            in_ch (int): 输入通道数。
            mid_ch (int): 中间层的通道数。
            out_ch (int): 输出通道数。
            cond_dim (int): 条件信息的维度。
            k (int): 卷积层的核大小。
            s (int): 卷积层的步幅。
            p (int): 卷积层的填充。
        """
        super(M_Block, self).__init__()
        # 定义卷积层和条件模块
        self.conv1 = nn.Conv2d(in_ch, mid_ch, k, s, p)
        self.fuse1 = DFBLK(cond_dim, mid_ch)
        self.conv2 = nn.Conv2d(mid_ch, out_ch, k, s, p)
        self.fuse2 = DFBLK(cond_dim, out_ch)
        # 可学习的快捷连接
        self.learnable_sc = in_ch != out_ch
        if self.learnable_sc:
            self.c_sc = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)

    def shortcut(self, x):
        """
        定义快捷连接。
        Args:
            x (torch.Tensor): 输入张量。
        Returns:
            torch.Tensor: 快捷连接的输出。
        """
        if self.learnable_sc:
            x = self.c_sc(x)
        return x

    def residual(self, h, text):
        """
        定义带有条件的残差路径。
        Args:
            h (torch.Tensor): 输入张量。
            text (torch.Tensor): 条件信息。
        Returns:
            torch.Tensor: 残差路径的输出。
        """
        h = self.conv1(h)
        h = self.fuse1(h, text)
        h = self.conv2(h)
        h = self.fuse2(h, text)
        return h

    def forward(self, h, c):
        """
        多尺度块的前向传播。
        Args:
            h (torch.Tensor): 输入张量。
            c (torch.Tensor): 条件信息。
        Returns:
            torch.Tensor: 输出张量。
        """
        return self.shortcut(h) + self.residual(h, c)

(9)定义生成器中的一个生成块(G_Block),包含了两个卷积层和条件模块,用于处理特征映射和融合条件信息,实现了残差连接和可学习的快捷连接,用于实现生成器网络中的特征处理和条件信息融合功能,并生成输出图像。

class G_Block(nn.Module):
    """
    生成器块,包含卷积层和条件模块。
    属性:
        imsize (int): 输出图像的尺寸。
        learnable_sc (bool): 标志是否可学习的快捷连接。
        c1 (nn.Conv2d): 第一个卷积层。
        c2 (nn.Conv2d): 第二个卷积层。
        fuse1 (DFBLK): 第一个卷积层的条件模块。
        fuse2 (DFBLK): 第二个卷积层的条件模块。
        c_sc (nn.Conv2d): 用于快捷连接的卷积层。
    """
    def __init__(self, cond_dim, in_ch, out_ch, imsize):
        """
        初始化生成器块。
        Args:
            cond_dim (int): 条件信息的维度。
            in_ch (int): 输入通道数。
            out_ch (int): 输出通道数。
            imsize (int): 输出图像的尺寸。
        """
        super(G_Block, self).__init__()
        # 初始化属性
        self.imsize = imsize
        self.learnable_sc = in_ch != out_ch
        # 定义卷积层和条件模块
        self.c1 = nn.Conv2d(in_ch, out_ch, 3, 1, 1)
        self.c2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1)
        self.fuse1 = DFBLK(cond_dim, in_ch)
        self.fuse2 = DFBLK(cond_dim, out_ch)
        # 可学习的快捷连接
        if self.learnable_sc:
            self.c_sc = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)

    def shortcut(self, x):
        """
        定义快捷连接。
        Args:
            x (torch.Tensor): 输入张量。
        Returns:
            torch.Tensor: 快捷连接的输出。
        """
        if self.learnable_sc:
            x = self.c_sc(x)
        return x

    def residual(self, h, y):
        """
        定义带条件的残差路径。
        Args:
            h (torch.Tensor): 输入张量。
            y (torch.Tensor): 条件信息。
        Returns:
            torch.Tensor: 残差路径的输出。
        """
        h = self.fuse1(h, y)
        h = self.c1(h)
        h = self.fuse2(h, y)
        h = self.c2(h)
        return h

    def forward(self, h, y):
        """
        生成器块的前向传播。
        Args:
            h (torch.Tensor): 输入张量。
            y (torch.Tensor): 条件信息。
        Returns:
            torch.Tensor: 输出张量。
        """
        h = F.interpolate(h, size=(self.imsize, self.imsize))
        return self.shortcut(h) + self.residual(h, y)

(10)下面代码定义了一个判别器中的一个块,包含了卷积层和残差连接,可以选择是否包含CLIP特征,用于处理输入特征和特征融合,输出判别器的特征表示。

class D_Block(nn.Module):
    """
    判别器块。
    """
    def __init__(self, fin, fout, k, s, p, res, CLIP_feat):
        """
        初始化判别器块。
        Args:
        - fin (int): 输入通道数。
        - fout (int): 输出通道数。
        - k (int): 卷积层的核大小。
        - s (int): 卷积层的步长。
        - p (int): 卷积层的填充。
        - res (bool): 是否使用残差连接。
        - CLIP_feat (bool): 是否使用CLIP特征。
        """
        super(D_Block, self).__init__()
        self.res, self.CLIP_feat = res, CLIP_feat
        self.learned_shortcut = (fin != fout)

        # 用于残差路径的卷积层
        self.conv_r = nn.Sequential(
            nn.Conv2d(fin, fout, k, s, p, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(fout, fout, k, s, p, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # 用于快捷连接的卷积层
        self.conv_s = nn.Conv2d(fin, fout, 1, stride=1, padding=0)
        # 用于学习的残差和CLIP特征参数
        if self.res == True:
            self.gamma = nn.Parameter(torch.zeros(1))
        if self.CLIP_feat == True:
            self.beta = nn.Parameter(torch.zeros(1))

    def forward(self, x, CLIP_feat=None):
        """
        判别器块的前向传播。
        Args:
        - x (torch.Tensor): 输入张量。
        - CLIP_feat (torch.Tensor): 可选的CLIP特征张量。
        Returns:
        - torch.Tensor: 输出张量。
        """
        # 计算残差特征
        res = self.conv_r(x)

        # 计算快捷连接
        if self.learned_shortcut:
            x = self.conv_s(x)
        # 如果启用了学习的残差和CLIP特征,则加入
        if (self.res == True) and (self.CLIP_feat == True):
            return x + self.gamma * res + self.beta * CLIP_feat
        elif (self.res == True) and (self.CLIP_feat != True):
            return x + self.gamma * res
        elif (self.res != True) and (self.CLIP_feat == True):
            return x + self.beta * CLIP_feat
        else:
            return x

(11)下面代码定义了生成器网络中的扩散块,其中包含了条件特征块,用于处理输入特征和条件信息的仿射变换,并应用激活函数,输出特征表示。

class DFBLK(nn.Module):
    """
    生成器网络的扩散块,带有条件特征块
    """
    def __init__(self, cond_dim, in_ch):
        """
        初始化DFBlock的条件特征块。
        Args:
        - cond_dim (int): 条件输入的维度。
        - in_ch (int): 输入通道数。
        """
        super(DFBLK, self).__init__()
        # 定义条件仿射变换
        self.affine0 = Affine(cond_dim, in_ch)
        self.affine1 = Affine(cond_dim, in_ch)

    def forward(self, x, y=None):
        """
        条件特征块的前向传播。
        Args:
        - x (torch.Tensor): 输入张量。
        - y (torch.Tensor, optional): 条件输入张量。默认为None。

        Returns:
        - torch.Tensor: 输出张量。
        """
        # 应用第一个仿射变换和激活函数
        h = self.affine0(x, y)
        h = nn.LeakyReLU(0.2, inplace=True)(h)
        # 应用第二个仿射变换和激活函数
        h = self.affine1(h, y)
        h = nn.LeakyReLU(0.2, inplace=True)(h)
        return h

(12)下面代码定义了QuickGELU激活函数,这是GELU的高效和更快速的版本,用于非线性变换,有助于网络学习复杂的模式。

class QuickGELU(nn.Module):
    """
    GELU的高效和更快速的版本,
    用于非线性变换和学习复杂模式
    """
    def forward(self, x: torch.Tensor):
        """
        QuickGELU激活函数的前向传播。
        Args:
        - x (torch.Tensor): 输入张量。
        Returns:
        - torch.Tensor: 输出张量。
        """
        # 应用QuickGELU激活函数
        return x * torch.sigmoid(1.702 * x)

(13)下面代码定义了一个仿射变换模块,用于在输入特征上应用条件缩放和平移,以便根据输入条件对生成的输出进行额外控制。模块包含两个全连接网络来计算 gamma 和 beta 参数,然后将这些参数应用于输入特征,以实现仿射变换。

class Affine(nn.Module):
    """
    仿射变换模块,对输入特征应用条件缩放和平移,
    以根据输入条件对生成的输出进行额外控制。
    """
    def __init__(self, cond_dim, num_features):
        """
        初始化仿射变换模块。
        Args:
            cond_dim (int): 条件信息的维度。
            num_features (int): 输入特征的数量。
        """
        super(Affine, self).__init__()
        # 定义两个全连接网络来计算 gamma 和 beta 参数,
        # 每个网络由两个线性层和中间的ReLU激活组成
        self.fc_gamma = nn.Sequential(OrderedDict([
            ('linear1', nn.Linear(cond_dim, num_features)),
            ('relu1', nn.ReLU(inplace=True)),
            ('linear2', nn.Linear(num_features, num_features)),
        ]))
        self.fc_beta = nn.Sequential(OrderedDict([
            ('linear1', nn.Linear(cond_dim, num_features)),
            ('relu1', nn.ReLU(inplace=True)),
            ('linear2', nn.Linear(num_features, num_features)),
        ]))
        # 初始化网络的权重和偏置
        self._initialize()

    def _initialize(self):
        """
        初始化用于计算 gamma 和 beta 的线性层的权重和偏置
        """
        nn.init.zeros_(self.fc_gamma.linear2.weight.data)
        nn.init.ones_(self.fc_gamma.linear2.bias.data)
        nn.init.zeros_(self.fc_beta.linear2.weight.data)
        nn.init.zeros_(self.fc_beta.linear2.bias.data)

    def forward(self, x, y=None):
        """
        仿射变换模块的前向传播。
        Args:
            x (torch.Tensor): 输入张量。
            y (torch.Tensor, optional): 条件信息张量,默认为 None。
        Returns:
            torch.Tensor: 应用仿射变换后的变换张量。
        """
        # 计算 gamma 和 beta 参数
        weight = self.fc_gamma(y)
        bias = self.fc_beta(y)

        # 确保权重和偏置张量的形状正确
        if weight.dim() == 1:
            weight = weight.unsqueeze(0)
        if bias.dim() == 1:
            bias = bias.unsqueeze(0)
        # 将权重和偏置张量扩展到与输入张量相匹配的形状
        size = x.size()
        weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
        bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)

        # 应用仿射变换
        return weight * x + bias

(14)下面代码定义了函数get_G_in_out_chs和get_D_in_out_chs,这两个函数用于根据给定的通道数和图像大小计算生成器和判别器块的输入输出通道对。生成器块和判别器块的通道数随着层数变化而变化,最高通道数为输入通道数的八倍,然后逐渐减小。

def get_G_in_out_chs(nf, imsize):
    """
    根据给定的通道数和图像大小,计算生成器块的输入输出通道对。
    Args:
        nf (int): 输入通道数。
        imsize (int): 输入图像的大小。

    Returns:
        list: 包含生成器块的输入输出通道对的元组列表。
    """
    # 根据图像大小确定层数
    layer_num = int(np.log2(imsize)) - 1
    # 计算每一层的通道数
    channel_nums = [nf * min(2 ** idx, 8) for idx in range(layer_num)]
    # 反转通道数列表,从最高通道数开始
    channel_nums = channel_nums[::-1]
    # 生成生成器块的输入输出通道对
    in_out_pairs = zip(channel_nums[:-1], channel_nums[1:])
    return in_out_pairs

def get_D_in_out_chs(nf, imsize):
    """
    根据给定的通道数和图像大小,计算判别器块的输入输出通道对。
    Args:
        nf (int): 输入通道数。
        imsize (int): 输入图像的大小。
    Returns:
        list: 包含判别器块的输入输出通道对的元组列表。
    """
    # 根据图像大小确定层数
    layer_num = int(np.log2(imsize)) - 1
    # 计算每一层的通道数
    channel_nums = [nf * min(2 ** idx, 8) for idx in range(layer_num)]
    # 生成判别器块的输入输出通道对
    in_out_pairs = zip(channel_nums[:-1], channel_nums[1:])
    return in_out_pairs

总之,本项目的多模态GAN模型包括了生成器块、判别器块、条件特征块和分类器网络等各种组件的定义,该模型可以根据文本描述生成图像,使用 CLIP 模型进行改进的条件设置,并结合了多尺度块、残差连接和条件归一化等功能,以实现有效的图像合成。此外,它提供了 QuickGELU 等高效激活函数和条件仿射变换,可以根据输入条件控制输出,旨在生成与给定文本提示相对应的高质量图像。

猜你喜欢

转载自blog.csdn.net/asd343442/article/details/143515419