pytorch 设置随机种子排除随机性

pytorch 设置随机种子排除随机性


本文章不同意转载,禁止以任何形式转载!!

前言

设置好随机种子,对于做重复性实验或者对比实验是十分重要的,pytorch官网也给出了文档说明

设置随机种子

为了解决随机性,需要把所有产生随机的地方进行限制,在这里我自己总结了一下:

  1. 排除PyTorch的随机性
  2. 排除第三方库的随机性
  3. 排除cudnn加速的随机性

这是mmdetection所给的方法:

def set_random_seed(seed, deterministic=False):
    """Set random seed.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: False.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

CUDA卷积操作使用的cuDNN库可能是跨应用程序多次执行的不确定性的来源。当使用一组新的尺寸参数调用cuDNN卷积时,一个可选的特性可以运行多个卷积算法,对它们进行基准测试以找到最快的一个。然后,在剩下的过程中,对于相应的尺寸参数集,将一致地使用最快的算法。由于基准测试的噪音和不同的硬件,基准测试可能会在后续的运行中选择不同的算法,即使是在同一台机器上。可以设置torch.backends.cudnn.benchmark = False,禁用基准功能会导致 cuDNN 确定性地选择算法,可能以降低性能为代价。

torch.use_deterministic_algorithms()允许您配置PyTorch,在可用的情况下使用确定算法,而不是非决定性算法,如果操作已知为非决定性算法(且没有确定性替代方案),则会抛出错误。

虽然禁用CUDA卷积基准,确保每次运行时CUDA选择相同的算法应用程序,但算法本身可能是不确定的,除非torch.use_deterministic_algorithms(true)或torch.backends.cudnn.deterministic = True。后者只设置控制这种行为,而torch.use_deterministic_algorithms()将使其他PyTorch操作的行为具有确定性。

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True / False

DataLoader

如果你发现在训练网络的时候发现固定了随机种子,但torch.utils.data.DataLoader中num_workers设置大于0也会造成两次训练结果不一样,这应该是开启多进程读取顺序不一致导致的,我设置成0后结果就一样了,但这样速度下降很多。解决办法:

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    numpy.random.seed(worker_seed)
    random.seed(worker_seed)

DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    worker_init_fn=seed_worker
)

本文章不同意转载,禁止以任何形式转载!!

猜你喜欢

转载自blog.csdn.net/qq_41917697/article/details/115042465