1. remove vits fp16 training; 2. grad_tts inference;
MingjieChen committed Mar 1, 2023
1 parent cd88d5c commit 52aae7d
Showing 7 changed files with 48 additions and 9 deletions.
4 changes: 2 additions & 2 deletions configs/vctk_vqwav2vec_uttdvec_ppgvcf0_vits_none.yaml
@@ -15,7 +15,7 @@ mel_type: vits_spec


# training
-fp16_run: !!bool True
+fp16_run: !!bool False
epochs: 200
save_freq: 2 # save ckpt frequency
show_freq: 100 # show training information frequency
@@ -28,7 +28,7 @@ ngpu: 2
dataset_class: VITSDataset
sampling_rate: !!int 24000
vits_hop_size: !!int 240
-spec_max_len: !!int 360
+spec_max_len: !!int 240
sort: !!bool True
dump_dir: dump
num_workers: !!int 4
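The two config edits work together: fp16_run: False disables mixed-precision VITS training (the "remove vits fp16 training" half of the commit), while spec_max_len: 240 caps training crops at 240 spectrogram frames, which at a 240-sample hop and 24 kHz sampling rate is 2.4 s of audio per segment. A minimal sketch of how an fp16_run flag typically gates a PyTorch training step follows; the function and variable names are illustrative assumptions, not this repo's trainer.

import torch

def train_step(model, batch, optimizer, scaler, fp16_run=False):
    # hypothetical helper: one optimization step, optionally under AMP
    optimizer.zero_grad()
    if fp16_run:
        with torch.cuda.amp.autocast():
            loss = model(batch)
        scaler.scale(loss).backward()   # scale to avoid fp16 underflow
        scaler.step(optimizer)
        scaler.update()
    else:
        loss = model(batch)             # plain fp32 path (this commit's choice)
        loss.backward()
        optimizer.step()
    return loss.item()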
38 changes: 35 additions & 3 deletions decoder/interface.py
@@ -3,15 +3,19 @@
from .taco_mol.model import MelDecoderMOLv2 as TacoMOL
from .vits.models import VITS
from .vits.utils import load_checkpoint as load_vits_checkpoint
+from .grad_tts.grad_tts_model import GradTTS
import torch
import yaml

def remove_module_from_state_dict(state_dict):
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
-        name = k[7:] # remove `module.`
-        new_state_dict[name] = v
+        if 'module' in k:
+            name = k[7:] # remove `module.`
+            new_state_dict[name] = v
+        else:
+            new_state_dict[k] = v
    return new_state_dict


@@ -28,7 +32,22 @@ def load_VITS(ckpt = None, config = None, device = 'cpu'):
    return generator


+def load_GradTTS(ckpt = None, config = None, device = 'cpu'):
+    with open(config) as f:
+        model_config = yaml.safe_load(f)
+    f.close()
+
+    model = GradTTS(model_config['decoder_params'])
+    params = torch.load(ckpt, map_location = torch.device(device))
+    params = params['model']
+    params = remove_module_from_state_dict(params)
+
+    model.load_state_dict(params)
+    model.to(device)
+    model.eval()
+    return model


def load_FastSpeech2(ckpt = None, config = None, device = 'cpu'):
    with open(config) as f:

@@ -106,4 +125,17 @@ def infer_TacoMOL(model, ling, pros, spk):

    _, mel, _ = model.inference(ling, pros, spk)
    return mel


+def infer_GradTTS(model, ling, pros, spk):
+    ling = ling.transpose(1,2)
+    pros = pros.transpose(1,2)
+    if ling.size(2) % 4 != 0:
+        # pad the time axis up to the next multiple of 4 (U-Net downsampling)
+        pad_length = 4 - ling.size(2) % 4
+        ling = torch.nn.functional.pad(ling, [0, pad_length])
+        pros = torch.nn.functional.pad(pros, [0, pad_length])
+
+    ling_lengths = torch.LongTensor([ling.size(2)]).to(ling.device)
+    mel = model(ling, ling_lengths, spk, pros, 30)  # 30 reverse-diffusion steps
+    mel = mel.transpose(1,2)
+    return mel
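With these additions the Grad-TTS decoder plugs into the same interface as the VITS and Taco decoders: remove_module_from_state_dict now also passes through checkpoints saved without DataParallel's module. prefix, load_GradTTS builds the model from the config's decoder_params and restores the cleaned state dict, and infer_GradTTS pads the time axis to a multiple of 4 before running 30 reverse-diffusion steps. A hypothetical end-to-end call; the checkpoint path, config name, and feature dimensions are illustrative assumptions, not values from the repo:

import torch
from decoder.interface import load_GradTTS, infer_GradTTS

model = load_GradTTS(ckpt='exp/gradtts/model.pth',
                     config='configs/vctk_gradtts.yaml',
                     device='cpu')

ling = torch.randn(1, 100, 512)  # linguistic features, [batch, frames, dim]
pros = torch.randn(1, 100, 2)    # prosodic lf0/uv features, [batch, frames, dim]
spk = torch.randn(1, 256)        # utterance-level speaker embedding

with torch.no_grad():
    mel = infer_GradTTS(model, ling, pros, spk)  # mel, [batch, frames, n_mels]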
3 changes: 3 additions & 0 deletions decoder/vits/trainer.py
@@ -187,6 +187,9 @@ def _train_epoch(self):
g_loss, g_losses = compute_g_loss(self.model, _batch, self.config)
self.timer.cnt('fw')
d_loss.backward()
+g_loss.backward()
+grad_norm_d = clip_grad_value_(self.model.discriminator.parameters(), None)
+grad_norm_g = clip_grad_value_(self.model.generator.parameters(), None)
self.optim_d.step()
self.optim_g.step()
self.timer.cnt('bw')
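This restores a plain fp32 GAN step in place of the removed fp16/GradScaler path: both losses are backpropagated before either optimizer steps, and clip_grad_value_ is called with clip_value=None, which in the upstream VITS utilities means "return the gradient norm for logging, but do not clip". A sketch of that helper, written to match the upstream VITS codebase and assumed to be what this trainer imports:

import torch

def clip_grad_value_(parameters, clip_value, norm_type=2):
    # returns the total gradient norm; clamps gradient values only
    # when clip_value is given, so None = measure without clipping
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)
    total_norm = 0.0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    return total_norm ** (1.0 / norm_type)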
4 changes: 4 additions & 0 deletions inference.py
@@ -17,6 +17,10 @@
from tqdm import tqdm
from scipy.io import wavfile
import resampy
+import logging
+import logging
+logger = logging.getLogger('log')
+logger.setLevel(logging.WARNING)
from ling_encoder.interface import *
from speaker_encoder.interface import *
from prosodic_encoder.interface import *
2 changes: 1 addition & 1 deletion prosodic_encoder/interface.py
@@ -45,5 +45,5 @@ def infer_ppgvc_f0(source_wav, target_wav, config_path = 'configs/preprocess_ppg
    ref_lf0_mean, ref_lf0_std = compute_mean_std(target_lf0)
    src_wav, _ = librosa.load(source_wav, sr=config['sampling_rate'])
    lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True, sr = config['sampling_rate'])
-    lf0_uv = torch.FloatTensor([lf0_uv])
+    lf0_uv = torch.FloatTensor(lf0_uv).unsqueeze(0)
    return lf0_uv
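The one-line change avoids constructing a tensor from a Python list that wraps a numpy array (an extra copy, and a slow path that newer PyTorch versions warn about) and makes the added batch dimension explicit. Illustrative shapes, assuming get_converted_lf0uv returns a 2-D [frames, features] array:

import numpy as np
import torch

lf0_uv = np.zeros((100, 2), dtype=np.float32)   # [frames, features]

old = torch.FloatTensor([lf0_uv])               # list-wrap: copies, may warn
new = torch.FloatTensor(lf0_uv).unsqueeze(0)    # explicit batch dimension

assert old.shape == new.shape == (1, 100, 2)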
2 changes: 1 addition & 1 deletion submit_inference.sh
@@ -6,7 +6,7 @@ split=eval_all
ling_enc=vqwav2vec
spk_enc=uttdvec
pros_enc=ppgvcf0
-dec=tacoar
+dec=gradtts
vocoder=ppgvchifigan

# exp setup
4 changes: 2 additions & 2 deletions submit_train.sh
@@ -19,7 +19,7 @@ dec=vits
#vocoder=ppgvchifigan
vocoder=none

-exp_name=vctk_first_train
+exp_name=vctk_no16fp
config=configs/${dataset}_${ling}_${spk}_${pros}_${dec}_${vocoder}.yaml
if [ ! -e $config ] ; then
echo "can't find config file $config"
@@ -34,7 +34,7 @@ slots=4
#gputypes="GeForceRTX3060|GeForceRTX3090"
#gputypes="GeForceRTX3090"
#gputypes="GeForceGTXTITANX|GeForceGTX1080Ti|GeForceRTX3060"
gputypes="GeForceGTX1080Ti|GeforceRTX3090|GeForceRTX3060"
gputypes="GeForceGTX1080Ti|GeforceRTX3090"

# create exp dir
[ ! -e $exp ] && mkdir -p $exp
