Skip to content

Commit

Permalink
glow_old.py: using fused res_skip, removing deep copy
Browse files Browse the repository at this point in the history
  • Loading branch information
rafaelvalle committed Nov 16, 2018
1 parent 3122e3e commit d19cc9d
Showing 1 changed file with 16 additions and 17 deletions.
33 changes: 16 additions & 17 deletions glow_old.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
self.n_layers = n_layers
self.n_channels = n_channels
self.in_layers = torch.nn.ModuleList()
self.res_layers = torch.nn.ModuleList()
self.skip_layers = torch.nn.ModuleList()
self.res_skip_layers = torch.nn.ModuleList()
self.cond_layers = torch.nn.ModuleList()

start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
Expand Down Expand Up @@ -56,13 +55,12 @@ def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,

# last one is not necessary
if i < n_layers - 1:
res_layer = torch.nn.Conv1d(n_channels, n_channels, 1)
res_layer = torch.nn.utils.weight_norm(res_layer, name='weight')
self.res_layers.append(res_layer)

skip_layer = torch.nn.Conv1d(n_channels, n_channels, 1)
skip_layer = torch.nn.utils.weight_norm(skip_layer, name='weight')
self.skip_layers.append(skip_layer)
res_skip_channels = 2*n_channels
else:
res_skip_channels = n_channels
res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
self.res_skip_layers.append(res_skip_layer)

def forward(self, forward_input):
audio, spect = forward_input
Expand All @@ -74,15 +72,17 @@ def forward(self, forward_input):
self.cond_layers[i](spect),
torch.IntTensor([self.n_channels]))

res_skip_acts = self.res_skip_layers[i](acts)
if i < self.n_layers - 1:
res_acts = self.res_layers[i](acts)
audio = res_acts + audio
audio = res_skip_acts[:,:self.n_channels,:] + audio
skip_acts = res_skip_acts[:,self.n_channels:,:]
else:
skip_acts = res_skip_acts

if i == 0:
output = self.skip_layers[i](acts)
output = skip_acts
else:
output = self.skip_layers[i](acts) + output

output = skip_acts + output
return self.end(output)


Expand Down Expand Up @@ -224,11 +224,10 @@ def infer(self, spect, sigma=1.0):

@staticmethod
def remove_weightnorm(model):
    """Strip weight normalization from every WN sub-module of *model*.

    Operates in place: *model* itself is mutated and returned (the
    pre-commit `copy.deepcopy(model)` was deliberately dropped, so the
    caller's reference and the return value are the same object).

    Args:
        model: a WaveGlow instance whose `WN` attribute is an iterable of
            WN modules (each with `start`, `in_layers`, `cond_layers`,
            and `res_skip_layers`).

    Returns:
        The same *model*, with weight norm removed from every WN layer.
    """
    waveglow = model  # in place — intentionally no deep copy
    for WN in waveglow.WN:
        WN.start = torch.nn.utils.remove_weight_norm(WN.start)
        # NOTE(review): `remove` is presumably a module-level helper that
        # strips weight norm from each layer in a ModuleList — not visible
        # in this view; confirm against the rest of the file.
        WN.in_layers = remove(WN.in_layers)
        WN.cond_layers = remove(WN.cond_layers)
        WN.res_skip_layers = remove(WN.res_skip_layers)
    return waveglow

0 comments on commit d19cc9d

Please sign in to comment.