update tempo change #403

Merged · 1 commit · Sep 18, 2024

16 changes: 8 additions & 8 deletions cosyvoice/cli/cosyvoice.py
@@ -53,51 +53,51 @@ def list_avaliable_spks(self):
spks = list(self.frontend.spk2info.keys())
return spks

def inference_sft(self, tts_text, spk_id, stream=False):
def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0):
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
model_input = self.frontend.frontend_sft(i, spk_id)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.inference(**model_input, stream=stream):
for model_output in self.model.inference(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / 22050
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()

def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False):
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0):
prompt_text = self.frontend.text_normalize(prompt_text, split=False)
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.inference(**model_input, stream=stream):
for model_output in self.model.inference(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / 22050
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()

def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False):
def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0):
if self.frontend.instruct is True:
raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.inference(**model_input, stream=stream):
for model_output in self.model.inference(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / 22050
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()

def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False):
def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0):
if self.frontend.instruct is False:
raise ValueError('{} do not support instruct inference'.format(self.model_dir))
instruct_text = self.frontend.text_normalize(instruct_text, split=False)
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.inference(**model_input, stream=stream):
for model_output in self.model.inference(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / 22050
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
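
Usage sketch (not part of the diff): with this change every inference_* entry point takes an optional speed keyword and forwards it to the model's inference call. Assuming the CosyVoice class defined in this file and a locally downloaded CosyVoice-300M-SFT directory (model path, input text, and output names below are placeholders), a call might look like:

# Example call using the new speed argument; output is 22.05 kHz as in the RTF logging above.
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice

cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT')  # placeholder model dir
spk_id = cosyvoice.list_avaliable_spks()[0]

# speed > 1.0 shortens the audio, speed < 1.0 lengthens it; per the assert added in
# model.py it only applies in non-streaming mode (stream=False).
for idx, out in enumerate(cosyvoice.inference_sft('收到好友从远方寄来的生日礼物。', spk_id, stream=False, speed=1.5)):
    torchaudio.save('sft_speed_{}.wav'.format(idx), out['tts_speech'], 22050)
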
11 changes: 8 additions & 3 deletions cosyvoice/cli/model.py
@@ -15,6 +15,7 @@
import numpy as np
import threading
import time
from torch.nn import functional as F
from contextlib import nullcontext
import uuid
from cosyvoice.utils.common import fade_in_out
@@ -91,7 +92,7 @@ def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
self.tts_speech_token_dict[uuid].append(i)
self.llm_end_dict[uuid] = True

def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
tts_mel = self.flow.inference(token=token.to(self.device),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
@@ -116,14 +117,17 @@ def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
tts_speech = tts_speech[:, :-self.source_cache_len]
else:
if speed != 1.0:
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
return tts_speech

def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, **kwargs):
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
# this_uuid is used to track variables related to this inference thread
this_uuid = str(uuid.uuid1())
with self.lock:
@@ -169,7 +173,8 @@ def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
finalize=True)
finalize=True,
speed=speed)
yield {'tts_speech': this_tts_speech.cpu()}
with self.lock:
self.tts_speech_token_dict.pop(this_uuid)
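
Mechanism note: token2wav now time-stretches the generated mel spectrogram with F.interpolate before the hift vocoder, so the frame count, and therefore the duration, scales by 1/speed; the new assert limits this to non-streaming inference. A standalone toy illustration of that resampling step (tensor shapes are made up, 80 mel bins as elsewhere in the repo):

# Toy illustration (not repo code) of the linear time-axis resampling used for tempo.
import torch
from torch.nn import functional as F

speed = 1.5
tts_mel = torch.randn(1, 80, 300)  # (batch, mel bins, frames): dummy mel with 300 frames

# Resampling the time axis to frames / speed makes the vocoded audio 1/speed as long.
stretched = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
print(stretched.shape)  # torch.Size([1, 80, 200])
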
13 changes: 0 additions & 13 deletions cosyvoice/utils/file_utils.py
@@ -45,16 +45,3 @@ def load_wav(wav, target_sr):
assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
return speech


def speed_change(waveform, sample_rate, speed_factor: str):
effects = [
["tempo", speed_factor], # speed_factor
["rate", f"{sample_rate}"]
]
augmented_waveform, new_sample_rate = torchaudio.sox_effects.apply_effects_tensor(
waveform,
sample_rate,
effects
)
return augmented_waveform, new_sample_rate
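
The speed_change helper removed above applied the tempo change to the waveform with sox effects; it is superseded by the mel-level adjustment in token2wav. For comparison, the removed approach could be driven roughly like this (file paths are placeholders, and a torchaudio build with sox support is assumed):

# For reference only: waveform-domain tempo change via torchaudio sox effects.
import torchaudio

waveform, sample_rate = torchaudio.load('example.wav')  # placeholder input file
effects = [
    ['tempo', '1.5'],            # sox effect arguments are passed as strings
    ['rate', str(sample_rate)],  # keep the original sample rate
]
fast_waveform, new_sr = torchaudio.sox_effects.apply_effects_tensor(waveform, sample_rate, effects)
torchaudio.save('example_fast.wav', fast_waveform, new_sr)
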
14 changes: 7 additions & 7 deletions webui.py
@@ -66,7 +66,7 @@ def change_instruction(mode_checkbox_group):


def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
seed, stream, speed_factor):
seed, stream, speed):
if prompt_wav_upload is not None:
prompt_wav = prompt_wav_upload
elif prompt_wav_record is not None:
@@ -117,24 +117,24 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
if mode_checkbox_group == '预训练音色':
logging.info('get sft inference request')
set_all_random_seed(seed)
for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream):
for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream, speed=speed):
yield (target_sr, i['tts_speech'].numpy().flatten())
elif mode_checkbox_group == '3s极速复刻':
logging.info('get zero_shot inference request')
prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
set_all_random_seed(seed)
for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream):
for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream, speed=speed):
yield (target_sr, i['tts_speech'].numpy().flatten())
elif mode_checkbox_group == '跨语种复刻':
logging.info('get cross_lingual inference request')
prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
set_all_random_seed(seed)
for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream):
for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream, speed=speed):
yield (target_sr, i['tts_speech'].numpy().flatten())
else:
logging.info('get instruct inference request')
set_all_random_seed(seed)
for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream):
for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream, speed=speed):
yield (target_sr, i['tts_speech'].numpy().flatten())


@@ -147,12 +147,12 @@ def main():
gr.Markdown("#### 请输入需要合成的文本,选择推理模式,并按照提示步骤进行操作")

tts_text = gr.Textbox(label="输入合成文本", lines=1, value="我是通义实验室语音团队全新推出的生成式语音大模型,提供舒适自然的语音合成能力。")
speed_factor = gr.Slider(minimum=0.25, maximum=4, step=0.05, label="语速调节", value=1.0, interactive=True)
with gr.Row():
mode_checkbox_group = gr.Radio(choices=inference_mode_list, label='选择推理模式', value=inference_mode_list[0])
instruction_text = gr.Text(label="操作步骤", value=instruct_dict[inference_mode_list[0]], scale=0.5)
sft_dropdown = gr.Dropdown(choices=sft_spk, label='选择预训练音色', value=sft_spk[0], scale=0.25)
stream = gr.Radio(choices=stream_mode_list, label='是否流式推理', value=stream_mode_list[0][1])
speed = gr.Number(value=1, label="速度调节(仅支持非流式推理)", minimum=0.5, maximum=2.0, step=0.1)
with gr.Column(scale=0.25):
seed_button = gr.Button(value="\U0001F3B2")
seed = gr.Number(value=0, label="随机推理种子")
@@ -170,7 +170,7 @@ def main():
seed_button.click(generate_seed, inputs=[], outputs=seed)
generate_button.click(generate_audio,
inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
seed, stream, speed_factor],
seed, stream, speed],
outputs=[audio_output])
mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
demo.queue(max_size=4, default_concurrency_limit=2)
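
UI wiring sketch (not part of the diff): the webui change above feeds a gr.Number into the generator-based handler as the new speed input. The same pattern in a minimal, self-contained form, with a stand-in for the synthesis call:

# Minimal Gradio sketch; fake_tts stands in for cosyvoice.inference_sft(..., speed=speed).
import gradio as gr

def fake_tts(text, speed):
    yield 'would synthesize "{}" at {}x speed'.format(text, speed)

with gr.Blocks() as demo:
    text = gr.Textbox(label='输入合成文本')
    speed = gr.Number(value=1, label='速度调节(仅支持非流式推理)', minimum=0.5, maximum=2.0, step=0.1)
    output = gr.Textbox(label='output')
    gr.Button('生成音频').click(fake_tts, inputs=[text, speed], outputs=[output])

# demo.launch()  # uncomment to run the sketch locally
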