Skip to content

Commit

Permalink
refactor gnorm
Browse files Browse the repository at this point in the history
  • Loading branch information
takenori-y committed Feb 1, 2024
1 parent c25d8be commit fb813b8
Show file tree
Hide file tree
Showing 14 changed files with 124 additions and 107 deletions.
123 changes: 51 additions & 72 deletions diffsptk/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License. #
# ------------------------------------------------------------------------ #

from . import modules as fm
from . import modules as nn


def decimate(x, period=1, start=0, dim=-1):
Expand All @@ -39,15 +39,8 @@ def decimate(x, period=1, start=0, dim=-1):
Tensor [shape=(..., T/P-S, ...)]
Decimated signal.
Examples
--------
>>> x = diffsptk.ramp(9)
>>> y = diffsptk.functional.decimate(x, 3, start=1)
>>> y
tensor([1., 4., 7.])
"""
return fm.Decimation._forward(x, period=period, start=start, dim=dim)
return nn.Decimation._forward(x, period=period, start=start, dim=dim)


def delay(x, start=0, keeplen=False, dim=-1):
Expand All @@ -72,18 +65,8 @@ def delay(x, start=0, keeplen=False, dim=-1):
Tensor [shape=(..., T-S, ...)] or [shape=(..., T, ...)]
Delayed signal.
Examples
--------
>>> x = diffsptk.ramp(1, 3)
>>> y = diffsptk.functional.delay(x, 2)
>>> y
tensor([0., 0., 1., 2., 3.])
>>> y = diffsptk.functional.delay(x, 2, keeplen=True)
>>> y
tensor([0., 0., 1.])
"""
return fm.Delay._forward(x, start=start, keeplen=keeplen, dim=dim)
return nn.Delay._forward(x, start=start, keeplen=keeplen, dim=dim)


def dequantize(y, abs_max=1, n_bit=8, quantizer="mid-rise"):
Expand All @@ -108,22 +91,34 @@ def dequantize(y, abs_max=1, n_bit=8, quantizer="mid-rise"):
Tensor [shape=(...,)]
Dequantized input.
Examples
--------
>>> x = diffsptk.ramp(-4, 4)
>>> x
tensor([-4., -3., -2., -1., 0., 1., 2., 3., 4.])
>>> y = diffsptk.functional.quantize(x, 4, 2)
>>> z = diffsptk.functional.dequantize(y, 4, 2)
>>> z
tensor([-3., -3., -1., -1., 1., 1., 3., 3., 3.])
"""
return fm.InverseUniformQuantization_.forward(
return nn.InverseUniformQuantization._forward(
y, abs_max=abs_max, n_bit=n_bit, quantizer=quantizer
)


def gnorm(x, gamma=0):
    """Apply gain normalization to generalized cepstrum.

    Functional counterpart of the gain normalization module; the call is
    forwarded to ``GeneralizedCepstrumGainNormalization._forward``.

    Parameters
    ----------
    x : Tensor [shape=(..., M+1)]
        Generalized cepstrum.

    gamma : float in [-1, 1]
        Gamma, :math:`\\gamma`.

    Returns
    -------
    Tensor [shape=(..., M+1)]
        Normalized generalized cepstrum.

    """
    normalize = nn.GeneralizedCepstrumGainNormalization._forward
    return normalize(x, gamma=gamma)


def grpdelay(b=None, a=None, *, fft_length=512, alpha=1, gamma=1, **kwargs):
"""Compute group delay.
Expand All @@ -149,19 +144,32 @@ def grpdelay(b=None, a=None, *, fft_length=512, alpha=1, gamma=1, **kwargs):
Tensor [shape=(..., L/2+1)]
Group delay or modified group delay function.
Examples
--------
>>> x = diffsptk.ramp(3)
>>> g = diffsptk.functional.grpdelay(x, fft_length=8)
>>> g
tensor([2.3333, 2.4278, 3.0000, 3.9252, 3.0000])
"""
return fm.GroupDelay._forward(
return nn.GroupDelay._forward(
b, a, fft_length=fft_length, alpha=alpha, gamma=gamma, **kwargs
)


def ignorm(y, gamma=0):
    """Undo gain normalization of generalized cepstrum.

    Functional counterpart of the inverse gain normalization module; the
    call is forwarded to
    ``GeneralizedCepstrumInverseGainNormalization._forward``.

    Parameters
    ----------
    y : Tensor [shape=(..., M+1)]
        Normalized generalized cepstrum.

    gamma : float in [-1, 1]
        Gamma, :math:`\\gamma`.

    Returns
    -------
    Tensor [shape=(..., M+1)]
        Generalized cepstrum.

    """
    denormalize = nn.GeneralizedCepstrumInverseGainNormalization._forward
    return denormalize(y, gamma=gamma)


def interpolate(x, period=1, start=0, dim=-1):
"""Interpolate signal.
Expand All @@ -184,15 +192,8 @@ def interpolate(x, period=1, start=0, dim=-1):
Tensor [shape=(..., TxP+S, ...)]
Interpolated signal.
Examples
--------
>>> x = diffsptk.ramp(1, 3)
>>> y = diffsptk.functional.interpolate(x, 3, start=1)
>>> y
tensor([0., 1., 0., 0., 2., 0., 0., 3., 0., 0.])
"""
return fm.Interpolation._forward(x, period=period, start=start, dim=dim)
return nn.Interpolation._forward(x, period=period, start=start, dim=dim)


def phase(b=None, a=None, *, fft_length=512, unwrap=False):
Expand All @@ -217,15 +218,8 @@ def phase(b=None, a=None, *, fft_length=512, unwrap=False):
Tensor [shape=(..., L/2+1)]
Phase spectrum [:math:`\\pi` rad].
Examples
--------
>>> x = diffsptk.ramp(3)
>>> p = diffsptk.functional.phase(x, fft_length=8)
>>> p
tensor([ 0.0000, -0.5907, 0.7500, -0.1687, 1.0000])
"""
return fm.Phase._forward(b, a, fft_length=fft_length, unwrap=unwrap)
return nn.Phase._forward(b, a, fft_length=fft_length, unwrap=unwrap)


def quantize(x, abs_max=1, n_bit=8, quantizer="mid-rise"):
Expand All @@ -250,15 +244,8 @@ def quantize(x, abs_max=1, n_bit=8, quantizer="mid-rise"):
Tensor [shape=(...,)]
Quantized input.
Examples
--------
>>> x = diffsptk.ramp(-4, 4)
>>> y = diffsptk.functional.quantize(x, 4, 2)
>>> y
tensor([0, 0, 1, 1, 2, 2, 3, 3, 3], dtype=torch.int32)
"""
return fm.UniformQuantization._forward(
return nn.UniformQuantization._forward(
x, abs_max=abs_max, n_bit=n_bit, quantizer=quantizer
)

Expand Down Expand Up @@ -293,16 +280,8 @@ def spec(
Tensor [shape=(..., L/2+1)]
Spectrum.
Examples
--------
>>> x = diffsptk.ramp(1, 3)
tensor([1., 2., 3.])
>>> s = diffsptk.functional.spec(x, fft_length=8)
>>> s
tensor([36.0000, 25.3137, 8.0000, 2.6863, 4.0000])
"""
return fm.Spectrum._forward(
return nn.Spectrum._forward(
b,
a,
fft_length=fft_length,
Expand Down
16 changes: 8 additions & 8 deletions diffsptk/misc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,6 @@ def to_3d(x):
return y


def check(x, names):
    """Resolve a selector to one of the supported names.

    Parameters
    ----------
    x : int or object
        Either a zero-based position into ``names`` (in iteration order) or
        a candidate member of ``names``.

    names : collection
        Supported items. It must support ``len``, iteration, and ``in``.

    Returns
    -------
    object
        The member of ``names`` selected by ``x``.

    Raises
    ------
    ValueError
        If ``x`` is neither an in-range index nor a member of ``names``.

    """
    # Exact type check: bool is a subclass of int and is not used as an index.
    # The upper bound is exclusive; ``x == len(names)`` is out of range and
    # must be reported as ValueError below instead of raising IndexError.
    if type(x) is int and 0 <= x < len(names):
        x = list(names)[x]
    if x not in names:
        raise ValueError(f"Unsupported value: {x}")
    return x


def reflect(x):
d = x.size(-1)
y = x.view(-1, d)
Expand Down Expand Up @@ -284,6 +276,14 @@ def check_size(x, y, cause):
assert x == y, f"Unexpected {cause} (input {x} vs target {y})"


def check_mode(x, names):
    """Resolve a mode selector to one of the supported names.

    Parameters
    ----------
    x : int or object
        Either a zero-based position into ``names`` (in iteration order) or
        a candidate member of ``names``.

    names : collection
        Supported items. It must support ``len``, iteration, and ``in``.

    Returns
    -------
    object
        The member of ``names`` selected by ``x``.

    Raises
    ------
    ValueError
        If ``x`` is neither an in-range index nor a member of ``names``.

    """
    # Exact type check: bool is a subclass of int and is not used as an index.
    # The upper bound is exclusive; ``x == len(names)`` is out of range and
    # must be reported as ValueError below instead of raising IndexError.
    if type(x) is int and 0 <= x < len(names):
        x = list(names)[x]
    if x not in names:
        raise ValueError(f"Unsupported item: {x}")
    return x


def read(filename, double=False, **kwargs):
"""Read waveform from file.
Expand Down
4 changes: 2 additions & 2 deletions diffsptk/modules/dequantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch
import torch.nn as nn

from ..misc.utils import check
from ..misc.utils import check_mode
from .quantize import _quantization_level


Expand All @@ -43,7 +43,7 @@ def __init__(self, abs_max=1, n_bit=8, quantizer="mid-rise"):

self.abs_max = abs_max
self.n_bit = n_bit
self.quantizer = check(quantizer, _quantization_level)
self.quantizer = check_mode(quantizer, _quantization_level)

assert 0 < self.abs_max
assert 1 <= self.n_bit
Expand Down
20 changes: 11 additions & 9 deletions diffsptk/modules/gnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ class GeneralizedCepstrumGainNormalization(nn.Module):
Parameters
----------
cep_order : int >= 0 [scalar]
cep_order : int >= 0
Order of cepstrum, :math:`M`.
gamma : float [-1 <= gamma <= 1]
gamma : float in [-1, 1]
Gamma, :math:`\\gamma`.
c : int >= 1 [scalar]
c : int >= 1 or None
Number of stages.
"""
Expand All @@ -57,7 +57,7 @@ def forward(self, x):
Returns
-------
y : Tensor [shape=(..., M+1)]
Tensor [shape=(..., M+1)]
Normalized generalized cepstrum.
Examples
Expand All @@ -70,15 +70,17 @@ def forward(self, x):
"""
check_size(x.size(-1), self.cep_order + 1, "dimension of cepstrum")
return self._forward(x, self.gamma)

x0, x1 = torch.split(x, [1, self.cep_order], dim=-1)
if self.gamma == 0:
@staticmethod
def _forward(x, gamma):
x0, x1 = torch.split(x, [1, x.size(-1) - 1], dim=-1)
if gamma == 0:
K = torch.exp(x0)
y = x1
else:
z = 1 + self.gamma * x0
K = torch.pow(z, 1 / self.gamma)
z = 1 + gamma * x0
K = torch.pow(z, 1 / gamma)
y = x1 / z

y = torch.cat((K, y), dim=-1)
return y
18 changes: 10 additions & 8 deletions diffsptk/modules/ignorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ class GeneralizedCepstrumInverseGainNormalization(nn.Module):
Parameters
----------
cep_order : int >= 0 [scalar]
cep_order : int >= 0
Order of cepstrum, :math:`M`.
gamma : float [-1 <= gamma <= 1]
gamma : float in [-1, 1]
Gamma, :math:`\\gamma`.
c : int >= 1 [scalar]
c : int >= 1 or None
Number of stages.
"""
Expand Down Expand Up @@ -71,15 +71,17 @@ def forward(self, y):
"""
check_size(y.size(-1), self.cep_order + 1, "dimension of cepstrum")
return self._forward(y, self.gamma)

K, y = torch.split(y, [1, self.cep_order], dim=-1)
if self.gamma == 0:
@staticmethod
def _forward(y, gamma):
K, y = torch.split(y, [1, y.size(-1) - 1], dim=-1)
if gamma == 0:
x0 = torch.log(K)
x1 = y
else:
z = torch.pow(K, self.gamma)
x0 = (z - 1) / self.gamma
z = torch.pow(K, gamma)
x0 = (z - 1) / gamma
x1 = y * z

x = torch.cat((x0, x1), dim=-1)
return x
4 changes: 2 additions & 2 deletions diffsptk/modules/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch
import torch.nn as nn

from ..misc.utils import check
from ..misc.utils import check_mode

_quantization_level = {
"mid-rise": lambda n_bit: 1 << n_bit,
Expand Down Expand Up @@ -67,7 +67,7 @@ def __init__(self, abs_max=1, n_bit=8, quantizer="mid-rise"):

self.abs_max = abs_max
self.n_bit = n_bit
self.quantizer = check(quantizer, _quantization_level)
self.quantizer = check_mode(quantizer, _quantization_level)

assert 0 < self.abs_max
assert 1 <= self.n_bit
Expand Down
4 changes: 2 additions & 2 deletions diffsptk/modules/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch
import torch.nn as nn

from ..misc.utils import check
from ..misc.utils import check_mode
from ..misc.utils import remove_gain

_spec2spec = {
Expand Down Expand Up @@ -53,7 +53,7 @@ def __init__(self, fft_length, eps=0, relative_floor=None, out_format="power"):

self.fft_length = fft_length
self.eps = eps
self.out_format = check(out_format)
self.out_format = check_mode(out_format, _spec2spec)

assert 2 <= self.fft_length
assert 0 <= self.eps
Expand Down
2 changes: 2 additions & 0 deletions docs/modules/dequantize.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ dequantize
.. autoclass:: diffsptk.InverseUniformQuantization
:members:

.. autofunction:: diffsptk.functional.dequantize

.. seealso:: :ref:`iulaw` :ref:`quantize`
2 changes: 2 additions & 0 deletions docs/modules/gnorm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ gnorm
.. autoclass:: diffsptk.GeneralizedCepstrumGainNormalization
:members:

.. autofunction:: diffsptk.functional.gnorm

.. seealso:: :ref:`ignorm` :ref:`mgc2mgc`
2 changes: 2 additions & 0 deletions docs/modules/ignorm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ ignorm
.. autoclass:: diffsptk.GeneralizedCepstrumInverseGainNormalization
:members:

.. autofunction:: diffsptk.functional.ignorm

.. seealso:: :ref:`gnorm` :ref:`mgc2mgc`
Loading

0 comments on commit fb813b8

Please sign in to comment.