Merge pull request PlayVoice#19 from NaruseMioShirakana/32k
Fix Readme, onnx_export.py and Gradio
innnky committed Jan 16, 2023
2 parents 669f1b8 + 0ffb1eb commit aa31743
Showing 4 changed files with 246 additions and 3 deletions.
36 changes: 35 additions & 1 deletion README.md
@@ -1,18 +1,25 @@
# SoftVC VITS Singing Voice Conversion
## English docs
[English documentation](Eng_docs.md)

---
## Update
> According to informal reports, multi-speaker training seems to make **timbre leakage** worse, so training a model with more than 5 speakers is not recommended. The current advice is to train a **single-speaker model** whenever possible if you want the result to sound closer to the target voice.\
> The audio drop-out problem has been fixed and the audio quality has improved considerably.\
> Version 2.0 has been moved to the sovits_2.0 branch.\
> Version 3.0 uses the FreeVC code structure and is not compatible with older versions.\
> Compared with [DiffSVC](https://github.com/prophesier/diff-svc): DiffSVC performs better when the training data is of very high quality, while this repository may perform better on lower-quality datasets; in addition, inference in this repository is much faster than DiffSVC.
---
## Model overview
A singing voice conversion model: the SoftVC content encoder extracts speech features from the source audio, and these are fed into VITS together with F0 in place of the original text input, achieving singing voice conversion. The vocoder is also replaced with [NSF HiFiGAN](https://github.com/openvpi/DiffSinger/tree/refactor/modules/nsf_hifigan) to fix the audio drop-out problem.
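
For orientation only, here is a tiny shape-level sketch of that data flow; the modules are placeholders (not the repository's actual SynthesizerTrn or vocoder classes) and the dimensions are made up for illustration.

```python
import torch

frames, unit_dim = 200, 256                      # assumed: ~200 HuBERT frames of 256-dim units
units = torch.randn(1, frames, unit_dim)         # content features from the SoftVC (HuBERT-soft) encoder
f0 = torch.rand(1, frames) * 400 + 50            # frame-level F0 in Hz for the same frames

synthesizer = torch.nn.Linear(unit_dim + 1, 192)  # stand-in for the VITS synthesizer
vocoder = torch.nn.Linear(192, 1)                 # stand-in for the NSF HiFiGAN vocoder

hidden = synthesizer(torch.cat([units, f0.unsqueeze(-1)], dim=-1))  # units + F0 replace the text input
waveform = vocoder(hidden).squeeze(-1)            # the real vocoder renders 32 kHz audio
print(units.shape, f0.shape, waveform.shape)
```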

---
## Notes
+ This branch is the 32 kHz version. The 32 kHz model runs inference faster, uses much less VRAM, and its dataset takes up much less disk space, so training this version is recommended.
+ To train a 48 kHz model, switch to the [main branch](https://github.com/innnky/so-vits-svc/tree/main).

---
## Model files to download in advance
+ Soft VC HuBERT: [hubert-soft-0d54a1f4.pt](https://github.com/bshall/hubert/releases/download/v0.1/hubert-soft-0d54a1f4.pt)
  + Place it in the hubert directory.
@@ -31,8 +38,12 @@ wget -P logs/32k/ https://huggingface.co/innnky/sovits_pretrained/resolve/main/G
wget -P logs/32k/ https://huggingface.co/innnky/sovits_pretrained/resolve/main/D_0.pth

```

---
## One-click Colab scripts for dataset preparation and training
[One-click Colab notebook](https://colab.research.google.com/drive/1_-gh9i-wCPNlRZw6pYF-9UufetcVrGBX?usp=sharing)

---
## Dataset preparation
Simply place the dataset in the dataset_raw directory with the following file structure:
```shell
@@ -47,6 +58,7 @@ dataset_raw
└───xxx7-xxx007.wav
```

---
## Data preprocessing
1. Resample to 32 kHz

@@ -68,17 +80,39 @@ python preprocess_hubert_f0.py
```
After completing the steps above, the dataset directory contains the fully preprocessed data, and the dataset_raw folder can be deleted.
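
As a reference for step 1, resampling a single file to 32 kHz amounts to something like the sketch below. This is not the repository's own preprocessing script; the input and output paths are placeholders and the output layout is an assumption.

```python
import os

import librosa
import soundfile

src = "dataset_raw/speaker0/xxx1-xxx1.wav"       # placeholder input path
dst = "dataset/32k/speaker0/xxx1-xxx1.wav"       # assumed output layout

wav, sr = librosa.load(src, sr=None, mono=True)              # load at the original sample rate
wav32k = librosa.resample(wav, orig_sr=sr, target_sr=32000)  # resample to 32 kHz
os.makedirs(os.path.dirname(dst), exist_ok=True)
soundfile.write(dst, wav32k, 32000)
```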

---
## Training
```shell
python train.py -c configs/config.json -m 32k
```

---
## Inference

Use [inference_main.py](inference_main.py); a sketch of typical parameter values follows the list below.
+ Change model_path to the latest checkpoint you trained.
+ Put the audio to be converted in the raw folder.
+ clean_names: the names of the audio files to convert.
+ trans: the transpose amount in semitones.
+ spk_list: the name(s) of the target speaker(s) to synthesize.
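
A sketch of the kind of values these variables take inside inference_main.py; the exact surrounding code may differ between versions, and the paths and names below are placeholders.

```python
model_path = "logs/32k/G_10000.pth"   # the newest checkpoint you trained
clean_names = ["song1"]               # converts raw/song1.wav
trans = [0]                           # transpose in semitones, one entry per clip
spk_list = ["speaker0"]               # target speaker name(s) from your dataset
```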

---
## ONNX export
Use [onnx_export.py](onnx_export.py); a folder-setup sketch follows the list below.
+ Create a folder named checkpoints and open it.
+ Inside checkpoints, create a project folder named after your project.
+ Rename your model to model.pth and your configuration file to config.json, then place both in the project folder you just created.
+ In [onnx_export.py](onnx_export.py), change "NyaruTaffy" in path = "NyaruTaffy" to your project name.
+ Run [onnx_export.py](onnx_export.py).
+ Wait for it to finish; a model.onnx file will be generated in your project folder, which is the exported model.
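
The folder preparation above can be scripted like this; "MyProject" and the source file paths are placeholders for your own checkpoint and config.

```python
import os
import shutil

project = "MyProject"                                  # your project name
os.makedirs(f"checkpoints/{project}", exist_ok=True)
shutil.copy("logs/32k/G_10000.pth", f"checkpoints/{project}/model.pth")   # placeholder checkpoint path
shutil.copy("configs/config.json", f"checkpoints/{project}/config.json")
# After setting path = "MyProject" in onnx_export.py and running it,
# the exported model appears at checkpoints/MyProject/model.onnx.
```
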
### UIs that support the ONNX model
+ [MoeSS](https://github.com/NaruseMioShirakana/MoeSS)
---
## Gradio (WebUI)
Use [sovits_gradio.py](sovits_gradio.py)
+ Create a folder named checkpoints and open it.
+ Inside checkpoints, create a project folder named after your project.
+ Rename your model to model.pth and your configuration file to config.json, then place both in the project folder you just created.
+ Run [sovits_gradio.py](sovits_gradio.py).

---
160 changes: 160 additions & 0 deletions inference/infer_tool_grad.py
@@ -0,0 +1,160 @@
import hashlib
import json
import logging
import os
import time
from pathlib import Path
import io
import librosa
import maad
import numpy as np
from inference import slicer
import parselmouth
import soundfile
import torch
import torchaudio

from hubert import hubert_model
import utils
from models import SynthesizerTrn
logging.getLogger('numba').setLevel(logging.WARNING)
logging.getLogger('matplotlib').setLevel(logging.WARNING)

def resize2d_f0(x, target_len):
    # linearly interpolate an f0 curve to target_len frames; values below 0.001 are
    # treated as unvoiced (NaN during interpolation, zero afterwards)
    source = np.array(x)
    source[source < 0.001] = np.nan
target = np.interp(np.arange(0, len(source) * target_len, len(source)) / target_len, np.arange(0, len(source)),
source)
res = np.nan_to_num(target)
return res

def get_f0(x, p_len, f0_up_key=0):
    # extract F0 with Praat (via parselmouth), pad to p_len frames, then map to coarse bins
    time_step = 160 / 16000 * 1000
f0_min = 50
f0_max = 1100
f0_mel_min = 1127 * np.log(1 + f0_min / 700)
f0_mel_max = 1127 * np.log(1 + f0_max / 700)

f0 = parselmouth.Sound(x, 16000).to_pitch_ac(
time_step=time_step / 1000, voicing_threshold=0.6,
pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']

pad_size=(p_len - len(f0) + 1) // 2
if(pad_size>0 or p_len - len(f0) - pad_size>0):
f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant')

f0 *= pow(2, f0_up_key / 12)
f0_mel = 1127 * np.log(1 + f0 / 700)
f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * 254 / (f0_mel_max - f0_mel_min) + 1
f0_mel[f0_mel <= 1] = 1
f0_mel[f0_mel > 255] = 255
    f0_coarse = np.rint(f0_mel).astype(int)  # quantize to coarse pitch bins in [1, 255]
return f0_coarse, f0

def clean_pitch(input_pitch):
    # if more than 90% of the frames are unvoiced (coarse pitch == 1), zero out the whole curve
    num_nan = np.sum(input_pitch == 1)
if num_nan / len(input_pitch) > 0.9:
input_pitch[input_pitch != 1] = 1
return input_pitch


def plt_pitch(input_pitch):
input_pitch = input_pitch.astype(float)
input_pitch[input_pitch == 1] = np.nan
return input_pitch


def f0_to_pitch(ff):
f0_pitch = 69 + 12 * np.log2(ff / 440)
return f0_pitch


def fill_a_to_b(a, b):
if len(a) < len(b):
for _ in range(0, len(b) - len(a)):
a.append(a[0])


def mkdir(paths: list):
for path in paths:
if not os.path.exists(path):
os.mkdir(path)


class VitsSvc(object):
def __init__(self):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.SVCVITS = None
self.hps = None
self.speakers = None
self.hubert_soft = hubert_model.hubert_soft("hubert/model.pt")

def set_device(self, device):
self.device = torch.device(device)
self.hubert_soft.to(self.device)
        if self.SVCVITS is not None:
            self.SVCVITS.to(self.device)

def loadCheckpoint(self, path):
self.hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
self.SVCVITS = SynthesizerTrn(
self.hps.data.filter_length // 2 + 1,
self.hps.train.segment_size // self.hps.data.hop_length,
**self.hps.model)
_ = utils.load_checkpoint(f"checkpoints/{path}/model.pth", self.SVCVITS, None)
_ = self.SVCVITS.eval().to(self.device)
self.speakers = self.hps.spk

def get_units(self, source, sr):
source = source.unsqueeze(0).to(self.device)
with torch.inference_mode():
units = self.hubert_soft.units(source)
return units


def get_unit_pitch(self, in_path, tran):
source, sr = torchaudio.load(in_path)
source = torchaudio.functional.resample(source, sr, 16000)
        # torchaudio.load returns (channels, samples); down-mix multi-channel audio to mono
        if source.shape[0] > 1:
            source = torch.mean(source, dim=0).unsqueeze(0)
soft = self.get_units(source, sr).squeeze(0).cpu().numpy()
f0_coarse, f0 = get_f0(source.cpu().numpy()[0], soft.shape[0]*2, tran)
return soft, f0

def infer(self, speaker_id, tran, raw_path):
speaker_id = self.speakers[speaker_id]
sid = torch.LongTensor([int(speaker_id)]).to(self.device).unsqueeze(0)
soft, pitch = self.get_unit_pitch(raw_path, tran)
f0 = torch.FloatTensor(clean_pitch(pitch)).unsqueeze(0).to(self.device)
stn_tst = torch.FloatTensor(soft)
with torch.no_grad():
x_tst = stn_tst.unsqueeze(0).to(self.device)
x_tst = torch.repeat_interleave(x_tst, repeats=2, dim=1).transpose(1, 2)
audio = self.SVCVITS.infer(x_tst, f0=f0, g=sid)[0,0].data.float()
return audio, audio.shape[-1]

    def inference(self, srcaudio, chara, tran, slice_db):
        # srcaudio follows the Gradio Audio convention: (sampling_rate, integer PCM ndarray)
        sampling_rate, audio = srcaudio
        audio = (audio / np.iinfo(audio.dtype).max).astype(np.float32)
if len(audio.shape) > 1:
audio = librosa.to_mono(audio.transpose(1, 0))
if sampling_rate != 16000:
audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
soundfile.write("tmpwav.wav", audio, 16000, format="wav")
chunks = slicer.cut("tmpwav.wav", db_thresh=slice_db)
audio_data, audio_sr = slicer.chunks2audio("tmpwav.wav", chunks)
audio = []
for (slice_tag, data) in audio_data:
length = int(np.ceil(len(data) / audio_sr * self.hps.data.sampling_rate))
raw_path = io.BytesIO()
soundfile.write(raw_path, data, audio_sr, format="wav")
raw_path.seek(0)
            if slice_tag:
                # silent slice: emit zeros instead of running the model
                _audio = np.zeros(length)
else:
out_audio, out_sr = self.infer(chara, tran, raw_path)
_audio = out_audio.cpu().numpy()
audio.extend(list(_audio))
audio = (np.array(audio) * 32768.0).astype('int16')
return (self.hps.data.sampling_rate,audio)
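
A minimal sketch of driving the VitsSvc class above directly, mirroring how sovits_gradio.py (further down) calls it; the project name, speaker name, and input audio are placeholders, and srcaudio follows the (sample_rate, int16 ndarray) convention used by Gradio's Audio component.

```python
import numpy as np

from inference.infer_tool_grad import VitsSvc

svc = VitsSvc()
svc.set_device("cpu")                          # or "cuda"
svc.loadCheckpoint("MyProject")                # expects checkpoints/MyProject/{model.pth,config.json}

sr_in = 44100
wav_in = np.zeros(sr_in * 2, dtype=np.int16)   # stand-in for a real 2-second recording
out_sr, out_wav = svc.inference((sr_in, wav_in), "speaker0", tran=0, slice_db=-40)
```
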
6 changes: 4 additions & 2 deletions onnx_export.py
@@ -11,6 +11,9 @@
from hubert import hubert_model_onnx

def main(HubertExport,NetExport):

path = "NyaruTaffy"

if(HubertExport):
device = torch.device("cuda")
hubert_soft = hubert_model_onnx.hubert_soft("hubert/model.pt")
@@ -30,7 +33,6 @@ def main(HubertExport,NetExport):
input_names=input_names,
output_names=output_names)
if(NetExport):
path = "NyaruTaffy"
device = torch.device("cuda")
hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
SVCVITS = SynthesizerTrn(
@@ -55,7 +57,7 @@ def main(HubertExport,NetExport):
test_pitch.to(device),
test_sid.to(device)
),
"a.onnx",
f"checkpoints/{path}/model.onnx",
dynamic_axes={
"hidden_unit": [0, 1],
"pitch": [1]
47 changes: 47 additions & 0 deletions sovits_gradio.py
@@ -0,0 +1,47 @@
from inference.infer_tool_grad import VitsSvc
import gradio as gr
import os

class VitsGradio:
def __init__(self):
self.so = VitsSvc()
self.lspk = []
self.modelPaths = []
        # every sub-directory of checkpoints/ is treated as a selectable project
        for root,dirs,files in os.walk("checkpoints"):
            for dir in dirs:
                self.modelPaths.append(dir)
with gr.Blocks() as self.Vits:
with gr.Tab("VoiceConversion"):
with gr.Row(visible=False) as self.VoiceConversion:
with gr.Column():
with gr.Row():
with gr.Column():
                                self.srcaudio = gr.Audio(label = "输入音频")  # input audio
                                self.btnVC = gr.Button("说话人转换")  # run voice conversion
with gr.Column():
                                self.dsid = gr.Dropdown(label = "目标角色", choices = self.lspk)  # target speaker
                                self.tran = gr.Slider(label = "升降调", maximum = 60, minimum = -60, step = 1, value = 0)  # transpose in semitones
                                self.th = gr.Slider(label = "切片阈值", maximum = 32767, minimum = -32768, step = 0.1, value = -40)  # slicing threshold in dB
with gr.Row():
self.VCOutputs = gr.Audio()
self.btnVC.click(self.so.inference, inputs=[self.srcaudio,self.dsid,self.tran,self.th], outputs=[self.VCOutputs])
with gr.Tab("SelectModel"):
with gr.Column():
                    modelstrs = gr.Dropdown(label = "模型", choices = self.modelPaths, value = self.modelPaths[0], type = "value")  # model
                    devicestrs = gr.Dropdown(label = "设备", choices = ["cpu","cuda"], value = "cpu", type = "value")  # device
                    btnMod = gr.Button("载入模型")  # load model
btnMod.click(self.loadModel, inputs=[modelstrs,devicestrs], outputs = [self.dsid,self.VoiceConversion])

def loadModel(self, path, device):
self.lspk = []
self.so.set_device(device)
self.so.loadCheckpoint(path)
for spk, sid in self.so.hps.spk.items():
self.lspk.append(spk)
VChange = gr.update(visible = True)
SDChange = gr.update(choices = self.lspk, value = self.lspk[0])
return [SDChange,VChange]

grVits = VitsGradio()

grVits.Vits.launch()
