From 6f25b66c78f4fe4ef7c74b40c7ae512c823d975e Mon Sep 17 00:00:00 2001 From: Weiyang Date: Wed, 29 Mar 2023 20:40:29 +0800 Subject: [PATCH 1/4] Update config_v1.json Fix Upsampling ratio == kernel size //2 --- tools/nsf_hifigan/config_v1.json | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/nsf_hifigan/config_v1.json b/tools/nsf_hifigan/config_v1.json index 5f4de6a1..6bfdb545 100644 --- a/tools/nsf_hifigan/config_v1.json +++ b/tools/nsf_hifigan/config_v1.json @@ -14,9 +14,9 @@ "upsample_kernel_sizes": [ 16, 16, - 8, - 2, - 2 + 4, + 4, + 4 ], "upsample_initial_channel": 512, "resblock_kernel_sizes": [ From eb825658135f7f27632e442bf356cdf29dce6257 Mon Sep 17 00:00:00 2001 From: Weiyang Date: Wed, 29 Mar 2023 20:43:48 +0800 Subject: [PATCH 2/4] Update train.py Choose random scale mel and stft, each 5 different levels --- tools/nsf_hifigan/train.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tools/nsf_hifigan/train.py b/tools/nsf_hifigan/train.py index ebf60efa..2220f4d7 100644 --- a/tools/nsf_hifigan/train.py +++ b/tools/nsf_hifigan/train.py @@ -2,6 +2,8 @@ import json from argparse import ArgumentParser +import numpy as np + import matplotlib.pyplot as plt import pytorch_lightning as pl import torch @@ -70,6 +72,8 @@ def __init__(self, config): ) for (n_fft, hop_length, win_length) in [ (self.h.n_fft, self.h.hop_size, self.h.win_size), + (512, 66, 264), + (1024, 135, 540), (2048, 270, 1080), (4096, 540, 2160), ] @@ -136,11 +140,15 @@ def training_step(self, batch, batch_idx): # We referenced STFT and Mel-Spectrogram loss from SingGAN # L1 STFT Loss stft_config = [ + (256, 25, 120), (512, 50, 240), (1024, 120, 600), (2048, 240, 1200), + (4096, 480, 2400), ] - + + stft_config = [stft_config[np.random.randint(5)]] + loss_stft = 0 for n_fft, hop_length, win_length in stft_config: y_stft = torch.stft( @@ -158,7 +166,10 @@ def training_step(self, batch, batch_idx): # L1 Mel-Spectrogram Loss loss_mel = 0 - for mel_transform in self.multi_scale_mels: + + rnd_scale_mel = [self.multi_scale_mels[np.random.randint(5)]] + + for mel_transform in rnd_scale_mel: y_mel = self.get_mels(y, mel_transform) y_g_hat_mel = self.get_mels(y_g_hat, mel_transform) loss_mel += F.l1_loss(y_mel, y_g_hat_mel) From e0d4bb21385511855e427078a31550dd876d8d4c Mon Sep 17 00:00:00 2001 From: catalystfrank Date: Tue, 13 Jun 2023 18:01:08 +0800 Subject: [PATCH 3/4] Add hop256 Support Add BigVGan Snake to nsf_hifigan, use modelsnake --- configs/vocoder_nsf_hifigan.py | 16 +- .../modules/pitch_extractors/builder.py | 12 +- .../modules/pitch_extractors/parsel_mouth.py | 3 +- .../nsf_hifigan/alias_free_torch/__init__.py | 6 + .../nsf_hifigan/alias_free_torch/act.py | 28 + .../nsf_hifigan/alias_free_torch/filter.py | 95 +++ .../nsf_hifigan/alias_free_torch/resample.py | 49 ++ .../vocoders/nsf_hifigan/modelsnake.py | 672 ++++++++++++++++++ tools/nsf_hifigan/config_v2_256.json | 61 ++ tools/nsf_hifigan/train.py | 19 +- tools/preprocessing/extract_features.py | 16 +- 11 files changed, 942 insertions(+), 35 deletions(-) create mode 100644 fish_diffusion/modules/vocoders/nsf_hifigan/alias_free_torch/__init__.py create mode 100644 fish_diffusion/modules/vocoders/nsf_hifigan/alias_free_torch/act.py create mode 100644 fish_diffusion/modules/vocoders/nsf_hifigan/alias_free_torch/filter.py create mode 100644 fish_diffusion/modules/vocoders/nsf_hifigan/alias_free_torch/resample.py create mode 100644 fish_diffusion/modules/vocoders/nsf_hifigan/modelsnake.py create mode 100644 tools/nsf_hifigan/config_v2_256.json diff --git a/configs/vocoder_nsf_hifigan.py b/configs/vocoder_nsf_hifigan.py index 0a811a1f..311e4c43 100644 --- a/configs/vocoder_nsf_hifigan.py +++ b/configs/vocoder_nsf_hifigan.py @@ -28,7 +28,7 @@ model = dict( type="NSF-HiFiGAN", - config="tools/nsf_hifigan/config_v1_256.json", + config="tools/nsf_hifigan/config_v2_256.json", # The following code are used for preprocessing vocoder=dict( type="NsfHifiGAN", @@ -39,16 +39,16 @@ dataset = dict( train=dict( type="NaiveVOCODERDataset", - path="/fs/nexus-scratch/lengyue/vocoder-data/train", - segment_size=32768, + path="/home/ai4/VocoderData/train256", + segment_size=16384, pitch_shift=[-12, 12], - loudness_shift=[0.1, 0.9], + loudness_shift=[0.1, 0.999], hop_length=hop_length, sampling_rate=sampling_rate, ), valid=dict( type="NaiveVOCODERDataset", - path="/fs/nexus-scratch/lengyue/vocoder-data/valid", + path="/home/ai4/VocoderData/val256", segment_size=None, pitch_shift=None, loudness_shift=None, @@ -59,7 +59,7 @@ dataloader = dict( train=dict( - batch_size=20, + batch_size=8, shuffle=True, num_workers=4, persistent_workers=True, @@ -74,10 +74,10 @@ preprocessing = dict( pitch_extractor=dict( - type="HarvestPitchExtractor", + type="ParselMouthPitchExtractor", keep_zeros=False, f0_min=40.0, - f0_max=2000.0, + f0_max=2100.0, hop_length=hop_length, ), ) diff --git a/fish_diffusion/modules/pitch_extractors/builder.py b/fish_diffusion/modules/pitch_extractors/builder.py index f85ab74c..f81e0012 100644 --- a/fish_diffusion/modules/pitch_extractors/builder.py +++ b/fish_diffusion/modules/pitch_extractors/builder.py @@ -10,17 +10,17 @@ class BasePitchExtractor: def __init__( self, - hop_length: int = 512, - f0_min: float = 50.0, - f0_max: float = 1100.0, + hop_length: int = 256, + f0_min: float = 40.0, + f0_max: float = 2100.0, keep_zeros: bool = True, ): """Base pitch extractor. Args: - hop_length (int, optional): Hop length. Defaults to 512. - f0_min (float, optional): Minimum f0. Defaults to 50.0. - f0_max (float, optional): Maximum f0. Defaults to 1100.0. + hop_length (int, optional): Hop length. Defaults to 256 + f0_min (float, optional): Minimum f0. Defaults to 40.0 + f0_max (float, optional): Maximum f0. Defaults to 2100.0. keep_zeros (bool, optional): Whether keep zeros in pitch. Defaults to True. """ diff --git a/fish_diffusion/modules/pitch_extractors/parsel_mouth.py b/fish_diffusion/modules/pitch_extractors/parsel_mouth.py index 15671c32..f3a1ebf9 100644 --- a/fish_diffusion/modules/pitch_extractors/parsel_mouth.py +++ b/fish_diffusion/modules/pitch_extractors/parsel_mouth.py @@ -36,7 +36,8 @@ def __call__(self, x, sampling_rate=44100, pad_to=None): # Pad zeros to the end if pad_to is not None: - total_pad = pad_to - f0.shape[0] + total_pad = pad_to - f0.shape[0] + 1 + #print(pad_to,f0.shape[0]) f0 = np.pad(f0, (total_pad // 2, total_pad - total_pad // 2), "constant") return self.post_process(x, sampling_rate, f0, pad_to) diff --git a/fish_diffusion/modules/vocoders/nsf_hifigan/alias_free_torch/__init__.py b/fish_diffusion/modules/vocoders/nsf_hifigan/alias_free_torch/__init__.py new file mode 100644 index 00000000..a2318b63 --- /dev/null +++ b/fish_diffusion/modules/vocoders/nsf_hifigan/alias_free_torch/__init__.py @@ -0,0 +1,6 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +from .filter import * +from .resample import * +from .act import * \ No newline at end of file diff --git a/fish_diffusion/modules/vocoders/nsf_hifigan/alias_free_torch/act.py b/fish_diffusion/modules/vocoders/nsf_hifigan/alias_free_torch/act.py new file mode 100644 index 00000000..028debd6 --- /dev/null +++ b/fish_diffusion/modules/vocoders/nsf_hifigan/alias_free_torch/act.py @@ -0,0 +1,28 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch.nn as nn +from .resample import UpSample1d, DownSample1d + + +class Activation1d(nn.Module): + def __init__(self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + + return x \ No newline at end of file diff --git a/fish_diffusion/modules/vocoders/nsf_hifigan/alias_free_torch/filter.py b/fish_diffusion/modules/vocoders/nsf_hifigan/alias_free_torch/filter.py new file mode 100644 index 00000000..7ad6ea87 --- /dev/null +++ b/fish_diffusion/modules/vocoders/nsf_hifigan/alias_free_torch/filter.py @@ -0,0 +1,95 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +if 'sinc' in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + # LICENSE is in incl_licenses directory. + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! + """ + return torch.where(x == 0, + torch.tensor(1., device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x) + + +# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License +# https://adefossez.github.io/julius/julius/lowpass.html +# LICENSE is in incl_licenses directory. +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] + even = (kernel_size % 2 == 0) + half_size = kernel_size // 2 + + #For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.: + beta = 0.1102 * (A - 8.7) + elif A >= 21.: + beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) + else: + beta = 0. + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = (torch.arange(-half_size, half_size) + 0.5) + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + # Normalize filter to have sum = 1, otherwise we will have a small leakage + # of the constant component in the input signal. + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter + + +class LowPassFilter1d(nn.Module): + def __init__(self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = 'replicate', + kernel_size: int = 12): + # kernel_size should be even number for stylegan3 setup, + # in this implementation, odd number is also possible. + super().__init__() + if cutoff < -0.: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = (kernel_size % 2 == 0) + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + #input [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), + mode=self.padding_mode) + out = F.conv1d(x, self.filter.expand(C, -1, -1), + stride=self.stride, groups=C) + + return out \ No newline at end of file diff --git a/fish_diffusion/modules/vocoders/nsf_hifigan/alias_free_torch/resample.py b/fish_diffusion/modules/vocoders/nsf_hifigan/alias_free_torch/resample.py new file mode 100644 index 00000000..750e6c34 --- /dev/null +++ b/fish_diffusion/modules/vocoders/nsf_hifigan/alias_free_torch/resample.py @@ -0,0 +1,49 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch.nn as nn +from torch.nn import functional as F +from .filter import LowPassFilter1d +from .filter import kaiser_sinc_filter1d + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + kernel_size=self.kernel_size) + self.register_buffer("filter", filter) + + # x: [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + x = F.pad(x, (self.pad, self.pad), mode='replicate') + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + x = x[..., self.pad_left:-self.pad_right] + + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size) + + def forward(self, x): + xx = self.lowpass(x) + + return xx \ No newline at end of file diff --git a/fish_diffusion/modules/vocoders/nsf_hifigan/modelsnake.py b/fish_diffusion/modules/vocoders/nsf_hifigan/modelsnake.py new file mode 100644 index 00000000..a12ec638 --- /dev/null +++ b/fish_diffusion/modules/vocoders/nsf_hifigan/modelsnake.py @@ -0,0 +1,672 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm +from .activations import Snake +from .alias_free_torch import * + +LRELU_SLOPE = 0.1 + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + #''' + self.num_layers = len(self.convs1) + len(self.convs2) + self.activations= nn.ModuleList([ + Activation1d(activation=Snake(channels,alpha_logscale=False)) + for _ in range(self.num_layers) + ]) + #''' + + def forward(self, x): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): + #for c1, c2 in zip(self.convs1, self.convs2): + #xt = F.leaky_relu(x, LRELU_SLOPE, inplace=False) + xt = a1(x) + xt = c1(xt) + #xt = F.leaky_relu(xt, LRELU_SLOPE, inplace=True) + xt = a2(xt) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.h = h + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + #''' + self.num_layers = len(self.convs) + + self.activations = nn.ModuleList([ + Activation1d(activation=Snake(channels, alpha_logscale=False)) + for _ in range(self.num_layers) + ]) + #''' + + def forward(self, x): + for c,a in zip(self.convs,self.activations): + #for c in self.convs: + #xt = F.leaky_relu(x, LRELU_SLOPE, inplace=True) + xt = a(x) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class SineGen(torch.nn.Module): + """Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__( + self, + samp_rate, + harmonic_num=0, + sine_amp=0.1, + noise_std=0.003, + voiced_threshold=0, + flag_for_pulse=False, + ): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.dim = self.harmonic_num + 1 + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + self.flag_for_pulse = flag_for_pulse + + def _f02uv(self, f0): + # generate uv signal + uv = torch.ones_like(f0) + uv = uv * (f0 > self.voiced_threshold) + return uv + + def _f02sine(self, f0_values): + """f0_values: (batchsize, length, dim) + where dim indicates fundamental tone and overtones + """ + # convert to F0 in rad. The interger part n can be ignored + # because 2 * np.pi * n doesn't affect phase + rad_values = (f0_values / self.sampling_rate) % 1 + + # initial phase noise (no noise for fundamental component) + rand_ini = torch.rand( + f0_values.shape[0], f0_values.shape[2], device=f0_values.device + ) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + + # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad) + if not self.flag_for_pulse: + # for normal case + + # To prevent torch.cumsum numerical overflow, + # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1. + # Buffer tmp_over_one_idx indicates the time step to add -1. + # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi + tmp_over_one = torch.cumsum(rad_values, 1) % 1 + tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0 + cumsum_shift = torch.zeros_like(rad_values) + cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + + sines = torch.sin( + torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi + ) + else: + # If necessary, make sure that the first time step of every + # voiced segments is sin(pi) or cos(0) + # This is used for pulse-train generation + + # identify the last time step in unvoiced segments + uv = self._f02uv(f0_values) + uv_1 = torch.roll(uv, shifts=-1, dims=1) + uv_1[:, -1, :] = 1 + u_loc = (uv < 1) * (uv_1 > 0) + + # get the instantanouse phase + tmp_cumsum = torch.cumsum(rad_values, dim=1) + # different batch needs to be processed differently + for idx in range(f0_values.shape[0]): + temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :] + temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :] + # stores the accumulation of i.phase within + # each voiced segments + tmp_cumsum[idx, :, :] = 0 + tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum + + # rad_values - tmp_cumsum: remove the accumulation of i.phase + # within the previous voiced segment. + i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1) + + # get the sines + sines = torch.cos(i_phase * 2 * np.pi) + return sines + + def forward(self, f0): + """sine_tensor, uv = forward(f0) + input F0: tensor(batchsize=1, length, dim=1) + f0 for unvoiced steps should be 0 + output sine_tensor: tensor(batchsize=1, length, dim) + output uv: tensor(batchsize=1, length, 1) + """ + with torch.no_grad(): + f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device) + # fundamental component + f0_buf[:, :, 0] = f0[:, :, 0] + for idx in np.arange(self.harmonic_num): + # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic + f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2) + + # generate sine waveforms + sine_waves = self._f02sine(f0_buf) * self.sine_amp + + # generate uv signal + # uv = torch.ones(f0.shape) + # uv = uv * (f0 > self.voiced_threshold) + uv = self._f02uv(f0) + + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # . for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class SourceModuleHnNSF(torch.nn.Module): + """SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__( + self, + sampling_rate, + harmonic_num=0, + sine_amp=0.1, + add_noise_std=0.003, + voiced_threshod=0, + ): + super(SourceModuleHnNSF, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + + # to produce sine waveforms + self.l_sin_gen = SineGen( + sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod + ) + + # to merge source harmonics into a single excitation + self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = torch.nn.Tanh() + + def forward(self, x): + """ + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + """ + # source for harmonic branch + sine_wavs, uv, _ = self.l_sin_gen(x) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + + # source for noise branch, in the same shape as uv + noise = torch.randn_like(uv) * self.sine_amp / 3 + return sine_merge, noise, uv + + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h.upsample_rates)) + self.m_source = SourceModuleHnNSF(sampling_rate=h.sampling_rate, harmonic_num=8) + self.noise_convs = nn.ModuleList() + self.conv_pre = weight_norm( + Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) + ) + resblock = ResBlock1 if h.resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + c_cur = h.upsample_initial_channel // (2 ** (i + 1)) + self.ups.append( + weight_norm( + ConvTranspose1d( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + if i + 1 < len(h.upsample_rates): # + stride_f0 = np.prod(h.upsample_rates[i + 1 :]) + self.noise_convs.append( + Conv1d( + 1, + c_cur, + kernel_size=stride_f0 * 2, + stride=stride_f0, + padding=stride_f0 // 2, + ) + ) + else: + self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) + ): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x, f0): + if f0.ndim == 2: + f0 = f0[:, None] + + f0 = F.interpolate( + f0, size=x.shape[-1] * self.h.hop_size, mode="linear" + ).transpose(1, 2) + + har_source, _, _ = self.m_source(f0) + har_source = har_source.transpose(1, 2) + x = self.conv_pre(x) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE, inplace=True) + x = self.ups[i](x) + x_source = self.noise_convs[i](har_source) + x = x + x_source + xs = None + + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + + x = xs / self.num_kernels + + x = F.leaky_relu(x, inplace=True) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + for l in self.ups: + remove_weight_norm(l) + + for l in self.resblocks: + l.remove_weight_norm() + + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f( + Conv2d( + 1, + 32, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + 32, + 128, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + 128, + 512, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + 512, + 1024, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ] + ) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE, inplace=True) + x = torch.nan_to_num(x) + + fmap.append(x) + + x = self.conv_post(x) + x = torch.nan_to_num(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, periods=None): + super(MultiPeriodDiscriminator, self).__init__() + self.periods = periods if periods is not None else [2, 3, 5, 7, 11] + self.discriminators = nn.ModuleList() + for period in self.periods: + self.discriminators.append(DiscriminatorP(period)) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE, inplace=True) + x = torch.nan_to_num(x) + fmap.append(x) + + x = self.conv_post(x) + x = torch.nan_to_num(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ] + ) + self.meanpools = nn.ModuleList( + [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg**2) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses diff --git a/tools/nsf_hifigan/config_v2_256.json b/tools/nsf_hifigan/config_v2_256.json new file mode 100644 index 00000000..2b2be9a5 --- /dev/null +++ b/tools/nsf_hifigan/config_v2_256.json @@ -0,0 +1,61 @@ +{ + "resblock": "1", + "learning_rate": 0.0002, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "upsample_rates": [ + 4, + 4, + 4, + 2, + 2 + ], + "upsample_kernel_sizes": [ + 8, + 8, + 8, + 4, + 4 + ], + "upsample_initial_channel": 768, + "resblock_kernel_sizes": [ + 3, + 7, + 11 + ], + "resblock_dilation_sizes": [ + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ] + ], + "discriminator_periods": [ + 3, + 5, + 7, + 11, + 17, + 23, + 37 + ], + "segment_size": 16384, + "num_mels": 128, + "n_fft": 2048, + "hop_size": 256, + "win_size": 2048, + "sampling_rate": 44100, + "fmin": 40, + "fmax": 16000 +} diff --git a/tools/nsf_hifigan/train.py b/tools/nsf_hifigan/train.py index 51e5a259..0aa10733 100644 --- a/tools/nsf_hifigan/train.py +++ b/tools/nsf_hifigan/train.py @@ -1,7 +1,6 @@ import itertools import json from argparse import ArgumentParser - import numpy as np import matplotlib.pyplot as plt @@ -71,8 +70,6 @@ def __init__(self, config): ) for (n_fft, hop_length, win_length) in [ (self.h.n_fft, self.h.hop_size, self.h.win_size), - (512, 66, 264), - (1024, 135, 540), (2048, 270, 1080), (4096, 540, 2160), ] @@ -160,13 +157,14 @@ def training_step(self, batch, batch_idx): (512, 50, 240), (1024, 120, 600), (2048, 240, 1200), - (4096, 480, 2400), + (4096, 480, 2400) ] - - stft_config = [stft_config[np.random.randint(5)]] - + loss_stft = 0 - for n_fft, hop_length, win_length in stft_config: + + tmp_stft_config = [stft_config[np.random.randint(0,5)] for i in range(3)] + + for n_fft, hop_length, win_length in tmp_stft_config: y_stft = torch.stft( y.squeeze(1), n_fft, hop_length, win_length, return_complex=True ) @@ -182,10 +180,7 @@ def training_step(self, batch, batch_idx): # L1 Mel-Spectrogram Loss loss_mel = 0 - - rnd_scale_mel = [self.multi_scale_mels[np.random.randint(5)]] - - for mel_transform in rnd_scale_mel: + for mel_transform in self.multi_scale_mels: y_mel = self.get_mels(y, mel_transform) y_g_hat_mel = self.get_mels(y_g_hat, mel_transform) loss_mel += F.l1_loss(y_mel, y_g_hat_mel) diff --git a/tools/preprocessing/extract_features.py b/tools/preprocessing/extract_features.py index d4f0d72a..16319895 100644 --- a/tools/preprocessing/extract_features.py +++ b/tools/preprocessing/extract_features.py @@ -124,14 +124,14 @@ def process( audio = torch.from_numpy(audio).unsqueeze(0).to(device) # Obtain mel spectrogram - if vocoder is not None: - mel = vocoder.wav2spec(audio, sr, key_shift=key_shift) - mel_length = mel.shape[-1] - sample["mel"] = mel.cpu().numpy() - else: - # Calculate mel length from audio length - hop_length = getattr(config, "hop_length", 512) - mel_length = int(audio.shape[-1] / hop_length) + 1 + #if vocoder is not None: + # mel = vocoder.wav2spec(audio, sr, key_shift=key_shift) + # mel_length = mel.shape[-1] + # sample["mel"] = mel.cpu().numpy() + #else: + # Calculate mel length from audio length + hop_length = getattr(config, "hop_length", 512) + mel_length = int(audio.shape[-1] / hop_length) + 1 # Extract text features if text_features_extractor is not None: From 5778f7a894a6fe2637351f3275819edae22db7c5 Mon Sep 17 00:00:00 2001 From: catalystfrank Date: Tue, 13 Jun 2023 20:04:26 +0800 Subject: [PATCH 4/4] Add Activations --- .../vocoders/nsf_hifigan/activations.py | 120 ++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 fish_diffusion/modules/vocoders/nsf_hifigan/activations.py diff --git a/fish_diffusion/modules/vocoders/nsf_hifigan/activations.py b/fish_diffusion/modules/vocoders/nsf_hifigan/activations.py new file mode 100644 index 00000000..61f2808a --- /dev/null +++ b/fish_diffusion/modules/vocoders/nsf_hifigan/activations.py @@ -0,0 +1,120 @@ +# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. +# LICENSE is in incl_licenses directory. + +import torch +from torch import nn, sin, pow +from torch.nn import Parameter + + +class Snake(nn.Module): + ''' + Implementation of a sine-based periodic activation function + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter + References: + - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snake(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha: trainable parameter + alpha is initialized to 1 by default, higher values = higher-frequency. + alpha will be trained along with the rest of your model. + ''' + super(Snake, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + Snake ∶= x + 1/a * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + if self.alpha_logscale: + alpha = torch.exp(alpha) + x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x + + +class SnakeBeta(nn.Module): + ''' + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + ''' + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + self.beta = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x \ No newline at end of file