From 044cc4aa11d9c6b379495b7a18b24a7243055a65 Mon Sep 17 00:00:00 2001 From: Chaitanya Talnikar Date: Tue, 13 Nov 2018 17:15:04 -0800 Subject: [PATCH] Ensure remove weight norm actually removes the weight computation from network --- glow.py | 18 +++++++++--------- inference.py | 3 ++- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/glow.py b/glow.py index fd283c6..c5b4848 100644 --- a/glow.py +++ b/glow.py @@ -281,15 +281,15 @@ def infer(self, spect, sigma=1.0): audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data return audio - def remove_weightnorm(self): - waveglow = copy.deepcopy(self) - for WN in waveglow.WN: - WN.start = torch.nn.utils.remove_weight_norm(WN.start) - WN.in_layers = remove(WN.in_layers) - WN.cond_layers = remove(WN.cond_layers) - WN.res_layers = remove(WN.res_layers) - WN.skip_layers = remove(WN.skip_layers) - self = waveglow +def remove_weightnorm(self): + waveglow = copy.deepcopy(self) + for WN in waveglow.WN: + WN.start = torch.nn.utils.remove_weight_norm(WN.start) + WN.in_layers = remove(WN.in_layers) + WN.cond_layers = remove(WN.cond_layers) + WN.res_layers = remove(WN.res_layers) + WN.skip_layers = remove(WN.skip_layers) + return waveglow def remove(conv_list): new_conv_list = torch.nn.ModuleList() diff --git a/inference.py b/inference.py index d615905..98bffa1 100644 --- a/inference.py +++ b/inference.py @@ -29,11 +29,12 @@ from scipy.io.wavfile import write import torch from mel2samp import files_to_list, MAX_WAV_VALUE +from glow import remove_weightnorm def main(mel_files, waveglow_path, sigma, output_dir, sampling_rate, is_fp16): mel_files = files_to_list(mel_files) waveglow = torch.load(waveglow_path)['model'] - waveglow.remove_weightnorm() + waveglow = remove_weightnorm(waveglow) waveglow.cuda().eval() if is_fp16: waveglow.half()