Notes on Multi-GPU Training with PyTorch DDP

Preface

I used to just modify other people's open-source code, but recently I had to write training code from scratch, and my experiments need multi-GPU training, so I'm writing this down.


Basic framework

from torch.utils.data.distributed import DistributedSampler

# 1) Initialize the process group
# torch.distributed.init_process_group(backend="nccl", init_method='env://', rank=0, world_size=torch.cuda.device_count())
torch.distributed.init_process_group(backend="nccl")

# 2) Bind each process to its own GPU
# (on a single node, the global rank returned here equals the local rank)
local_rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)

# 3) Use DistributedSampler so each process gets its own shard of the data
trainloader = DataLoader(trainset, batch_size=opt.batch_size, pin_memory=True,
                         num_workers=opt.num_workers,
                         sampler=DistributedSampler(trainset, shuffle=True))

# 4) Move the model to its GPU before wrapping it
if num_gpus >= 1:
    torch.backends.cudnn.enabled = False
    net = net.cuda()
if num_gpus > 1:
    # 5) Wrap the model with DistributedDataParallel
    net = torch.nn.parallel.DistributedDataParallel(net,
                                                    device_ids=[local_rank],
                                                    output_device=local_rank,
                                                    find_unused_parameters=True)

Add this to the argument parser right at the start:

# for distribution
parser.add_argument("--local_rank", type=int)

For example, to run on a four-GPU server:

python -m torch.distributed.launch --nproc_per_node 4 train_ddp.py
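
On newer PyTorch versions, torchrun replaces torch.distributed.launch and passes the rank through the LOCAL_RANK environment variable rather than a --local_rank argument. A minimal sketch of how the script would pick it up under torchrun (my addition, not part of the original setup):

# Launch command with torchrun instead of torch.distributed.launch:
#   torchrun --nproc_per_node=4 train_ddp.py
import os
import torch

torch.distributed.init_process_group(backend="nccl")
# torchrun exports LOCAL_RANK for every process it spawns
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)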



Details

checkpoint

It's enough to have the main process save the model weights. If you need to resume training later, you also have to torch.save the whole checkpoint dict; at first I only saved the model weights here.
If you're in a hurry, you can refer to: Pytorch模型resume training,加载模型基础上继续训练 (resuming training from a loaded model).

if cur_epoch > 0 and cur_epoch % 2 == 0:
    if check_print_rank(opt):
        net_state_dict = net.module.state_dict() if opt.num_gpus > 1 else net.state_dict()
        train_state = {
            "net": net_state_dict,
            "optimizer": optimizer.state_dict(),
            "cur_epoch": cur_epoch,
            "cur_step": cur_step
        }
        # Make sure the output directories exist
        train_ckpt_path = 'train_ckpt/save_net_ckpt'
        if not os.path.exists(train_ckpt_path):
            os.makedirs(train_ckpt_path)
        train_state_path = 'train_ckpt/save_state'
        if not os.path.exists(train_state_path):
            os.makedirs(train_state_path)

        # Save the model weights only
        torch.save(net_state_dict,
                   '{}/epoch_{}.pth'.format(train_ckpt_path, cur_epoch))
        # Save the training state so training can be resumed later
        torch.save(train_state,
                   '{}/state_epoch_{}.pth'.format(train_state_path, cur_epoch))



Update, July 14: I couldn't take it any more, the server hit a problem and simply died. So I had to go and write a proper resume-train after all...
First, add two options where the options are parsed:

# for resume train
# note: argparse's type=bool is a trap (any non-empty string parses as True),
# so a store_true flag is used instead
parser.add_argument("--resume_train", action="store_true")
parser.add_argument("--resume_state_path", type=str, default='')

Then, before training starts, load the model weights, optimizer state, and the current epoch and step:

# Create the optimizer up front so it exists whether or not we resume
optimizer = optim.AdamW(net.parameters(), lr=2e-4)

if opt.resume_train and opt.resume_state_path is not None and opt.resume_state_path != '':
    # Load the previously saved training state
    state_dict = torch.load(opt.resume_state_path)
    print('load state dict from {}'.format(opt.resume_state_path))
    net_to_load = net.module if opt.num_gpus > 1 else net
    net_to_load.load_state_dict(state_dict['net'])
    optimizer.load_state_dict(state_dict['optimizer'])
    cur_epoch = state_dict['cur_epoch']
    cur_step = state_dict['cur_step']
    # -----------------------
    if check_print_rank(opt):
        print('start resume train at epoch {}, step {}'.format(cur_epoch, cur_step))
else:
    cur_epoch = 0
    cur_step = 0
    if check_print_rank(opt):
        print('train from epoch 0, step 0 ...')
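
One caveat the snippet above glosses over: every rank calls torch.load, and by default the tensors are deserialized onto the device they were saved from (usually cuda:0), so several processes can pile onto GPU 0. A small sketch of the usual workaround, loading to CPU first (my addition, not from the original post):

# Deserialize on the CPU so every rank does not load its copy onto cuda:0;
# load_state_dict then moves the weights to the right device
state_dict = torch.load(opt.resume_state_path, map_location='cpu')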

After that, set up the learning-rate scheduler:

# LinearWarmupCosineAnnealingLR is an external scheduler (pl_bolts ships one implementation)
scheduler = LinearWarmupCosineAnnealingLR(
    optimizer=optimizer, warmup_epochs=15, max_epochs=150)
max_epoch = opt.epochs

while cur_epoch <= max_epoch:
    cur_epoch = cur_epoch + 1
    scheduler.step(cur_epoch)
    if check_print_rank(opt):
        # lr = scheduler.get_lr()
        lr = scheduler.get_last_lr()[0]
        print('--> cur_epoch: {}, use lr: {}\n'.format(cur_epoch, lr))
        writer.add_scalar("learning rate", lr, cur_epoch)

    for batch in tqdm(trainloader):
        net.train()
        cur_step = cur_step + 1
        optimizer.zero_grad()
        # ... training code
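
One DDP detail that is easy to forget (my addition, not in the original loop): with DistributedSampler(shuffle=True), the sampler needs to be told the current epoch at the top of every epoch, otherwise every epoch reuses the same shuffle order:

# At the start of each epoch, before iterating trainloader
trainloader.sampler.set_epoch(cur_epoch)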

Saving the model weights and training state:

if (cur_epoch + 1) % 2 == 0:
    if check_print_rank(opt):
        net_state_dict = net.module.state_dict() if opt.num_gpus > 1 else net.state_dict()
        train_state = {
            "net": net_state_dict,
            "optimizer": optimizer.state_dict(),
            "cur_epoch": cur_epoch,
            "cur_step": cur_step
        }

        # Make sure the output directories exist
        train_ckpt_path = 'train_ckpt/save_net_ckpt'
        if not os.path.exists(train_ckpt_path):
            os.makedirs(train_ckpt_path)
        train_state_path = 'train_ckpt/save_state'
        if not os.path.exists(train_state_path):
            os.makedirs(train_state_path)

        # Save the model weights only
        torch.save(net_state_dict,
                   '{}/epoch_{}.pth'.format(train_ckpt_path, cur_epoch))
        # Save the training state so training can be resumed later
        torch.save(train_state,
                   '{}/state_epoch_{}.pth'.format(train_state_path, cur_epoch))



validation

Here I also run validation only on the rank-0 main process, because for a generative task, gathering the test outputs from multiple GPUs is a bit of a hassle (not impossible, I'm just lazy), so I skipped it. For a classification task it would be easier.

if opt.num_gpus <= 1:
    print('epoch: {}, start eval...'.format(cur_epoch))
    eval_operation(net, writer, cur_epoch)
    print('epoch: {}, eval end\n\n'.format(cur_epoch))
else:
    if torch.distributed.get_rank() == 0:
        print('\n rank: {}, epoch: {}, start eval...'.format(torch.distributed.get_rank(), cur_epoch))
        eval_operation(net.module, writer, cur_epoch)
        print('\n rank: {}, epoch: {}, eval end\n\n'.format(torch.distributed.get_rank(), cur_epoch))
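
Since only rank 0 runs the evaluation, the other ranks will race ahead into the next epoch. If you want all ranks to stay in step (my addition, a sketch rather than something from the original code), an explicit barrier after the eval works:

if opt.num_gpus > 1:
    # Make every rank wait until rank 0 has finished evaluating
    torch.distributed.barrier()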



tensorboard

TensorBoard comes up all the time.
For how to view it from a remote server, see my other post: tensorboard显示远程服务器 (displaying TensorBoard from a remote server).

def check_print_rank(opt):
    return opt.num_gpus <= 1 or torch.distributed.get_rank() == 0

if check_print_rank(opt):
    # logs go under the Tensorboard_logs directory
    start_time = datetime.datetime.now().strftime('%Y-%m-%d  %H:%M:%S')
    Tensorboard_logs_dir = "Tensorboard_logs/{}".format(start_time)
    writer = SummaryWriter(Tensorboard_logs_dir)
    print('log in {}...\n'.format(Tensorboard_logs_dir))

The opt here is this argparse namespace:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--num_gpus",type=int,default= 4,help = "Number of GPUs to use for training")
# ... 一系列命令行参数配置
opt = parser.parse_args()



When logging, just write from the main process:

if check_print_rank(opt):
    # Logging to TensorBoard (if installed) by default
    writer.add_scalar("train loss", loss, cur_step)



Finally, close the writer when you're done:

if check_print_rank(opt):
    writer.close()
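
As a last bit of cleanup (my addition, not in the original post), the process group can be torn down explicitly once training is done:

if opt.num_gpus > 1:
    # Release the resources held by the NCCL process group
    torch.distributed.destroy_process_group()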


Reposted from blog.csdn.net/weixin_43850253/article/details/131706419