vita-clip 模型(从源码层次)


前言:细节不过多赘述,工作量太大


创新点

视觉端

全局令牌
        if self.use_global_prompts:
            for i, blk in enumerate(self.blocks):
                global_prompts = self.global_prompts[i].expand(B*T, -1, -1)

                x = torch.cat((x[:, :1, :], global_prompts, x[:, 1:, :]), dim=1)
                x = blk(x)
                x = torch.cat((x[:, :1, :], x[:, self.num_global_prompts+1:, :]), dim=1)

所谓全局令牌:全局视频级提示令牌( ( G ( l ) = [ g 1 ( l ) , ⋯   , g M v ( l ) ] (G^{(l)}=[g_{1}^{(l)},\cdots,g_{M_{v}}^{(l)}] (G(l)=[g1(l),,gMv(l)])是随机初始化的可学习向量,用于为模型提供学习视频数据分布的能力,使模型更好地适应视频领域的特点。
可以算一种加了 `时间维度`` 来捕捉

局部令牌
 # then if local prompts are being used
        if self.use_local_prompts:
            local_prompts = self.local_prompts.expand(B, -1, -1)
            # If train time frames and
            # test time frames are not equal
            if T != self.num_frames:
                token_multiplier = T//self.num_frames
                local_prompts = local_prompts.repeat(1,token_multiplier,1)
            
            # use additive conditioning
            local_prompts = local_prompts + cls_token_proj

            # repeat across frames
            local_prompts = local_prompts.repeat_interleave(repeats=T, dim=0)
            x = torch.cat((x[:, :1, :], local_prompts, x[:, 1:, :]), dim=1)

局部帧级提示:局部帧级提示令牌( ( L ( l ) = [ l 1 ( l ) , ⋯   , l T ( l ) ] (L^{(l)}=[l_{1}^{(l)},\cdots,l_{T}^{(l)}] (L(l)=[l1(l),,lT(l)]))数量与视频帧数T相等,也是随机初始化的可学习向量,且基于每个帧的分类令牌进行条件设定。这使得模型能够在帧与帧之间传递判别信息,增强对每帧局部信息的学习,其计算方式为 ( l ^ t ( l ) = l t ( l ) + z t , 0 ( l − 1 ) (\hat{l}_{t}^{(l)} = l_{t}^{(l)} + z_{t,0}^{(l - 1)} (l^t(l)=lt(l)+zt,0(l1))。
其实也是对分类标记 加入了 时间维度 来捕捉时间信息

摘要令牌
 if self.use_summary_token:
            summary_token_norm = self.summary_ln(cls_token_proj)
            summary_token_attn = cls_token_proj + self.summary_attn_layer(summary_token_norm, summary_token_norm, summary_token_norm)
            summary_token_attn_reshape = summary_token_attn.view(BT, 1, C)
            x = torch.cat([x, summary_token_attn_reshape], dim=1)

跨帧交流

文本端

class TextPromptLearner(nn.Module):
    def __init__(self, classnames, text_model, num_prompts, prompts_init='', CSC=False, ctx_pos='end'):
        super().__init__()

        _tokenizer = _Tokenizer()
        n_cls = len(classnames)
        n_ctx = num_prompts
        ctx_init = prompts_init
        ctx_dim = text_model.ln_final.weight.shape[0]
        #特定
        if ctx_init:
            # use given words to initialize context vectors
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))#作为上下文的长度
            prompt = tokenize(ctx_init)
            with torch.no_grad():
                embedding = text_model.token_embedding(prompt)
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init #空格
        #随机
        else:
            # random initialization
            if CSC:
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim)
            else:
                print("Initializing a generic context")
                ctx_vectors = torch.empty(n_ctx, ctx_dim)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{
      
      prompt_prefix}"')
        print(f"Number of context words (tokens): {
      
      n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors)  # to be optimized

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]#token序列 词表对应的序列
        prompts = [prompt_prefix + " " + name + "." for name in classnames]#提示词

        tokenized_prompts = torch.cat([tokenize(p) for p in prompts])
        # print(tokenized_prompts.shape)
        with torch.no_grad():
            embedding = text_model.token_embedding(tokenized_prompts)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])  # CLS, EOS

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts  # torch.Tensor
        self.name_lens = name_lens
        self.class_token_position = ctx_pos

    def forward(self):
        ctx = self.ctx
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        prefix = self.token_prefix#(n,1,d)
        suffix = self.token_suffix

        if self.class_token_position == "end":
            prompts = torch.cat(
                [
                    prefix,  # (n_cls, 1, dim)
                    ctx,     # (n_cls, n_ctx, dim)
                    suffix,  # (n_cls, *, dim)
                ],
                dim=1,
            )

        elif self.class_token_position == "middle":
            half_n_ctx = self.n_ctx // 2
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :]
                ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :]
                prompt = torch.cat(
                    [
                        prefix_i,     # (1, 1, dim)
                        ctx_i_half1,  # (1, n_ctx//2, dim)
                        class_i,      # (1, name_len, dim)
                        ctx_i_half2,  # (1, n_ctx//2, dim)
                        suffix_i,     # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        elif self.class_token_position == "front":
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i = ctx[i : i + 1, :, :]
                prompt = torch.cat(
                    [
                        prefix_i,  # (1, 1, dim)
                        class_i,   # (1, name_len, dim)
                        ctx_i,     # (1, n_ctx, dim)
                        suffix_i,  # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        else:
            raise ValueError

        return prompts

你要用清楚 两个量一个是 tokenized_prompts 和prompts

  • tokenized_prompts
classnames = [name.replace("_", " ") for name in classnames]
name_lens = [len(_tokenizer.encode(name)) for name in classnames]#token序列 词表对应的序列
prompts = [prompt_prefix + " " + name + "." for name in classnames]
tokenized_prompts = torch.cat([tokenize(p) for p in prompts])# 是拿一句话出来 (n_cls,77)
# print(tokenized_prompts.shape)

看注解 直接得到 (n_cls,77)

  • prompt
 if self.class_token_position == "end":
            prompts = torch.cat(
                [
                    prefix,  # (n_cls, 1, dim)
                    ctx,     # (n_cls, n_ctx, dim)
                    suffix,  # (n_cls, *, dim)
                ],
                dim=1,
            )

默认是end 所以 prompt是(n_cls,77,d)

数据集

dataset

class VideoDataset(torch.utils.data.Dataset):

    def __init__(
        self, list_path: str, data_root: str,
        num_spatial_views: int, num_temporal_views: int, random_sample: bool,
        num_frames: int, sampling_rate: int, spatial_size: int,
        mean: torch.Tensor, std: torch.Tensor,
        auto_augment: Optional[str] = None, interpolation: str = 'bicubic',
        mirror: bool = False,
    ):
参数名称 参数类型 默认值 含义
list_path str 数据列表文件的路径,该文件可能包含了数据的相关信息,如图片或视频的文件名、标签等,用于指引程序找到具体的数据。
data_root str 数据的根目录,结合 list_path 中的信息,程序可以定位到具体的数据文件。
num_spatial_views int 空间视角的数量,在处理图像或视频数据时,可能会从不同的空间视角对数据进行采样或处理,该参数指定了视角的数量。
num_temporal_views int 时间视角的数量,对于视频数据,可能会从不同的时间点进行采样或处理,该参数指定了时间视角的数量。
random_sample bool 是否进行随机采样。如果设置为 True,则在数据处理过程中会采用随机采样的方式;如果设置为 False,则可能采用固定的采样策略。
num_frames int 要采样的帧数,对于视频数据,指定了从视频中选取的帧数。
sampling_rate int 采样率,决定了在时间维度上采样的间隔。例如,采样率为 2 表示每隔 2 帧进行一次采样。
spatial_size int 空间尺寸,通常用于指定图像或视频在空间维度上的大小,如图片的边长或视频帧的尺寸。
mean torch.Tensor 数据的均值,用于数据归一化处理。在图像或视频数据预处理中,通常会减去均值以将数据中心化。
std torch.Tensor 数据的标准差,同样用于数据归一化处理。在减去均值后,会除以标准差以将数据的方差调整为 1。
auto_augment Optional[str] None 自动数据增强策略,指定了使用的自动数据增强方法。如果为 None,则不进行自动数据增强。
interpolation str 'bicubic' 图像插值方法,在对图像进行缩放或变形时,用于计算新像素值的方法,默认使用双三次插值。
mirror bool False 是否进行镜像翻转,设置为 True 时会对数据进行镜像操作,是一种简单的数据增强方式。

dataloader

def create_train_loader(args: argparse.Namespace, resume_step: int = 0) -> torch.utils.data.DataLoader:
    dataset = create_train_dataset(args)
    rank, world_size = (0, 1) if not dist.is_initialized() else (dist.get_rank(), dist.get_world_size())

    assert args.batch_size % world_size == 0
    batch_size_per_gpu = args.batch_size // world_size

    # manually create a step-based sampler
    sampler = []
    while len(sampler) * len(dataset) < args.num_steps * args.batch_size:
        g = torch.Generator()
        g.manual_seed(len(sampler))
        indices = torch.randperm(len(dataset), generator=g)
        sampler.append(indices)
    sampler = torch.cat(sampler, dim=0)[:args.num_steps * args.batch_size].view(args.num_steps, args.batch_size)
    sampler = sampler[resume_step:, batch_size_per_gpu * rank: batch_size_per_gpu * (rank + 1)].flatten().tolist()

    loader = torch.utils.data.DataLoader(
        dataset, sampler=sampler, batch_size=batch_size_per_gpu,
        num_workers=args.num_workers, pin_memory=False, drop_last=True,
    )

  • 训练步数:在深度学习训练里,每一步通常意味着对一个批次(batch)的数据进行一次前向传播、计算损失、反向传播以及参数更新的完整过程。args.num_steps 明确了在整个训练过程中要进行多少次这样的操作
  • 进程(work_num):在深度学习训练中,数据加载是一个重要环节。当使用 DataLoader 加载数据时,num_workers 决定了会创建多少个独立的子进程来并行地完成数据加载任务。简单来说,它控制了数据加载的并行程度
    其他参考 分布式训练

模型

  • 优化器 AdamW函数
  • 学习率 余弦退火
  • 损失函数 :交叉熵函数
  • 输入 :图片
    for i, (data, labels) in enumerate(train_loader, resume_step):
        data, labels = data.cuda(), labels.cuda()
        data_ed = datetime.now()

        optimizer.zero_grad()

        assert data.size(0) % args.batch_split == 0
        split_size = data.size(0) // args.batch_split
        hit1, hit5, loss_value = 0, 0, 0
        for j in range(args.batch_split):
            data_slice = data[split_size * j: split_size * (j + 1)]
            labels_slice = labels[split_size * j: split_size * (j + 1)]

            with torch.cuda.amp.autocast(args.fp16):
                logits = model(data_slice)
                loss = criterion(logits, labels_slice)
                
            if labels.dtype == torch.long: # no mixup, can calculate accuracy
                hit1 += (logits.topk(1, dim=1)[1] == labels_slice.view(-1, 1)).sum().item()
                hit5 += (logits.topk(5, dim=1)[1] == labels_slice.view(-1, 1)).sum().item()
            loss_value += loss.item() / args.batch_split
            
            loss_scaler.scale(loss / args.batch_split).backward()
  • 输出为 样本对类别的相似度分数

  • 训练技巧 参考混合精度