Skip to content

Commit

Permalink
Merge pull request NVIDIA#21 from nvchai/weight_norm_fix
Browse files Browse the repository at this point in the history
Ensure remove weight norm actually removes the weight computation from the inference network
  • Loading branch information
rafaelvalle committed Nov 14, 2018
2 parents ea431fd + 044cc4a commit 045482e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
18 changes: 9 additions & 9 deletions glow.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,15 +281,15 @@ def infer(self, spect, sigma=1.0):
audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data
return audio

def remove_weightnorm(self):
waveglow = copy.deepcopy(self)
for WN in waveglow.WN:
WN.start = torch.nn.utils.remove_weight_norm(WN.start)
WN.in_layers = remove(WN.in_layers)
WN.cond_layers = remove(WN.cond_layers)
WN.res_layers = remove(WN.res_layers)
WN.skip_layers = remove(WN.skip_layers)
self = waveglow
def remove_weightnorm(self):
waveglow = copy.deepcopy(self)
for WN in waveglow.WN:
WN.start = torch.nn.utils.remove_weight_norm(WN.start)
WN.in_layers = remove(WN.in_layers)
WN.cond_layers = remove(WN.cond_layers)
WN.res_layers = remove(WN.res_layers)
WN.skip_layers = remove(WN.skip_layers)
return waveglow

def remove(conv_list):
new_conv_list = torch.nn.ModuleList()
Expand Down
3 changes: 2 additions & 1 deletion inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@
from scipy.io.wavfile import write
import torch
from mel2samp import files_to_list, MAX_WAV_VALUE
from glow import remove_weightnorm

def main(mel_files, waveglow_path, sigma, output_dir, sampling_rate, is_fp16):
mel_files = files_to_list(mel_files)
waveglow = torch.load(waveglow_path)['model']
waveglow.remove_weightnorm()
waveglow = remove_weightnorm(waveglow)
waveglow.cuda().eval()
if is_fp16:
waveglow.half()
Expand Down

0 comments on commit 045482e

Please sign in to comment.