Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mlsacheck #30

Merged
merged 4 commits into from
Mar 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions diffsptk/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from .mglsadf import PseudoMGLSADigitalFilter as MLSA
from .mlpg import MaximumLikelihoodParameterGeneration
from .mlpg import MaximumLikelihoodParameterGeneration as MLPG
from .mlsacheck import MLSADigitalFilterStabilityCheck
from .mpir2c import MinimumPhaseImpulseResponseToCepstrum
from .msvq import MultiStageVectorQuantization
from .ndps2c import NegativeDerivativeOfPhaseSpectrumToCepstrum
Expand Down
2 changes: 1 addition & 1 deletion diffsptk/core/b2mc.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class MLSADigitalFilterCoefficientsToMelCepstrum(nn.Module):

"""

def __init__(self, cep_order, alpha):
def __init__(self, cep_order, alpha=0):
super(MLSADigitalFilterCoefficientsToMelCepstrum, self).__init__()

assert 0 <= cep_order
Expand Down
2 changes: 1 addition & 1 deletion diffsptk/core/freqt.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class FrequencyTransform(nn.Module):

"""

def __init__(self, in_order, out_order, alpha):
def __init__(self, in_order, out_order, alpha=0):
super(FrequencyTransform, self).__init__()

assert 0 <= in_order
Expand Down
2 changes: 1 addition & 1 deletion diffsptk/core/freqt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class SecondOrderAllPassFrequencyTransform(nn.Module):

"""

def __init__(self, in_order, out_order, alpha, theta, n_fft=512):
def __init__(self, in_order, out_order, alpha=0, theta=0, n_fft=512):
super(SecondOrderAllPassFrequencyTransform, self).__init__()

assert 0 <= in_order
Expand Down
2 changes: 1 addition & 1 deletion diffsptk/core/ifreqt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class SecondOrderAllPassInverseFrequencyTransform(nn.Module):

"""

def __init__(self, in_order, out_order, alpha, theta, n_fft=512):
def __init__(self, in_order, out_order, alpha=0, theta=0, n_fft=512):
super(SecondOrderAllPassInverseFrequencyTransform, self).__init__()

assert 0 <= in_order
Expand Down
8 changes: 4 additions & 4 deletions diffsptk/core/lpc2par.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,10 @@ def forward(self, a):
if torch.any(1 <= torch.abs(km)):
if self.warn_type == "ignore":
pass
elif self.warn_type == "warn": # pragma: no cover
warnings.warn("Unstable LPC is detected")
elif self.warn_type == "exit": # pragma: no cover
raise RuntimeError("Unstable LPC is detected")
elif self.warn_type == "warn":
warnings.warn("Unstable LPC coefficients")
elif self.warn_type == "exit":
raise RuntimeError("Unstable LPC coefficients")
else:
raise RuntimeError

Expand Down
2 changes: 1 addition & 1 deletion diffsptk/core/lpccheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def forward(self, a1):
Returns
-------
a2 : Tensor [shape=(..., M+1)]
Stabilized LPC coefficients.
Modified LPC coefficients.

Examples
--------
Expand Down
2 changes: 1 addition & 1 deletion diffsptk/core/mc2b.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class MelCepstrumToMLSADigitalFilterCoefficients(nn.Module):

"""

def __init__(self, cep_order, alpha):
def __init__(self, cep_order, alpha=0):
super(MelCepstrumToMLSADigitalFilterCoefficients, self).__init__()

assert 0 <= cep_order
Expand Down
168 changes: 168 additions & 0 deletions diffsptk/core/mlsacheck.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #

import warnings

import numpy as np
import torch
import torch.nn as nn

from ..misc.utils import check_size
from ..misc.utils import numpy_to_torch


class MLSADigitalFilterStabilityCheck(nn.Module):
    """See `this page <https://sp-nitech.github.io/sptk/latest/main/mlsacheck.html>`_
    for details.

    Parameters
    ----------
    cep_order : int >= 0 [scalar]
        Order of mel-cepstrum, :math:`M`.

    alpha : float [-1 < alpha < 1]
        Frequency warping factor, :math:`\\alpha`.

    fft_length : int > M [scalar]
        Number of FFT bins, :math:`L`. Only used when ``fast`` is False.

    pade_order : int [4 <= pade_order <= 7]
        Order of Pade approximation. Only used to choose the default threshold
        when ``threshold`` is not given.

    strict : bool [scalar]
        If True, keep maximum log approximation error rather than MLSA filter
        stability.

    fast : bool [scalar]
        Fast mode. Cannot be combined with ``mod_type='clip'``.

    threshold : float > 0 [scalar]
        Threshold value. If not given, automatically computed.

    mod_type : ['clip', 'scale']
        Modification type.

    warn_type : ['ignore', 'warn', 'exit']
        Behavior for unstable MLSA.

    """

    def __init__(
        self,
        cep_order,
        alpha=0,
        fft_length=256,
        pade_order=4,
        strict=True,
        fast=True,
        threshold=None,
        mod_type="scale",
        warn_type="warn",
    ):
        super(MLSADigitalFilterStabilityCheck, self).__init__()

        self.cep_order = cep_order
        self.fft_length = fft_length
        self.fast = fast
        self.mod_type = mod_type
        self.warn_type = warn_type

        assert 0 <= self.cep_order
        assert self.mod_type in ("clip", "scale")
        assert self.warn_type in ("ignore", "warn", "exit")
        # 'clip' needs the per-bin amplitudes, which are not computed in fast mode.
        assert not (self.fast and self.mod_type == "clip")
        # The FFT path would silently truncate the mel-cepstrum (and return a
        # tensor of the wrong size) if the FFT were shorter than M + 1 points.
        assert self.fast or self.cep_order < self.fft_length

        if threshold is None:
            # Default thresholds per Pade order: (strict, non-strict).
            default_thresholds = {
                4: (4.5, 6.20),
                5: (6.0, 7.65),
                6: (7.4, 9.13),
                7: (8.9, 10.6),
            }
            if pade_order not in default_thresholds:
                raise ValueError("Unexpected Pade order")
            strict_value, loose_value = default_thresholds[pade_order]
            threshold = strict_value if strict else loose_value
        self.threshold = threshold
        assert 0 < threshold

        # Weights (-alpha)^m for m = 0..M; their inner product with the
        # mel-cepstrum gives the gain term that is excluded from the check.
        alpha_vector = (-alpha) ** np.arange(self.cep_order + 1)
        self.register_buffer("alpha_vector", numpy_to_torch(alpha_vector))

    def forward(self, c1):
        """Check stability of MLSA filter.

        Parameters
        ----------
        c1 : Tensor [shape=(..., M+1)]
            Mel-cepstrum.

        Returns
        -------
        c2 : Tensor [shape=(..., M+1)]
            Modified mel-cepstrum.

        Raises
        ------
        RuntimeError
            If an unstable frame is found and ``warn_type`` is 'exit'.

        Examples
        --------
        >>> c1 = diffsptk.nrand(4, stdv=10)
        >>> c1
        tensor([ 1.8963,  7.6629,  4.4804,  8.0669, -1.2768])
        >>> mlsacheck = diffsptk.MLSADigitalFilterStabilityCheck(4, warn_type="ignore")
        >>> c2 = mlsacheck(c1)
        >>> c2
        tensor([ 1.3336,  1.7537,  1.0254,  1.8462, -0.2922])

        """
        check_size(c1.size(-1), self.cep_order + 1, "dimension of mel-cepstrum")

        gain = (c1 * self.alpha_vector).sum(-1, keepdim=True)

        if self.fast:
            # Sum of the coefficients with the gain term removed.
            max_amplitude = c1.sum(-1, keepdim=True) - gain
        else:
            # Remove the gain from the 0th coefficient, then take the largest
            # spectral magnitude over the FFT bins.
            c1 = torch.cat((c1[..., :1] - gain, c1[..., 1:]), dim=-1)
            C1 = torch.fft.rfft(c1, n=self.fft_length)
            C1_amplitude = C1.abs()
            max_amplitude, _ = C1_amplitude.max(-1, keepdim=True)
        # Floor the amplitude on both paths: this avoids division by zero and
        # keeps a non-positive fast-mode amplitude from producing a negative
        # scale, which would flip the sign of the mel-cepstrum.
        max_amplitude = torch.clip(max_amplitude, min=1e-16)

        if torch.any(self.threshold < max_amplitude):
            if self.warn_type == "ignore":
                pass
            elif self.warn_type == "warn":
                warnings.warn("Unstable MLSA filter")
            elif self.warn_type == "exit":
                raise RuntimeError("Unstable MLSA filter")
            else:
                raise RuntimeError

        if self.mod_type == "clip":
            # Per-bin factor: only bins above the threshold are reduced.
            scale = self.threshold / C1_amplitude
        elif self.mod_type == "scale":
            # One factor per frame, derived from the largest amplitude.
            scale = self.threshold / max_amplitude
        else:
            raise RuntimeError
        # Frames/bins already within the threshold are left untouched.
        scale = torch.clip(scale, max=1)

        if self.fast:
            c0, cX = torch.split(c1, [1, self.cep_order], dim=-1)
            # Scale everything except the gain term carried by c0.
            c0 = (c0 - gain) * scale + gain
            cX = cX * scale
            c2 = torch.cat((c0, cX), dim=-1)
        else:
            # Pass n explicitly so that an odd fft_length is reconstructed with
            # the same number of samples that rfft consumed.
            c2 = torch.fft.irfft(C1 * scale, n=self.fft_length)
            c2 = c2[..., : self.cep_order + 1]
            # Restore the gain that was removed before the FFT.
            c2 = torch.cat((c2[..., :1] + gain, c2[..., 1:]), dim=-1)

        return c2
2 changes: 1 addition & 1 deletion docs/core/imglsadf.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ imglsadf
.. autoclass:: diffsptk.PseudoInverseMGLSADigitalFilter
:members:

.. seealso:: :ref:`mgcep` :ref:`mglsadf`
.. seealso:: :ref:`mgcep` :ref:`mglsadf` :ref:`mlsacheck`
2 changes: 1 addition & 1 deletion docs/core/mglsadf.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ mglsadf
.. autoclass:: diffsptk.PseudoMGLSADigitalFilter
:members:

.. seealso:: :ref:`mgcep` :ref:`imglsadf`
.. seealso:: :ref:`mgcep` :ref:`imglsadf` :ref:`mlsacheck`
9 changes: 9 additions & 0 deletions docs/core/mlsacheck.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
.. _mlsacheck:

mlsacheck
---------

.. autoclass:: diffsptk.MLSADigitalFilterStabilityCheck
:members:

.. seealso:: :ref:`mglsadf` :ref:`imglsadf`
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
[tool.coverage.report]
exclude_lines = [
"pragma: no cover",
"if self.verbose:",
"raise NotImplementedError",
"raise RuntimeError",
"raise ValueError",
"self.verbose",
"warn_type",
"warnings",
]

Expand Down
58 changes: 58 additions & 0 deletions tests/test_mlsacheck.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #

import pytest

import diffsptk
import tests.utils as U


@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("P", [4, 5, 6, 7])
@pytest.mark.parametrize("strict", [False, True])
@pytest.mark.parametrize("fast", [False, True])
@pytest.mark.parametrize("mod_type", ["clip", "scale"])
def test_compatibility(device, P, strict, fast, mod_type, M=9, L=32, alpha=0.1, B=10):
    """Run the module against the ``mlsacheck`` command line and check gradients.

    The combination of fast mode and 'clip' modification is not exercised.
    """
    # The module does not accept fast mode together with 'clip'.
    if mod_type == "clip" and fast:
        return

    module_options = dict(
        alpha=alpha,
        fft_length=L,
        pade_order=P,
        strict=strict,
        fast=fast,
        mod_type=mod_type,
        warn_type="ignore",
    )
    mlsacheck = diffsptk.MLSADigitalFilterStabilityCheck(M, **module_options)

    # Command-line flags mirroring the module options above.
    flag_parts = (
        "-f " if fast else "",
        "-r 0 " if strict else "-r 1 ",
        "-t 0 " if mod_type == "clip" else "-t 1 ",
    )
    opt = "".join(flag_parts)

    input_command = f"nrand -l {B*L} | mgcep -m {M} -l {L} -a {alpha} | sopr -m 10"
    reference_command = f"mlsacheck -m {M} -l {L} -a {alpha} -P {P} {opt} -e 0 -x"
    frame_size = M + 1

    U.check_compatibility(
        device,
        mlsacheck,
        [],
        input_command,
        reference_command,
        [],
        dx=frame_size,
        dy=frame_size,
    )

    U.check_differentiable(device, mlsacheck, [B, frame_size])