Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add plp #62

Merged
merged 2 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ diffsptk.write("voiced.wav", x_voiced, sr)
diffsptk.write("unvoiced.wav", x_unvoiced, sr)
```

### Mel-spectrogram and MFCC extraction
### Mel-spectrogram, MFCC, and PLP extraction
```python
import diffsptk

Expand Down Expand Up @@ -127,6 +127,16 @@ mfcc = diffsptk.MFCC(
)
Y = mfcc(X)
print(Y.shape)

# Extract PLP.
plp = diffsptk.PLP(
plp_order=M,
n_channel=n_channel,
fft_length=n_fft,
sample_rate=sr,
)
Y = plp(X)
print(Y.shape)
```

### Subband decomposition
Expand Down
2 changes: 2 additions & 0 deletions diffsptk/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@
from .pca import PrincipalComponentAnalysis as PCA
from .phase import Phase
from .pitch import Pitch
from .plp import PerceptualLinearPredictiveCoefficientsAnalysis
from .plp import PerceptualLinearPredictiveCoefficientsAnalysis as PLP
from .pol_root import RootsToPolynomial
from .poledf import AllPoleDigitalFilter
from .pqmf import PseudoQuadratureMirrorFilterBanks
Expand Down
3 changes: 1 addition & 2 deletions diffsptk/core/c2acr.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,5 @@ def forward(self, c):
"""
x = torch.fft.rfft(c, n=self.fft_length).real
x = torch.exp(2 * x)
r = torch.fft.hfft(x)[..., : self.acr_order + 1]
r = r / self.fft_length
r = torch.fft.hfft(x, norm="forward")[..., : self.acr_order + 1]
return r
17 changes: 12 additions & 5 deletions diffsptk/core/fbank.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,13 @@ def __init__(
f_min=0,
f_max=None,
floor=1e-5,
use_power=False,
out_format="y",
):
super(MelFilterBankAnalysis, self).__init__()

self.floor = floor
self.use_power = use_power
self.out_format = out_format

if f_max is None:
Expand All @@ -87,6 +89,9 @@ def __init__(
def hz_to_mel(x):
return 1127 * np.log(x / 700 + 1)

def mel_to_hz(x):
return 700 * (np.exp(x / 1127) - 1)

lower_bin_index = max(1, int(f_min / sample_rate * fft_length + 1.5))
upper_bin_index = min(
fft_length // 2, int(f_max / sample_rate * fft_length + 0.5)
Expand All @@ -96,17 +101,18 @@ def hz_to_mel(x):
mel_max = hz_to_mel(f_max)

seed = np.arange(1, n_channel + 2)
freq = (mel_max - mel_min) / (n_channel + 1) * seed + mel_min
center_frequencies = (mel_max - mel_min) / (n_channel + 1) * seed + mel_min
self.center_frequencies = mel_to_hz(center_frequencies)

seed = np.arange(lower_bin_index, upper_bin_index)
mel = hz_to_mel(sample_rate * seed / fft_length)
lower_channel_map = [np.argmax(0 < (m <= freq)) for m in mel]
lower_channel_map = [np.argmax(0 < (m <= center_frequencies)) for m in mel]

diff = freq - np.insert(freq[:-1], 0, mel_min)
diff = center_frequencies - np.insert(center_frequencies[:-1], 0, mel_min)
weights = np.zeros((fft_length // 2 + 1, n_channel))
for i, k in enumerate(seed):
m = lower_channel_map[i]
w = (freq[max(0, m)] - mel[i]) / diff[max(0, m)]
w = (center_frequencies[max(0, m)] - mel[i]) / diff[max(0, m)]
if 0 < m:
weights[k, m - 1] += w
if m < n_channel:
Expand Down Expand Up @@ -141,7 +147,8 @@ def forward(self, x):
[3.3640, 3.4518, 2.7717, 0.5088]])

"""
y = torch.matmul(torch.sqrt(x), self.H)
y = x if self.use_power else torch.sqrt(x)
y = torch.matmul(y, self.H)
y = torch.log(torch.clip(y, min=self.floor))
E = (2 * x[..., 1:-1]).sum(-1) + x[..., 0] + x[..., -1]
E = torch.log(E / (2 * (x.size(-1) - 1))).unsqueeze(-1)
Expand Down
2 changes: 1 addition & 1 deletion diffsptk/core/levdur.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

class LevinsonDurbin(nn.Module):
"""See `this page <https://sp-nitech.github.io/sptk/latest/main/levdur.html>`_
for details.
for details. The implementation is based on a simple matrix inversion.

Parameters
----------
Expand Down
14 changes: 7 additions & 7 deletions diffsptk/core/mfcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
self.mfcc_order = mfcc_order

assert 1 <= self.mfcc_order < n_channel
assert 1 <= lifter

if out_format == 0 or out_format == "y":
self.format_func = lambda y, c, E: y
Expand All @@ -90,11 +91,10 @@ def __init__(
)
self.dct = DiscreteCosineTransform(n_channel)

m = np.arange(1, self.mfcc_order + 1)
liftering_vector = 1 + (lifter / 2) * np.sin((np.pi / lifter) * m)
self.register_buffer("liftering_vector", numpy_to_torch(liftering_vector))

self.const = np.sqrt(2)
m = np.arange(self.mfcc_order + 1)
v = 1 + (lifter / 2) * np.sin((np.pi / lifter) * m)
v[0] = np.sqrt(2)
self.register_buffer("liftering_vector", numpy_to_torch(v))

def forward(self, x):
"""Compute MFCC.
Expand Down Expand Up @@ -128,6 +128,6 @@ def forward(self, x):
"""
y, E = self.fbank(x)
y = self.dct(y)
c = y[..., :1] * self.const
y = y[..., 1 : self.mfcc_order + 1] * self.liftering_vector
y = y[..., : self.mfcc_order + 1] * self.liftering_vector
c, y = torch.split(y, [1, self.mfcc_order], dim=-1)
return self.format_func(y, c, E)
165 changes: 165 additions & 0 deletions diffsptk/core/plp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #

import numpy as np
import torch
import torch.nn as nn

from ..misc.utils import numpy_to_torch
from .fbank import MelFilterBankAnalysis
from .levdur import LevinsonDurbin
from .mgc2mgc import MelGeneralizedCepstrumToMelGeneralizedCepstrum


class PerceptualLinearPredictiveCoefficientsAnalysis(nn.Module):
"""See `this page <https://sp-nitech.github.io/sptk/latest/main/plp.html>`_
for details.

Parameters
----------
mfcc_order : int >= 1 [scalar]
Order of MFCC, :math:`M`.

n_channel : int >= 1 [scalar]
Number of mel-filter banks, :math:`C`.

fft_length : int >= 2 [scalar]
Number of FFT bins, :math:`L`.

sample_rate : int >= 1 [scalar]
Sample rate in Hz.

lifter : int >= 1 [scalar]
Liftering coefficient.

compression_factor : float > 0 [scalar]
Amplitude compression factor.

f_min : float >= 0 [scalar]
Minimum frequency in Hz.

f_max : float <= sample_rate // 2 [scalar]
Maximum frequency in Hz.

floor : float > 0 [scalar]
Minimum mel-filter bank output in linear scale.

out_format : ['y', 'yE', 'yc', 'ycE']
`y` is MFCC, `c` is C0, and `E` is energy.

n_fft : int >> :math:`M` [scalar]
Number of FFT bins. Accurate conversion requires the large value.

"""

def __init__(
self,
plp_order,
n_channel,
fft_length,
sample_rate,
lifter=1,
compression_factor=0.33,
out_format="y",
n_fft=512,
**fbank_kwargs,
):
super(PerceptualLinearPredictiveCoefficientsAnalysis, self).__init__()

self.plp_order = plp_order
self.compression_factor = compression_factor

assert 1 <= self.plp_order < n_channel
assert 1 <= lifter
assert 0 < self.compression_factor

if out_format == 0 or out_format == "y":
self.format_func = lambda y, c, E: y
elif out_format == 1 or out_format == "yE":
self.format_func = lambda y, c, E: torch.cat((y, E), dim=-1)
elif out_format == 2 or out_format == "yc":
self.format_func = lambda y, c, E: torch.cat((y, c), dim=-1)
elif out_format == 3 or out_format == "ycE":
self.format_func = lambda y, c, E: torch.cat((y, c, E), dim=-1)
else:
raise ValueError(f"out_format {out_format} is not supported")

self.fbank = MelFilterBankAnalysis(
n_channel,
fft_length,
sample_rate,
use_power=True,
out_format="y,E",
**fbank_kwargs,
)
self.levdur = LevinsonDurbin(self.plp_order)
self.lpc2c = MelGeneralizedCepstrumToMelGeneralizedCepstrum(
self.plp_order,
self.plp_order,
in_gamma=-1,
in_norm=True,
in_mul=True,
n_fft=n_fft,
)

f = self.fbank.center_frequencies[:-1] ** 2
e = (f / (f + 1.6e5)) ** 2 * (f + 1.44e6) / (f + 9.61e6)
self.register_buffer("equal_loudness_curve", numpy_to_torch(e))

m = np.arange(self.plp_order + 1)
v = 1 + (lifter / 2) * np.sin((np.pi / lifter) * m)
v[0] = 2
self.register_buffer("liftering_vector", numpy_to_torch(v))

def forward(self, x):
"""Compute PLP.

Parameters
----------
x : Tensor [shape=(..., L/2+1)]
Power spectrum.

Returns
-------
y : Tensor [shape=(..., M)]
PLP without C0.

E : Tensor [shape=(..., 1)]
Energy.

c : Tensor [shape=(..., 1)]
C0.

Examples
--------
>>> x = diffsptk.ramp(19)
>>> stft = diffsptk.STFT(frame_length=10, frame_period=10, fft_length=32)
>>> plp = diffsptk.PLP(4, 8, 32, 8000)
>>> y = plp(stft(x))
>>> y
tensor([[-0.2896, -0.2356, -0.0586, -0.0387],
[ 0.4468, -0.5820, 0.0104, -0.0505]])

"""
y, E = self.fbank(x)
y = (torch.exp(y) * self.equal_loudness_curve) ** self.compression_factor
y = torch.cat((y[..., :1], y, y[..., -1:]), dim=-1)
y = torch.fft.hfft(y, norm="forward")[..., : self.plp_order + 1].real
y = self.levdur(y)
y = self.lpc2c(y)
y *= self.liftering_vector
c, y = torch.split(y, [1, self.plp_order], dim=-1)
return self.format_func(y, c, E)
2 changes: 1 addition & 1 deletion docs/core/mfcc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ mfcc
.. autoclass:: diffsptk.MelFrequencyCepstralCoefficientsAnalysis
:members:

.. seealso:: :ref:`stft` :ref:`fbank`
.. seealso:: :ref:`stft` :ref:`fbank` :ref:`plp`
11 changes: 11 additions & 0 deletions docs/core/plp.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
.. _plp:

plp
---

.. autoclass:: diffsptk.PLP

.. autoclass:: diffsptk.PerceptualLinearPredictiveCoefficientsAnalysis
:members:

.. seealso:: :ref:`stft` :ref:`fbank` :ref:`mfcc`
45 changes: 45 additions & 0 deletions tests/test_plp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #

import pytest

import diffsptk
import tests.utils as U


@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("o", [0, 1, 2, 3])
def test_compatibility(
device, o, M=4, C=10, L=32, sr=8000, lifter=20, f_min=300, factor=0.3, B=2
):
spec = diffsptk.Spectrum(L, eps=0)
plp = diffsptk.PLP(
M, C, L, sr, lifter=lifter, f_min=f_min, compression_factor=factor, out_format=o
)

s = sr // 1000
U.check_compatibility(
device,
[plp, spec],
[],
f"nrand -l {B*L}",
f"plp -m {M} -n {C} -l {L} -s {s} -c {lifter} -L {f_min} -f {factor} -o {o}",
[],
dx=L,
dy=M + (o if o <= 1 else o - 1),
)

U.check_differentiable(device, [plp, spec], [B, L])