Skip to content

Commit

Permalink
moving cuda before optimizer construction
Browse files Browse the repository at this point in the history
  • Loading branch information
LiyuanLucasLiu committed May 31, 2019
1 parent afe55ed commit 21117f0
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 5 additions & 1 deletion mt_dnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def __init__(self, opt, state_dict=None, num_train_step=-1):
self.network.load_state_dict(state_dict['state'])
self.mnetwork = nn.DataParallel(self.network) if opt['multi_gpu_on'] else self.network
self.total_param = sum([p.nelement() for p in self.network.parameters() if p.requires_grad])
if opt['cuda']:
self.network.cuda()

no_decay = ['bias', 'gamma', 'beta', 'LayerNorm.bias', 'LayerNorm.weight']

Expand Down Expand Up @@ -86,6 +88,8 @@ def __init__(self, opt, state_dict=None, num_train_step=-1):
self.ema = None
if opt['ema_opt'] > 0:
self.ema = EMA(self.config['ema_gamma'], self.network)
if opt['cuda']:
self.ema.cuda()
self.para_swapped=False

def setup_ema(self):
Expand Down Expand Up @@ -200,4 +204,4 @@ def save(self, filename):
def cuda(self):
    """Transfer the model to GPU.

    Moves the underlying network onto the CUDA device; when exponential
    moving averages are enabled (``ema_opt`` truthy in the config), the
    EMA shadow copy is moved as well so both stay on the same device.
    """
    self.network.cuda()
    ema_enabled = self.config['ema_opt']
    if ema_enabled:
        self.ema.cuda()
2 changes: 0 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,6 @@ def main():
if args.freeze_layers > 0:
model.network.freeze_layers(args.freeze_layers)

if args.cuda:
model.cuda()
for epoch in range(0, args.epochs):
logger.warning('At epoch {}'.format(epoch))
for train_data in train_data_list:
Expand Down

0 comments on commit 21117f0

Please sign in to comment.