【datawhale】学习小组打卡博客5

比赛链接:零基础入门 CV 赛事 - 街景字符编码识别
打卡任务:模型集成

多模型集成

四种最常用的多模型集成方法

假设共有 N 个模型待集成,对某测试样本xx, 其预测结果为 N 个 C 维向量,(C 为数据的标记空间大小):s1,s2,⋯,sN

  • 直接平均
    直接平均不同模型产生的类别置信度得到最后预测结果

  • 加权平均
    在直接平均法基础上加入权重
    调整不同模型输出的重要程度
    在这里插入图片描述
    wi 作为第 i 个模型的权重,需满足:
    在这里插入图片描述
    高准确率的模型权重较高,低准确率模型可设置稍小权重

  • 投票
    多数表决法 (majority voting)
    将各自模型返回的预测置信度 si 转化为预测类
    若某预测类别获得一半以上模型投票,则该样本预测结果为该类别;
    若对该样本无任何类别获得一半以上投票,则拒绝作出预测(称为”rejection option”)
    相对多数表决法 (plurality voting)
    选择投票数最高的类别作为最后预测
    一定会返回某个类别

  • 堆叠 (stacking)
    又称” 二次集成法”,高阶的集成学习方法
    样本 x 作为学习算法或网络模型的输入,si 作为第 i 个模型的类别置信度输出,整个学习过程记作一阶学习过程 (first-level learning)
    Stacking 是以一阶学习过程的输出作为输入,展开二阶学习过程 (second-level learning)

  • 元学习(meta learning)
    置信度可以级联作为新的特征表示。
    之后基于这样的” 特征表示” 训练学习器将其映射到样本原本的标记空间。
    此时的学习器可以为任何算法习得的模型

定义多模型

import torchvision.models as models
import torch
from efficientnet_pytorch import EfficientNet
from torchsummary import summary
import torch.nn as nn


def net(num_class, model_name, pretrain=True):
    model_list = []

    for name in model_name:
        if name == "resnet":
            model = models.resnet101(pretrained=True)
            model.fc = torch.nn.Sequential(torch.nn.Linear(2048, num_class))
            model_list.append(model)
        elif name == "densenet":
            model = models.densenet121(pretrained=pretrain)
            model = Densenet(model, num_class)
            # model.fc = torch.nn.Sequential(torch.nn.Linear(2048, 2))
            model_list.append(model)
        elif name == "efficientnet":
            model = EfficientNet.from_pretrained('efficientnet-b5', num_classes=num_class)
            model_list.append(model)

        # model = models.densenet121(pretrained=pretrain)
        # model = EfficientNet.from_pretrained('efficientnet-b5')
        # model.conv1 = torch.nn.Sequential(torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False))
        # model.fc = torch.nn.Sequential(torch.nn.Linear(2048, num_class))
    return model_list


class Densenet(nn.Module):
    def __init__(self, model, num_class):
        super(Densenet, self).__init__()
        self.ori_model = model
        # self.relu = nn.ReLU(inplace=True)
        self.linear = nn.Linear(1000, num_class)

    def forward(self, x):
        x = self.ori_model(x)
        # x = self.relu(x)
        x = self.linear(x)

TTA

在测试阶段将测试集做数据增强,让测试图像来“拟合”模型

class Prediction(FlyAI):
    def load_model(self):
        self.model_name = ["resnet", "densenet", "efficientnet"]
        self.model_list = []
        for name in self.model_name:
            model = torch.load(os.path.join(MODEL_PATH, name+"_best.pth"))
            self.model_list.append(model)
            # model = torch.load("lab_model/pretrain_model/COVIDC_densenet121_best.pth")
            # model = ttach.SegmentationTTAWrapper(model, ttach.aliases.d4_transform(), merge_mode='mean')
            # self.model_list = model.to(device)

    def predict(self, image_path):
        out_list = []

        for n in range(len(self.model_list)):
            model = self.model_list[n]
            model = model.to(device)
            img = Image.open(image_path).convert("RGB")
            output_list = []
            transform = Mydata().v_transform()
            img_tensor = transform(img).unsqueeze(0)
            img_tensor = img_tensor.to(device)
            output = model(img_tensor)
            output_list.append(output)

            for i in range(4):
                transform = Mydata().t_transform()
                img_tensor = transform(img).unsqueeze(0)
                img_tensor = img_tensor.to(device)
                output = model(img_tensor)
                output_list.append(output)

            output = 0
            for i in range(len(output_list)):
                output += output_list[i]

            output = output / 5
            out_list.append(output)

        output = 0
        for i in out_list:
            output += i
        output = output / 3
        pred = output.max(1, keepdim=True)[1]
        # print(pred)
        return {
    
    "label": pred}

猜你喜欢

转载自blog.csdn.net/weixin_45612763/article/details/106504944