From 21117f0c6908ea956af1c58dfa4e64eabdbed17b Mon Sep 17 00:00:00 2001
From: Liyuan Liu
Date: Fri, 31 May 2019 10:35:55 -0700
Subject: [PATCH] moving cuda before optimizer construction

---
 mt_dnn/model.py | 6 +++++-
 train.py        | 2 --
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/mt_dnn/model.py b/mt_dnn/model.py
index eed86c2d..a459336e 100644
--- a/mt_dnn/model.py
+++ b/mt_dnn/model.py
@@ -34,6 +34,8 @@ def __init__(self, opt, state_dict=None, num_train_step=-1):
             self.network.load_state_dict(state_dict['state'])
         self.mnetwork = nn.DataParallel(self.network) if opt['multi_gpu_on'] else self.network
         self.total_param = sum([p.nelement() for p in self.network.parameters() if p.requires_grad])
+        if opt['cuda']:
+            self.network.cuda()
 
         no_decay = ['bias', 'gamma', 'beta', 'LayerNorm.bias', 'LayerNorm.weight']
 
@@ -86,6 +88,8 @@ def __init__(self, opt, state_dict=None, num_train_step=-1):
         self.ema = None
         if opt['ema_opt'] > 0:
             self.ema = EMA(self.config['ema_gamma'], self.network)
+            if opt['cuda']:
+                self.ema.cuda()
         self.para_swapped=False
 
     def setup_ema(self):
@@ -200,4 +204,4 @@ def save(self, filename):
     def cuda(self):
         self.network.cuda()
         if self.config['ema_opt']:
-            self.ema.cuda()
+            self.ema.cuda()
\ No newline at end of file
diff --git a/train.py b/train.py
index 959c2f3b..f23b7278 100644
--- a/train.py
+++ b/train.py
@@ -278,8 +278,6 @@ def main():
     if args.freeze_layers > 0:
         model.network.freeze_layers(args.freeze_layers)
 
-    if args.cuda:
-        model.cuda()
     for epoch in range(0, args.epochs):
         logger.warning('At epoch {}'.format(epoch))
         for train_data in train_data_list: