support merge_all_iters_to_one_epoch to avoid deadlock when switching epochs, eval last 10 epochs after training
sshaoshuai committed Jul 8, 2020
1 parent 9fdb443 commit ed6f3dd
Showing 4 changed files with 40 additions and 5 deletions.
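In essence, the change lets a dataset pretend to be total_epochs copies of itself, so the whole training run can be consumed as one long epoch instead of restarting a distributed iterator at every epoch boundary, which is where the reported deadlock occurs. Below is a minimal, self-contained sketch of the idea; ToyDataset is purely illustrative and not the real DatasetTemplate / KittiDataset classes touched in this commit:

from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    # Illustrative stand-in for a pcdet dataset; only the merged-epoch bookkeeping is shown.
    def __init__(self, num_samples=100):
        self.num_samples = num_samples
        self.total_epochs = 0
        self._merge_all_iters_to_one_epoch = False

    def merge_all_iters_to_one_epoch(self, merge=True, epochs=None):
        self._merge_all_iters_to_one_epoch = merge
        self.total_epochs = epochs if merge else 0

    def __len__(self):
        # When merged, report epochs * samples so one pass covers the whole run.
        if self._merge_all_iters_to_one_epoch:
            return self.num_samples * self.total_epochs
        return self.num_samples

    def __getitem__(self, index):
        # Fold the inflated index back into the real sample range.
        return index % self.num_samples

dataset = ToyDataset(num_samples=100)
dataset.merge_all_iters_to_one_epoch(merge=True, epochs=80)
loader = DataLoader(dataset, batch_size=4)
print(len(loader))  # 2000 batches: 80 epochs * 100 samples / batch size 4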
9 changes: 9 additions & 0 deletions pcdet/datasets/dataset.py
@@ -34,6 +34,8 @@ def __init__(self, dataset_cfg=None, class_names=None, training=True, root_path=

self.grid_size = self.data_processor.grid_size
self.voxel_size = self.data_processor.voxel_size
self.total_epochs = 0
self._merge_all_iters_to_one_epoch = False

@property
def mode(self):
@@ -65,6 +67,13 @@ def generate_prediction_dicts(batch_dict, pred_dicts, class_names, output_path=N
"""

def merge_all_iters_to_one_epoch(self, merge=True, epochs=None):
if merge:
self._merge_all_iters_to_one_epoch = True
self.total_epochs = epochs
else:
self._merge_all_iters_to_one_epoch = False

def __len__(self):
raise NotImplementedError

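A caller still has to flip the new flag on before the dataloader is iterated. A hedged sketch of what that call could look like; maybe_merge_epochs is a hypothetical helper name, not part of this commit:

def maybe_merge_epochs(dataset, total_epochs, merge=True):
    # Hypothetical caller-side guard: only merge when the dataset implements the
    # new hook, and pass the planned epoch count so __len__ can be scaled.
    if merge and hasattr(dataset, 'merge_all_iters_to_one_epoch'):
        dataset.merge_all_iters_to_one_epoch(merge=True, epochs=total_epochs)
    return dataset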
6 changes: 6 additions & 0 deletions pcdet/datasets/kitti/kitti_dataset.py
@@ -330,10 +330,16 @@ def evaluation(self, det_annos, class_names, **kwargs):
return ap_result_str, ap_dict

def __len__(self):
if self._merge_all_iters_to_one_epoch:
return len(self.kitti_infos) * self.total_epochs

return len(self.kitti_infos)

def __getitem__(self, index):
# index = 4
if self._merge_all_iters_to_one_epoch:
index = index % len(self.kitti_infos)

info = copy.deepcopy(self.kitti_infos[index])

sample_idx = info['point_cloud']['lidar_idx']
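The wrapped indexing above is plain modular arithmetic; a standalone illustration (3712 is the usual size of the KITTI train split, assumed here for the example):

num_infos = 3712        # assumed size of kitti_infos for the train split
total_epochs = 80
merged_len = num_infos * total_epochs    # what __len__ reports when merged
print('merged length:', merged_len)      # 296960

for global_index in (0, 3711, 3712, 100000):
    sample_index = global_index % num_infos  # what __getitem__ actually reads
    print(global_index, '->', sample_index)
# prints: 0 -> 0, 3711 -> 3711, 3712 -> 0, 100000 -> 3488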
5 changes: 3 additions & 2 deletions tools/test.py
@@ -96,8 +96,9 @@ def repeat_eval_ckpt(model, test_loader, args, eval_output_dir, logger, ckpt_dir
cur_epoch_id, cur_ckpt = get_no_evaluated_ckpt(ckpt_dir, ckpt_record_file, args)
if cur_epoch_id == -1 or int(float(cur_epoch_id)) < args.start_epoch:
wait_second = 30
print('Wait %s seconds for next check (progress: %.1f / %d minutes): %s \r'
% (wait_second, total_time * 1.0 / 60, args.max_waiting_mins, ckpt_dir), end='', flush=True)
if cfg.LOCAL_RANK == 0:
print('Wait %s seconds for next check (progress: %.1f / %d minutes): %s \r'
% (wait_second, total_time * 1.0 / 60, args.max_waiting_mins, ckpt_dir), end='', flush=True)
time.sleep(wait_second)
total_time += 30
if total_time > args.max_waiting_mins * 60 and (first_eval is False):
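The test.py change simply keeps the repeated wait message on the rank-0 process in multi-GPU runs, so every worker does not print the same line. A generic sketch of the same pattern, with local_rank standing in for cfg.LOCAL_RANK and wait_and_report as a hypothetical helper:

import time

def wait_and_report(local_rank, wait_second=30, message='waiting for next checkpoint'):
    # Only rank 0 writes the progress line; all ranks sleep for the same interval,
    # mirroring the cfg.LOCAL_RANK == 0 guard added above.
    if local_rank == 0:
        print('Wait %s seconds: %s' % (wait_second, message), end='\r', flush=True)
    time.sleep(wait_second)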
25 changes: 22 additions & 3 deletions tools/train.py
@@ -9,7 +9,7 @@
from train_utils.optimization import build_optimizer, build_scheduler
from train_utils.train_utils import train_model
import torch.distributed as dist

from test import repeat_eval_ckpt
from pathlib import Path
import argparse
import datetime
@@ -129,9 +129,8 @@ def main():
model = nn.parallel.DistributedDataParallel(model, device_ids=[cfg.LOCAL_RANK % torch.cuda.device_count()])
logger.info(model)

total_iters_each_epoch = len(train_loader) if not args.merge_all_iters_to_one_epoch else len(train_loader) // args.epochs
lr_scheduler, lr_warmup_scheduler = build_scheduler(
optimizer, total_iters_each_epoch=total_iters_each_epoch, total_epochs=args.epochs,
optimizer, total_iters_each_epoch=len(train_loader), total_epochs=args.epochs,
last_epoch=last_epoch, optim_cfg=cfg.OPTIMIZATION
)

Expand Down Expand Up @@ -161,6 +160,26 @@ def main():
logger.info('**********************End training %s/%s(%s)**********************\n\n\n'
% (cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))

logger.info('**********************Start evaluation %s/%s(%s)**********************' %
(cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))
test_set, test_loader, sampler = build_dataloader(
dataset_cfg=cfg.DATA_CONFIG,
class_names=cfg.CLASS_NAMES,
batch_size=args.batch_size,
dist=dist_train, workers=args.workers, logger=logger, training=False
)
eval_output_dir = output_dir / 'eval' / 'eval_with_train'
eval_output_dir.mkdir(parents=True, exist_ok=True)
args.start_epoch = max(args.epochs - 10, 0) # Only evaluate the last 10 epochs

repeat_eval_ckpt(
model.module if dist_train else model,
test_loader, args, eval_output_dir, logger, ckpt_dir,
dist_test=dist_train
)
logger.info('**********************End evaluation %s/%s(%s)**********************' %
(cfg.EXP_GROUP_PATH, cfg.TAG, args.extra_tag))


if __name__ == '__main__':
main()
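The end-of-training block reuses repeat_eval_ckpt from test.py and restricts it to the final stretch via args.start_epoch = max(args.epochs - 10, 0). A hedged, standalone sketch of that selection; checkpoints_to_evaluate is a hypothetical helper (the real filtering happens inside repeat_eval_ckpt / get_no_evaluated_ckpt), and the checkpoint_epoch_<n>.pth naming is assumed:

import glob
import os
import re

def checkpoints_to_evaluate(ckpt_dir, total_epochs, last_n=10):
    # Keep only checkpoints whose epoch id falls in the last `last_n` epochs,
    # mirroring args.start_epoch = max(args.epochs - 10, 0) above.
    start_epoch = max(total_epochs - last_n, 0)
    selected = []
    for ckpt in sorted(glob.glob(os.path.join(ckpt_dir, 'checkpoint_epoch_*.pth'))):
        match = re.search(r'checkpoint_epoch_(\d+)', ckpt)
        if match and int(match.group(1)) >= start_epoch:
            selected.append(ckpt)
    return selected

# With 80 training epochs, only checkpoints from epoch 70 onward would be evaluated.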
