Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update modules #59

Merged
merged 31 commits into from
Apr 6, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
c8d6f86
VALLE add continual inference
lifeiteng Mar 19, 2023
91ecd50
separate text embedding & position of AR and NAR Decoders
lifeiteng Mar 19, 2023
e34c101
Separate Modules of AR and NAR Decoders
lifeiteng Mar 19, 2023
486898d
Support train AR Decoder and NAR Decoder separately
lifeiteng Mar 19, 2023
1297357
Copy transformer modules from pytorch
lifeiteng Mar 20, 2023
b6a824c
update trainer.py
lifeiteng Mar 20, 2023
aced965
Implement InputStrategy PromptedPrecomputedFeatures
lifeiteng Mar 20, 2023
7afedd5
VALL-E Add prefix_mode=4
lifeiteng Mar 20, 2023
fbb3fbc
Fix InputStrategy PromptedPrecomputedFeatures
lifeiteng Mar 20, 2023
4c05d68
Fix InputStrategy PromptedPrecomputedFeatures
lifeiteng Mar 20, 2023
cfe4965
LibriTTS update README
lifeiteng Mar 21, 2023
5c4f85f
use load_manifest_lazy
lifeiteng Mar 22, 2023
0f0c7fd
Fix index of PromptedPrecomputedFeatures
lifeiteng Mar 22, 2023
db5997c
Trainer - Add config --filter-min-duration
lifeiteng Mar 22, 2023
e7162e5
Unify Prefix Mode 2 and 4
lifeiteng Mar 22, 2023
751c226
update trainer
lifeiteng Mar 26, 2023
637c476
Add Hparam --share-embedding
lifeiteng Mar 26, 2023
a50b5b4
Merge branch 'prefix4' into stage
lifeiteng Mar 26, 2023
f6f3017
Fix Hparam --share-embedding
lifeiteng Mar 26, 2023
140a0b9
Fix MultiGPU load_checkpoint
lifeiteng Mar 31, 2023
7657ef6
Tune prefix_mode 1
lifeiteng Mar 31, 2023
a952f95
valid every epoch
lifeiteng Mar 31, 2023
51a6955
update --train-stage logic
lifeiteng Mar 31, 2023
e55582f
set NUM_TEXT_TOKENS=512 for multi-language models
lifeiteng Mar 31, 2023
d34b025
VALLF support --train-stage
lifeiteng Mar 31, 2023
8a8facf
VALLF support --prefix-mode
lifeiteng Mar 31, 2023
7e3bb2f
Fix VALL-F test
lifeiteng Apr 3, 2023
7d6b721
Fix DDP --train-stage
lifeiteng Apr 4, 2023
9acece1
Add model hparam --scale-factor
lifeiteng Apr 4, 2023
5154048
VALL-E & F update embedding sharing and inference sampling
lifeiteng Apr 6, 2023
cf9f26c
egs rename run.sh to prepare.sh and simplify README
lifeiteng Apr 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
VALL-E Add prefix_mode=4
  • Loading branch information
lifeiteng committed Mar 20, 2023
commit 7afedd5bb460ebf1af168be75645371690d8b07f
46 changes: 40 additions & 6 deletions valle/models/valle.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from icefall.utils import make_pad_mask
from torchmetrics.classification import MulticlassAccuracy

from valle.data.input_strategies import PromptedFeatures
from valle.modules.embedding import SinePositionalEmbedding, TokenEmbedding
from valle.modules.transformer import (
AdaptiveLayerNorm,
Expand Down Expand Up @@ -287,8 +288,8 @@ def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: torch.Tensor,
y_lens: torch.Tensor,
y: Union[torch.Tensor, PromptedFeatures],
y_lens: Union[torch.Tensor, PromptedFeatures],
reduction: str = "sum",
train_stage: int = 0,
) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
Expand Down Expand Up @@ -570,8 +571,8 @@ def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: torch.Tensor,
y_lens: torch.Tensor,
y: Union[torch.Tensor, PromptedFeatures],
y_lens: Union[torch.Tensor, PromptedFeatures],
reduction: str = "sum",
train_stage: int = 0,
) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
Expand All @@ -594,9 +595,17 @@ def forward(
"""
assert x.ndim == 2, x.shape
assert x_lens.ndim == 1, x_lens.shape

y_prompts_codes = None
if isinstance(y, PromptedFeatures):
y_prompts_codes, y = y.data
prompts_len, y_lens = y_lens.data
assert prompts_len.min() == prompts_len.max()
assert self.prefix_mode == 4
y_prompts_codes = y_prompts_codes.type(torch.int64)

assert y.ndim == 3, y.shape
assert y_lens.ndim == 1, y_lens.shape

assert torch.all(x_lens > 0)

# NOTE: x has been padded in TextTokenCollater
Expand Down Expand Up @@ -756,6 +765,31 @@ def pad_y_eos(y, eos_id):
)

prefix_len = 0
elif self.prefix_mode == 4:
assert y_prompts_codes is not None
y_prompts = self.nar_audio_embeddings[0](
y_prompts_codes[..., 0]
)
y_emb = self.nar_audio_embeddings[0](y)
for j in range(1, 8):
y_prompts += self.nar_audio_embeddings[j](
y_prompts_codes[..., j]
)
if j < nar_stage:
y_emb += self.nar_audio_embeddings[j](codes[..., j])
y_emb = torch.concat([y_prompts, y_emb], axis=1)

prompts_len += y_prompts.shape[1]
xy_padding_mask = torch.concat(
[
x_mask,
F.pad(y_mask, (y_prompts.shape[1], 0), value=False),
],
dim=1,
)

prefix_len = 0

else:
raise ValueError

Expand Down Expand Up @@ -902,7 +936,7 @@ def inference(
# Non-AR Decoders
y_emb = self.nar_audio_embeddings[0](y)

if self.prefix_mode == 2: # Exclude enrolled_phonemes
if self.prefix_mode in [2, 4]: # Exclude enrolled_phonemes
enrolled_len = enroll_x_lens.max().item()
# SOS + Synthesis Text + EOS
x = torch.concat(
Expand Down
44 changes: 44 additions & 0 deletions valle/tests/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from icefall.utils import AttributeDict
from torchmetrics.classification import MulticlassAccuracy

from valle.data.input_strategies import PromptedFeatures
from valle.models import NUM_MEL_BINS, get_model


Expand Down Expand Up @@ -107,6 +108,49 @@ def test_valle(self):
x[-1:], x_lens[-1:], y[-1:], enroll_x_lens=enroll_x_lens
)

def test_valle_prefix4(self):
    """Smoke-test VALL-E training and inference with prefix_mode=4.

    prefix_mode=4 feeds precomputed acoustic prompts alongside the
    target codes, so ``y``/``y_lens`` are wrapped in ``PromptedFeatures``
    (prompt tensor paired with target tensor) instead of plain tensors.
    """
    params = AttributeDict()
    params.decoder_dim = 64
    params.nhead = 16
    params.num_decoder_layers = 4

    # Text tokens: batch of 4, max length 8; force the last row to the
    # max so padding logic is exercised with at least one full-length row.
    x = torch.from_numpy(np.random.randint(0, 100, size=[4, 8]))
    x_lens = torch.from_numpy(np.random.randint(4, 8, size=[4]))
    x_lens[-1] = 8
    enroll_x_lens = torch.from_numpy(np.random.randint(1, 3, size=[4]))

    # Target audio codes: [batch, frames, 8 codebooks].
    y = torch.from_numpy(np.random.randint(0, 1000, size=[4, 16, 8]))
    y_lens = torch.from_numpy(np.random.randint(8, 16, size=[4]))
    y_lens[-1] = 16

    # Acoustic prompts: all rows share one length (randint(12, 13) is
    # always 12) — the model asserts prompts_len.min() == prompts_len.max().
    prompts = torch.from_numpy(np.random.randint(0, 1000, size=[4, 12, 8]))
    prompts_lens = torch.from_numpy(np.random.randint(12, 13, size=[4]))

    params.norm_first = False
    params.add_prenet = True
    params.model_name = "VALL-E"

    for device in self.devices:
        for mode in [4]:
            params.prefix_mode = mode
            # VALL-E
            model = get_model(params)
            model.to(device)
            x = x.to(device)
            x_lens = x_lens.to(device)
            y = y.to(device)
            # Fix: enroll_x_lens previously stayed on CPU, which makes
            # the inference call below fail with a device mismatch when
            # self.devices includes a CUDA device.
            enroll_x_lens = enroll_x_lens.to(device)

            _y = PromptedFeatures(prompts, y).to(device)
            _y_lens = PromptedFeatures(prompts_lens, y_lens).to(device)

            # Training
            codes, loss, metrics = model(x, x_lens, _y, _y_lens)
            # Inference
            model.eval()
            codes = model.inference(
                x[-1:], x_lens[-1:], y[-1:], enroll_x_lens=enroll_x_lens
            )

def test_topmetric(self):
metric_top10 = MulticlassAccuracy(1024, top_k=10, average="micro")
metric_top1 = MulticlassAccuracy(1024, top_k=1, average="micro")
Expand Down