Add Hop256 support and BigVGAN Snake structure #101

Open · wants to merge 5 commits into v3-backup
16 changes: 8 additions & 8 deletions configs/vocoder_nsf_hifigan.py
@@ -28,7 +28,7 @@
 model = dict(
     type="NSF-HiFiGAN",
-    config="tools/nsf_hifigan/config_v1_256.json",
+    config="tools/nsf_hifigan/config_v2_256.json",
     # The following code is used for preprocessing
     vocoder=dict(
         type="NsfHifiGAN",
@@ -39,16 +39,16 @@
 dataset = dict(
     train=dict(
         type="NaiveVOCODERDataset",
-        path="/fs/nexus-scratch/lengyue/vocoder-data/train",
-        segment_size=32768,
+        path="/home/ai4/VocoderData/train256",
+        segment_size=16384,
         pitch_shift=[-12, 12],
-        loudness_shift=[0.1, 0.9],
+        loudness_shift=[0.1, 0.999],
         hop_length=hop_length,
         sampling_rate=sampling_rate,
     ),
     valid=dict(
         type="NaiveVOCODERDataset",
-        path="/fs/nexus-scratch/lengyue/vocoder-data/valid",
+        path="/home/ai4/VocoderData/val256",
         segment_size=None,
         pitch_shift=None,
         loudness_shift=None,
@@ -59,7 +59,7 @@
 dataloader = dict(
     train=dict(
-        batch_size=20,
+        batch_size=8,
         shuffle=True,
         num_workers=4,
         persistent_workers=True,
@@ -74,10 +74,10 @@
 preprocessing = dict(
     pitch_extractor=dict(
-        type="HarvestPitchExtractor",
+        type="ParselMouthPitchExtractor",
         keep_zeros=False,
         f0_min=40.0,
-        f0_max=2000.0,
+        f0_max=2100.0,
         hop_length=hop_length,
     ),
 )
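
Reviewer note on the numbers: with hop 256, the new 16384-sample segment covers exactly 64 frames of conditioning. A quick sanity check (standalone sketch; the 44100 Hz rate is an assumption, since sampling_rate is defined above the excerpted hunk):

hop_length = 256
segment_size = 16384
sampling_rate = 44100  # assumed; the actual value is set earlier in this config

print(segment_size // hop_length)    # 64 frames per training segment
print(segment_size / sampling_rate)  # ~0.37 s of audio per example
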
12 changes: 6 additions & 6 deletions fish_diffusion/modules/pitch_extractors/builder.py
@@ -10,17 +10,17 @@
 class BasePitchExtractor:
     def __init__(
         self,
-        hop_length: int = 512,
-        f0_min: float = 50.0,
-        f0_max: float = 1100.0,
+        hop_length: int = 256,
+        f0_min: float = 40.0,
+        f0_max: float = 2100.0,
         keep_zeros: bool = True,
     ):
         """Base pitch extractor.

         Args:
-            hop_length (int, optional): Hop length. Defaults to 512.
-            f0_min (float, optional): Minimum f0. Defaults to 50.0.
-            f0_max (float, optional): Maximum f0. Defaults to 1100.0.
+            hop_length (int, optional): Hop length. Defaults to 256.
+            f0_min (float, optional): Minimum f0. Defaults to 40.0.
+            f0_max (float, optional): Maximum f0. Defaults to 2100.0.
             keep_zeros (bool, optional): Whether to keep zeros in pitch. Defaults to True.
         """
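
With these defaults, every extractor derived from BasePitchExtractor now analyzes f0 from 40 to 2100 Hz on a 256-sample grid. A hypothetical usage sketch (the ParselMouthPitchExtractor class name comes from the config above; the expected waveform layout is not shown in this diff, so a 1-D mono tensor is assumed):

import torch

from fish_diffusion.modules.pitch_extractors.parsel_mouth import ParselMouthPitchExtractor

# New defaults apply: hop_length=256, f0_min=40.0, f0_max=2100.0
extractor = ParselMouthPitchExtractor(keep_zeros=False)

audio = torch.randn(44100)  # assumed: 1 s of mono audio at 44.1 kHz
f0 = extractor(audio, sampling_rate=44100, pad_to=173)  # 173 = ceil(44100 / 256), an assumption
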
3 changes: 2 additions & 1 deletion fish_diffusion/modules/pitch_extractors/parsel_mouth.py
@@ -36,7 +36,8 @@ def __call__(self, x, sampling_rate=44100, pad_to=None):

         # Pad zeros to the end
         if pad_to is not None:
-            total_pad = pad_to - f0.shape[0]
+            total_pad = pad_to - f0.shape[0] + 1
+            # print(pad_to, f0.shape[0])
             f0 = np.pad(f0, (total_pad // 2, total_pad - total_pad // 2), "constant")

         return self.post_process(x, sampling_rate, f0, pad_to)
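
The padding is split symmetrically: total_pad // 2 zeros in front, the remainder at the end. Note that with the new + 1 the padded series comes out one frame longer than pad_to; presumably post_process trims or tolerates the extra frame (that code is not visible in this hunk). The arithmetic, as a standalone sketch:

import numpy as np

f0 = np.ones(170)                      # e.g. 170 extracted frames
pad_to = 172                           # target frame count
total_pad = pad_to - f0.shape[0] + 1   # 3, per the patched line above
f0 = np.pad(f0, (total_pad // 2, total_pad - total_pad // 2), "constant")
print(f0.shape)                        # (173,) == pad_to + 1
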
120 changes: 120 additions & 0 deletions fish_diffusion/modules/vocoders/nsf_hifigan/activations.py
@@ -0,0 +1,120 @@
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.

import torch
from torch import nn, sin, pow
from torch.nn import Parameter


class Snake(nn.Module):
    '''
    Implementation of a sine-based periodic activation function.
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - This activation function is from the paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha: trainable parameter
            alpha is initialized to 1 by default; higher values give higher frequency.
            alpha will be trained along with the rest of your model.
        '''
        super(Snake, self).__init__()
        self.in_features = in_features

        # Initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log-scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear-scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake := x + 1/a * sin^2(x * a)
        '''
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x


class SnakeBeta(nn.Module):
    '''
    A modified Snake function which uses separate parameters for the magnitude
    of the periodic components.
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on the paper by
          Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = SnakeBeta(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    '''

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha - trainable parameter that controls frequency
            - beta - trainable parameter that controls magnitude
            alpha is initialized to 1 by default; higher values give higher frequency.
            beta is initialized to 1 by default; higher values give higher magnitude.
            alpha and beta will be trained along with the rest of your model.
        '''
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # Initialize alpha and beta
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log-scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
            self.beta = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear-scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)
            self.beta = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2(x * a)
        '''
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x
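
Both activations are elementwise and shape-preserving on [B, C, T] input, with one alpha (and beta) per channel. A minimal smoke test (a sketch; the import path follows the diff header above):

import torch

from fish_diffusion.modules.vocoders.nsf_hifigan.activations import Snake, SnakeBeta

x = torch.randn(4, 256, 1024)  # [B, C, T]

snake = Snake(256)                                # one alpha per channel
snake_beta = SnakeBeta(256, alpha_logscale=True)  # log-scale variant (BigVGAN's configs enable this)

assert snake(x).shape == x.shape
assert snake_beta(x).shape == x.shape
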
6 changes: 6 additions & 0 deletions — __init__.py (new alias-free-torch package)
@@ -0,0 +1,6 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

from .filter import *
from .resample import *
from .act import *
28 changes: 28 additions & 0 deletions — act.py (new alias-free-torch package)
@@ -0,0 +1,28 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch.nn as nn

from .resample import UpSample1d, DownSample1d


class Activation1d(nn.Module):
    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    # x: [B, C, T]
    def forward(self, x):
        x = self.upsample(x)
        x = self.act(x)
        x = self.downsample(x)

        return x
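
Activation1d is the alias-free activation wrapper in the BigVGAN/StyleGAN-3 style: upsample by 2, apply the nonlinearity, then low-pass filter and downsample by 2, leaving T unchanged. A wiring sketch with SnakeBeta (the import path of the new package is an assumption; only the module filenames are visible in this PR):

import torch

from fish_diffusion.modules.vocoders.nsf_hifigan.activations import SnakeBeta
# Assumed package location for the three new modules:
from fish_diffusion.modules.vocoders.nsf_hifigan.alias_free_torch import Activation1d

act = Activation1d(activation=SnakeBeta(128, alpha_logscale=True))

x = torch.randn(2, 128, 500)  # [B, C, T]
assert act(x).shape == x.shape  # 2x up -> Snake -> 2x down keeps the length
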
95 changes: 95 additions & 0 deletions — filter.py (new alias-free-torch package)
@@ -0,0 +1,95 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

if 'sinc' in dir(torch):
    sinc = torch.sinc
else:
    # This code is adapted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    # LICENSE is in incl_licenses directory.
    def sinc(x: torch.Tensor):
        """
        Implementation of sinc, i.e. sin(pi * x) / (pi * x).
        __Warning__: Unlike julius.sinc, the input is multiplied by `pi`!
        """
        return torch.where(
            x == 0,
            torch.tensor(1., device=x.device, dtype=x.dtype),
            torch.sin(math.pi * x) / math.pi / x,
        )


# This code is adapted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):  # returns filter [1, 1, kernel_size]
    even = (kernel_size % 2 == 0)
    half_size = kernel_size // 2

    # For the Kaiser window
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.:
        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.)
    else:
        beta = 0.
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # ratio = 0.5 / cutoff -> 2 * cutoff = 1 / ratio
    if even:
        time = torch.arange(-half_size, half_size) + 0.5
    else:
        time = torch.arange(kernel_size) - half_size
    if cutoff == 0:
        filter_ = torch.zeros_like(time)
    else:
        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
        # Normalize the filter to have sum = 1, otherwise we will have a small
        # leakage of the constant component in the input signal.
        filter_ /= filter_.sum()
    filter = filter_.view(1, 1, kernel_size)

    return filter


class LowPassFilter1d(nn.Module):
    def __init__(
        self,
        cutoff=0.5,
        half_width=0.6,
        stride: int = 1,
        padding: bool = True,
        padding_mode: str = 'replicate',
        kernel_size: int = 12,
    ):
        # kernel_size should be even for the StyleGAN3 setup;
        # in this implementation, odd sizes are also possible.
        super().__init__()
        if cutoff < -0.:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = (kernel_size % 2 == 0)
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        self.register_buffer("filter", filter)

    # Input: [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)

        return out
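
The filter is a windowed sinc whose taps sum to 1, so the DC component passes through unchanged, and with stride > 1 LowPassFilter1d doubles as a decimator. A quick standalone check, assuming the two definitions above are in scope:

import torch

taps = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
print(taps.shape)         # torch.Size([1, 1, 12])
print(taps.sum().item())  # ~1.0: normalized taps, no leakage of the constant component

lowpass = LowPassFilter1d(cutoff=0.25, half_width=0.3, stride=2, kernel_size=12)
y = lowpass(torch.randn(1, 4, 100))
print(y.shape)            # torch.Size([1, 4, 50]): filtered and decimated by 2
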
49 changes: 49 additions & 0 deletions — resample.py (new alias-free-torch package)
@@ -0,0 +1,49 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
# LICENSE is in incl_licenses directory.

import torch.nn as nn
from torch.nn import functional as F

from .filter import LowPassFilter1d, kaiser_sinc_filter1d


class UpSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        self.stride = ratio
        self.pad = self.kernel_size // ratio - 1
        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
        self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
        filter = kaiser_sinc_filter1d(
            cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
        )
        self.register_buffer("filter", filter)

    # x: [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        x = F.pad(x, (self.pad, self.pad), mode='replicate')
        x = self.ratio * F.conv_transpose1d(
            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
        )
        x = x[..., self.pad_left:-self.pad_right]

        return x


class DownSample1d(nn.Module):
    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=self.kernel_size,
        )

    def forward(self, x):
        return self.lowpass(x)
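
UpSample1d's replicate padding and final crop make the output exactly ratio * T, and DownSample1d brings it back, so the pair is shape-exact. A round-trip sketch, assuming the definitions above are in scope:

import torch

up, down = UpSample1d(ratio=2), DownSample1d(ratio=2)
x = torch.randn(1, 8, 256)  # [B, C, T]

y = up(x)
print(y.shape)        # torch.Size([1, 8, 512]): exactly ratio * T
print(down(y).shape)  # torch.Size([1, 8, 256]): back to the original length
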