Skip to content

Commit

Permalink
Self distillation support (#167)
Browse files Browse the repository at this point in the history
  • Loading branch information
KSGulin committed Feb 7, 2023
1 parent 778d4f5 commit d5687f4
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,16 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
else None
)
amp = check_amp(model) # check AMP

# Knowledge distillation
with torch_distributed_zero_first(LOCAL_RANK):
teacher_model = (
attempt_load(opt.teacher_weights, device=device)
if (opt.teacher_weights and sparsification_manager and sparsification_manager.distillation_active)
else None
)
if opt.teacher_weights and sparsification_manager and sparsification_manager.distillation_active:
teacher_model = (
deepcopy(model.eval()) if opt.teacher_weights == "self"
else attempt_load(opt.teacher_weights, device=device)
)
else:
teacher_model = None

# Freeze
freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # layers to freeze
Expand Down Expand Up @@ -514,7 +518,7 @@ def parse_opt(known=False, skip_parse=False):
parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default=SAVE_ROOT / 'yolov5s.pt', help='initial weights path')
parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
parser.add_argument('--teacher-weights', type=str, default='', help='distillation teacher initial weights path')
parser.add_argument('--teacher-weights', type=str, default='', help='distillation teacher initial weights path, can be set to `self` for self distillation')
parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
parser.add_argument('--data-path', type=str, default= '', help='path to dataset to overwrite the path in dataset.yaml')
parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-low.yaml', help='hyperparameters path')
Expand Down

0 comments on commit d5687f4

Please sign in to comment.