add a spec_max_len parameter for faster training of VITS
MingjieChen committed Feb 20, 2023
1 parent cba8783 commit 184178a
Showing 5 changed files with 37 additions and 13 deletions.
11 changes: 6 additions & 5 deletions configs/vctk_vqw2v_uttdvec_fs2pitchenergy_vits.yaml
@@ -21,18 +21,19 @@ show_freq: 100 # show training information frequency
load_only_params: !!bool False
seed: !!int 1234
trainer: VITSTrainer
-ngpu: 1
+ngpu: 2

#dataloader
dataset_class: VITSDataset
sampling_rate: !!int 24000
vits_hop_size: !!int 240
-sort: !!bool False
+spec_max_len: !!int 360
+sort: !!bool True
dump_dir: dump
-num_workers: !!int 8
-batch_size: !!int 16
+num_workers: !!int 4
+batch_size: !!int 12
drop_last: !!bool True
-rm_long_utt: !!bool True # remove too long utterances from metadata
+rm_long_utt: !!bool False # remove too long utterances from metadata
max_utt_duration: !!float 10.0 # max utterance duration (seconds)


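For scale, the new spec_max_len caps every training item at 360 spectrogram frames; with vits_hop_size 240 at a 24 kHz sampling rate that is 360 × 240 = 86,400 waveform samples, about 3.6 seconds per slice, which bounds per-batch memory even for long utterances. A small sanity-check sketch of that arithmetic, assuming the config is read with PyYAML (the snippet itself is not part of the commit):

```python
# Illustrative only: derive the audio length implied by spec_max_len.
import yaml

with open('configs/vctk_vqw2v_uttdvec_fs2pitchenergy_vits.yaml') as f:
    cfg = yaml.safe_load(f)

frames = cfg['spec_max_len']      # 360 spectrogram frames per training item
hop = cfg['vits_hop_size']        # 240 samples advanced per frame
sr = cfg['sampling_rate']         # 24000 Hz

samples = frames * hop            # 86,400 samples of raw audio
print(f'{frames} frames -> {samples} samples -> {samples / sr:.1f} s per slice')
```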
14 changes: 14 additions & 0 deletions dataset.py
@@ -93,6 +93,7 @@ def __init__(self, config, metadata_csv, split):
self.hop_size = config['vits_hop_size']
self.sampling_rate = config['sampling_rate']
self.segment_size = config['decoder_params']['segment_size'] # random slice segment size of the HiFi-GAN decoder in the VITS model.
+self.spec_max_len = config['spec_max_len']
# read metadata
with open(metadata_csv) as f:
reader = csv.DictReader(f, delimiter = ',')
@@ -102,6 +103,8 @@ def __init__(self, config, metadata_csv, split):
_duration = row['duration']
if float(_duration) < config['max_utt_duration']:
    self.metadata.append(row)
+else:
+    self.metadata.append(row)
f.close()

print(f'{split} data samples {len(self.metadata)}')
@@ -200,6 +203,17 @@ def __getitem__(self, idx):
elif audio_duration > int(spec_duration * self.hop_size):
audio = audio[:int(spec_duration * self.hop_size)]

+# slice by spec_max_len
+if spec_duration > self.spec_max_len:
+    start = random.randint(0, spec_duration - self.spec_max_len)
+    end = start + self.spec_max_len
+    spec_duration = self.spec_max_len
+    spec = spec[start:end, :]
+    ling_rep = ling_rep[start:end, :]
+    pros_rep = pros_rep[start:end, :]
+    audio = audio[start * self.hop_size: end * self.hop_size]




return (audio, spec, ling_rep, pros_rep, spk_emb, spec_duration)
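The substance of the dataset change is an aligned random crop: the frame-level features (spectrogram, linguistic representation, prosodic representation) are sliced to at most spec_max_len frames, and the waveform is cut over the matching sample range (frame index × hop size) so all streams stay time-aligned. A self-contained sketch of that step with hypothetical feature dimensions; the real __getitem__ applies it to the features it has just loaded and still returns spk_emb and the updated spec_duration:

```python
# Standalone sketch of the crop added to VITSDataset.__getitem__ (the array
# shapes below are hypothetical; only the slicing logic mirrors the commit).
import random
import numpy as np

def crop_to_max_len(audio, spec, ling_rep, pros_rep, hop_size, spec_max_len):
    """Randomly pick a window of spec_max_len frames and cut every frame-level
    stream by frame indices and the waveform by the matching sample indices."""
    spec_duration = spec.shape[0]
    if spec_duration > spec_max_len:
        start = random.randint(0, spec_duration - spec_max_len)
        end = start + spec_max_len
        spec = spec[start:end, :]
        ling_rep = ling_rep[start:end, :]
        pros_rep = pros_rep[start:end, :]
        audio = audio[start * hop_size: end * hop_size]
        spec_duration = spec_max_len
    return audio, spec, ling_rep, pros_rep, spec_duration

# A 500-frame utterance at hop size 240 is cut to 360 frames / 86,400 samples.
spec = np.zeros((500, 513))
ling = np.zeros((500, 256))
pros = np.zeros((500, 4))
audio = np.zeros(500 * 240)
audio, spec, ling, pros, n_frames = crop_to_max_len(audio, spec, ling, pros, 240, 360)
print(spec.shape, ling.shape, pros.shape, audio.shape, n_frames)
```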
1 change: 0 additions & 1 deletion decoder/vits/losses.py
@@ -6,7 +6,6 @@
from .commons import slice_segments
def compute_g_loss(model, batch, config):
y, spec, ling, pros, spk_emb, spec_lengths, audio_length = batch

with autocast(enabled=config['fp16_run']):
y_hat, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = model.generator(spec, spec_lengths, ling, spk_emb, pros)

19 changes: 14 additions & 5 deletions decoder/vits/trainer.py
@@ -32,7 +32,6 @@ def __init__(self,

self.args = args
self.epochs = initial_epochs
-self.model = model
self.model_ema = model_ema
self.train_dataloader = train_dataloader
self.dev_dataloader = dev_dataloader
@@ -42,6 +41,10 @@
self.fp16_run = fp16_run
self.step_writer = step_writer
self.timer = timer
+if self.config['ngpu'] > 1:
+    self.model = model.module
+else:
+    self.model = model
print(f'trainer device {self.device}')
self.iters = 0
self.optim_g = torch.optim.AdamW(
@@ -54,6 +57,7 @@
config['optimizer']['discriminator']['lr'],
config['optimizer']['discriminator']['betas'],
eps=config['optimizer']['discriminator']['eps'])

self.scheduler_g = torch.optim.lr_scheduler.ExponentialLR(self.optim_g, gamma=config['scheduler']['generator']['lr_decay'], last_epoch=self.epochs - 1)
self.scheduler_d = torch.optim.lr_scheduler.ExponentialLR(self.optim_d, gamma=config['scheduler']['discriminator']['lr_decay'], last_epoch=self.epochs - 1)

@@ -63,7 +67,10 @@ def save_checkpoint(self, checkpoint_path):
checkpoint_path (str): Checkpoint path to be saved.
"""
state_dict = {
"optimizer": self.optimizer.state_dict(),
"optim_g": self.optim_g.state_dict(),
"optim_d": self.optim_d.state_dict(),
"sched_g": self.scheduler_g.state_dict(),
"sched_d": self.scheduler_d.state_dict(),
"epochs": self.epochs,
"model": self.model.state_dict(),
"iters": self.iters
@@ -92,8 +99,10 @@ def load_checkpoint(self, checkpoint_path, load_only_params=False):
if not load_only_params:
self.epochs = state_dict["epochs"]
self.iters = state_dict['iters']
-self.optimizer.load_state_dict(state_dict["optimizer"])
-self.scheduler.current_step = self.iters
+self.optim_g.load_state_dict(state_dict["optim_g"])
+self.optim_d.load_state_dict(state_dict["optim_d"])
+self.scheduler_g.load_state_dict(state_dict['sched_g'])
+self.scheduler_d.load_state_dict(state_dict['sched_d'])


def _load(self, states, model, force_load=True):
@@ -148,7 +157,7 @@ def _train_epoch(self):
scaler = torch.cuda.amp.GradScaler() if (('cuda' in str(self.device)) and self.fp16_run) else None


-for train_steps_per_epoch, batch in tqdm(enumerate(self.train_dataloader, 1)):
+for train_steps_per_epoch, batch in tqdm(enumerate(self.train_dataloader, 1), total = len(self.train_dataloader)):
_batch = []
for b in batch:
if isinstance(b, torch.Tensor):
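Two smaller trainer fixes ride along with the data change: when ngpu > 1 the model handed to the trainer is presumably wrapped for multi-GPU training, so the trainer keeps a handle on the underlying network via .module before building its optimizers and checkpoints, and save_checkpoint/load_checkpoint now persist both optimizers and both LR schedulers so that resuming restores the full training state. A minimal sketch of the unwrapping, assuming a torch.nn.DataParallel wrapper (the wrapping itself is outside this diff):

```python
# Why the trainer grabs `.module` when ngpu > 1: checkpoints taken from the
# unwrapped network carry no 'module.' prefix, so they reload on any setup.
# (Assumes a torch.nn.DataParallel wrapper; this snippet is not repo code.)
import torch

net = torch.nn.Linear(8, 8)
wrapped = torch.nn.DataParallel(net)   # what the trainer may receive for ngpu > 1

model = wrapped.module if isinstance(wrapped, torch.nn.DataParallel) else wrapped
assert model is net

state = model.state_dict()             # keys like 'weight', not 'module.weight'
fresh = torch.nn.Linear(8, 8)
fresh.load_state_dict(state)           # restores cleanly on a single-GPU/CPU run
```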
5 changes: 3 additions & 2 deletions submit_train.sh
@@ -22,11 +22,12 @@ exp_dir=exp
model_name=${ling}_${spk}_${pros}_${dec}
exp=$exp_dir/$model_name/$exp_name
njobs=1
-ngpus=1
+ngpus=2
slots=4
#gputypes="GeForceRTX3060|GeForceRTX3090"
gputypes="GeForceRTX3090"
#gputypes="GeForceRTX3090"
#gputypes="GeForceGTXTITANX|GeForceGTX1080Ti|GeForceRTX3060"
gputypes="GeForceGTX1080Ti"

# create exp dir
[ ! -e $exp ] && mkdir -p $exp
