【3D 图像分割】基于 Pytorch 的 3D 图像分割5(训练篇)

在本系列的开篇,就对整个项目训练所需要的所有模块都进行了一个简要的介绍,尤其是针对训练中需要引入的各个结构,进行一个串联介绍。

而在之前的数据构建篇和网络模型篇中,都对其中的每一个组块进行了分别的验证。预先在未开始训练前,检验其中各个模块的正确性,避免在训练时候,问题连连,着实抓马。

通过这一系列文章的学习后,我相信绝大部分的模块都已经介绍过了。包括:

  1. 综述篇中对优化器、模型获取和保存模型进行了介绍;
  2. 在数据流模块中,学习了如何导入数据,验证数据流;
  3. 网络模型那里,损失函数loss的调用。

本篇其实存在的最大意义,就在于将这些零零散散的东西,拼接成一个整体。至于推理阶段,将单独新开一节,放到后面。通过这个系列的学习,也能多一些思考,加深一些感悟。

一、损失函数

在分割任务中,把目标分割任务的mask,转化为对像素点的分类任务。所以在计算损失的时候,论文里面的损失函数采用的就是交叉熵损失函数

在后续的损失改进中,多引入dice lossfocal loss。我们就从交叉熵损失函数开始,探讨下它为什么可以应用在分割任务中。

本文继续沿着在网络模型评估阶段,使用的交叉熵损失函数,定义如下。对于其他分割的损失函数,参考这篇文章:【AI面试】CrossEntropy Loss 、Balanced Cross Entropy、 Dice Loss 和 Focal Loss 分类损失横评

1.1、CrossEntropyLoss

在上一篇关于网络模型中,对模型的测试阶段,引入了交叉熵损失函数。链接在这:【3D图像分割】基于 Pytorch 的 VNet 3D 图像分割3(3D UNet 模型篇)。其中引入loss的方式,如下这样:

expected_output_shape = (batch_size, num_out_classes, 64, 64, 64)
assert output.shape == expected_output_shape, "Unexpected output shape, check the architecture!"

# Defining loss fn
ce_layer = torch.nn.CrossEntropyLoss()
# Calculating loss
ce_loss = ce_layer(output, ground_truth)
print("CE Loss = {}".format(ce_loss))

其中,

  • ground_truth 的大小是 BxDxHxW
  • output 的大小是 BxCxDxHxW
  • 对于输入的预测张量,通常会在C维度上进行softmax操作,使得每个通道(类别)的输出值都在[0,1]范围内,并且所有通道的输出值之和为1。
  • 这样做的目的是将预测结果转换成概率分布,方便计算交叉熵损失。
  • PyTorch中,torch.nn.CrossEntropyLoss()函数会自动将输入进行softmax操作。

1.2、Dice loss

Dice系数中的"Dice"实际上是一位科学家名字的缩写,其全名是Sørensen–Dice coefficient,常被称为Dice similarity coefficient或者F1 score。它由植物学家Thorvald SørensenLee Raymond Dice独立研制,分别于 1948 年和 1945 年发表。

Dice系数是一种常见的相似度计算方法,主要用于计算两个集合的相似度。在 Dice Loss 中,用 Dice 系数来计算预测结果和真实标签的相似度,因此得名 Dice Loss

dice coefficient定义如下:
1

如果看作是对像素点类别的分类任务,也可以写成:
2

于是,dice loss就可以表示为:
3

Dice系数的中文名称为“Dice相似系数”或“Dice相似度”,因此 Dice Loss 也可以称为“Dice相似度损失”或“Dice相似系数损失”。

multi dice loss定义如下:

import torch
import numpy as np

def one_hot_encode(label, num_classes):
    """ Torch One Hot Encode
    :param label: Tensor of shape BxHxW or BxDxHxW
    :param num_classes: K classes
    :return: label_ohe, Tensor of shape BxKxHxW or BxKxDxHxW
    """
    assert len(label.shape) == 3 or len(label.shape) == 4, 'Invalid Label Shape {}'.format(label.shape)
    label_ohe = None
    if len(label.shape) == 3:
        label_ohe = torch.zeros((label.shape[0], num_classes, label.shape[1], label.shape[2]))
    elif len(label.shape) == 4:
        label_ohe = torch.zeros((label.shape[0], num_classes, label.shape[1], label.shape[2], label.shape[3]))

    for batch_idx, batch_el_label in enumerate(label):
        for cls in range(num_classes):
            label_ohe[batch_idx, cls] = (batch_el_label == cls)
    label_ohe = label_ohe.long()
    return label_ohe

def dice(outputs, labels):
    eps = 1e-5
    outputs, labels = outputs.float(), labels.float()
    outputs, labels = outputs.flatten(), labels.flatten()
    intersect = torch.dot(outputs, labels)  # 对应元素相乘再相加
    union = torch.add(torch.sum(outputs), torch.sum(labels))
    dice_coeff = (2 * intersect + eps) / (union + eps)
    dice_loss = 1 - dice_coeff
    return dice_loss

def dice_n_classes(outputs, labels, do_one_hot=False, get_list=False, device=None):
    """
    Computes the Multi-class classification Dice Coefficient.
    It is computed as the average Dice for all classes, each time
    considering a class versus all the others.
    Class 0 (background) is not considered in the average(不计入平均数).

    :param outputs: probabilities outputs of the CNN. Shape: [BxCxDxHxW]
    :param labels:  ground truth                      Shape: [BxDxHxW]
    :param do_one_hot: set to True if ground truth has shape [BxHxW]
    :param get_list:   set to True if you want the list of dices per class instead of average
    :param device: CUDA device on which compute the dice
    :return: Multiclass classification Dice Loss
    """
    num_classes = outputs.shape[1]
    if do_one_hot:
        labels = one_hot_encode(labels, num_classes)
        labels = labels.cuda(device=device)

    dices = list()
    for cls in range(1, num_classes):
        outputs_ = outputs[:, cls].unsqueeze(dim=1)
        labels_  = labels[:, cls].unsqueeze(dim=1)
        dice_ = dice(outputs_, labels_)
        dices.append(dice_)
    if get_list:
        return dices
    else:
        return sum(dices) / (num_classes-1)


def get_multi_dice_loss(outputs, labels, device=None):
    return dice_n_classes(outputs, labels, do_one_hot=True, get_list=False, device=device)

二、Dice coeff(系数)评价指标

在定义 Dice loss的时候,就已经介绍了 Dice coeff,他们两者之间的关系是:Dice loss = 1- Dice coeff

在本文中,尽管是只有一个类别,但是还是给出了多个类别情况下的Dice coeff,求平均就是average Dice coeff。但是,由于本篇的输出有个背景类,在计算的时候是不算上背景的。所以计算Dice coeff时候是从1开始的。

代码如下:

def one_hot_encode_np(label, num_classes):
    """ Numpy One Hot Encode
    :param label: Numpy Array of shape BxHxW or BxDxHxW
    :param num_classes: K classes
    :return: label_ohe, Numpy Array of shape BxKxHxW or BxKxDxHxW
    """
    assert len(label.shape) == 3 or len(label.shape) == 4, 'Invalid Label Shape {}'.format(label.shape)
    label_ohe = None
    if len(label.shape) == 3:
        label_ohe = np.zeros((label.shape[0], num_classes, label.shape[1], label.shape[2]))
    elif len(label.shape) == 4:
        label_ohe = np.zeros((label.shape[0], num_classes, label.shape[1], label.shape[2], label.shape[3]))
    for batch_idx, batch_el_label in enumerate(label):
        for cls in range(num_classes):
            label_ohe[batch_idx, cls] = (batch_el_label == cls)
    return label_ohe

def dice_coeff(gt, pred, eps=1e-5):
    dice = np.sum(pred[gt == 1]) * 2.0 / (np.sum(pred) + np.sum(gt))
    return dice

def multi_dice_coeff(gt, pred, num_classes):
    print('loss shape:', gt.shape, pred)
    labels = one_hot_encode_np(gt, num_classes)
    outputs = one_hot_encode_np(pred, num_classes)
    dices = list()
    for cls in range(1, num_classes):
        outputs_ = outputs[:, cls]
        labels_  = labels[:, cls]
        dice_ = dice_coeff(outputs_, labels_)
        dices.append(dice_)
    return sum(dices) / (num_classes-1)

对于多个类别的情况,在调用multi_dice_coeff前,需要先进行如下的操作:(下面的操作,默认了一种情况,那就是targetmask,是以不同的数字,代表不同的类别的,比如0-背景;1-类别1;2-类别2;3-类别3)

outputs = torch.argmax(output, dim=1)  # B x Z x Y x X
outputs_np = outputs.data.cpu().numpy()  # B x Z x Y x X
labels_np = target.data.cpu().numpy()  # B x Z x Y x X
multi_dice = multi_dice_coeff(labels_np, outputs_np, config.num_outs)

其中,torch.argmax 在类别channel上进行argmax操作,确定该像素属于哪个类别。如此得到的output,就与target的方式,保持了一致。

三、训练和验证

在综述篇,已经把框架固定内容基本上都介绍完了,到了本文就显得没什么好展开的了。那就把训练和验证中大的组块给补上。再配合上模型和数据流两篇文章,搭建好自己的训练代码不是问题。

3.1、main 主函数部分

主函数部分,其实是统筹整个训练主代码的。他包括了:

  1. 对训练超参数的定义
  2. 数据流的加载
  3. 网络模型的创建
  4. 优化器的定义
  5. 学习率的调整策略
  6. 损失函数的定义
  7. 训练和验证函数循环
  8. 训练过程参数的保存
  9. 训练模型的保存

这个过程在综述篇基本上已经介绍了,感兴趣的可以翻过去,再仔细的看看。如果是你自己来构建,是不是可以完整的走完这些内容。

下面就是主函数的代码,如下:

def main():
    Config = Configuration()
    Config.display()

    train_loader, valid_loader = get_Dataloader(Config)

    print('---start get model now---')
    model = get_model(Config).to(DEVICE)

    # ---- OPTIMIZER ----
    if Config.OPTIMR == "SGD":
        optimizer = optim.SGD(model.parameters(), lr=Config.LR, momentum=Config.momentum, weight_decay=Config.weight_decay)
    elif Config.OPTIMR == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=Config.LR, betas=(0.9, 0.999))
    elif Config.OPTIMR == "AdamW":
        optimizer = optim.AdamW(model.parameters(), lr=Config.LR, betas=(0.9, 0.999))
    elif Config.OPTIMR == "RMSProp":
        optimizer = optim.RMSprop(model.parameters(), lr=Config.LR)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.05, patience=20,
                                                           verbose=False, threshold=0.0001, threshold_mode='rel',
                                                           cooldown=0, min_lr=0, eps=1e-08)

    # Defining loss fn
    ce_layer = torch.nn.CrossEntropyLoss()

    train_loss_list = []  # 用来记录训练损失
    valid_loss_list = []  # 用来记录验证损失
    valid_dice_list = []
    epoch_list = []
    for epoch in range(1, Config.Max_epoch + 1):
        epoch_list.append(epoch)
        train_loss = train_model(model, DEVICE, train_loader, optimizer, ce_layer, epoch)  # 训练

        valid_loss, valid_dice = valid_model(model, DEVICE, valid_loader, ce_layer, epoch)   # 验证
        train_loss_list.append(train_loss)  # 记录每个epoch训练损失
        valid_loss_list.append(valid_loss)  # 验证损失
        valid_dice_list.append(valid_dice)
        draw_plot(epoch_list, valid_dice_list, 'valid_dice')
        draw_plot(epoch_list, valid_loss_list, 'valid_loss')
        draw_plot(epoch_list, train_loss_list, 'train_loss')

        if valid_dice > Config.Dice_Best:  
            path_ckpt = os.path.join(Config.model_path, 'best_model.pth')
            save_model(path_ckpt, model)
            Config.Dice_Best = valid_dice 
        else:
            path_ckpt = os.path.join(Config.model_path, 'last_model.pth')
            save_model(path_ckpt, model)

        scheduler.step(valid_loss)
    print('best val Dice is ', Config.Dice_Best)

3.2、训练部分

单个epoch的训练过程,和单个epoch的验证过程,在这里单独来定义。这样做的好处就是主函数的代码,相对会简洁一些,避免都放到一起,缩进了太深了,反正影响阅读。

下面是训练的部分,包括了:

  1. 对单个epoch中所有batch的迭代
  2. 对单个batch的前向推理
  3. 对单个batch预测结果损伤计算
  4. 对单个batch的预测结果进行dice coeff计算
  5. 梯度清零,反向回归
  6. 实时打印

下面是训练代码:

def train_model(model, device, train_loader, optimizer, ce_layer, epoch):  # 训练模型
    config = Configuration()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    end = time.time()

    multi_dices = list()

    model.train()
    bar = Bar('Processing train ', max=len(train_loader))
    for batch_index, (data, target) in enumerate(train_loader):  # 取batch索引,(data,target),也就是图和标签
        data_time.update(time.time() - end)
        data, target = data.to(device), target.to(device)

        output = model(data)  # 图 进模型 得到预测输出
        # loss = Loss(output, target)  # 计算损失
        loss = ce_layer(output, target)
        losses.update(loss.item(), data.size(0))

        outputs = torch.argmax(output, dim=1)  # B x Z x Y x X
        outputs_np = outputs.data.cpu().numpy()  # B x Z x Y x X
        labels_np = target.data.cpu().numpy()  # B x Z x Y x X
        multi_dice = multi_dice_coeff(labels_np, outputs_np, config.num_outs)
        multi_dices.append(multi_dice)

        optimizer.zero_grad()  # 梯度归零
        loss.backward()  # 反向传播
        optimizer.step()  # 优化器走一步

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        multi_dices_np = np.array(multi_dices)
        mean_multi_dice = np.mean(multi_dices_np)

        # plot progress
        bar.suffix = '(Epoch: {epoch: .1f} | {batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Dice: {dice:.4f}| LR: {lr:.6f}'.format(
            epoch=epoch,
            batch=batch_index + 1,
            size=len(train_loader),
            data=data_time.val,
            bt=batch_time.val,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss=losses.avg,
            dice=mean_multi_dice,
            lr=optimizer.param_groups[0]['lr']
        )
        bar.next()
    bar.finish()
    return losses.avg  # 返回平均损失

3.3、验证部分

验证部分与训练部分基本上一致的,只不过:

  1. 在训练阶段,model.train(),而在验证阶段,需要model.eval()
  2. 验证阶段不进行梯度回归更新模型,损失只是为了统计使用

其他几乎是没什么两样了,代码如下:

def valid_model(model, device, test_loader, ce_layer, epoch):    # 加了个test  1是想打印时好看(区分valid和test)  2是test要打印图,需要特别设计
    config = Configuration()
    # 模型训练-----调取方法
    model.eval()  # 用来验证或测试的
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    end = time.time()

    multi_dices = list()
    bar = Bar('Processing valid ', max=len(test_loader))

    with torch.no_grad():  # 不进行 梯度计算(反向传播)
        for batch_index, (data, target) in enumerate(test_loader):  # 枚举batch索引,(图,标签)
            data_time.update(time.time() - end)
            data, target = data.to(device), target.to(device)

            output = model(data)
            loss = ce_layer(output, target)
            losses.update(loss.item(), data.size(0))

            outputs = torch.argmax(output, dim=1)  # B x C x Z x Y x X   >   B x Z x Y x X
            outputs_np = outputs.data.cpu().numpy()  # B x Z x Y x X
            labels_np = target.data.cpu().numpy()  # B x Z x Y x X
            multi_dice = multi_dice_coeff(labels_np, outputs_np, config.num_outs)
            multi_dices.append(multi_dice)

            multi_dices_np = np.array(multi_dices)
            mean_multi_dice = np.mean(multi_dices_np)
            std_multi_dice = np.std(multi_dices_np)

            # plot progress
            bar.suffix = '(Epoch: {epoch: .1f} | {batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Dice: {dice:.4f}'.format(
                epoch=epoch,
                batch=batch_index + 1,
                size=len(test_loader),
                data=data_time.val,
                bt=batch_time.val,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                loss=losses.avg,
                dice=mean_multi_dice
            )
            bar.next()
    bar.finish()

    return losses.avg, mean_multi_dice

3.4、训练感触

3D UNet 模型那一篇中,我们提到:

模型在训练阶段,是不需要在最后增加sigmoidsoftmax操作的。只有在推理阶段,才需要。

但是,反观 CrossEntropyLoss,它尽管没有在模型中,定义使用了sigmoidsoftmax操作,但是他在计算损失函数的时候,是偷偷使用了sigmoidsoftmax操作的。

如果不用 CrossEntropyLoss,采用 Dice loss,那在计算损失函数前,需要先对模型输出,做一个类似于 CrossEntropyLoss的归一化操作吗?

依照我自己训练发现:如果在计算 Dice loss 前,未进行归一化操作,梯度很容易消失,表现出来的就是没法收敛,很难训练。这或许及时sigmoidsoftmax起到的规范化作用,使得模型的训练更加简单了。至于其他的原因和现象,待发现了进一步补充。

四、总结

上次有人评论说要完整的代码,这个到最后肯定是会都发出来的。其中在单个文章里面,基本上已经将完整的代码给都贴上去了,稍作做下问题排查,应该就没什么问题。即便有什么问题,也都是一些简单的小问题,这点我都做过了验证。

对于一些初学的,比如pythonos文件操作的库,都不明白的,建议看看其他的文章,把这部分的知识给补齐,再继续学习。

如果出现了报错,第一时间先看看报错提示的修改建议,或者根据提示,定位到错误的地方,针对性的修改。不行就百度,绝大部分的问题,网上都已经有人遇到过了。最后实在不行,就在评论区留言,大家一起解决问题,会比较的快。

最后,还差一个预测篇,继续往后看吧。

猜你喜欢

转载自blog.csdn.net/wsLJQian/article/details/134250370