
vits start training
MingjieChen committed Feb 20, 2023
1 parent 043ebb7 commit cba8783
Showing 11 changed files with 561 additions and 501 deletions.
99 changes: 99 additions & 0 deletions configs/vctk_vqw2v_uttdvec_fs2pitchenergy_vits.yaml
@@ -0,0 +1,99 @@
# experiment
dataset: vctk
train_meta: data/vctk/train_nodev_all/metadata.csv
dev_meta: data/vctk/dev_all/metadata.csv
train_set: train_nodev_all
dev_set: dev_all


# encoder-decoder
ling_enc: vqwav2vec
spk_enc: utt_dvec
pros_enc: norm_fastspeech2_pitch_energy
decoder: VITS
mel_type: vits_spec

# training
fp16_run: !!bool True
epochs: 200
save_freq: 2 # how often to save checkpoints
show_freq: 100 # how often to log training information
load_only_params: !!bool False
seed: !!int 1234
trainer: VITSTrainer
ngpu: 1

# dataloader
dataset_class: VITSDataset
sampling_rate: !!int 24000
vits_hop_size: !!int 240
sort: !!bool False
dump_dir: dump
num_workers: !!int 8
batch_size: !!int 16
drop_last: !!bool True
rm_long_utt: !!bool True # drop overly long utterances from the metadata
max_utt_duration: !!float 10.0 # max utterance duration (seconds)


# decoder params
decoder_params:
  spk_emb_dim: 256
  prosodic_net:
    hidden_dim: !!int 192
    prosodic_bins: !!int 256
    prosodic_stats_path: dump/vctk/train_nodev_all/fastspeech2_pitch_energy/pitch_energy_min_max.npy
  input_dim: !!int 512
  spec_channels: !!int 513
  inter_channels: !!int 192
  hidden_channels: !!int 192
  filter_channels: !!int 768
  n_heads: !!int 2
  n_layers: !!int 6
  kernel_size: !!int 3
  p_dropout: !!float 0.1
  resblock: 1
  resblock_kernel_sizes: [3,7,11]
  resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
  upsample_rates: [10,6,2,2]
  upsample_initial_channel: !!int 512
  upsample_kernel_sizes: [20, 12, 4, 4]
  n_layers_q: !!int 3
  use_spectral_norm: !!bool False
  filter_length: !!int 1024
  n_mels_channels: !!int 80
  win_length: !!int 1024
  hop_length: !!int 240
  sampling_rate: !!int 24000
  segment_size: !!int 9600
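Two invariants are worth noting in these values: the HiFi-GAN upsampling factors must multiply out to the STFT hop, and the training segment must be a whole number of frames. A quick check (plain arithmetic on the config values above, not code from this commit):

```python
# Sanity checks implied by the config values above.
upsample_rates = [10, 6, 2, 2]
hop_length = 240
segment_size = 9600

prod = 1
for r in upsample_rates:
    prod *= r
assert prod == hop_length              # one spectrogram frame -> 240 waveform samples
assert segment_size % hop_length == 0  # 9600 / 240 = 40 frames per training slice
```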




# optimizer & scheduler
optimizer:
  generator:
    lr: !!float 2e-4
    betas: [0.8, 0.99]
    eps: !!float 1e-9
  discriminator:
    lr: !!float 2e-4
    betas: [0.8, 0.99]
    eps: !!float 1e-9
scheduler:
  generator:
    lr_decay: !!float 0.999875
  discriminator:
    lr_decay: !!float 0.999875
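A minimal sketch of how these sections are typically wired up in VITS-style trainers — the optimizer class is an assumption (upstream VITS uses AdamW with exactly these hyper-parameters; this commit's trainer is not shown):

```python
import torch
import torch.nn as nn

# `net` stands in for the VITS generator; the discriminator uses identical settings.
net = nn.Linear(192, 192)  # placeholder module for illustration
optim_g = torch.optim.AdamW(net.parameters(), lr=2e-4, betas=(0.8, 0.99), eps=1e-9)
sched_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=0.999875)  # lr_decay, stepped per epoch
```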

# loss hyper-parameters
losses:
  mel: !!int 45
  kl: !!int 1
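These weights scale the VITS generator objective. A minimal sketch of how they are typically applied — the variable names (`mel_hat`, `kl_term`, `adv_terms`, `fm_term`) are illustrative, not taken from this commit:

```python
import torch.nn.functional as F

def generator_objective(mel, mel_hat, kl_term, adv_terms, fm_term, weights):
    """Hedged sketch of the weighted VITS generator loss; weights = config['losses']."""
    loss_mel = F.l1_loss(mel, mel_hat) * weights['mel']   # spectrogram reconstruction, weight 45
    loss_kl = kl_term * weights['kl']                     # prior/posterior KL, weight 1
    return loss_mel + loss_kl + sum(adv_terms) + fm_term  # plus adversarial & feature-matching terms
```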







97 changes: 55 additions & 42 deletions dataset.py
@@ -7,15 +7,16 @@
import json
import csv
import random
import librosa
from torch.utils.data import DataLoader
from collections import defaultdict
from prosodic_encoder.ppgvc_f0.ppgvc_lf0 import get_cont_lf0 as process_ppgvc_f0
from prosodic_encoder.fastspeech2_pitch_energy.pitch_energy import process_norm_fastspeech2_pitch_energy
import decoder.vits.commons as vits_commons
from decoder.vits.mel_processing import vits_spectrogram_torch
def get_dataloader(config):
-    train_dataset = Dataset(config, config['train_meta'], config['train_set'])
-    dev_dataset = Dataset(config, config['dev_meta'], config['dev_set'])
+    train_dataset = eval(config['dataset_class'])(config, config['train_meta'], config['train_set'])
+    dev_dataset = eval(config['dataset_class'])(config, config['dev_meta'], config['dev_set'])

    if config['ngpu'] >1:
        shuffle = False
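The switch from a hard-coded `Dataset` to `eval(config['dataset_class'])` lets the YAML pick the dataset implementation (`dataset_class: VITSDataset` above). A minimal sketch of the assumed glue — the repository's actual entry point is not part of this diff:

```python
import yaml

# The !!bool/!!int/!!float tags in the config are standard YAML and resolve under safe_load.
with open('configs/vctk_vqw2v_uttdvec_fs2pitchenergy_vits.yaml') as f:
    config = yaml.safe_load(f)

# get_dataloader (above) resolves VITSDataset by name; assumed here to
# return the train and dev loaders.
loaders = get_dataloader(config)
```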
@@ -88,6 +89,10 @@ def __init__(self, config, metadata_csv, split):
        super().__init__()
        self.metadata = []

+        # setup
+        self.hop_size = config['vits_hop_size']
+        self.sampling_rate = config['sampling_rate']
+        self.segment_size = config['decoder_params']['segment_size'] # random-slice segment size of the HiFi-GAN decoder inside VITS
        # read metadata
        with open(metadata_csv) as f:
            reader = csv.DictReader(f, delimiter = ',')
@@ -113,9 +118,7 @@ def __init__(self, config, metadata_csv, split):
        self.pros_enc = config['pros_enc'] # e.g. ppgvc_f0
        self.pros_rep_dir = os.path.join(config['dump_dir'], config['dataset'], split, self.pros_enc)
        self.pros_rep_process_func = f'process_{self.pros_enc}'
-        # frames per step (only work for TacoMOL)
-        self.frames_per_step = config['frames_per_step'] if 'frames_per_step' in config else 1
@@ -129,11 +132,10 @@ def __getitem__(self, idx):
        wav_path = row['wav_path']
        start, end = float(row['start']), float(row['end'])
        # audio
-        audio, fs = librosa.load(wav_path, sr = config['sampling_rate'])
-        audio = audio[ int(start * config['sampling_rate']):
-                       int(end * config['sampling_rate'])
+        audio, fs = librosa.load(wav_path, sr = self.sampling_rate)
+        audio = audio[ int(start * self.sampling_rate):
+                       int(end * self.sampling_rate)
                     ]
-        audio_tensor = audio

        # feature path
@@ -143,14 +145,23 @@ def __getitem__(self, idx):
        spk_emb_path = os.path.join(self.spk_emb_dir, spk, ID+'.npy')
        pros_rep_path = os.path.join(self.pros_rep_dir, spk, ID + '.npy')

-        assert os.path.exists(mel_path), f"{spec_path}"
+        assert os.path.exists(spec_path), f"{spec_path}"
        assert os.path.exists(ling_rep_path), f'{ling_rep_path}'
        assert os.path.exists(spk_emb_path), f'{spk_emb_path}'
        assert os.path.exists(pros_rep_path), f'{pros_rep_path}'

        # load feature
        spec = np.load(spec_path)
        spec_duration = spec.shape[0]

+        # pad spec to match the segment_size
+        spec_segment_size = self.segment_size // self.hop_size
+        if spec_duration < spec_segment_size:
+            spec_pad_length = spec_segment_size - spec_duration
+            spec = np.pad(spec, [[0, spec_pad_length], [0, 0]], mode = 'constant', constant_values = 0.)
+        assert spec.shape[0] >= spec_segment_size
+        spec_duration = spec.shape[0]
+        audio_duration = audio.shape[0]
        ling_rep = np.load(ling_rep_path)
        ling_duration = ling_rep.shape[0]
        spk_emb = np.load(spk_emb_path)
@@ -165,22 +176,32 @@ def __getitem__(self, idx):
        ling_duration = ling_rep.shape[0]

-        # match length between mel and ling_rep
+        # match length between spec and ling_rep
        if spec_duration > ling_duration :
            pad_vec = np.expand_dims(ling_rep[-1,:], axis = 0)
            ling_rep = np.concatenate((ling_rep, np.repeat(pad_vec, spec_duration - ling_duration, 0)),0)
        elif spec_duration < ling_duration:
            ling_rep = ling_rep[:spec_duration,:]

-        # match length between mel and pros_rep
+        # match length between spec and pros_rep
        if spec_duration > pros_duration:
            pad_vec = np.expand_dims(pros_rep[-1,:],axis = 0)
            pros_rep = np.concatenate((pros_rep, np.repeat(pad_vec, spec_duration - pros_duration, 0)),0)
        elif spec_duration < pros_duration:
            pros_rep = pros_rep[:spec_duration,:]

+        # match length between audio and spec
+        if audio_duration < int(spec_duration * self.hop_size):
+            # pad
+            pad_length = int(spec_duration * self.hop_size) - audio_duration
+            audio = np.concatenate([audio, np.array([0.] * pad_length)], axis = 0)
+        elif audio_duration > int(spec_duration * self.hop_size):
+            audio = audio[:int(spec_duration * self.hop_size)]

        return (audio, spec, ling_rep, pros_rep, spk_emb, spec_duration)
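A quick sanity check of the bookkeeping above, using the values from the config:

```python
# Values from the YAML config above.
segment_size = 9600                            # decoder_params.segment_size (waveform samples)
hop_size = 240                                 # vits_hop_size
spec_segment_size = segment_size // hop_size   # 40 spectrogram frames

# 9600 samples at 24 kHz is 0.4 s, so every random training slice spans exactly
# 40 frames, and trimming/padding audio to spec_duration * hop_size samples
# keeps the waveform and the spectrogram aligned frame-for-frame.
assert spec_segment_size == 40
```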

def collate_fn(self, data):
@@ -191,35 +212,27 @@ def collate_fn(self, data):
            idx_arr = np.argsort(~len_arr)
        else:
            idx_arr = np.arange(batch_size)
-        #audio = [ data[id][0] for id in idx_arr]
-        #spec = [data[id][1] for id in idx_arr]
-        #ling_rep = [ data[id][2] for id in idx_arr]
-        #pros_rep = [ data[id][3] for id in idx_arr]
-        #spk_emb = [ data[id][4] for id in idx_arr]
-        #spec_length = [ data[id][5] for id in idx_arr ]
-        #audio_length = [len(_audio) for _audio in audio]
-
-        #max_spec_len = max(spec_length)
-        #max_wav_len = max(audio_length)
-
-        #padded_mel = torch.FloatTensor(pad_2D(mel, max_len))
-        #padded_ling_rep = torch.FloatTensor(pad_2D(ling_rep, max_len))
-        #padded_pros_rep = torch.FloatTensor(pad_2D(pros_rep, max_len))
-        #spk_emb_tensor = torch.FloatTensor(np.array(spk_emb))
-        #length = torch.LongTensor(np.array(length))
-
-        max_spec_len = max([x[1].shape[1] for x in batch])
-        max_wav_len = max([x[0].shape[1] for x in batch])
-
-        spec_lengths = torch.LongTensor(len(batch))
-        wav_lengths = torch.LongTensor(len(batch))
-
-        spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
-        wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
-        text_padded.zero_()
-        spec_padded.zero_()
-        wav_padded.zero_()
-        output = (padded_mel, padded_ling_rep, padded_pros_rep, spk_emb_tensor, length, max_len)
+        audio = [ data[id][0] for id in idx_arr]
+        spec = [ data[id][1] for id in idx_arr]
+        ling_rep = [ data[id][2] for id in idx_arr]
+        pros_rep = [ data[id][3] for id in idx_arr]
+        spk_emb = [ data[id][4] for id in idx_arr]
+        spec_length = [ data[id][5] for id in idx_arr ]
+        audio_length = [len(_audio) for _audio in audio]
+
+        max_spec_len = max(spec_length)
+        max_wav_len = max(audio_length)
+
+        padded_audio = torch.FloatTensor(pad_1D(audio, max_wav_len)).unsqueeze(1)
+        padded_spec = torch.FloatTensor(pad_2D(spec, max_spec_len)).transpose(1,2)
+        padded_ling_rep = torch.FloatTensor(pad_2D(ling_rep, max_spec_len)).transpose(1,2)
+        padded_pros_rep = torch.FloatTensor(pad_2D(pros_rep, max_spec_len)).transpose(1,2)
+        spk_emb_tensor = torch.FloatTensor(np.array(spk_emb)).unsqueeze(2)
+        spec_length = torch.LongTensor(np.array(spec_length))
+        audio_length = torch.LongTensor(np.array(audio_length))
+
+        output = (padded_audio, padded_spec, padded_ling_rep, padded_pros_rep, spk_emb_tensor, spec_length, audio_length)

        return output
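`collate_fn` leans on `pad_1D`/`pad_2D` helpers that are outside this diff. Minimal sketches consistent with how they are called here — assumptions, not the repository's actual definitions:

```python
import numpy as np

def pad_1D(inputs, max_len):
    # Zero-pad each 1-D waveform to max_len and stack to (batch, max_len).
    return np.stack([np.pad(x, (0, max_len - len(x))) for x in inputs])

def pad_2D(inputs, max_len):
    # Zero-pad each (T_i, D) feature along time and stack to (batch, max_len, D).
    return np.stack([np.pad(x, ((0, max_len - x.shape[0]), (0, 0))) for x in inputs])
```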

21 changes: 10 additions & 11 deletions decoder/vits/attentions.py
@@ -5,9 +5,8 @@
from torch import nn
from torch.nn import functional as F

-import commons
-import modules
-from modules import LayerNorm
+from .commons import convert_pad_shape, subsequent_mask
+from .modules import LayerNorm
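For reference, the two names now imported from `.commons` are small utilities; the sketches below follow the upstream VITS definitions (they are not part of this diff):

```python
import torch

def convert_pad_shape(pad_shape):
    # torch.nn.functional.pad expects a flat pad list ordered from the last
    # dimension backwards, so reverse the [[left, right], ...] spec and flatten it.
    return [item for sublist in pad_shape[::-1] for item in sublist]

def subsequent_mask(length):
    # Lower-triangular causal mask of shape (1, 1, length, length).
    return torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
```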


class Encoder(nn.Module):
@@ -79,7 +78,7 @@ def forward(self, x, x_mask, h, h_mask):
        x: decoder input
        h: encoder output
        """
-        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
+        self_attn_mask = subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
@@ -205,7 +204,7 @@ def _get_relative_embeddings(self, relative_embeddings, length):
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
-                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
+                convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position]
@@ -218,11 +217,11 @@ def _relative_position_to_absolute_position(self, x):
"""
batch, heads, length, _ = x.size()
# Concat columns of pad to shift from relative to absolute indexing.
x = F.pad(x, commons.convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
x = F.pad(x, convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))

# Concat extra elements so to add up to shape (len+1, 2*len-1).
x_flat = x.view([batch, heads, length * 2 * length])
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0,0],[0,0],[0,length-1]]))
x_flat = F.pad(x_flat, convert_pad_shape([[0,0],[0,0],[0,length-1]]))

# Reshape and slice out the padded elements.
x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
@@ -235,10 +234,10 @@ def _absolute_position_to_relative_position(self, x):
"""
batch, heads, length, _ = x.size()
# padd along column
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
x_flat = x.view([batch, heads, length**2 + length*(length -1)])
# add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
return x_final

@@ -290,7 +289,7 @@ def _causal_padding(self, x):
        pad_l = self.kernel_size - 1
        pad_r = 0
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x, commons.convert_pad_shape(padding))
+        x = F.pad(x, convert_pad_shape(padding))
        return x

def _same_padding(self, x):
@@ -299,5 +298,5 @@ def _same_padding(self, x):
        pad_l = (self.kernel_size - 1) // 2
        pad_r = self.kernel_size // 2
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x, commons.convert_pad_shape(padding))
+        x = F.pad(x, convert_pad_shape(padding))
        return x
