| import torch |
| import random |
| import numpy as np |
| import os |
| |
def set_seed(seed=1):
    """Seed every random number generator used in training.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and CUDA), then configures cuDNN for deterministic behaviour.

    Args:
        seed: Integer seed applied to all generators. Defaults to 1.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # The CUDA seeding calls are safe to make on a CPU-only machine.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # NOTE: hash randomization is fixed when the interpreter starts, so this
    # assignment does not change hashing in the current process; it is only
    # inherited by child processes started afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Request deterministic cuDNN kernels and disable the autotuner, which
    # could otherwise select different algorithms from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
另外,还可以固定数据加载(DataLoader worker)时的随机种子(注意:这可能会影响训练效果,甚至导致模型不收敛):
# Base seed from which every data-loading worker derives its own seed.
GLOBAL_SEED = 1
# Id of the worker running in this process; stays None until
# worker_init_fn has been called here.
GLOBAL_WORKER_ID = None


def worker_init_fn(worker_id):
    """Record the worker id and seed this worker's random generators.

    Intended to be passed as the ``worker_init_fn`` argument of a data
    loader.

    Args:
        worker_id: Integer id of the worker being initialised.
    """
    global GLOBAL_WORKER_ID
    GLOBAL_WORKER_ID = worker_id
    # The seed depends only on GLOBAL_SEED and worker_id, so workers differ
    # from each other, but a worker re-initialised with the same id replays
    # the same random sequence.
    set_seed(GLOBAL_SEED + worker_id)
| |
# Build the training loader and hand it the seeding hook for its workers.
# NOTE(review): with num_workers=0 a standard torch DataLoader loads data in
# the main process and never calls worker_init_fn. Presumably DataLoaderX
# behaves the same way; if per-worker seeding is intended, confirm and use
# num_workers > 0.
train_loader = DataLoaderX(
    local_rank=local_rank,
    dataset=trainset,
    batch_size=cfg.batch_size,
    sampler=train_sampler,
    num_workers=0,
    pin_memory=True,
    drop_last=True,
    worker_init_fn=worker_init_fn,
)
使用 numpy.random.RandomState 固定 NumPy 随机数(相同的种子会生成相同的随机序列):
# Demonstration: two independent generators created with the same seed
# produce identical draws. The file imports NumPy as ``np``, so the
# generator is reached through that alias.
rng = np.random.RandomState(1)
first_sample = rng.uniform(0, 1, (2, 3))

rng = np.random.RandomState(1)
second_sample = rng.uniform(0, 1, (2, 3))