Skip to content

Commit

Permalink
Merge pull request #53 from sp-nitech/lpc2lsp
Browse files Browse the repository at this point in the history
Add lpc2lsp
  • Loading branch information
takenori-y committed Dec 5, 2023
2 parents b558659 + 33717d4 commit 7489ef9
Show file tree
Hide file tree
Showing 12 changed files with 248 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ diffsptk
[![Stable Manual](https://img.shields.io/badge/docs-stable-blue.svg)](https://sp-nitech.github.io/diffsptk/1.0.1/)
[![Downloads](https://static.pepy.tech/badge/diffsptk)](https://pepy.tech/project/diffsptk)
[![Python Version](https://img.shields.io/pypi/pyversions/diffsptk.svg)](https://pypi.python.org/pypi/diffsptk)
[![PyTorch Version](https://img.shields.io/badge/pytorch-1.11.0%20%7C%202.1.0-orange.svg)](https://pypi.python.org/pypi/diffsptk)
[![PyTorch Version](https://img.shields.io/badge/pytorch-1.11.0%20%7C%202.1.1-orange.svg)](https://pypi.python.org/pypi/diffsptk)
[![PyPI Version](https://img.shields.io/pypi/v/diffsptk.svg)](https://pypi.python.org/pypi/diffsptk)
[![Codecov](https://codecov.io/gh/sp-nitech/diffsptk/branch/master/graph/badge.svg)](https://app.codecov.io/gh/sp-nitech/diffsptk)
[![License](https://img.shields.io/github/license/sp-nitech/diffsptk.svg)](https://github.com/sp-nitech/diffsptk/blob/master/LICENSE)
Expand Down
1 change: 1 addition & 0 deletions diffsptk/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from .linear_intpl import LinearInterpolation
from .lpc import LinearPredictiveCodingAnalysis
from .lpc import LinearPredictiveCodingAnalysis as LPC
from .lpc2lsp import LinearPredictiveCoefficientsToLineSpectralPairs
from .lpc2par import LinearPredictiveCoefficientsToParcorCoefficients
from .lpccheck import LinearPredictiveCoefficientsStabilityCheck
from .magic_intpl import MagicNumberInterpolation
Expand Down
3 changes: 2 additions & 1 deletion diffsptk/core/excite.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch.nn as nn
import torch.nn.functional as F

from ..misc.utils import TWO_PI
from ..misc.utils import UNVOICED_SYMBOL
from .linear_intpl import LinearInterpolation

Expand Down Expand Up @@ -107,7 +108,7 @@ def forward(self, p):
e = torch.zeros_like(p)
e[pulse_pos] = torch.sqrt(p[pulse_pos])
elif self.voiced_region == "sinusoidal":
e = torch.sin((2 * torch.pi) * phase)
e = torch.sin(TWO_PI * phase)
elif self.voiced_region == "sawtooth":
e = torch.fmod(phase, 2) - 1
else:
Expand Down
185 changes: 185 additions & 0 deletions diffsptk/core/lpc2lsp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..misc.utils import TWO_PI
from ..misc.utils import check_size
from ..misc.utils import numpy_to_torch


class LinearPredictiveCoefficientsToLineSpectralPairs(nn.Module):
    """See `this page <https://sp-nitech.github.io/sptk/latest/main/lpc2lsp.html>`_
    for details. **Note that this module cannot compute gradient**.

    The line spectral pairs are located by evaluating two cosine series on a
    fixed grid over the unit semicircle, detecting sign changes between
    adjacent grid points, and refining each root by linear interpolation.

    Parameters
    ----------
    lpc_order : int >= 0 [scalar]
        Order of LPC, :math:`M`. Must be smaller than ``n_split``.

    n_split : int >= 1 [scalar]
        Number of splits of unit semicircle.

    n_iter : int >= 0 [scalar]
        Number of pseudo iterations. The search grid has
        ``n_split * (n_iter + 1)`` intervals.

    log_gain : bool [scalar]
        If True, output gain in log scale.

    sample_rate : int >= 1 [scalar]
        Sample rate in Hz. Required only if ``out_format`` is 'khz' or 'hz'.

    out_format : ['radian', 'cycle', 'khz', 'hz']
        Output format. The integers 0, 1, 2, and 3 are accepted as aliases
        in the same order.

    """

    def __init__(
        self,
        lpc_order,
        n_split=512,
        n_iter=0,
        log_gain=False,
        sample_rate=None,
        out_format="radian",
    ):
        super(LinearPredictiveCoefficientsToLineSpectralPairs, self).__init__()

        self.lpc_order = lpc_order
        self.log_gain = log_gain

        assert 0 <= self.lpc_order < n_split
        assert 0 <= n_iter

        if self.lpc_order % 2 == 0:
            # These buffers are used only by the even-order branch of forward().
            # sign is [-1, 1, -1, ...] and mask is [0, 1, 0, ...].
            sign = np.ones(self.lpc_order // 2 + 2)
            sign[::2] = -1
            self.register_buffer("sign", numpy_to_torch(sign))
            mask = np.ones(self.lpc_order // 2 + 2)
            mask[::2] = 0
            self.register_buffer("mask", numpy_to_torch(mask))

        # Grid over x = cos(omega), running from 1 down to -1 so that
        # omega = arccos(x) ascends from 0 to pi.
        x = np.linspace(1, -1, n_split * (n_iter + 1) + 1)
        self.register_buffer("x", numpy_to_torch(x))

        # Avoid the use of Chebyshev polynomials.
        # Tx[k] holds cos(k * omega) on the grid; the k = 0 row is halved.
        omega = np.arccos(x)
        k = np.arange(self.lpc_order // 2 + 2)
        Tx = np.cos(k.reshape(-1, 1) * omega.reshape(1, -1))
        scale = np.ones(self.lpc_order // 2 + 2)
        scale[0] = 0.5
        Tx = scale.reshape(-1, 1) * Tx
        self.register_buffer("Tx", numpy_to_torch(Tx))

        # Select the conversion applied to the frequencies (in radians).
        if out_format == 0 or out_format == "radian":
            self.convert = lambda x: x
        elif out_format == 1 or out_format == "cycle":
            self.convert = lambda x: x / TWO_PI
        elif out_format == 2 or out_format == "khz":
            assert sample_rate is not None and 0 < sample_rate
            self.convert = lambda x: x * (sample_rate / 1000 / TWO_PI)
        elif out_format == 3 or out_format == "hz":
            assert sample_rate is not None and 0 < sample_rate
            self.convert = lambda x: x * (sample_rate / TWO_PI)
        else:
            raise ValueError(f"out_format {out_format} is not supported")

    def forward(self, a):
        """Convert LPC to LSP.

        Parameters
        ----------
        a : Tensor [shape=(..., M+1)]
            LPC coefficients.

        Returns
        -------
        w : Tensor [shape=(..., M+1)]
            LSP coefficients.

        Raises
        ------
        RuntimeError
            If the number of roots found on the grid is not :math:`M` for
            some frame.

        Examples
        --------
        >>> x = diffsptk.nrand(4)
        >>> x
        tensor([-1.5326,  1.0875, -1.5925,  0.6913,  1.6217])
        >>> lpc = diffsptk.LPC(3, 5)
        >>> a = lpc(x)
        >>> a
        tensor([ 2.7969,  0.3908,  0.0458, -0.0859])
        >>> lpc2lsp = diffsptk.LinearPredictiveCoefficientsToLineSpectralPairs(3)
        >>> w = lpc2lsp(a)
        >>> w
        tensor([2.7969, 0.9037, 1.8114, 2.4514])

        """
        check_size(a.size(-1), self.lpc_order + 1, "dimension of LPC")

        # Separate the gain from the M predictor coefficients.
        K, a = torch.split(a, [1, self.lpc_order], dim=-1)

        # First half of the symmetric (q1) and antisymmetric (q2) combinations
        # of the coefficients with their reversed copy.
        p1 = a[..., : (self.lpc_order + 1) // 2]
        p2 = a.flip(-1)[..., : (self.lpc_order + 1) // 2]
        q1 = p1 + p2
        q2 = p1 - p2
        if self.lpc_order % 2 == 0:
            d1 = F.pad(q1, (1, 0), value=1)
            d2 = F.pad(q2, (1, 0), value=1)
            # Alternating-sign running sum for c1 and plain running sum for c2.
            c1_odd = torch.cumsum(d1 * self.sign[:-1], dim=-1)
            c1_even = torch.cumsum(d1 * self.sign[1:], dim=-1)
            c1 = c1_odd * self.mask[:-1] + c1_even * self.mask[1:]
            c2 = torch.cumsum(d2, dim=-1)
        elif self.lpc_order == 1:
            c1 = F.pad(q1, (1, 0), value=1)
            c2 = c1
        else:
            c1 = F.pad(q1, (1, 0), value=1)
            # The last element of q2 pairs the middle coefficient with itself
            # and is therefore always zero, so it carries no information.
            d2 = F.pad(q2[..., :-1], (1, 0), value=1)
            # Running sum over every other coefficient: c2[k] = d2[k] + c2[k-2].
            # Writing into strided slices works for any odd order; stacking the
            # two interleaved halves fails when their lengths differ
            # (M = 5, 9, 13, ...).
            c2 = torch.empty_like(d2)
            c2[..., 0::2] = torch.cumsum(d2[..., 0::2], dim=-1)
            c2[..., 1::2] = torch.cumsum(d2[..., 1::2], dim=-1)
        c1 = c1.flip(-1)
        c2 = c2.flip(-1)

        # Evaluate both cosine series on the whole grid at once.
        y1 = torch.matmul(c1, self.Tx[: c1.size(-1)])
        y2 = torch.matmul(c2, self.Tx[: c2.size(-1)])

        # A root lies in every grid interval whose end points differ in sign.
        index1 = y1[..., :-1] * y1[..., 1:] <= 0
        index2 = y2[..., :-1] * y2[..., 1:] <= 0
        index = torch.logical_or(index1, index2)

        n_root = index.sum(-1)
        if torch.any(n_root != self.lpc_order):
            raise RuntimeError(
                f"expected {self.lpc_order} line spectral pairs per frame, "
                f"but found between {int(n_root.min())} and {int(n_root.max())}; "
                "the LPC coefficients may be unstable or n_split may be too small"
            )

        # Keep, at both end points of each bracketing interval, the value of
        # the series that changes sign there.
        i1 = F.pad(index1, (0, 1), value=False)
        i2 = F.pad(index2, (0, 1), value=False)
        i1 = torch.logical_or(i1, torch.roll(i1, 1, dims=-1))
        i2 = torch.logical_or(i2, torch.roll(i2, 1, dims=-1))
        y = y1 * i1 + y2 * i2

        # Refine each root by linear interpolation inside its interval.
        x_upper = torch.masked_select(self.x[:-1], index)
        x_lower = torch.masked_select(self.x[1:], index)
        y_upper = torch.masked_select(y[..., :-1], index)
        y_lower = torch.masked_select(y[..., 1:], index)
        x = (y_lower * x_upper - y_upper * x_lower) / (y_lower - y_upper)
        w = torch.acos(x).view_as(a)

        w = self.convert(w)
        if self.log_gain:
            K = torch.log(K)
        w = torch.cat([K, w], dim=-1)
        return w
1 change: 1 addition & 0 deletions diffsptk/core/lpc2par.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def forward(self, a):
Examples
--------
>>> x = diffsptk.nrand(4)
>>> x
tensor([ 0.7829, -0.2028, 1.6912, 0.1454, 0.4861])
>>> lpc = diffsptk.LPC(3, 5)
>>> a = lpc(x)
Expand Down
1 change: 1 addition & 0 deletions diffsptk/core/par2lpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def forward(self, k):
Examples
--------
>>> x = diffsptk.nrand(4)
>>> x
tensor([ 0.7829, -0.2028, 1.6912, 0.1454, 0.4861])
>>> lpc = diffsptk.LPC(3, 5)
>>> a = lpc(x)
Expand Down
1 change: 1 addition & 0 deletions diffsptk/misc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .signals import *
from .utils import TWO_PI as two_pi
from .utils import get_alpha
from .utils import read
from .utils import write
1 change: 1 addition & 0 deletions diffsptk/misc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch

UNVOICED_SYMBOL = 0
TWO_PI = 2 * torch.pi


class Lambda(torch.nn.Module):
Expand Down
9 changes: 9 additions & 0 deletions docs/core/lpc2lsp.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
.. _lpc2lsp:

lpc2lsp
-------

.. autoclass:: diffsptk.LinearPredictiveCoefficientsToLineSpectralPairs
:members:

.. seealso:: :ref:`lpc`
4 changes: 4 additions & 0 deletions docs/misc/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,7 @@ utils
.. autofunction:: diffsptk.read

.. autofunction:: diffsptk.write

.. data:: diffsptk.two_pi

The value of :math:`2\pi`.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@
"Programming Language :: Python :: 3.11",
]
dependencies = [
"soundfile",
"soundfile >= 0.10.2",
"torch >= 1.11.0",
"torchcrepe >= 0.0.16",
"torchlpc >= 0.2.0",
"numpy",
"vector-quantize-pytorch >= 0.8.0",
"numpy",
]
dynamic = ["version"]

Expand Down
40 changes: 40 additions & 0 deletions tests/test_lpc2lsp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #

import pytest

import diffsptk
import tests.utils as U


@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("M", [1, 7, 8])
@pytest.mark.parametrize("out_format", [0, 1, 2, 3])
def test_compatibility(device, M, out_format, L=32, B=2):
    """Run the LPC-to-LSP module through the shared compatibility checker.

    The module is configured to mirror the ``lpc2lsp`` command line below
    (one pseudo iteration, log gain, 8 kHz sampling). The checker itself lives
    in ``tests.utils``; presumably it compares the module output against the
    command output -- confirm there.
    """
    module_options = dict(
        n_iter=1,
        log_gain=True,
        sample_rate=8000,
        out_format=out_format,
    )
    module = diffsptk.LinearPredictiveCoefficientsToLineSpectralPairs(
        M, **module_options
    )

    # Input and output frames both carry the gain plus M coefficients.
    frame_size = M + 1
    source_command = f"nrand -l {B*L} | lpc -l {L} -m {M}"
    target_command = f"lpc2lsp -m {M} -o {out_format} -i 1 -k 1 -s 8"

    U.check_compatibility(
        device,
        module,
        [],
        source_command,
        target_command,
        [],
        dx=frame_size,
        dy=frame_size,
    )

0 comments on commit 7489ef9

Please sign in to comment.