This article covers TD-GCN, which makes some small modifications on top of CTR-GCN; if you haven't read CTR-GCN yet, you may want to start with that article.
Model Definition
The model code is the same as CTR-GCN's; the difference is the core graph-convolution module used inside.
import math

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable

# import_class, conv_init, bn_init and weights_init are utility functions
# from the CTR-GCN codebase (dynamic class import and weight initialization).

class Model(nn.Module):
    def __init__(self, num_class=60, num_point=25, num_person=2, graph=None, graph_args=dict(), in_channels=3,
                 drop_out=0, adaptive=True):
        super(Model, self).__init__()

        if graph is None:
            raise ValueError()
        else:
            Graph = import_class(graph)
            self.graph = Graph(**graph_args)

        A = self.graph.A  # 3,25,25

        self.num_class = num_class
        self.num_point = num_point
        self.data_bn = nn.BatchNorm1d(num_person * in_channels * num_point)

        # 10 stacked TCN-GCN units; channels double (and time halves) at l5 and l8
        base_channel = 64
        self.l1 = TCN_GCN_unit(in_channels, base_channel, A, residual=False, adaptive=adaptive)
        self.l2 = TCN_GCN_unit(base_channel, base_channel, A, adaptive=adaptive)
        self.l3 = TCN_GCN_unit(base_channel, base_channel, A, adaptive=adaptive)
        self.l4 = TCN_GCN_unit(base_channel, base_channel, A, adaptive=adaptive)
        self.l5 = TCN_GCN_unit(base_channel, base_channel*2, A, stride=2, adaptive=adaptive)
        self.l6 = TCN_GCN_unit(base_channel*2, base_channel*2, A, adaptive=adaptive)
        self.l7 = TCN_GCN_unit(base_channel*2, base_channel*2, A, adaptive=adaptive)
        self.l8 = TCN_GCN_unit(base_channel*2, base_channel*4, A, stride=2, adaptive=adaptive)
        self.l9 = TCN_GCN_unit(base_channel*4, base_channel*4, A, adaptive=adaptive)
        self.l10 = TCN_GCN_unit(base_channel*4, base_channel*4, A, adaptive=adaptive)

        self.fc = nn.Linear(base_channel*4, num_class)
        nn.init.normal_(self.fc.weight, 0, math.sqrt(2. / num_class))
        bn_init(self.data_bn, 1)
        if drop_out:
            self.drop_out = nn.Dropout(drop_out)
        else:
            self.drop_out = lambda x: x

    def forward(self, x):
        # Accept flattened (N, T, V*C) input and reshape to (N, C, T, V, M=1)
        if len(x.shape) == 3:
            N, T, VC = x.shape
            x = x.view(N, T, self.num_point, -1).permute(0, 3, 1, 2).contiguous().unsqueeze(-1)

        N, C, T, V, M = x.size()

        # Normalize per (person, joint, channel) with BatchNorm1d over time
        x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T)
        x = self.data_bn(x)
        # Fold the person dimension into the batch: (N*M, C, T, V)
        x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(N * M, C, T, V)

        x = self.l1(x)
        x = self.l2(x)
        x = self.l3(x)
        x = self.l4(x)
        x = self.l5(x)
        x = self.l6(x)
        x = self.l7(x)
        x = self.l8(x)
        x = self.l9(x)
        x = self.l10(x)

        # N*M,C,T,V
        c_new = x.size(1)
        x = x.view(N, M, c_new, -1)
        # Global average pooling over time x joints, then over persons
        x = x.mean(3).mean(1)
        x = self.drop_out(x)

        return self.fc(x)
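As a quick sanity check, the model can be instantiated on a dummy batch. This is only a sketch: the 'graph.ntu_rgb_d.Graph' path and the labeling_mode argument follow the CTR-GCN repo layout and are assumptions here, as is the availability of the repo's init helpers.

# Hypothetical smoke test (graph path and args assumed from the CTR-GCN repo)
model = Model(num_class=60, num_point=25, num_person=2,
              graph='graph.ntu_rgb_d.Graph', graph_args=dict(labeling_mode='spatial'))
x = torch.randn(2, 3, 64, 25, 2)   # (N, C, T, V, M): 64 frames, 25 joints, 2 persons
logits = model(x)
print(logits.shape)                # torch.Size([2, 60])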
Forward Pass
The core module is TCN_GCN_unit, and this part is also the same as in CTR-GCN: as the figure in the paper shows, it is a spatial GCN block followed by a multi-scale temporal convolution. In code:
class TCN_GCN_unit(nn.Module):
    def __init__(self, in_channels, out_channels, A, stride=1, residual=True, adaptive=True, kernel_size=5, dilations=[1, 2]):
        super(TCN_GCN_unit, self).__init__()
        # Spatial graph convolution followed by multi-scale temporal convolution
        self.gcn1 = unit_gcn(in_channels, out_channels, A, adaptive=adaptive)
        self.tcn1 = MultiScale_TemporalConv(out_channels, out_channels, kernel_size=kernel_size, stride=stride,
                                            dilations=dilations, residual=False)
        self.relu = nn.ReLU(inplace=True)
        # Residual branch: zero, identity, or a 1x1 unit_tcn that matches shapes
        if not residual:
            self.residual = lambda x: 0
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x
        else:
            self.residual = unit_tcn(in_channels, out_channels, kernel_size=1, stride=stride)

    def forward(self, x):
        y = self.relu(self.tcn1(self.gcn1(x)) + self.residual(x))
        return y
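The residual selection above matters at the two striding stages (l5 and l8): when the channel count or stride changes, the shortcut becomes a 1x1 unit_tcn. A minimal shape check (A here is a stand-in (3, 25, 25) adjacency stack, a hypothetical value):

# Stride-2 unit: channels 64 -> 128, temporal length halves
A = np.stack([np.eye(25, dtype=np.float32)] * 3)   # stand-in adjacency subsets
unit = TCN_GCN_unit(64, 128, A, stride=2)          # shortcut is a 1x1 unit_tcn here
x = torch.randn(8, 64, 64, 25)                     # (N*M, C, T, V)
print(unit(x).shape)                               # torch.Size([8, 128, 32, 25])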
Spatial Convolution
This part is also basically the same. There are two differences: first, the core module is replaced with TDGC; second, two learnable parameters, beta and gamma, are added.
class unit_gcn(nn.Module):
    def __init__(self, in_channels, out_channels, A, coff_embedding=4, adaptive=True, residual=True):
        super(unit_gcn, self).__init__()
        inter_channels = out_channels // coff_embedding
        self.inter_c = inter_channels
        self.out_c = out_channels
        self.in_c = in_channels
        self.adaptive = adaptive
        self.num_subset = A.shape[0]
        # One TDGC branch per adjacency subset
        self.convs = nn.ModuleList()
        for i in range(self.num_subset):
            self.convs.append(TDGC(in_channels, out_channels))

        if residual:
            if in_channels != out_channels:
                self.down = nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, 1),
                    nn.BatchNorm2d(out_channels)
                )
            else:
                self.down = lambda x: x
        else:
            self.down = lambda x: 0

        if self.adaptive:
            # The adjacency matrices themselves are learnable parameters
            self.PA = nn.Parameter(torch.from_numpy(A.astype(np.float32)))
        else:
            self.A = Variable(torch.from_numpy(A.astype(np.float32)), requires_grad=False)
        self.alpha = nn.Parameter(torch.zeros(1))  # No I
        self.bn = nn.BatchNorm2d(out_channels)
        self.soft = nn.Softmax(-2)
        self.relu = nn.ReLU(inplace=True)
        # Learnable weights for the Temporal-Channel Fusion, shared by all branches
        self.beta = nn.Parameter(torch.tensor(0.5))  # 1.0 1.4 2.0
        self.gamma = nn.Parameter(torch.tensor(0.1))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                conv_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                bn_init(m, 1)
        bn_init(self.bn, 1e-6)

    def forward(self, x):
        y = None
        if self.adaptive:
            A = self.PA
        else:
            A = self.A.cuda(x.get_device())
        # Sum the outputs of the per-subset TD-GC branches
        for i in range(self.num_subset):
            z = self.convs[i](x, A[i], self.alpha, self.beta, self.gamma)
            y = z + y if y is not None else z
        y = self.bn(y)
        y += self.down(x)
        y = self.relu(y)
        return y
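Note that beta and gamma live in unit_gcn and are passed into every TDGC branch, so the three subsets share one fusion weighting. A quick check (reusing the stand-in A from above; the default adaptive=True keeps everything on CPU):

gcn = unit_gcn(3, 64, A)                    # A: (3, 25, 25) numpy array
x = torch.randn(8, 3, 64, 25)               # (N*M, C, T, V)
print(gcn(x).shape)                         # torch.Size([8, 64, 64, 25])
print(gcn.beta.item(), gcn.gamma.item())    # 0.5 0.1 -- learnable scalars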
TD-GC
This is the module with the core changes; compared with CTR-GCN:
On the channel-wise topology side (the left of (b) in the paper's figure), the authors removed one 1x1 convolution and replaced the pair-wise subtraction with a self-pair-wise subtraction; everything else is unchanged.
On the temporal-wise topology side (the right of (b)), a similar construction is used, except that the pooled dimension changes from temporal to channel (the code averages over channels with mean(-3)), so the resulting topology is time-dependent. On this branch the authors add no 1x1 convolution and no scaling factor α afterwards.
The outputs of the two branches are then fused by a Temporal-Channel Fusion step, using the two learnable parameters beta and gamma.
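To make "self-pair-wise" concrete, here is a minimal sketch (not the authors' code; the tensor names are hypothetical) contrasting CTR-GC's pair-wise subtraction with TD-GC's variant:

N, R, V = 2, 8, 25                 # batch, reduced channels, joints
e1 = torch.randn(N, R, V)          # embedding from the first 1x1 conv (after the temporal mean)
e2 = torch.randn(N, R, V)          # second embedding from another 1x1 conv (CTR-GC only)

# CTR-GC: pair-wise subtraction between two different embeddings
M_pair = torch.tanh(e1.unsqueeze(-1) - e2.unsqueeze(-2))   # (N, R, V, V)

# TD-GC: self-pair-wise subtraction of one embedding against itself,
# which is what lets the second 1x1 conv be deleted
M_self = torch.tanh(e1.unsqueeze(-1) - e1.unsqueeze(-2))   # (N, R, V, V)

Since tanh is odd, M_self is antisymmetric in its last two dimensions; that is the structural price of sharing a single embedding.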
class TDGC(nn.Module):
    def __init__(self, in_channels, out_channels, rel_reduction=8, mid_reduction=1):  # r = 4 r = 16
        super(TDGC, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        if in_channels == 3 or in_channels == 9:
            self.rel_channels = 8
            self.mid_channels = 16
        else:
            self.rel_channels = in_channels // rel_reduction
            self.mid_channels = in_channels // mid_reduction
        self.conv1 = nn.Conv2d(self.in_channels, self.rel_channels, kernel_size=1)
        # This convolution (self.conv2) is redundant,
        # but when you want to use the weight files we provide for action recognition, you have to uncomment it!
        # self.conv2 = nn.Conv2d(self.in_channels, self.rel_channels, kernel_size=1)
        self.conv3 = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1)
        self.conv4 = nn.Conv2d(self.rel_channels, self.out_channels, kernel_size=1)
        self.tanh = nn.Tanh()
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                conv_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                bn_init(m, 1)

    def forward(self, x, A=None, alpha=1, beta=1, gamma=0.1):
        # x: (N, C, T, V)
        # Channel-wise branch: temporal mean, then self-pair-wise subtraction
        x1, x3 = self.conv1(x).mean(-2), self.conv3(x)       # x1: (N, R, V), x3: (N, C', T, V)
        x1 = self.tanh(x1.unsqueeze(-1) - x1.unsqueeze(-2))  # (N, R, V, V)
        x1 = self.conv4(x1) * alpha + (A.unsqueeze(0).unsqueeze(0) if A is not None else 0)  # (N, C', V, V)
        x1 = torch.einsum('ncuv,nctv->nctu', x1, x3)         # channel-wise aggregation: (N, C', T, V)
        # Temporal-wise branch: channel mean, then per-frame pair-wise subtraction
        x4 = self.tanh(x3.mean(-3).unsqueeze(-1) - x3.mean(-3).unsqueeze(-2))  # (N, T, V, V)
        x3 = x3.permute(0, 2, 1, 3)                          # (N, T, C', V)
        x5 = torch.einsum('btmn,btcn->bctm', x4, x3)         # time-dependent aggregation: (N, C', T, V)
        # Temporal-Channel Fusion with the learnable weights beta and gamma
        x1 = x1 * beta + x5 * gamma
        return x1
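A shape walk-through of one TD-GC call (a smoke test assuming the repo's conv_init/bn_init helpers are in scope, as in the classes above):

x = torch.randn(4, 64, 32, 25)     # (N, C, T, V)
A1 = torch.rand(25, 25)            # one adjacency subset (stand-in values)
tdgc = TDGC(64, 128)               # rel_channels = 64 // 8 = 8
out = tdgc(x, A1, alpha=torch.zeros(1), beta=torch.tensor(0.5), gamma=torch.tensor(0.1))
print(out.shape)                   # torch.Size([4, 128, 32, 25])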
Temporal Convolution
This part is exactly the same as CTR-GCN's; see the earlier article for details.
class MultiScale_TemporalConv(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 dilations=[1, 2, 3, 4],
                 residual=True,
                 residual_kernel_size=1):
        super().__init__()
        assert out_channels % (len(dilations) + 2) == 0, '# out channels should be multiples of # branches'

        # Multiple branches of temporal convolution
        self.num_branches = len(dilations) + 2
        branch_channels = out_channels // self.num_branches
        if type(kernel_size) == list:
            assert len(kernel_size) == len(dilations)
        else:
            kernel_size = [kernel_size] * len(dilations)

        # Temporal Convolution branches
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    branch_channels,
                    kernel_size=1,
                    padding=0),
                nn.BatchNorm2d(branch_channels),
                nn.ReLU(inplace=True),
                TemporalConv(
                    branch_channels,
                    branch_channels,
                    kernel_size=ks,
                    stride=stride,
                    dilation=dilation),
            )
            for ks, dilation in zip(kernel_size, dilations)
        ])

        # Additional Max & 1x1 branch
        self.branches.append(nn.Sequential(
            nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(branch_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(3, 1), stride=(stride, 1), padding=(1, 0)),
            nn.BatchNorm2d(branch_channels)
        ))

        self.branches.append(nn.Sequential(
            nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0, stride=(stride, 1)),
            nn.BatchNorm2d(branch_channels)
        ))

        # Residual connection
        if not residual:
            self.residual = lambda x: 0
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x
        else:
            self.residual = TemporalConv(in_channels, out_channels, kernel_size=residual_kernel_size, stride=stride)

        # initialize
        self.apply(weights_init)

    def forward(self, x):
        # Input dim: (N,C,T,V)
        res = self.residual(x)
        branch_outs = []
        for tempconv in self.branches:
            out = tempconv(x)
            branch_outs.append(out)

        out = torch.cat(branch_outs, dim=1)
        out += res
        return out
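MultiScale_TemporalConv depends on a TemporalConv helper that the excerpt above does not show. For reference, the CTR-GCN version is a dilated temporal convolution plus BatchNorm, with padding chosen so the temporal length changes only through the stride:

class TemporalConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1):
        super(TemporalConv, self).__init__()
        # "Same" padding for a dilated temporal kernel
        pad = (kernel_size + (kernel_size - 1) * (dilation - 1) - 1) // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(kernel_size, 1),
                              padding=(pad, 0), stride=(stride, 1), dilation=(dilation, 1))
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return self.bn(self.conv(x))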
class unit_tcn(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=9, stride=1):
        super(unit_tcn, self).__init__()
        pad = int((kernel_size - 1) / 2)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(kernel_size, 1), padding=(pad, 0),
                              stride=(stride, 1))
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        conv_init(self.conv)
        bn_init(self.bn, 1)

    def forward(self, x):
        x = self.bn(self.conv(x))
        return x
Summary
Compared with CTR-GCN, the only change in the TD-GCN code is in the core module, where CTR-GC is replaced by TD-GC, and that change simply adds a temporal-wise topology built with a similar recipe. The authors evaluate on four input modalities; with joint input, TD-GCN improves on CTR-GCN by roughly 0.4% to 2%.
The authors also made some simplifications, so the parameter count drops slightly.
The model reaches state of the art on the gesture-recognition datasets SHREC'17 and DHG-14/28, and the paper also reports results on the UCLA and NTU-RGB+D datasets.
So although there is a whiff of leaderboard-chasing, adding the temporal-decoupling module does bring a real performance gain.