问题描述
在使用Pytorch Lightning时,若使用Pycharm或在Colab中用“python train.py”方式运行时,验证时的进度条会出现一个batch打印一行的情况。
例如:
trainer = pl.Trainer(
callbacks=[TQDMProgressBar()],
)
在Pycharm中运行时,进度条会出现如下情况:
Epoch 0: 100%|██████████| 20/20 [00:13<00:00, 1.49it/s, v_num=10]
Validation: 0it [00:00, ?it/s]
Validation: 0%| | 0/5 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/5 [00:00<?, ?it/s]
Validation DataLoader 0: 20%|██ | 1/5 [00:00<00:00, 28.66it/s]
Validation DataLoader 0: 40%|████ | 2/5 [00:00<00:00, 27.07it/s]
Validation DataLoader 0: 60%|██████ | 3/5 [00:00<00:00, 27.41it/s]
Validation DataLoader 0: 80%|████████ | 4/5 [00:00<00:00, 26.82it/s]
Validation DataLoader的进度条出现多次。
解决方案
重写TQDMProgressBar
的init_validation_tqdm
方法。代码如下:
class MyTQDMProgressBar(TQDMProgressBar):
def __init__(self):
super(MyTQDMProgressBar, self).__init__()
def init_validation_tqdm(self):
bar = Tqdm(
desc=self.validation_description,
position=0, # 这里固定写0
disable=self.is_disabled,
leave=True, # leave写True
dynamic_ncols=True,
file=sys.stdout,
)
return bar
然后使用MyTQDMProgressBar
代替TQDMProgressBar
即可。