Skip to content

Commit

Permalink
remove grad_tts use_text_encoder parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
MingjieChen committed Jun 13, 2023
1 parent 10546e6 commit 714bd17
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions decoder/grad_tts/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,19 +158,18 @@ def _train_epoch(self):
loss, losses = compute_loss(self.model, _batch)
self.timer.cnt('fw')
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(),
max_norm=1)
self.optimizer.step()
self.timer.cnt('bw')
enc_grad_norm = torch.nn.utils.clip_grad_norm_(self.model.encoder.parameters(),
max_norm=1)
dec_grad_norm = torch.nn.utils.clip_grad_norm_(self.model.decoder.parameters(),
max_norm=1)

loss_string = f"epoch: {self.epochs}| iters: {self.iters}| timer: {self.timer.show()}|"
for key in losses:
train_losses["train/%s" % key].append(losses[key])
loss_string += f" {key}:{losses[key]:.3f} "
self.step_writer.add_scalar('step/'+key, losses[key], self.iters)
self.step_writer.add_scalar('step/lr', self._get_lr(), self.iters)
self.step_writer.add_scalar('step/grad_norm', grad_norm, self.iters)
self.iters+=1
if self.iters % self.config['show_freq'] == 0:
print(loss_string, flush = True)
Expand Down

0 comments on commit 714bd17

Please sign in to comment.