ColumnParallelLinear in CogView

I'm a beginner writing these notes to record what I learn, hoping they also help others who are just getting started; corrections from more experienced readers are very welcome. (Content will be removed immediately upon request in case of infringement.)

Contents

I. Principle

II. Code walkthrough

1. __init__

(1) Parameters

(2) Partition the weight matrix along its second dimension (columns assigned to each process)

(3) Initialize the weight matrix (W) and the bias (b)

2. forward


I. Principle

In short, this is a linear transformation whose weight matrix is split column-wise across model-parallel shards.

Weight: W = [W_1, ..., W_p] (p is the number of partitions, i.e. the number of GPUs);

Bias: B = [b_1, ..., b_p];

Input: X (every GPU holds the same X);

Output: Y;

Expression: Y = XW + B = X·[W_1, ..., W_p] + [b_1, ..., b_p] = [XW_1 + b_1, ..., XW_p + b_p] = [Y_1, ..., Y_p] (Y_1 is the slice computed on the first GPU).

There are two cases: either each GPU keeps only its own slice of the result (as above), or the full result Y is made available to all GPUs (by concatenating the per-GPU slices along the column dimension).
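
To make the identity above concrete, here is a minimal single-process sketch (plain tensor arithmetic with made-up shapes, no actual multi-GPU communication) checking that concatenating the per-partition outputs reproduces the ordinary linear transform:

import torch

p = 2                      # pretend there are 2 model-parallel partitions
X = torch.randn(4, 6)      # input, identical on every "GPU"
W = torch.randn(6, 8)      # full weight, 8 output columns
B = torch.randn(8)         # full bias

# Column-wise split: W = [W_1, ..., W_p], B = [b_1, ..., b_p]
W_parts = torch.chunk(W, p, dim=1)
B_parts = torch.chunk(B, p, dim=0)

# Each partition computes its own slice Y_i = X @ W_i + b_i
Y_parts = [X @ W_i + b_i for W_i, b_i in zip(W_parts, B_parts)]

# "Gathering" just concatenates the slices along the output dimension
Y_gathered = torch.cat(Y_parts, dim=1)
assert torch.allclose(Y_gathered, X @ W + B, atol=1e-6)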


II. Code walkthrough

(Code location: model/mpu/layers)

1. __init__

(1) Parameters

  • input_size: the first dimension of the weight matrix W (called A in the docstring);
  • output_size: the second dimension of W;
  • bias: whether to add a bias term;
  • gather_output: if true, an all-gather is called on the output so that the full Y is available to all GPUs; otherwise each GPU keeps only its own output, i.e. Y_i = X·W_i (on the i-th GPU);
  • init_method: the method used to initialize the weights;
  • stride: used for strided linear layers;
  • keep_master_weight_for_test: added for testing and should be set to False; when true, the master weight used for initialization is returned;

Assume W has shape (c, d).

class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
    """
    def __init__(self, input_size, output_size, bias=True, gather_output=True,
                 init_method=init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False):
        super(ColumnParallelLinear, self).__init__()
        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output

(2) Partition the weight matrix along its second dimension (columns assigned to each process)

        # Divide the weight matrix along the last dimension.
        world_size = get_model_parallel_world_size()  # number of processes in the model-parallel group (by default there is only one group)
        self.output_size_per_partition = divide(output_size, world_size)  # number of output columns owned by each partition
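
The divide helper itself is defined elsewhere in mpu; as a rough sketch (my reading of the usual Megatron-style implementation, not the verbatim CogView code), it is just an integer division that asserts divisibility:

def divide(numerator, denominator):
    # e.g. output_size=1024 on world_size=4 GPUs -> each partition owns 256 columns
    assert numerator % denominator == 0, \
        '{} is not evenly divisible by {}'.format(numerator, denominator)
    return numerator // denominator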

(3) Initialize the weight matrix (W) and the bias (b)

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        self.weight = Parameter(torch.Tensor(self.output_size_per_partition,
                                             self.input_size))  # allocate an (uninitialized) (d/p, c) tensor and wrap it in Parameter so it becomes a trainable parameter of the module
        self.weight.model_parallel = True  # mark the weight as a model-parallel parameter
        if bias:  # a bias term b is used
            self.bias = Parameter(torch.Tensor(self.output_size_per_partition))  # allocate the per-partition bias of shape (d/p,)
            self.bias.model_parallel = True  # mark the bias as a model-parallel parameter
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:  # no bias term
            self.register_parameter('bias', None)

        # Initialize weight.
        self.master_weight = _initialize_affine_weight(
            self.weight, self.output_size, self.input_size,
            self.output_size_per_partition, 0, init_method,
            stride=stride, return_master_weight=keep_master_weight_for_test)
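
_initialize_affine_weight is defined elsewhere in mpu/layers. As a hedged sketch of the idea (simplified from my understanding of the Megatron-style code: stride handling, dtype/CPU placement and reading the rank from the process group are omitted, and passing rank explicitly is my own simplification), every rank builds the same full-size master weight, initializes it, and then copies only its own block into the local parameter:

import torch

def _initialize_affine_weight_sketch(weight, full_output_size, input_size,
                                     per_partition_size, partition_dim,
                                     init_method, rank,
                                     return_master_weight=False):
    # Build the full master weight on every rank; in the real code the RNG is
    # seeded identically on all ranks so they all produce the same tensor.
    master_weight = torch.empty(full_output_size, input_size, requires_grad=False)
    init_method(master_weight)
    # Split along the partition dimension (dim 0 here, because the weight is
    # stored transposed) and copy this rank's slice into the local parameter.
    weight_list = torch.split(master_weight, per_partition_size, dim=partition_dim)
    with torch.no_grad():
        weight.copy_(weight_list[rank])
    return master_weight if return_master_weight else None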

2. forward

The forward pass is just Y = XW + B (note that F.linear computes X·W^T + B, which is why the weight was allocated in transposed (d/p, c) form above).

    def forward(self, input_):
        # Set up backprop all-reduce.
        input_parallel = copy_to_model_parallel_region(input_)  # pass the input into the model-parallel region (identity in the forward pass, all-reduce of gradients in the backward pass)
        # Matrix multiply.
        output_parallel = F.linear(input_parallel, self.weight, self.bias)  # X·W_i + b_i on this partition
        if self.gather_output:  # every GPU gets the same, full output
            # All-gather across the partitions.
            output = gather_from_model_parallel_region(output_parallel)
        else:  # each GPU keeps only its own output, i.e. Y_i = X·W_i + b_i (the i-th GPU)
            output = output_parallel
        return output
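
As a quick sanity check of the transpose convention (a toy single-process example with illustrative shapes, not part of the library), F.linear applied to the (d/p, c)-shaped weight does compute X·W_i + b_i:

import torch
import torch.nn.functional as F

c, d, p = 6, 8, 2                 # input dim, output dim, number of partitions
X = torch.randn(4, c)
W_i = torch.randn(c, d // p)      # one column block of the full weight W
b_i = torch.randn(d // p)

weight = W_i.t().contiguous()     # stored as (d/p, c), matching __init__ above
out = F.linear(X, weight, b_i)    # F.linear computes X @ weight.T + b
assert torch.allclose(out, X @ W_i + b_i, atol=1e-6)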

Comments and corrections are very welcome, thank you!

Reprinted from blog.csdn.net/weixin_55073640/article/details/126465070