diff --git a/fish_diffusion/denoisers/wavenet.py b/fish_diffusion/denoisers/wavenet.py index 24264727..eb49df42 100644 --- a/fish_diffusion/denoisers/wavenet.py +++ b/fish_diffusion/denoisers/wavenet.py @@ -200,7 +200,10 @@ def forward(self, x, diffusion_step, conditioner): :param conditioner: [B, M, T] :return: """ - + Onnx = False + if x.dim() == 4: + x = x.squeeze(0) + Onnx = True assert x.dim() == 3, f"mel must be 3 dim tensor, but got {x.dim()}" x = self.input_projection(x) # x [B, residual_channel, T] @@ -219,4 +222,4 @@ def forward(self, x, diffusion_step, conditioner): x = F.relu(x) x = self.output_projection(x) # [B, 128, T] - return x + return x.unsqueeze(0) if Onnx else x diff --git a/fish_diffusion/diffusions/noise_predictor.py b/fish_diffusion/diffusions/noise_predictor.py index bc9720b6..66541309 100644 --- a/fish_diffusion/diffusions/noise_predictor.py +++ b/fish_diffusion/diffusions/noise_predictor.py @@ -9,6 +9,10 @@ def extract(a, t): return a[t].reshape((1, 1, 1)) +def extract_MoeSS(a, t): + return a[t].reshape((1, 1, 1, 1)) + + to_torch = partial(torch.tensor, dtype=torch.float32) @@ -112,8 +116,8 @@ def __init__(self, betas): self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) def forward(self, x, noise_t, t, t_prev): - a_t = extract(self.alphas_cumprod, t) - a_prev = extract(self.alphas_cumprod, t_prev) + a_t = extract(self.alphas_cumprod, t) if x.dim() == 3 else extract_MoeSS(self.alphas_cumprod, t) + a_prev = extract(self.alphas_cumprod, t_prev) if x.dim() == 3 else extract_MoeSS(self.alphas_cumprod, t_prev) a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt() x_delta = (a_prev - a_t) * ( diff --git a/fish_diffusion/moessdiffusion/Moess_diffusion.py b/fish_diffusion/moessdiffusion/Moess_diffusion.py new file mode 100644 index 00000000..45102bef --- /dev/null +++ b/fish_diffusion/moessdiffusion/Moess_diffusion.py @@ -0,0 +1,185 @@ +from fish_diffusion.denoisers import DENOISERS +from fish_diffusion.diffusions.noise_predictor import ( + NaiveNoisePredictor, + PLMSNoisePredictor, +) +import numpy as np +import torch +from torch import nn +import json +from functools import partial +from fish_diffusion.moessdiffusion import MOESSDIFFUSIONS + + +def get_noise_schedule_list(schedule_mode, timesteps, max_beta=0.01, s=0.008): + if schedule_mode == "linear": + schedule_list = np.linspace(1e-4, max_beta, timesteps) + elif schedule_mode == "cosine": + steps = timesteps + 1 + x = np.linspace(0, steps, steps) + alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + schedule_list = np.clip(betas, a_min=0, a_max=0.999) + else: + raise NotImplementedError + + return schedule_list + + +def predict_stage0(noise_pred, noise_pred_prev): + return (noise_pred + noise_pred_prev) / 2 + + +def predict_stage1(noise_pred, noise_list): + return (noise_pred * 3 + - noise_list[-1]) / 2 + + +def predict_stage2(noise_pred, noise_list): + return (noise_pred * 23 + - noise_list[-1] * 16 + + noise_list[-2] * 5) / 12 + + +def predict_stage3(noise_pred, noise_list): + return (noise_pred * 55 + - noise_list[-1] * 59 + + noise_list[-2] * 37 + - noise_list[-3] * 9) / 24 + + +class AfterDiffusion(nn.Module): + def __init__(self, spec_max, spec_min): + super().__init__() + self.spec_max = spec_max + self.spec_min = spec_min + + def forward(self, x): + x = x.squeeze(1).permute(0, 2, 1) + d = (self.spec_max - self.spec_min) / 2 + m = (self.spec_max + self.spec_min) / 2 + mel_out = x * d + m + mel_out = mel_out * 2.30259 + return mel_out.transpose(2, 1) + + +@MOESSDIFFUSIONS.register_module() +class GaussianDiffusion(nn.Module): + def __init__( + self, denoiser, mel_channels=128, noise_schedule="linear", timesteps=1000, max_beta=0.01, s=0.008, + noise_loss="l1", sampler_interval=10, spec_stats_path="dataset/stats.json", spec_min=None, spec_max=None, + ): + super().__init__() + self.denoise_fn = DENOISERS.build(denoiser) + self.mel_bins = mel_channels + betas = get_noise_schedule_list(noise_schedule, timesteps, max_beta, s) + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.noise_loss = noise_loss + to_torch = partial(torch.tensor, dtype=torch.float32) + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) + ) + assert (spec_min is None and spec_max is None) or ( + spec_min is not None and spec_max is not None + ), "spec_min and spec_max must be both None or both not None" + if spec_min is None: + with open(spec_stats_path) as f: + stats = json.load(f) + + spec_min = stats["spec_min"] + spec_max = stats["spec_max"] + assert ( + len(spec_min) == len(spec_max) == mel_channels + or len(spec_min) == len(spec_max) == 1 + ), "spec_min and spec_max must be either of length 1 or mel_channels" + self.register_buffer("spec_min", torch.FloatTensor(spec_min).view(1, 1, -1)) + self.register_buffer("spec_max", torch.FloatTensor(spec_max).view(1, 1, -1)) + self.sampler_interval = sampler_interval + self.naive_noise_predictor = NaiveNoisePredictor(betas=betas) + self.plms_noise_predictor = PLMSNoisePredictor(betas=betas) + self.ad = AfterDiffusion(spec_max=self.spec_max, spec_min=self.spec_min) + + def MoeSSOnnxExport(self, project_name, device): + sampler_interval = self.sampler_interval + features = torch.zeros([1, 256, 10]).to(device) + shape = (features.shape[0], 1, self.mel_bins, features.shape[2]) + x = torch.randn(shape, device=device) + pndms = 100 + n_frames = features.shape[2] + step_range = torch.arange(0, 1000, pndms, dtype=torch.long, device=device).flip(0)[:, None] + plms_noise_stage = torch.tensor(0, dtype=torch.long, device=device) + noise_list = torch.zeros((0, 1, 1, self.mel_bins, n_frames), device=device) + ot = step_range[0] + torch.onnx.export( + self.denoise_fn, + (x.to(device), ot.to(device), features.to(device)), + f"{project_name}_denoise.onnx", + input_names=["noise", "time", "condition"], + output_names=["noise_pred"], + dynamic_axes={ + "noise": [3], + "condition": [2] + }, + opset_version=16 + ) + for t in step_range: + noise_pred = self.denoise_fn(x, t, features) + t_prev = t - sampler_interval + t_prev = t_prev * (t_prev > 0) + if plms_noise_stage == 0: + torch.onnx.export( + self.plms_noise_predictor, + (x.to(device), noise_pred.to(device), t.to(device), t_prev.to(device)), + f"{project_name}_pred.onnx", + input_names=["noise", "noise_pred", "time", "time_prev"], + output_names=["noise_pred_o"], + dynamic_axes={ + "noise": [3], + "noise_pred": [3] + }, + opset_version=16 + ) + x_pred = self.plms_noise_predictor(x, noise_pred, t, t_prev) + noise_pred_prev = self.denoise_fn(x_pred, t_prev, features) + noise_pred_prime = self.plms_noise_predictor.predict_stage0( + noise_pred, noise_pred_prev + ) + elif plms_noise_stage == 1: + noise_pred_prime = self.plms_noise_predictor.predict_stage1( + noise_pred, noise_list + ) + elif plms_noise_stage == 2: + noise_pred_prime = self.plms_noise_predictor.predict_stage2( + noise_pred, noise_list + ) + else: + noise_pred_prime = self.plms_noise_predictor.predict_stage3( + noise_pred, noise_list + ) + + noise_pred = noise_pred.unsqueeze(0) + if plms_noise_stage < 3: + noise_list = torch.cat((noise_list, noise_pred), dim=0) + plms_noise_stage = plms_noise_stage + 1 + else: + noise_list = torch.cat((noise_list[-2:], noise_pred), dim=0) + + x = self.plms_noise_predictor(x, noise_pred_prime, t, t_prev) + torch.onnx.export( + self.ad, + x.to(device), + f"{project_name}_after.onnx", + input_names=["x"], + output_names=["mel_out"], + dynamic_axes={ + "x": [3] + }, + opset_version=16 + ) diff --git a/fish_diffusion/moessdiffusion/__init__.py b/fish_diffusion/moessdiffusion/__init__.py new file mode 100644 index 00000000..c0e79cf7 --- /dev/null +++ b/fish_diffusion/moessdiffusion/__init__.py @@ -0,0 +1,4 @@ +from .builder import MOESSDIFFUSIONS +from .Moess_diffusion import GaussianDiffusion + +__all__ = ["MOESSDIFFUSIONS", "GaussianDiffusion"] \ No newline at end of file diff --git a/fish_diffusion/moessdiffusion/builder.py b/fish_diffusion/moessdiffusion/builder.py new file mode 100644 index 00000000..1e05fe6b --- /dev/null +++ b/fish_diffusion/moessdiffusion/builder.py @@ -0,0 +1,3 @@ +from mmengine import Registry + +MOESSDIFFUSIONS = Registry("moessdiffusion") diff --git a/tools/onnx/MoeSS_Onnx_Model.py b/tools/onnx/MoeSS_Onnx_Model.py new file mode 100644 index 00000000..a381653b --- /dev/null +++ b/tools/onnx/MoeSS_Onnx_Model.py @@ -0,0 +1,93 @@ +import torch +from torch import nn +import pytorch_lightning as pl +import torch.nn.functional as F +from fish_diffusion.encoders import ENCODERS +from mmengine import Config +from fish_diffusion.moessdiffusion import MOESSDIFFUSIONS + + +def denorm_f0(f0, pitch_padding=None): + rf0 = 2 ** f0 + rf0[pitch_padding] = 0 + return rf0 + + +def add_pitch(f0, mel2ph): + pitch_padding = (mel2ph == 0) + f0_denorm = denorm_f0(f0, pitch_padding=pitch_padding) + return f0_denorm + + +class DiffSvc(nn.Module): + def __init__(self, model_config): + super(DiffSvc, self).__init__() + self.text_encoder = ENCODERS.build(model_config.text_encoder) + self.diffusion = MOESSDIFFUSIONS.build(model_config.diffusion) + self.speaker_encoder = ENCODERS.build(model_config.speaker_encoder) + self.pitch_encoder = ENCODERS.build(model_config.pitch_encoder) + + def forward(self, hubert, mel2ph, spk_embed, f0): + decoder_inp = F.pad(hubert, [0, 0, 1, 0]) + mel2ph_ = mel2ph.unsqueeze(2).repeat([1, 1, hubert.shape[-1]]) + decoder_inp = torch.gather(decoder_inp, 1, mel2ph_) # [B, T, H] + + f0_denorm = add_pitch(f0, mel2ph) + + max_src_len = decoder_inp.shape[1] + features = self.text_encoder(decoder_inp, None) + speaker_embed = ( + self.speaker_encoder(spk_embed).unsqueeze(1).expand(-1, max_src_len, -1) + ) + features += speaker_embed + features += self.pitch_encoder(f0_denorm) + return features.transpose(1, 2), f0_denorm + + +class FishDiffusion(pl.LightningModule): + def __init__(self, config): + super().__init__() + self.save_hyperparameters() + self.model = DiffSvc(config.model) + self.config = config + + +def main(project_name): + device = "cpu" + config = Config.fromfile("configs/svc_hubert_soft_multi_speakers.py") + model = FishDiffusion(config) + state_dict = torch.load( + "epoch=619-step=140000-valid_loss=0.22.ckpt", + map_location=device, + )["state_dict"] + model.load_state_dict(state_dict, strict=False) + model.eval() + model.to(device) + model = model.model + + hubert = torch.randn(1, 300, 256) + mel2ph = torch.arange(0, 300, dtype=torch.int64).unsqueeze(0) + f0 = torch.randn(1, 300) + spk_embed = torch.LongTensor([0]) + print(hubert.shape, mel2ph.shape, spk_embed.shape, f0.shape) + torch.onnx.export( + model, + (hubert, mel2ph, spk_embed, f0), + f"{project_name}_encoder.onnx", + input_names=["hubert", "mel2ph", "spk_embed", "f0"], + output_names=["mel_pred", "f0_pred"], + dynamic_axes={ + "hubert": [1], + "f0": [1], + "mel2ph": [1] + }, + opset_version=16 + ) + + print("exporting Diffusion") + model.diffusion.MoeSSOnnxExport(project_name, device) + print("Diffusion exported") + + +if __name__ == "__main__": + main(project_name="MyModel")