Skip to content

Commit

Permalink
Merge pull request #178 from r9y9/fix-pytorch13
Browse files Browse the repository at this point in the history
Fixes for pytorch 1.3
  • Loading branch information
r9y9 authored Dec 21, 2019
2 parents f04a271 + f6f87aa commit 897f31e
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 14 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ python synthesis.py --preset=20180505_deepvoice3_ljspeech.json \

## Requirements

- Python 3 (<= 3.6)
- Python >= 3.5
- CUDA >= 8.0
- PyTorch >= v0.4.0
- PyTorch >= v1.0.0
- [nnmnkwii](https://github.com/r9y9/nnmnkwii) >= v0.0.11
- [MeCab](http://taku910.github.io/mecab/) (Japanese only)

Expand Down
6 changes: 3 additions & 3 deletions deepvoice3_pytorch/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def get_mask_from_lengths(memory, memory_lengths):
memory: (batch, max_time, dim)
memory_lengths: array-like
"""
mask = memory.data.new(memory.size(0), memory.size(1)).byte().zero_()
for idx, l in enumerate(memory_lengths):
mask[idx][:l] = 1
max_len = max(memory_lengths)
mask = torch.arange(max_len).expand(memory.size(0), max_len) < torch.tensor(memory_lengths).unsqueeze(-1)
mask = mask.to(memory.device)
return ~mask
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def create_readme_rst():
install_requires=[
"numpy",
"scipy",
"torch >= 0.4.0",
"torch >= 1.0.0",
"unidecode",
"inflect",
"librosa",
Expand Down
16 changes: 10 additions & 6 deletions tests/test_deepvoice3.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

def _get_model(n_speakers=1, speaker_embed_dim=None,
force_monotonic_attention=False,
use_decoder_state_for_postnet_input=False):
use_decoder_state_for_postnet_input=False, use_memory_mask=False):
model = deepvoice3(n_vocab=n_vocab,
embed_dim=32,
mel_dim=num_mels,
Expand All @@ -42,6 +42,7 @@ def _get_model(n_speakers=1, speaker_embed_dim=None,
converter_channels=32,
force_monotonic_attention=force_monotonic_attention,
use_decoder_state_for_postnet_input=use_decoder_state_for_postnet_input,
use_memory_mask=use_memory_mask,
)
return model

Expand All @@ -62,7 +63,7 @@ def _test_data():
x = torch.LongTensor(seqs)
y = torch.rand(x.size(0), 12, 80)

return x, y
return x, y, input_lengths


def _deepvoice3(n_vocab, embed_dim=256, mel_dim=80,
Expand Down Expand Up @@ -110,11 +111,14 @@ def _deepvoice3(n_vocab, embed_dim=256, mel_dim=80,


def test_single_speaker_deepvoice3():
x, y = _test_data()
x, y, lengths = _test_data()

for v in [False, True]:
model = _get_model(use_decoder_state_for_postnet_input=v)
mel_outputs, linear_outputs, alignments, done = model(x, y)
mel_outputs, linear_outputs, alignments, done = model(x, y, input_lengths=lengths)

model = _get_model(use_memory_mask=True)
mel_outputs, linear_outputs, alignments, done = model(x, y, input_lengths=lengths)


def _pad_2d(x, max_len, b_pad=0):
Expand Down Expand Up @@ -192,7 +196,7 @@ def test_incremental_correctness():
assert max_target_len % r == 0
mel = _pad_2d(mel, max_target_len)
mel = torch.from_numpy(mel)
mel_reshaped = mel.view(1, -1, mel_dim * r)
mel_reshaped = mel.contiguous().view(1, -1, mel_dim * r)
frame_positions = np.arange(1, mel_reshaped.size(1) + 1).reshape(1, mel_reshaped.size(1))

x = torch.LongTensor(seqs)
Expand Down Expand Up @@ -269,7 +273,7 @@ def test_incremental_forward():
assert max_target_len % r == 0
mel = _pad_2d(mel, max_target_len)
mel = torch.from_numpy(mel)
mel_reshaped = mel.view(1, -1, mel_dim * r)
mel_reshaped = mel.contiguous().view(1, -1, mel_dim * r)

frame_positions = np.arange(1, mel_reshaped.size(1) + 1).reshape(1, mel_reshaped.size(1))

Expand Down
6 changes: 4 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,10 @@ def collate_fn(batch):
s, e = 1, max_decoder_target_len + 1
# if b_pad > 0:
# s, e = s - 1, e - 1
# NOTE: needs clone to suppress RuntimeError in dataloader...
# ref: https://github.com/pytorch/pytorch/issues/10756
frame_positions = torch.arange(s, e).long().unsqueeze(0).expand(
len(batch), max_decoder_target_len)
len(batch), max_decoder_target_len).clone()

# done flags
done = np.array([_pad(np.zeros(len(x[1]) // r // downsample_step - 1),
Expand Down Expand Up @@ -963,7 +965,7 @@ def restore_parts(path, model):
data_loader = data_utils.DataLoader(
dataset, batch_size=hparams.batch_size,
num_workers=hparams.num_workers, sampler=sampler,
collate_fn=collate_fn, pin_memory=hparams.pin_memory)
collate_fn=collate_fn, pin_memory=hparams.pin_memory, drop_last=True)

device = torch.device("cuda" if use_cuda else "cpu")

Expand Down

0 comments on commit 897f31e

Please sign in to comment.