[Graph Neural Networks & Action Recognition] [Code Reading] ST-GCN: Spatial-Temporal Graph Convolutional Networks (AAAI 2018)

Original code

Paper walkthrough

Initialization

Model definition and initialization. The __init__ function of Model:

class Model(nn.Module):
    r"""Spatial temporal graph convolutional networks.

    Args:
        in_channels (int): Number of channels in the input data
        num_class (int): Number of classes for the classification task
        graph_args (dict): The arguments for building the graph
        edge_importance_weighting (bool): If ``True``, adds a learnable
            importance weighting to the edges of the graph
        **kwargs (optional): Other parameters for graph convolution units

    Shape:
        - Input: :math:`(N, in_channels, T_{in}, V_{in}, M_{in})`
        - Output: :math:`(N, num_class)` where
            :math:`N` is the batch size,
            :math:`T_{in}` is the length of the input sequence,
            :math:`V_{in}` is the number of graph nodes,
            :math:`M_{in}` is the number of instances in a frame.
    """

    def __init__(self, in_channels, num_class, graph_args,
                 edge_importance_weighting, **kwargs):
        super().__init__()

        # load graph
        self.graph = Graph(**graph_args)
        A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)
        self.register_buffer('A', A)

        # build networks
        spatial_kernel_size = A.size(0)
        temporal_kernel_size = 9
        kernel_size = (temporal_kernel_size, spatial_kernel_size)
        self.data_bn = nn.BatchNorm1d(in_channels * A.size(1))
        kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'}
        self.st_gcn_networks = nn.ModuleList((
            st_gcn(in_channels, 64, kernel_size, 1, residual=False, **kwargs0),
            st_gcn(64, 64, kernel_size, 1, **kwargs),
            st_gcn(64, 64, kernel_size, 1, **kwargs),
            st_gcn(64, 64, kernel_size, 1, **kwargs),
            st_gcn(64, 128, kernel_size, 2, **kwargs),
            st_gcn(128, 128, kernel_size, 1, **kwargs),
            st_gcn(128, 128, kernel_size, 1, **kwargs),
            st_gcn(128, 256, kernel_size, 2, **kwargs),
            st_gcn(256, 256, kernel_size, 1, **kwargs),
            st_gcn(256, 256, kernel_size, 1, **kwargs),
        ))

        # initialize parameters for edge importance weighting
        if edge_importance_weighting:
            self.edge_importance = nn.ParameterList([
                nn.Parameter(torch.ones(self.A.size()))
                for i in self.st_gcn_networks
            ])
        else:
            self.edge_importance = [1] * len(self.st_gcn_networks)

        # fcn for prediction
        self.fcn = nn.Conv2d(256, num_class, kernel_size=1)

First, the adjacency matrix A is built:

# load graph
self.graph = Graph(**graph_args)
A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)
self.register_buffer('A', A)

The graph convolution kernel size determines the number of partition groups: the kernel weights are multiplied with the nodes of each group and averaged (via the normalization described below), and the results of all groups are then summed.

A has shape K×V×V, where K is the kernel size, determined by the partitioning strategy chosen in the paper: uniform gives 1, distance gives 2, spatial gives 3. The default is spatial, so A is 3×25×25 (V = 25 joints). Each K-slice encodes which nodes are connected within one group.
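
As a quick sanity check, a minimal sketch (the import path and layout name are assumptions based on the open-mmlab st-gcn repo):

import numpy as np
from net.utils.graph import Graph  # import path assumed from the st-gcn repo

graph = Graph(layout='ntu-rgb+d', strategy='spatial')  # layout name assumed
print(np.shape(graph.A))  # (3, 25, 25): K=3 partitions, each a 25x25 matrix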

The construction happens in the get_adjacency function of the Graph class:

    def get_adjacency(self, strategy):
        valid_hop = range(0, self.max_hop + 1, self.dilation)
        adjacency = np.zeros((self.num_node, self.num_node))
        for hop in valid_hop:
            adjacency[self.hop_dis == hop] = 1
        normalize_adjacency = normalize_digraph(adjacency)

        if strategy == 'uniform':
            A = np.zeros((1, self.num_node, self.num_node))
            A[0] = normalize_adjacency
            self.A = A
        elif strategy == 'distance':
            A = np.zeros((len(valid_hop), self.num_node, self.num_node))
            for i, hop in enumerate(valid_hop):
                A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis ==
                                                                hop]
            self.A = A
        elif strategy == 'spatial':
            A = []
            for hop in valid_hop:
                a_root = np.zeros((self.num_node, self.num_node))
                a_close = np.zeros((self.num_node, self.num_node))
                a_further = np.zeros((self.num_node, self.num_node))
                for i in range(self.num_node):
                    for j in range(self.num_node):
                        if self.hop_dis[j, i] == hop:
                            if self.hop_dis[j, self.center] == self.hop_dis[
                                    i, self.center]:
                                a_root[j, i] = normalize_adjacency[j, i]
                            elif self.hop_dis[j, self.
                                              center] > self.hop_dis[i, self.
                                                                     center]:
                                a_close[j, i] = normalize_adjacency[j, i]
                            else:
                                a_further[j, i] = normalize_adjacency[j, i]
                if hop == 0:
                    A.append(a_root)
                else:
                    A.append(a_root + a_close)
                    A.append(a_further)
            A = np.stack(A)
            self.A = A
        else:
            raise ValueError("Do Not Exist This Strategy")

Note that A is normalized (via normalize_digraph): if a node is connected to two nodes within a group, each of those edges is assigned the value 1/2. This corresponds to the normalization term in the paper.
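
For intuition, here is a minimal sketch of this column-degree normalization (the repo implements it in normalize_digraph; the toy graph below is made up):

import numpy as np

def normalize_digraph(A):
    # Divide each column j of A by node j's degree, so the weights of the
    # nodes connected to j sum to 1.
    Dl = np.sum(A, 0)
    Dn = np.zeros_like(A)
    for i in range(A.shape[0]):
        if Dl[i] > 0:
            Dn[i, i] = Dl[i] ** (-1)
    return A @ Dn

# Toy graph: node 0 is connected to nodes 1 and 2 (self-loops included).
A = np.array([[1., 1., 1.],
              [1., 1., 0.],
              [1., 0., 1.]])
print(normalize_digraph(A))  # column 0 entries become 1/3, columns 1 and 2 become 1/2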

Then the GCN network modules, the edge importance weight parameters, and the final classifier are initialized.

Forward pass

def forward(self, x):

    # data normalization
    N, C, T, V, M = x.size()
    x = x.permute(0, 4, 3, 1, 2).contiguous()
    x = x.view(N * M, V * C, T)
    x = self.data_bn(x)
    x = x.view(N, M, V, C, T)
    x = x.permute(0, 1, 3, 4, 2).contiguous()
    x = x.view(N * M, C, T, V)

    # forward
    for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
        x, _ = gcn(x, self.A * importance)

    # global pooling
    x = F.avg_pool2d(x, x.size()[2:])
    x = x.view(N, M, -1, 1, 1).mean(dim=1)

    # prediction
    x = self.fcn(x)
    x = x.view(x.size(0), -1)

    return x

Let's first look at the input data and how it is stored. The input has shape (N, C, T, V, M): N is the batch size, C the number of channels, T the temporal length, V the number of joints, and M the number of instances (for example, one input may contain two skeletons).

The data first goes through a series of reshapes and batch normalization, then enters the GCN modules with shape (N*M, C, T, V).
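
To see what those permutes and views do, a runnable sketch with made-up sizes:

import torch

N, C, T, V, M = 2, 3, 150, 25, 2  # hypothetical sizes
x = torch.randn(N, C, T, V, M)

# Fold the M skeletons into the batch and flatten V*C for BatchNorm1d,
# which normalizes each (joint, channel) pair over the batch and time.
x = x.permute(0, 4, 3, 1, 2).contiguous().view(N * M, V * C, T)
print(x.shape)  # torch.Size([4, 75, 150])

# After data_bn, restore the (N*M, C, T, V) layout the st_gcn blocks expect.
x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(N * M, C, T, V)
print(x.shape)  # torch.Size([4, 3, 150, 25])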

ST-GCN stacks 10 such layers in total; the first has no dropout (kwargs0 strips the dropout argument), while the rest do. Two of the layers (64→128 and 128→256) use temporal stride 2, halving T each time.

Inside the st_gcn block

def forward(self, x, A):

    res = self.residual(x)
    x, A = self.gcn(x, A)
    x = self.tcn(x) + res

    return self.relu(x), A

x first passes through a residual branch. When residual=False (the first block) this branch is zero, and when the input and output shapes match it is the identity; otherwise it is essentially a 1×1 convolution that matches the channel count and temporal stride:

self.residual = nn.Sequential(
    nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=1,
        stride=(stride, 1)),
    nn.BatchNorm2d(out_channels),
)

The spatial-temporal graph convolution is split into two steps: a spatial convolution and a temporal convolution. The temporal convolution can be done with an ordinary convolution, because temporal adjacency is naturally reflected in the tensor layout (neighboring frames sit next to each other along the T axis). For the spatial convolution, however, the nodes are not stored in adjacency order, so the adjacency matrix has to be brought in explicitly.
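
A quick sketch of that asymmetry (sizes made up): a (3, 1) kernel sliding along T aggregates each joint at t-1, t, t+1, which is exactly its temporal neighborhood, while a (1, 3) kernel along V would mix joints that merely happen to be stored next to each other.

import torch
import torch.nn as nn

x = torch.randn(1, 64, 150, 25)  # (N*M, C, T, V)

# Temporal convolution: neighbors along T really are temporal neighbors.
conv_t = nn.Conv2d(64, 64, kernel_size=(3, 1), padding=(1, 0))
print(conv_t(x).shape)  # torch.Size([1, 64, 150, 25])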

First the spatial convolution: x and A are fed into the spatial convolution module gcn.

self.gcn = ConvTemporalGraphical(in_channels, out_channels,
                                    kernel_size[1])
def forward(self, x, A):
    assert A.size(0) == self.kernel_size

    x = self.conv(x)

    n, kc, t, v = x.size()
    x = x.view(n, self.kernel_size, kc//self.kernel_size, t, v)
    x = torch.einsum('nkctv,kvw->nctw', (x, A))

    return x.contiguous(), A

At its core it is also implemented with a convolution:

self.conv = nn.Conv2d(
    in_channels,
    out_channels * kernel_size,
    kernel_size=(t_kernel_size, 1),
    padding=(t_padding, 0),
    stride=(t_stride, 1),
    dilation=(t_dilation, 1),
    bias=bias)

When implementing the spatial convolution, the multiplication and the summation are performed separately so that the node connectivity can be taken into account. In principle one would use out_channels kernels of size 1×3 and get out_channels output features, but because of how the nodes are stored in the tensor, adjacent entries are not necessarily neighboring joints, so a plain 2D convolution over the node axis cannot be used.

Instead, the author uses K*out_channels 1×1 kernels (K = 3 for the spatial strategy). This step is the multiplication, computed as if every node were connected to every other, so there is some redundancy; it outputs features of shape (N, K*C, T, V), which are then reshaped to (N, K, C, T, V).

The einsum call performs the summation. It computes out[n, c, t, w] = Σ_k Σ_v x[n, k, c, t, v] · A[k, v, w]: for every output node w, the features of all nodes v are weighted by the corresponding adjacency entries of each partition k and summed, which zeroes out the redundant products for unconnected nodes and yields the output of shape (N, C, T, V).
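
A self-contained check of the einsum against an explicit per-partition matrix multiplication (all sizes made up):

import torch

n, k, c, t, v = 2, 3, 4, 7, 25
x = torch.randn(n, k, c, t, v)  # conv output after the view to (N, K, C, T, V)
A = torch.rand(k, v, v)         # K stacked V x V adjacency matrices

out = torch.einsum('nkctv,kvw->nctw', x, A)

# Equivalent form: multiply each partition's features by its adjacency
# matrix over the node axis, then sum the K partial results.
ref = sum(torch.matmul(x[:, i], A[i]) for i in range(k))
print(out.shape, torch.allclose(out, ref, atol=1e-5))  # torch.Size([2, 4, 7, 25]) True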

The temporal convolution is then an ordinary convolution along the time axis (a Conv2d with a (kernel_size[0], 1) kernel, i.e. a 1D temporal convolution applied to each joint independently):

padding = ((kernel_size[0] - 1) // 2, 0)  # pad the time axis so T is preserved when stride == 1
self.tcn = nn.Sequential(
    nn.BatchNorm2d(out_channels),
    nn.ReLU(inplace=True),
    nn.Conv2d(
        out_channels,
        out_channels,
        (kernel_size[0], 1),
        (stride, 1),
        padding,
    ),
    nn.BatchNorm2d(out_channels),
    nn.Dropout(dropout, inplace=True),
)

The output of the spatial convolution goes through the temporal convolution, and the residual branch output is added on top, giving the output of the st_gcn block (the forward shown above). There are ten such blocks in total.

Then a global pooling layer aggregates the T×V feature map into a single value per channel, reshapes back to (N, M, C, 1, 1), and averages over the instance dimension to get (N, C, 1, 1):

# global pooling
x = F.avg_pool2d(x, x.size()[2:])
x = x.view(N, M, -1, 1, 1).mean(dim=1)

# prediction
x = self.fcn(x)
x = x.view(x.size(0), -1)
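
With made-up sizes (T = 75 would follow from a 300-frame input halved by the two stride-2 blocks), the shapes work out as:

import torch
import torch.nn.functional as F

N, M, C, T, V = 2, 2, 256, 75, 25  # hypothetical sizes after the last block
x = torch.randn(N * M, C, T, V)

x = F.avg_pool2d(x, x.size()[2:])        # pool the whole T x V map: (N*M, C, 1, 1)
x = x.view(N, M, -1, 1, 1).mean(dim=1)   # average over the M skeletons
print(x.shape)                           # torch.Size([2, 256, 1, 1])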

Finally comes the classifier: a 1×1 convolution applied to the 1×1 map, which is effectively a fully connected layer producing the class scores.

# fcn for prediction
self.fcn = nn.Conv2d(256, num_class, kernel_size=1)
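
To verify the "effectively fully connected" claim, a small sketch that copies the conv weights into an nn.Linear (num_class = 60 is made up):

import torch
import torch.nn as nn

num_class = 60  # hypothetical number of classes
fcn = nn.Conv2d(256, num_class, kernel_size=1)
fc = nn.Linear(256, num_class)

# Share the weights so both layers compute the same affine map.
with torch.no_grad():
    fc.weight.copy_(fcn.weight.view(num_class, 256))
    fc.bias.copy_(fcn.bias)

x = torch.randn(4, 256, 1, 1)  # pooled features
same = torch.allclose(fcn(x).view(4, -1), fc(x.view(4, 256)), atol=1e-6)
print(same)  # True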