Writing Your Own Transformer Structure

Overview

Note: this post defines three classes in total:

1. TransformerEncoderLayer: the basic multi-head encoder layer
2. TransformerEncoder: a stack of encoder layers; it ultimately calls the class above
3. Transformer: the overall structure, which wires together the stacked encoders and decoders



1. Define a Multi-Head Transformer Encoder Layer

The code is as follows (example):

1. nn.MultiheadAttention is PyTorch's built-in attention module; you only need to pass in the input sequence (as query, key and value) and the number of heads, and the masks are optional (see the sketch right after this list).
2. The forward pass normally goes through forward_post, where the LayerNorm is applied after the attention block (post-norm).
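
A minimal, self-contained sketch of the built-in attention module, assuming hypothetical sizes and the (seq_len, batch, d_model) layout used throughout this post:

import torch
from torch import nn

attn = nn.MultiheadAttention(embed_dim=256, num_heads=8)   # d_model=256, nhead=8 (hypothetical)
x = torch.randn(768, 2, 256)                               # (seq_len, batch, d_model): e.g. HW=768, bs=2
out, attn_weights = attn(x, x, x)                          # self-attention; attn_mask / key_padding_mask are optional
print(out.shape)                                           # torch.Size([768, 2, 256])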

import copy
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn, Tensor


class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)   # small helper, shown in section 4 below
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(src, pos)                 # add positional encoding; shape e.g. (HW, bs, d_model) = (768, 2, 256)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)

So how does this layer get called?

src = self.forward_post(src, src_mask, src_key_padding_mask, pos)
# pos is the positional encoding (covered later); src is the input sequence
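
As a hedged standalone sketch (hypothetical shapes; pos comes from the positional encoding described in section 5):

layer = TransformerEncoderLayer(d_model=256, nhead=8)
src = torch.randn(768, 2, 256)    # (HW, bs, d_model): a flattened CNN feature map
pos = torch.randn(768, 2, 256)    # positional encoding, same shape as src
out = layer(src, pos=pos)         # both masks are optional and omitted here
print(out.shape)                  # torch.Size([768, 2, 256])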

2. Stacking Multiple Encoder Layers

The code is as follows (example):

class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        ##----------- the main stacked layers; _get_clones is shown in section 4 -----------##
        self.layers = _get_clones(encoder_layer, num_layers)
        ##----------------------------------------------------------------------------------##
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        output = src

        #----- The main forward loop. mask is None; padding_mask has shape (bs, h*w) -------
        for layer in self.layers:
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)
        #------------------------------------------------------------------ 
        if self.norm is not None:
            output = self.norm(output)

        return output
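
A minimal sketch of stacking the layer from section 1 (hypothetical sizes; _get_clones from section 4 is assumed to be defined):

encoder_layer = TransformerEncoderLayer(d_model=256, nhead=8)
encoder = TransformerEncoder(encoder_layer, num_layers=6, norm=None)
memory = encoder(src, pos=pos)   # src and pos as in the earlier sketch; the (768, 2, 256) shape is preserved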

The encoder_layer being cloned is the multi-head TransformerEncoderLayer defined at the top; the Transformer class below is what instantiates and calls it:

class Transformer(nn.Module):

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
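        # Note: the original DETR Transformer also builds a decoder stack here
        # (TransformerDecoderLayer / TransformerDecoder, not shown in this post);
        # the self.decoder used in forward() below refers to that stack.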

    def forward(self, src, mask, query_embed, pos_embed):
        # flatten the (bs, c, h, w) feature map to (h*w, bs, c) for the encoder
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
        mask = mask.flatten(1)

        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)      # ([768, 2, 256])
        hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                          pos=pos_embed, query_pos=query_embed)                   # (6, 100, 2, 256)
        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)

3. How the Encoder Stack Is Used

memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)      # ([768, 2, 256])
# self.encoder is the TransformerEncoder built in Transformer.__init__ above

So how is this full Transformer structure ultimately defined and used? The lines below come from DETR's detection model (models/detr.py):

from .transformer import build_transformer
self.transformer = transformer      # in DETR.__init__; transformer was created by build_transformer(args)
hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
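
For context, a rough, simplified sketch of the surrounding pieces of DETR's detection head; input_proj projects the 2048-channel backbone feature map to d_model, and query_embed holds the learned object queries:

# inside DETR.__init__ (simplified):
#   self.input_proj  = nn.Conv2d(2048, hidden_dim, kernel_size=1)   # backbone channels -> d_model
#   self.query_embed = nn.Embedding(num_queries, hidden_dim)        # e.g. 100 learned object queries
# inside DETR.forward, after the hs = self.transformer(...) call above:
outputs_class = self.class_embed(hs)             # classification head on every decoder layer output
outputs_coord = self.bbox_embed(hs).sigmoid()    # box head; boxes normalized to [0, 1]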

4. Finally, Two Important Functions

Both come from models/transformer.py in the DETR codebase.

def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
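
The encoder layer in section 1 also references _get_activation_fn. In the DETR repository it sits next to these helpers; it looks roughly like this:

def _get_activation_fn(activation):
    """Return an activation function given a string."""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")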


def build_transformer(args):
    return Transformer(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        normalize_before=args.pre_norm,
        return_intermediate_dec=True,
    )
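
As a usage sketch, build_transformer only needs an args object carrying these hyper-parameters; the values below are hypothetical (they match DETR's defaults):

from types import SimpleNamespace

args = SimpleNamespace(hidden_dim=256, dropout=0.1, nheads=8, dim_feedforward=2048,
                       enc_layers=6, dec_layers=6, pre_norm=False)
transformer = build_transformer(args)   # d_model=256, 8 heads, 6 encoder + 6 decoder layers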

5. The mask and pos in DETR

mask = F.interpolate(tensor_list.mask[None].float(), size=x.shape[-2:]).bool()[0]   # ([2, 38, 25])
## Taking an input batch of shape (2, 3, 1194, 800) as an example, tensor_list.mask has shape (2, 1194, 800);
## it is downsampled here to the feature-map resolution (38, 25).
In DETR, tensor_list.mask is the padding mask of the batched images (True where pixels were added to pad images of different sizes to a common shape); it is not an instance-segmentation label.

pos = self[1](x).to(x.tensors.dtype)     # x is the backbone feature map, shape (2, 2048, 38, 25)
self[1] = position_embedding = build_position_encoding(args)
The positional encoding used here is the sine/cosine variant, implemented as a standalone class in models/position_encoding.py:
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
# N_steps = args.hidden_dim // 2
 
import math

import torch
from torch import nn


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor_list):
        x = tensor_list.tensors
        mask = tensor_list.mask
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
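
A hedged usage sketch: the SimpleNamespace below stands in for DETR's NestedTensor (which carries .tensors and .mask), and the shapes are the hypothetical ones used above:

from types import SimpleNamespace

feat = SimpleNamespace(
    tensors=torch.randn(2, 2048, 38, 25),                   # backbone feature map
    mask=torch.zeros(2, 38, 25, dtype=torch.bool),          # False = valid pixel, True = padding
)
pe = PositionEmbeddingSine(num_pos_feats=128, normalize=True)   # N_steps = 256 // 2
pos = pe(feat)
print(pos.shape)                                                # torch.Size([2, 256, 38, 25])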

Summary

Note: the approach above can be used to build a Transformer structure on your own.


Reposted from blog.csdn.net/qq_45752541/article/details/120378920