Skip to content

Commit

Permalink
wrap image write in try
Browse files Browse the repository at this point in the history
This sometimes fails due to "Unsupported format" error.
Report the exception and move on.
  • Loading branch information
tripzero committed Dec 19, 2019
1 parent 6aef4c8 commit 68296f7
Showing 1 changed file with 57 additions and 24 deletions.
81 changes: 57 additions & 24 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,11 +409,17 @@ def eval_model(global_step, writer, device, model, checkpoint_dir, ismultispeake
global_step, idx, speaker_str))
save_alignment(path, alignment)
tag = "eval_averaged_alignment_{}_{}".format(idx, speaker_str)
writer.add_image(tag, np.uint8(cm.viridis(np.flip(alignment, 1).T) * 255), global_step)
try:
writer.add_image(tag, np.uint8(cm.viridis(np.flip(alignment, 1).T) * 255), global_step)
except Exception as e:
warn(str(e))

# Mel
writer.add_image("(Eval) Predicted mel spectrogram text{}_{}".format(idx, speaker_str),
prepare_spec_image(mel), global_step)
try:
writer.add_image("(Eval) Predicted mel spectrogram text{}_{}".format(idx, speaker_str),
prepare_spec_image(mel), global_step)
except Exception as e:
warn(str(e))

# Audio
path = join(eval_output_dir, "step{:09d}_text{}_{}_predicted.wav".format(
Expand Down Expand Up @@ -442,44 +448,63 @@ def save_states(global_step, writer, mel_outputs, linear_outputs, attn, mel, y,
for i, alignment in enumerate(attn):
alignment = alignment[idx].cpu().data.numpy()
tag = "alignment_layer{}".format(i + 1)
writer.add_image(tag, np.uint8(cm.viridis(np.flip(alignment, 1).T) * 255), global_step)

# save files as well for now
alignment_dir = join(checkpoint_dir, "alignment_layer{}".format(i + 1))
os.makedirs(alignment_dir, exist_ok=True)
path = join(alignment_dir, "step{:09d}_layer_{}_alignment.png".format(
global_step, i + 1))
save_alignment(path, alignment)
try:
writer.add_image(tag, np.uint8(cm.viridis(
np.flip(alignment, 1).T) * 255), global_step)
# save files as well for now
alignment_dir = join(
checkpoint_dir, "alignment_layer{}".format(i + 1))
os.makedirs(alignment_dir, exist_ok=True)
path = join(alignment_dir, "step{:09d}_layer_{}_alignment.png".format(
global_step, i + 1))
save_alignment(path, alignment)
except Exception as e:
warn(str(e))

# Save averaged alignment
alignment_dir = join(checkpoint_dir, "alignment_ave")
os.makedirs(alignment_dir, exist_ok=True)
path = join(alignment_dir, "step{:09d}_alignment.png".format(global_step))
path = join(alignment_dir, "step{:09d}_layer_alignment.png".format(global_step))
alignment = attn.mean(0)[idx].cpu().data.numpy()
save_alignment(path, alignment)

tag = "averaged_alignment"
writer.add_image(tag, np.uint8(cm.viridis(np.flip(alignment, 1).T) * 255), global_step)

try:
writer.add_image(tag, np.uint8(cm.viridis(
np.flip(alignment, 1).T) * 255), global_step)
except Exception as e:
warn(str(e))

# Predicted mel spectrogram
if mel_outputs is not None:
mel_output = mel_outputs[idx].cpu().data.numpy()
mel_output = prepare_spec_image(audio._denormalize(mel_output))
writer.add_image("Predicted mel spectrogram", mel_output, global_step)
try:
writer.add_image("Predicted mel spectrogram",
mel_output, global_step)
except Exception as e:
warn(str(e))
pass

# Predicted spectrogram
if linear_outputs is not None:
linear_output = linear_outputs[idx].cpu().data.numpy()
spectrogram = prepare_spec_image(audio._denormalize(linear_output))
writer.add_image("Predicted linear spectrogram", spectrogram, global_step)
try:
writer.add_image("Predicted linear spectrogram",
spectrogram, global_step)
except Exception as e:
warn(str(e))
pass

# Predicted audio signal
signal = audio.inv_spectrogram(linear_output.T)
signal /= np.max(np.abs(signal))
path = join(checkpoint_dir, "step{:09d}_predicted.wav".format(
global_step))
try:
writer.add_audio("Predicted audio signal", signal, global_step, sample_rate=hparams.sample_rate)
writer.add_audio("Predicted audio signal", signal,
global_step, sample_rate=hparams.sample_rate)
except Exception as e:
warn(str(e))
pass
Expand All @@ -489,13 +514,22 @@ def save_states(global_step, writer, mel_outputs, linear_outputs, attn, mel, y,
if mel_outputs is not None:
mel_output = mel[idx].cpu().data.numpy()
mel_output = prepare_spec_image(audio._denormalize(mel_output))
writer.add_image("Target mel spectrogram", mel_output, global_step)
try:
writer.add_image("Target mel spectrogram", mel_output, global_step)
except Exception as e:
warn(str(e))
pass

# Target spectrogram
if linear_outputs is not None:
linear_output = y[idx].cpu().data.numpy()
spectrogram = prepare_spec_image(audio._denormalize(linear_output))
writer.add_image("Target linear spectrogram", spectrogram, global_step)
try:
writer.add_image("Target linear spectrogram",
spectrogram, global_step)
except Exception as e:
warn(str(e))
pass


def logit(x, eps=1e-8):
Expand Down Expand Up @@ -712,7 +746,8 @@ def train(device, model, data_loader, optimizer, writer,
train_seq2seq, train_postnet)

if global_step > 0 and global_step % hparams.eval_interval == 0:
eval_model(global_step, writer, device, model, checkpoint_dir, ismultispeaker)
eval_model(global_step, writer, device, model,
checkpoint_dir, ismultispeaker)

# Update
loss.backward()
Expand All @@ -731,8 +766,7 @@ def train(device, model, data_loader, optimizer, writer,
if train_postnet:
writer.add_scalar("linear_loss", float(linear_loss.item()), global_step)
writer.add_scalar("linear_l1_loss", float(linear_l1_loss.item()), global_step)
writer.add_scalar("linear_binary_div_loss", float(
linear_binary_div.item()), global_step)
writer.add_scalar("linear_binary_div_loss", float(linear_binary_div.item()), global_step)
if train_seq2seq and hparams.use_guided_attention:
writer.add_scalar("attn_loss", float(attn_loss.item()), global_step)
if clip_thresh > 0:
Expand Down Expand Up @@ -963,8 +997,7 @@ def restore_parts(path, model):
# Setup summary writer for tensorboard
if log_event_path is None:
if platform.system() == "Windows":
log_event_path = "log/run-test" + \
str(datetime.now()).replace(" ", "_").replace(":", "_")
log_event_path = "log/run-test" + str(datetime.now()).replace(" ", "_").replace(":", "_")
else:
log_event_path = "log/run-test" + str(datetime.now()).replace(" ", "_")
print("Log event path: {}".format(log_event_path))
Expand Down

0 comments on commit 68296f7

Please sign in to comment.