Sill-Net代码

https://gitee.com/Lost_star/Sill-Net.git

数据

data_loader = get_loader(args.dataset)
data_path = get_data_path(args.dataset)  # e.g. '../db/'

# Dataset objects for the train / test splits. Both are built with the same
# img_size (img_rows, img_cols); they differ only in the split name and the
# augmentation pipeline passed in.
tr_loader = data_loader(
    data_path,
    args.exp,
    is_transform=True,
    split='train',
    img_size=(args.img_rows, args.img_cols),
    augmentations=data_trans_train,
)
te_loader = data_loader(
    data_path,
    args.exp,
    is_transform=True,
    split='test',
    img_size=(args.img_rows, args.img_cols),
    augmentations=data_trans_test,
)

# pin_memory=True: batches are placed in pinned (page-locked) host memory, which
# speeds up copying them to CUDA. drop_last=True drops a final incomplete batch.
trainloader = DataLoader(
    tr_loader,
    batch_size=args.batch_size,
    shuffle=True,
    pin_memory=True,
    drop_last=True,
)
testloader = DataLoader(
    te_loader,
    batch_size=args.batch_size,
    shuffle=True,
    pin_memory=True,
)

超参数

# Scalar hyperparameters read by the training code (their exact roles are not
# shown in this excerpt; beta/gamma are presumably loss weights — TODO confirm).
beta = 0.01
gamma = 0.1
mix_ratio = 0.5
# Presumably the extractor's output channel count (6 = 3 semantic + 3 illumination,
# matching the 6-channel h6 in extract) — confirm against the model definition.
feature_channel = 6

main函数

if __name__ == "__main__":
    # Directory for logged images, e.g. results/img_log_gtsrb2gtsrb/2021-08-24-143653.
    out_root = Path(outimg_path)
    # Create it together with any missing parents; an existing directory is kept as-is.
    out_root.mkdir(parents=True, exist_ok=True)

    best_acc = 0
    # Training loop; epochs are numbered 1 .. args.epochs inclusive.
    for e in range(1, args.epochs + 1):
        train(e)
        # Evaluate and save the model; returns the (possibly updated) best accuracy.
        best_acc = lalala(e, best_acc)

        print('========epoch(%d)=========\nbest_acc:%02f' % (e, best_acc))

训练

  • batch_size为16
  • 指定类
# Class IDs belonging to the training and test splits, as exposed by the dataset
# objects. Example values as printed in the original notes:
#   tr_class: [ 1  2  3  4  5  7  8  9 10 11 12 13 14 15 17 18 25 26 31 33 35 38]
#   te_class: [ 0  6 16 19 20 21 22 23 24 27 28 29 30 32 34 36 37 39 40 41 42]
tr_class, te_class = tr_loader.tr_class, te_loader.te_class

在这里插入图片描述

损失函数

论文中有四个损失函数:

  • match loss:纠正扭曲的特征
def loss_match_func(feat_sem, temp_sem):
    """Match loss between an image's semantic features and its template's.

    Delegates to ``match_loss`` (defined elsewhere; the original note describes
    it as a mean squared error — TODO confirm).

    :param feat_sem: semantic features of the images, e.g. torch.Size([16, 3, 64, 64])
    :param temp_sem: semantic features of the templates, same shape as ``feat_sem``
    :return: scalar loss
    """
    return match_loss(feat_sem, temp_sem)
  • recon loss:保证语义特征有足够信息重建到模板
def loss_recon_func(recon_feat_sem, recon_temp_sem, template, recon_temp_sup=None, template_sup=None):
    """Reconstruction loss: both reconstructions are compared with the template.

    :param recon_feat_sem: reconstruction for the image's semantic features,
        e.g. torch.Size([16, 3, 64, 64])
    :param recon_temp_sem: reconstruction for the template's semantic features,
        same shape
    :param template: target template, same shape
    :param recon_temp_sup: optional extra reconstruction; when given, its loss
        against ``template_sup`` is added to the total
    :param template_sup: target for ``recon_temp_sup``
    :return: scalar sum of the ``recon_loss`` terms
    """
    total = recon_loss(recon_feat_sem, template) + recon_loss(recon_temp_sem, template)
    if recon_temp_sup is None:
        return total
    return total + recon_loss(recon_temp_sup, template_sup)

  • class loss:分类正确
def loss_class_func(out, target, out_sup=None, target_sup=None):
    """Classification loss: cross-entropy, plus an optional second cross-entropy term.

    :param out: class scores, e.g. torch.Size([16, 11])
    :param target: integer labels, e.g. torch.Size([16])
    :param out_sup: optional extra class scores; when given, the cross-entropy of
        ``out_sup`` against ``target_sup`` is added
    :param target_sup: labels for ``out_sup``
    :return: scalar loss
    """
    loss = F.cross_entropy(out, target)
    if out_sup is not None:
        loss = loss + F.cross_entropy(out_sup, target_sup)
    return loss
  • illu loss:使照明特征变得语义无关
def loss_illu_func(feat_illu, target):
    """Illumination loss: the negated PIDA loss of the illumination features.

    Minimising this maximises ``PIDA_loss``. Assuming ``match_loss`` is a distance,
    that pushes the illumination features of same-label images away from their
    per-label mean, i.e. makes them as different as possible within a class.

    :param feat_illu: illumination features, e.g. torch.Size([16, 3, 64, 64])
    :param target: labels, one per sample (e.g. 16 of them)
    :return: scalar loss
    """
    return -PIDA_loss(feat_illu, target)
def PIDA_loss(feature, target):
    """Sum, over labels, of the match loss between each sample and its label's mean.

    For every distinct label in ``target``, the features of the samples carrying
    that label are compared (via ``match_loss``, defined elsewhere) with the mean
    feature of that group, broadcast to the group's shape. Per-label losses are
    summed.

    :param feature: tensor whose first dimension is the batch, e.g.
        torch.Size([16, 3, 64, 64]); any number of trailing dimensions is accepted
    :param target: 1-D tensor of labels, one per sample in ``feature``
    :return: the summed loss (the integer 0 if ``target`` is empty)
    """
    pida_loss = 0
    for label in torch.unique(target):
        # Boolean mask on the batch dimension only, so this works for any feature rank
        # (the previous ``[mask, :, :, :]`` indexing required exactly 4-D input).
        group = feature[target == label]
        # Per-label mean materialised at the group's shape (contiguous, like the previous
        # repeat) so match_loss receives two tensors of identical shape.
        group_mean = group.mean(dim=0, keepdim=True).expand_as(group).contiguous()
        pida_loss += match_loss(group, group_mean)
    return pida_loss

模型

在这里插入图片描述

extract

  1. 提取图片的语义特征和光照特征
  2. 由6个不带pooling层的卷积层组成
  3. 最后将6通道的输出拆成两个3通道的特征,一个作为语义特征,一个作为光照特征
    def extract(self, x, is_warping):
        """Split an image batch into semantic and illumination feature maps.

        Six conv + batch-norm stages are applied (``ex_pd*`` presumably padding —
        the shape annotations show the spatial size stays 64x64). Leaky-ReLU is
        used after stages 1-5 and sigmoid after stage 6. When ``is_warping`` is
        true, spatial transformers stn1/stn2/stn3 are applied before stages 1, 3
        and 5, and stn4 to the semantic half; each one only runs if its
        ``param1..param4`` is not None. The 6-channel output is chunked along
        dim 1 into two 3-channel halves.

        Returns:
            (feat_sem, feat_illu, feat_sem_nowarp): semantic features (warped by
            stn4 when enabled), illumination features, and the semantic features
            before the stn4 warp.
        """
        act = self.leakyrelu

        if is_warping and self.param1 is not None:
            x = self.stn1(x)  # input e.g. (16, 3, 64, 64)
        feat = act(self.ex_bn1(self.ex1(self.ex_pd1(x))))  # (16, 100, 64, 64)
        feat = act(self.ex_bn2(self.ex2(self.ex_pd2(feat))))  # (16, 150, 64, 64)

        if is_warping and self.param2 is not None:
            feat = self.stn2(feat)
        feat = act(self.ex_bn3(self.ex3(self.ex_pd3(feat))))  # (16, 200, 64, 64)
        feat = act(self.ex_bn4(self.ex4(self.ex_pd4(feat))))  # (16, 150, 64, 64)

        if is_warping and self.param3 is not None:
            feat = self.stn3(feat)
        feat = act(self.ex_bn5(self.ex5(self.ex_pd5(feat))))  # (16, 100, 64, 64)
        feat = self.sigmoid(self.ex_bn6(self.ex6(self.ex_pd6(feat))))  # (16, 6, 64, 64)

        # First 3 channels -> semantic features, last 3 -> illumination features.
        feat_sem, feat_illu = torch.chunk(feat, 2, 1)
        unwarped_sem = feat_sem
        if is_warping and self.param4 is not None:
            feat_sem = self.stn4(feat_sem)

        return feat_sem, feat_illu, unwarped_sem

其中又用到了stn,用于校正变形的语义特征
在这里插入图片描述

class stn(nn.Module):
    """Spatial transformer: predicts a per-sample 2x3 affine matrix and warps the input.

    The localization network regresses the 6 affine parameters. It consists of a
    max-pool, three padded conv/ReLU/max-pool stages, a hidden linear layer and a
    6-way linear head. ``forward`` then resamples the input with ``F.affine_grid``
    and ``F.grid_sample``, so the output has the same shape as the input. The head
    is initialised to the identity transform, so the module starts as a no-op warp.

    Args:
        input_channels: channel count of the tensor to warp.
        input_size: spatial size of the input (presumably square, H == W — TODO confirm).
        params: four ints — output channels of conv1, conv2, conv3 and the width of
            the hidden fc layer (e.g. ``[150, 150, 150, 150]``).
    """

    def __init__(self, input_channels, input_size, params):  # e.g. 3, 64, [150, 150, 150, 150]
        super().__init__()

        self.input_size = input_size

        # Applied first in localization_network: halves the spatial size.
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv1 = nn.Sequential(
            nn.ReplicationPad2d(2),
            nn.Conv2d(input_channels, params[0], kernel_size=5, stride=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.conv2 = nn.Sequential(
            nn.ReplicationPad2d(2),
            nn.Conv2d(params[0], params[1], kernel_size=5, stride=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.conv3 = nn.Sequential(
            nn.ReplicationPad2d(2),
            nn.Conv2d(params[1], params[2], kernel_size=3, stride=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Size the hidden fc layer from the conv stack's output for an input of
        # input_size / 2 (the spatial size after self.maxpool). convNoutput is
        # defined elsewhere; its first result is used as the fc input width.
        out_numel, _ = convNoutput([self.conv1, self.conv2, self.conv3], input_size / 2)
        self.fc = nn.Sequential(
            View(),
            nn.Linear(out_numel, params[3]),
            nn.ReLU(),
        )
        # Head with 6 outputs: the entries of a 2x3 affine matrix.
        # (Module layout/indices are kept so existing state_dict keys still load.)
        self.classifier = nn.Sequential(
            View(),
            nn.Linear(params[3], 6),
        )
        # Identity initialisation: zero weights and bias = [1, 0, 0, 0, 1, 0] (the
        # row-major 2x3 identity), so theta is the identity for every input at first.
        # copy_ writes into the existing parameter, keeping its dtype and device,
        # instead of replacing it with a new CPU FloatTensor via ``.data = ...``.
        self.classifier[1].weight.data.fill_(0)
        self.classifier[1].bias.data.copy_(torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]))

    def localization_network(self, x):
        """Regress the 6 affine parameters for each sample in ``x``.

        Shape comments are for a (16, 3, 64, 64) input with params [150] * 4.
        """
        x = self.maxpool(x)  # (16, 3, 32, 32)
        x = self.conv1(x)  # (16, 150, 16, 16)
        x = self.conv2(x)  # (16, 150, 8, 8)
        x = self.conv3(x)  # (16, 150, 5, 5)
        x = self.fc(x)  # (16, 150)
        return self.classifier(x)  # (16, 6)

    def forward(self, x):
        """Warp ``x`` with its predicted affine transform; output shape equals input shape."""
        # affine_grid expects theta of shape (N, 2, 3).
        theta = self.localization_network(x).view(-1, 2, 3)
        # align_corners=False is the PyTorch default since 1.3; passing it explicitly
        # keeps that behaviour and avoids the UserWarning emitted when it is omitted.
        grid = F.affine_grid(theta, x.size(), align_corners=False)  # (N, H, W, 2)
        return F.grid_sample(x, grid, align_corners=False)

decode

将语义特征解码成模板

    def decode(self, x):
        """Decode semantic features back into a template image.

        Four conv + batch-norm + leaky-ReLU stages (``de_pd*`` presumably padding)
        are followed by a final conv and a sigmoid. Per the shape annotations, a
        (16, 3, 64, 64) input yields a (16, 3, 64, 64) output.
        """
        act = self.leakyrelu
        y = act(self.de_bn1(self.de1(self.de_pd1(x))))  # (16, 100, 64, 64)
        y = act(self.de_bn2(self.de2(self.de_pd2(y))))  # (16, 150, 64, 64)
        y = act(self.de_bn3(self.de3(self.de_pd3(y))))  # (16, 200, 64, 64)
        y = act(self.de_bn4(self.de4(self.de_pd4(y))))  # (16, 150, 64, 64)
        return self.sigmoid(self.de5(self.de_pd5(y)))  # (16, 3, 64, 64)

classify

  1. 为什么有两个分类器?
    因为分类器的最后一层是全连接层,而train中的类有43个,test中只有11个。全连接层的输出维度是固定的,所以需要两个分类器。
  2. 为什么不直接用分割出来的语义特征训练分类器?
    论文中提到(见下图):

在这里插入图片描述

这个分类器(classify)对应43个训练类

    def classify(self, x):
        """Classify a feature batch with the first classifier head (``fc1``).

        Six conv + batch-norm + leaky-ReLU stages are applied. ``pool2`` after
        stages 1, 3 and 5 reduces the spatial size (64 -> 8 per the shape
        annotations). The result is flattened per sample and fed to ``fc1``
        (43 outputs per the annotation).

        Args:
            x: input batch, e.g. (32, 6, 64, 64).

        Returns:
            Class scores of shape (batch, fc1.out_features).
        """
        h1 = self.pool2(self.leakyrelu(self.cls_bn1(self.cls1(x))))  # (32, 100, 32, 32)
        h2 = self.leakyrelu(self.cls_bn2(self.cls2(h1)))  # (32, 150, 32, 32)
        h3 = self.pool2(self.leakyrelu(self.cls_bn3(self.cls3(h2))))  # (32, 200, 16, 16)
        h4 = self.leakyrelu(self.cls_bn4(self.cls4(h3)))  # (32, 250, 16, 16)
        h5 = self.pool2(self.leakyrelu(self.cls_bn5(self.cls5(h4))))  # (32, 300, 8, 8)
        h6 = self.leakyrelu(self.cls_bn6(self.cls6(h5)))  # (32, 100, 8, 8)
        # Flatten per sample with the batch dimension held fixed. The previous
        # view(-1, int(input_size / 8 * input_size / 8 * classify_chn[5])) recomputed
        # the width with float arithmetic. Whenever that width disagreed with the
        # real C*H*W (e.g. input_size not a multiple of 8), it could silently regroup
        # values across samples. flatten(1) gives the same result when the width is
        # right, and otherwise lets fc1 raise a clear shape error.
        h7 = h6.flatten(1)  # (32, 6400)
        out = self.fc1(h7)  # (32, 43)
        return out

classify2

这个分类器(classify2)对应11个测试类,测试时使用的就是它

    def classify2(self, x):
        """Classify a feature batch with the second classifier head (``fc2``).

        Same layout as ``classify`` but with its own layers (``cls2*``, ``cls2_bn*``).
        Six conv + batch-norm + leaky-ReLU stages are applied, with ``pool2`` after
        stages 1, 3 and 5 (64 -> 8 spatially per the shape annotations). The result
        is flattened per sample and fed to ``fc2`` (11 outputs per the annotation).

        Args:
            x: input batch, e.g. (16, 6, 64, 64).

        Returns:
            Class scores of shape (batch, fc2.out_features).
        """
        h1 = self.pool2(self.leakyrelu(self.cls2_bn1(self.cls21(x))))  # (16, 100, 32, 32)
        h2 = self.leakyrelu(self.cls2_bn2(self.cls22(h1)))  # (16, 150, 32, 32)
        h3 = self.pool2(self.leakyrelu(self.cls2_bn3(self.cls23(h2))))  # (16, 200, 16, 16)
        h4 = self.leakyrelu(self.cls2_bn4(self.cls24(h3)))  # (16, 250, 16, 16)
        h5 = self.pool2(self.leakyrelu(self.cls2_bn5(self.cls25(h4))))  # (16, 300, 8, 8)
        h6 = self.leakyrelu(self.cls2_bn6(self.cls26(h5)))  # (16, 100, 8, 8)
        # Flatten per sample with the batch dimension held fixed. The previous
        # view(-1, int(input_size / 8 * input_size / 8 * classify_chn[5])) used a
        # float-computed width that could silently regroup values across samples
        # when it disagreed with the real C*H*W. flatten(1) is identical when the
        # width is right, and otherwise lets fc2 raise a clear shape error.
        h7 = h6.flatten(1)  # (16, 6400)
        out = self.fc2(h7)  # (16, 11)
        return out

init_params

    def init_params(self, net):
        """Load the parameters of ``net`` whose names also exist in this model.

        ``net`` may be a state dict or a module (its ``state_dict()`` is used).
        Entries whose keys are not in this model's state dict are ignored. All
        other entries of this model keep their current values.
        """
        print('Loading the model from the file...')
        own_state = self.state_dict()
        source = net if isinstance(net, dict) else net.state_dict()
        # 1. keep only the entries whose keys also exist in this model
        # 2. overwrite the matching entries of this model's state dict
        own_state.update({name: value for name, value in source.items() if name in own_state})
        # 3. load the merged state dict
        self.load_state_dict(own_state)

测试

测试流程与训练大致相同,只是分类时使用的是classify2。

训练效果

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/qq_37252519/article/details/119955467