Re-organize fix random seed code; Make saved checkpoint compatible among different torch versions (open-mmlab#986)
jihanyang authored May 30, 2022
1 parent 846cf3e commit 519b156
Showing 4 changed files with 24 additions and 6 deletions.
5 changes: 3 additions & 2 deletions pcdet/datasets/__init__.py
@@ -1,4 +1,5 @@
 import torch
+from functools import partial
 from torch.utils.data import DataLoader
 from torch.utils.data import DistributedSampler as _DistributedSampler

@@ -44,7 +45,7 @@ def __iter__(self):
         return iter(indices)


-def build_dataloader(dataset_cfg, class_names, batch_size, dist, root_path=None, workers=4,
+def build_dataloader(dataset_cfg, class_names, batch_size, dist, root_path=None, workers=4, seed=None,
                      logger=None, training=True, merge_all_iters_to_one_epoch=False, total_epochs=0):

     dataset = __all__[dataset_cfg.DATASET](
@@ -70,7 +71,7 @@ def build_dataloader(dataset_cfg, class_names, batch_size, dist, root_path=None,
     dataloader = DataLoader(
         dataset, batch_size=batch_size, pin_memory=True, num_workers=workers,
         shuffle=(sampler is None) and training, collate_fn=dataset.collate_batch,
-        drop_last=False, sampler=sampler, timeout=0
+        drop_last=False, sampler=sampler, timeout=0, worker_init_fn=partial(common_utils.worker_init_fn, seed=seed)
     )

     return dataset, dataloader, sampler
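
Note: this hunk threads an optional seed from build_dataloader into every DataLoader worker. functools.partial binds the seed keyword so the resulting callable has the worker_init_fn(worker_id) signature PyTorch expects. A minimal standalone sketch of the same wiring (ToyDataset and the values here are illustrative, not OpenPCDet code):

import random
from functools import partial

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


def worker_init_fn(worker_id, seed=666):
    # Same idea as the helper added in common_utils.py below:
    # derive a distinct, reproducible seed for each worker.
    if seed is not None:
        random.seed(seed + worker_id)
        np.random.seed(seed + worker_id)
        torch.manual_seed(seed + worker_id)


class ToyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        # Stands in for random data augmentation; now reproducible per run.
        return idx, float(np.random.rand())


if __name__ == '__main__':
    loader = DataLoader(ToyDataset(), batch_size=2, num_workers=2,
                        worker_init_fn=partial(worker_init_fn, seed=666))
    for batch in loader:
        print(batch)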
10 changes: 10 additions & 0 deletions pcdet/utils/common_utils.py
@@ -103,10 +103,20 @@ def set_random_seed(seed):
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     torch.backends.cudnn.deterministic = True
     torch.backends.cudnn.benchmark = False


+def worker_init_fn(worker_id, seed=666):
+    if seed is not None:
+        random.seed(seed + worker_id)
+        np.random.seed(seed + worker_id)
+        torch.manual_seed(seed + worker_id)
+        torch.cuda.manual_seed(seed + worker_id)
+        torch.cuda.manual_seed_all(seed + worker_id)
+
+
 def get_pad_params(desired_size, cur_size):
     """
     Get padding parameters for np.pad function
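
Note: the per-worker offsets matter because on platforms that fork (the default DataLoader worker start method on Linux), every worker inherits a copy of the parent process's NumPy RNG state, so un-seeded workers can emit identical "random" augmentations. A quick sketch of the pitfall, assuming fork semantics (NoisyDataset is illustrative only):

import numpy as np
from torch.utils.data import DataLoader, Dataset


class NoisyDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        # Depends only on the worker's NumPy RNG state, not on idx.
        return float(np.random.rand())


if __name__ == '__main__':
    # Without a worker_init_fn, the two forked workers start from the same
    # NumPy state and can return duplicated values across workers.
    print([float(x) for x in DataLoader(NoisyDataset(), num_workers=2)])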
5 changes: 3 additions & 2 deletions tools/train.py
@@ -76,7 +76,7 @@ def main():
     args.epochs = cfg.OPTIMIZATION.NUM_EPOCHS if args.epochs is None else args.epochs

     if args.fix_random_seed:
-        common_utils.set_random_seed(666)
+        common_utils.set_random_seed(666 + cfg.LOCAL_RANK)

     output_dir = cfg.ROOT_DIR / 'output' / cfg.EXP_GROUP_PATH / cfg.TAG / args.extra_tag
     ckpt_dir = output_dir / 'ckpt'
@@ -110,7 +110,8 @@
         logger=logger,
         training=True,
         merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch,
-        total_epochs=args.epochs
+        total_epochs=args.epochs,
+        seed=666 if args.fix_random_seed else None
     )

     model = build_network(model_cfg=cfg.MODEL, num_class=len(cfg.CLASS_NAMES), dataset=train_set)
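
Note: seeding with 666 + cfg.LOCAL_RANK gives each distributed rank its own base seed (rank 0 uses 666, rank 1 uses 667, and so on), so ranks stop sharing one augmentation stream while each run remains reproducible. A minimal sketch of the same idea outside the repo, assuming a torchrun-style launcher that sets the LOCAL_RANK environment variable:

import os
import random

import numpy as np
import torch


def set_random_seed(seed):
    # Mirrors common_utils.set_random_seed after this commit.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


local_rank = int(os.environ.get('LOCAL_RANK', 0))  # 0 when run single-process
set_random_seed(666 + local_rank)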
10 changes: 8 additions & 2 deletions tools/train_utils/train_utils.py
@@ -166,7 +166,13 @@ def save_checkpoint(state, filename='checkpoint'):
         optimizer_state = state['optimizer_state']
         state.pop('optimizer_state', None)
         optimizer_filename = '{}_optim.pth'.format(filename)
-        torch.save({'optimizer_state': optimizer_state}, optimizer_filename)
+        if torch.__version__ >= '1.4':
+            torch.save({'optimizer_state': optimizer_state}, optimizer_filename, _use_new_zipfile_serialization=False)
+        else:
+            torch.save({'optimizer_state': optimizer_state}, optimizer_filename)

     filename = '{}.pth'.format(filename)
-    torch.save(state, filename)
+    if torch.__version__ >= '1.4':
+        torch.save(state, filename, _use_new_zipfile_serialization=False)
+    else:
+        torch.save(state, filename)
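
Note: passing _use_new_zipfile_serialization=False keeps checkpoints in the legacy pickle-based format (the zipfile format, introduced around PyTorch 1.4, became the default in 1.6), so files written by a newer torch still load in older ones. One caveat: torch.__version__ >= '1.4' compares strings lexicographically, so it is False for '1.10' and later. A more robust gate, assuming the third-party packaging module is installed, might look like:

import torch
from packaging import version  # assumption: pip install packaging


def save_legacy_format(obj, path):
    # Force the pre-1.6 serialization format when the flag is supported.
    if version.parse(torch.__version__) >= version.parse('1.4'):
        torch.save(obj, path, _use_new_zipfile_serialization=False)
    else:
        torch.save(obj, path)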
