Skip to content

Commit

Permalink
vits inference success run
Browse files Browse the repository at this point in the history
  • Loading branch information
MingjieChen committed Mar 2, 2023
1 parent b6b1a37 commit 8fa0695
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 7 deletions.
9 changes: 6 additions & 3 deletions decoder/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from .taco_ar.model import Model as TacoAR
from .taco_mol.model import MelDecoderMOLv2 as TacoMOL
from .vits.models import VITS
from .vits.utils import load_checkpoint as load_vits_checkpoint
from .grad_tts.grad_tts_model import GradTTS
import torch
import yaml
Expand All @@ -23,8 +22,12 @@ def load_VITS(ckpt = None, config = None, device = 'cpu'):
with open(config) as f:
model_config = yaml.safe_load(f)
f.close()
model = VITS(model_config)
model = load_vits_checkpoint(ckpt, model, None)
model = VITS(model_config['decoder_params'])
params = torch.load(ckpt, map_location = torch.device(device))
params = params['model']
params = remove_module_from_state_dict(params)
model.load_state_dict(params)

generator = model.generator
generator.dec.remove_weight_norm()
generator.to(device)
Expand Down
4 changes: 3 additions & 1 deletion decoder/vits/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,9 +427,11 @@ def forward(self, spec, spec_lengths, ling, spk, pros):
return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

def infer(self, ling, lengths, pros, spk):
if self.prosodic_net is not None and pros is not None:
pros = self.prosodic_net(pros)
z_p, m_p, logs_p, mask = self.enc_p(ling, lengths, pros)
z = self.flow(z_p, mask, g=spk, reverse=True)
o = self.dec((z * mask), g=spk, pros = pros)
o = self.dec((z * mask), g=spk)
return o


6 changes: 3 additions & 3 deletions submit_inference.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ split=eval_all
ling_enc=vqwav2vec
spk_enc=uttdvec
pros_enc=ppgvcf0
dec=gradtts
vocoder=ppgvchifigan
dec=vits
vocoder=none

# exp setup
exp_name=vctk_first_train
exp_name=vctk_no16fp_split
exp_dir=exp/${dataset}_${ling_enc}_${spk_enc}_${pros_enc}_${dec}_${vocoder}/${exp_name}


Expand Down

0 comments on commit 8fa0695

Please sign in to comment.