Pytorch Lightning验证时TQDMProgressBar进度条输出异常问题与解决方案

问题描述

在使用Pytorch Lightning时,若使用Pycharm或在Colab中用“python train.py”方式运行时,验证时的进度条会出现一个batch打印一行的情况。

例如:

trainer = pl.Trainer(
            callbacks=[TQDMProgressBar()],
        )

在Pycharm中运行时,进度条会出现如下情况:

Epoch 0: 100%|██████████| 20/20 [00:13<00:00,  1.49it/s, v_num=10]
Validation: 0it [00:00, ?it/s]
Validation:   0%|          | 0/5 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|          | 0/5 [00:00<?, ?it/s]
Validation DataLoader 0:  20%|██        | 1/5 [00:00<00:00, 28.66it/s]
Validation DataLoader 0:  40%|████      | 2/5 [00:00<00:00, 27.07it/s]
Validation DataLoader 0:  60%|██████    | 3/5 [00:00<00:00, 27.41it/s]
Validation DataLoader 0:  80%|████████  | 4/5 [00:00<00:00, 26.82it/s]

Validation DataLoader的进度条出现多次。

解决方案

重写TQDMProgressBarinit_validation_tqdm方法。代码如下:

class MyTQDMProgressBar(TQDMProgressBar):

    def __init__(self):
        super(MyTQDMProgressBar, self).__init__()

    def init_validation_tqdm(self):
        bar = Tqdm(
            desc=self.validation_description,
            position=0,  # 这里固定写0
            disable=self.is_disabled,
            leave=True,  # leave写True
            dynamic_ncols=True,
            file=sys.stdout,
        )
        return bar

然后使用MyTQDMProgressBar代替TQDMProgressBar即可。

猜你喜欢

转载自blog.csdn.net/zhaohongfei_358/article/details/129914003