TensorBoard的学习与使用

(一)基础使用

这里首先介绍TensorBoard中各个功能如何使用。

1.1 add_scalar

add_scalar(tag, scalar_value, global_step=None, walltime=None)

参数

tag (string): 数据名称,不同名称的数据以不同曲线展示
scalar_value (float): 要记录的标量数值(如 loss、准确率)
global_step (int, optional): 训练步数,作为曲线的横轴
walltime (float, optional): 记录发生的时间,默认为 time.time()

基础用法可参考这篇文章:
https://blog.csdn.net/bigbennyguo/article/details/87956434

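(二)远程使用TensorBoard

在本地浏览器中查看远程服务器上的训练情况。常用做法是 SSH 端口转发:先执行 ssh -L 16006:127.0.0.1:6006 用户名@服务器地址,在服务器上运行 tensorboard --logdir=<日志目录> --port=6006,然后在本地浏览器打开 http://127.0.0.1:16006 即可。
详细步骤参考这篇博文:
https://blog.csdn.net/u010626937/article/details/107747070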

(三)TensorBoard如何在具体的实例中使用

参考pytorch官方文档
https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html

(四)存在的问题

1、ERROR: TensorBoard could not bind to port 6006, it was already in use

这表明6006端口已被占用。先用 lsof -i:6006 找出占用该端口的进程号(PID),再用 kill <PID> 结束该进程;也可以用 --port 参数换一个空闲端口。
(原文此处为截图)

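2、No dashboards are active for the current data set

如下图所示(原文此处为截图)。
这是路径问题。--logdir 的相对路径是相对于启动 tensorboard 时所在的目录解析的;而训练代码中 SummaryWriter('runs/FTmodel') 的日志目录,是相对于运行训练脚本时所在的目录。两者不一致时,TensorBoard 就找不到事件文件。把启动 TensorBoard 的命令中 --logdir 后的路径改为绝对路径即可:
tensorboard --logdir=/home/sgyj/code/FrequecyTransformer/runs/FTmodel --port=6006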

(五)修改之后的训练代码

import torch
import torch.nn as nn
import matplotlib as mpl
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
mpl.use('Agg')

import matplotlib.pyplot as plt

import Copy_data as dataload
from FTmodelEasy import FTModel
from vit_seg_modeling import VisionTransformer
from vit_seg_modeling import CONFIGS
from tensorboardX import SummaryWriter

import time
import numpy as np
import os
# Training hyper-parameters: mini-batch size, initial learning rate, epoch count.
BATCH, LR, EPOCHES = 16, 1e-3, 1


# 计算模型准确率,召回率和f1分数
# output->[batch, 1, 256, 256]
# img_gt->[batch, 1, 256, 256]
def calprecise(output, img_gt):
    output = torch.sigmoid(output)
    mask = output > 0.3

    acc_mask = torch.mul(mask.float(), img_gt)
    acc_mask = acc_mask.sum()
    acc_fenmu = mask.sum()
    recall_fenmu = img_gt.sum()

    acc = acc_mask / (acc_fenmu + 0.0001)
    recall = acc_mask / (recall_fenmu + 0.0001)
    f1 = 2 * acc * recall / (acc + recall + 0.0001)

    return acc, recall, f1
# TensorBoardX writer shared by train(); events go to runs/FTmodel, a path
# relative to the working directory the script is launched from.
writer = SummaryWriter('runs/FTmodel')


def train(resume=False,
          resume_path="/home/sgyj/code/FrequecyTransformer/checkpoint/ckpt_best_55.pth"):
    """Train FTModel on the "train" split of Copy_DATA and log to TensorBoard.

    Per-step metrics are printed; per-epoch averages of loss, precision,
    recall and F1 are written to the module-level ``writer``. A resumable
    checkpoint (network, optimizer and scheduler state plus the epoch) is
    saved every 5 epochs, a bare ``state_dict`` every 20 epochs, and the
    final weights after the last epoch.

    Args:
        resume: If True, restore the run stored in ``resume_path`` and
            continue with the epoch after the one recorded there.
        resume_path: Checkpoint written by an earlier call of this function.

    Raises:
        ValueError: If the training set yields no batches.
    """
    project_dir = "/home/sgyj/code/FrequecyTransformer"
    ckpt_dir = os.path.join(project_dir, "checkpoint")
    weights_dir = os.path.join(project_dir, "tem")
    # Create both output directories up front: previously only checkpoint/
    # was created, so saving into tem/ failed if it did not already exist.
    os.makedirs(ckpt_dir, exist_ok=True)
    os.makedirs(weights_dir, exist_ok=True)

    copy_train = dataload.Copy_DATA("train")
    train_loader = torch.utils.data.DataLoader(copy_train, batch_size=BATCH, shuffle=True)
    n_batches = len(train_loader)
    if n_batches == 0:
        raise ValueError("training set is empty; nothing to train on")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = FTModel().to(device)
    net.train()

    # The network output is treated as logits (calprecise also applies a
    # sigmoid to it). BCEWithLogitsLoss fuses sigmoid + BCE and is numerically
    # more stable than an explicit sigmoid followed by BCELoss.
    lossfunction = nn.BCEWithLogitsLoss()
    # 'initial_lr' is needed if the scheduler is ever built with last_epoch != -1.
    optimizer = torch.optim.SGD([{'params': net.parameters(), 'initial_lr': LR}],
                                lr=LR, momentum=0.9, weight_decay=1e-4)
    # Multiply the learning rate by 0.1 after epochs 40 and 80.
    scheduler = MultiStepLR(optimizer, milestones=[40, 80], gamma=0.1)

    start_epoch = 0
    if resume:
        checkpoint = torch.load(resume_path, map_location=device)
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Checkpoints from older runs carry no scheduler state; the LR
        # schedule then restarts from the restored optimizer's learning rate.
        if 'scheduler' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler'])
        # The stored epoch has already been completed; continue with the next.
        start_epoch = checkpoint['epoch'] + 1

    for epoch in range(start_epoch, EPOCHES):
        total_loss = 0.0
        precise = 0.0
        recall_score = 0.0
        f1_score = 0.0
        st = time.time()
        for step, (img, img_gt) in enumerate(train_loader):
            img = img.to(device)
            img_gt = img_gt.to(device)
            pred_mask = net(img)
            loss = lossfunction(pred_mask.reshape(-1), img_gt.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Convert metrics to Python floats so the epoch sums, the prints and
            # the TensorBoard values are plain numbers rather than device tensors.
            acc, recall, f1 = (m.item() for m in calprecise(pred_mask, img_gt))
            loss_value = loss.item()
            lrr = optimizer.param_groups[0]['lr']

            print("(train)epoch%d->step%d loss:%.6f acc:%.6f recall:%.6f f1:%.6f lr:%.6f cost time:%ds" % (
                epoch, step, loss_value, acc, recall, f1, lrr, time.time() - st))

            total_loss += loss_value
            precise += acc
            recall_score += recall
            f1_score += f1
        scheduler.step()

        # Per-epoch averages.
        avg_loss = total_loss / n_batches
        avg_precise = precise / n_batches
        avg_recall = recall_score / n_batches
        avg_f1 = f1_score / n_batches
        cost = time.time() - st
        print("(train)epoch%d-> loss:%.6f acc:%.6f recall:%.6f f1:%.6f cost time:%ds" %
              (epoch, avg_loss, avg_precise, avg_recall, avg_f1, cost))
        writer.add_scalar('training loss', avg_loss, epoch)
        writer.add_scalar('precises', avg_precise, epoch)
        writer.add_scalar('recalles', avg_recall, epoch)
        writer.add_scalar('f1es', avg_f1, epoch)

        # Resumable checkpoint every 5 epochs (includes scheduler state so the
        # LR schedule survives a resume).
        if epoch != 0 and epoch % 5 == 0:
            checkpoint = {
                "net": net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                "epoch": epoch,
            }
            torch.save(checkpoint, os.path.join(ckpt_dir, 'ckpt_best_%s.pth' % epoch))
        # Weights-only snapshot every 20 epochs.
        if epoch != 0 and epoch % 20 == 0:
            torch.save(net.state_dict(),
                       os.path.join(weights_dir, "FrequecyTransformer-copy_epoch_%d.pth" % epoch))

    torch.save(net.state_dict(), os.path.join(weights_dir, 'FrequecyTransformer-copy_final.pth'))
    # Make sure the last epoch's scalars reach disk before the process exits.
    writer.flush()

if __name__ == "__main__":
    # Only start training when executed as a script, not on import.
    train()

转载自 blog.csdn.net/weixin_44020747/article/details/119538087