Commit

Merge pull request NVIDIA#27 from NVIDIA/model_perfimprov
Model perfimprov
rafaelvalle committed Nov 14, 2018
2 parents 923531a + 6be0977 commit ee5ce70
Showing 3 changed files with 113 additions and 21 deletions.
38 changes: 24 additions & 14 deletions glow.py
@@ -29,6 +29,17 @@
from torch.autograd import Variable
import torch.nn.functional as F


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
in_act = input_a+input_b
t_act = torch.nn.functional.tanh(in_act[:, :n_channels_int, :])
s_act = torch.nn.functional.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts


class WaveGlowLoss(torch.nn.Module):
def __init__(self, sigma=1.0):
super(WaveGlowLoss, self).__init__()
@@ -145,12 +156,10 @@ def forward(self, forward_input):
audio = self.start(audio)

for i in range(self.n_layers):
in_act = self.in_layers[i](audio)
in_act = in_act + self.cond_layers[i](spect)

t_act = torch.nn.functional.tanh(in_act[:, :self.n_channels, :])
s_act = torch.nn.functional.sigmoid(in_act[:, self.n_channels:, :])
acts = t_act * s_act
acts = fused_add_tanh_sigmoid_multiply(
self.in_layers[i](audio),
self.cond_layers[i](spect),
torch.IntTensor([self.n_channels]))

res_skip_acts = self.res_skip_layers[i](acts)
if i < self.n_layers - 1:
@@ -282,14 +291,15 @@ def infer(self, spect, sigma=1.0):
audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data
return audio

def remove_weightnorm(self):
waveglow = copy.deepcopy(self)
for WN in waveglow.WN:
WN.start = torch.nn.utils.remove_weight_norm(WN.start)
WN.in_layers = remove(WN.in_layers)
WN.cond_layers = remove(WN.cond_layers)
WN.res_skip_layers = remove(WN.res_skip_layers)
return waveglow
@staticmethod
def remove_weightnorm(model):
waveglow = copy.deepcopy(model)
for WN in waveglow.WN:
WN.start = torch.nn.utils.remove_weight_norm(WN.start)
WN.in_layers = remove(WN.in_layers)
WN.cond_layers = remove(WN.cond_layers)
WN.res_skip_layers = remove(WN.res_skip_layers)
return waveglow

def remove(conv_list):
new_conv_list = torch.nn.ModuleList()
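Note on the glow.py change: the per-layer add, tanh, sigmoid, and multiply were previously four separate eager-mode ops inside the forward loop; folding them into a single @torch.jit.script function lets the TorchScript JIT fuse the pointwise operations. n_channels is passed as a one-element torch.IntTensor, presumably because script functions at the time handled tensor arguments most reliably. For reference, a minimal eager-mode sketch of the gated activation the fused function computes (sizes are illustrative, not from the commit):

import torch

# Illustrative sizes: batch 1, 2*n_channels = 8 channels, 16 time steps.
n_channels = 4
in_act = torch.randn(1, 2 * n_channels, 16)  # conv(audio) + cond(spect), already summed

t_act = torch.tanh(in_act[:, :n_channels, :])     # "filter" half
s_act = torch.sigmoid(in_act[:, n_channels:, :])  # "gate" half
acts = t_act * s_act                              # WaveNet-style gated activation, shape (1, 4, 16)
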
91 changes: 87 additions & 4 deletions glow_old.py
@@ -1,7 +1,89 @@
import copy
import torch
from glow import Invertible1x1Conv, remove
from glow import WN


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
n_channels_int = n_channels[0]
in_act = input_a+input_b
t_act = torch.nn.functional.tanh(in_act[:, :n_channels_int, :])
s_act = torch.nn.functional.sigmoid(in_act[:, n_channels_int:, :])
acts = t_act * s_act
return acts


class WN(torch.nn.Module):
"""
This is the WaveNet-like layer for the affine coupling. The primary difference
from WaveNet is that the convolutions need not be causal. There is also no
dilation size reset; the dilation only doubles on each layer.
"""
def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
kernel_size):
super(WN, self).__init__()
assert(kernel_size % 2 == 1)
assert(n_channels % 2 == 0)
self.n_layers = n_layers
self.n_channels = n_channels
self.in_layers = torch.nn.ModuleList()
self.res_layers = torch.nn.ModuleList()
self.skip_layers = torch.nn.ModuleList()
self.cond_layers = torch.nn.ModuleList()

start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
start = torch.nn.utils.weight_norm(start, name='weight')
self.start = start

# Initializing last layer to 0 makes the affine coupling layers
# do nothing at first. This helps with training stability
end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1)
end.weight.data.zero_()
end.bias.data.zero_()
self.end = end

for i in range(n_layers):
dilation = 2 ** i
padding = int((kernel_size*dilation - dilation)/2)
in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size,
dilation=dilation, padding=padding)
in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
self.in_layers.append(in_layer)

cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels, 1)
cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
self.cond_layers.append(cond_layer)

# last one is not necessary
if i < n_layers - 1:
res_layer = torch.nn.Conv1d(n_channels, n_channels, 1)
res_layer = torch.nn.utils.weight_norm(res_layer, name='weight')
self.res_layers.append(res_layer)

skip_layer = torch.nn.Conv1d(n_channels, n_channels, 1)
skip_layer = torch.nn.utils.weight_norm(skip_layer, name='weight')
self.skip_layers.append(skip_layer)

def forward(self, forward_input):
audio, spect = forward_input
audio = self.start(audio)

for i in range(self.n_layers):
acts = fused_add_tanh_sigmoid_multiply(
self.in_layers[i](audio),
self.cond_layers[i](spect),
torch.IntTensor([self.n_channels]))

if i < self.n_layers - 1:
res_acts = self.res_layers[i](acts)
audio = res_acts + audio

if i == 0:
output = self.skip_layers[i](acts)
else:
output = self.skip_layers[i](acts) + output

return self.end(output)


class WaveGlow(torch.nn.Module):
@@ -140,12 +222,13 @@ def infer(self, spect, sigma=1.0):

return audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data

def remove_weightnorm(self):
waveglow = copy.deepcopy(self)
@staticmethod
def remove_weightnorm(model):
waveglow = copy.deepcopy(model)
for WN in waveglow.WN:
WN.start = torch.nn.utils.remove_weight_norm(WN.start)
WN.in_layers = remove(WN.in_layers)
WN.cond_layers = remove(WN.cond_layers)
WN.res_layers = remove(WN.res_layers)
WN.skip_layers = remove(WN.skip_layers)
self = waveglow
return waveglow
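
Note on the glow_old.py change: the file now carries its own copy of WN, keeping the separate res_layers and skip_layers of the original architecture (glow.py uses a fused res_skip_layers), presumably so checkpoints trained before this commit still load. As the WN docstring says, the dilation doubles on every layer with no reset; a short sketch (values illustrative) of the per-layer dilation and the "same" padding computed in __init__ above:

# For an odd kernel_size k and dilation d, padding (k*d - d) // 2
# keeps the time dimension unchanged after the convolution.
kernel_size = 3            # illustrative; the assert above requires an odd value
for i in range(8):         # illustrative n_layers
    dilation = 2 ** i
    padding = (kernel_size * dilation - dilation) // 2
    print(f"layer {i}: dilation={dilation}, padding={padding}")
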
5 changes: 2 additions & 3 deletions inference.py
@@ -25,16 +25,15 @@
#
# *****************************************************************************
import os
import numpy as np
from scipy.io.wavfile import write
import torch
from mel2samp import files_to_list, MAX_WAV_VALUE
from glow import remove_weightnorm


def main(mel_files, waveglow_path, sigma, output_dir, sampling_rate, is_fp16):
mel_files = files_to_list(mel_files)
waveglow = torch.load(waveglow_path)['model']
waveglow = remove_weightnorm(waveglow)
waveglow = waveglow.remove_weightnorm(waveglow)
waveglow.cuda().eval()
if is_fp16:
waveglow.half()
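Note on the inference.py change: with remove_weightnorm now a @staticmethod that takes the model explicitly, the module-level import from glow is dropped and the call goes through the loaded instance. Calling it on the class is equivalent; a minimal sketch of the call pattern (the checkpoint path is hypothetical):

import torch
from glow import WaveGlow  # assumes glow.py is importable, so torch.load can unpickle the model

waveglow = torch.load('waveglow_checkpoint.pt')['model']  # hypothetical path
waveglow = WaveGlow.remove_weightnorm(waveglow)  # same as waveglow.remove_weightnorm(waveglow)
waveglow.cuda().eval()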
