Skip to content

Commit

Permalink
Merge pull request #55 from sp-nitech/root_pol
Browse files Browse the repository at this point in the history
Update root_pol
  • Loading branch information
takenori-y committed Dec 7, 2023
2 parents 48058d1 + a812d5d commit abea444
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 79 deletions.
2 changes: 1 addition & 1 deletion diffsptk/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
from .rlevdur import ReverseLevinsonDurbin
from .rmse import RootMeanSquaredError
from .rmse import RootMeanSquaredError as RMSE
from .root_pol import DurandKernerMethod
from .root_pol import PolynomialToRoots
from .smcep import SecondOrderAllPassMelCepstralAnalysis
from .snr import SignalToNoiseRatio
from .snr import SignalToNoiseRatio as SNR
Expand Down
83 changes: 17 additions & 66 deletions diffsptk/core/root_pol.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,44 +14,30 @@
# limitations under the License. #
# ------------------------------------------------------------------------ #

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..misc.utils import check_size
from ..misc.utils import numpy_to_torch
from ..misc.utils import vander


class DurandKernerMethod(nn.Module):
class PolynomialToRoots(nn.Module):
"""See `this page <https://sp-nitech.github.io/sptk/latest/main/root_pol.html>`_
for details.
order : int >= 1 [scalar]
Order of coefficients.
n_iter : int >= 1 [scalar]
Number of iterations.
eps : float >= 0 [scalar]
Convergence threshold.
out_format : ['rectangular', 'polar']
Output format.
"""

def __init__(self, order, n_iter=100, eps=1e-14, out_format="rectangular"):
super(DurandKernerMethod, self).__init__()
def __init__(self, order, out_format="rectangular"):
super(PolynomialToRoots, self).__init__()

self.order = order
self.n_iter = n_iter
self.eps = eps

assert 1 <= self.order
assert 1 <= self.n_iter
assert 0 <= self.eps

if out_format == 0 or out_format == "rectangular":
self.convert = lambda x: x
Expand All @@ -60,17 +46,10 @@ def __init__(self, order, n_iter=100, eps=1e-14, out_format="rectangular"):
else:
raise ValueError(f"out_format {out_format} is not supported")

ramp = np.arange(order + 1)
exponent = 1 / ramp[2:]
self.register_buffer("exponent", numpy_to_torch(exponent))

angle = ramp[:-1] * np.pi / (order / 2) + np.pi / (order * 2)
self.register_buffer("sin", numpy_to_torch(np.sin(angle)))
self.register_buffer("cos", numpy_to_torch(np.cos(angle)))
self.register_buffer("eye", torch.eye(order).bool())
self.register_buffer("eye", torch.eye(order - 1, order))

def forward(self, a):
"""Find roots of equations.
"""Find roots of polynomial.
Parameters
----------
Expand All @@ -82,53 +61,25 @@ def forward(self, a):
x : Tensor [shape=(..., M)]
Complex roots.
is_converged : Tensor [shape=(...,)]
True if convergence is reached.
Examples
--------
>>> a = torch.tensor([3, 4, 5])
>>> root_pol = diffsptk.DurandKernerMethod(a.size(-1) - 1)
>>> x, is_converged = root_pol(a)
>>> root_pol = diffsptk.PolynomialToRoots(a.size(-1) - 1)
>>> x = root_pol(a)
>>> x
tensor([[-0.6667+1.1055j, -0.6667-1.1055j]])
>>> is_converged
tensor([True])
"""
check_size(a.size(-1), self.order + 1, "dimension of coefficients")
if torch.any(a[..., 0] == 0):
raise RuntimeError("leading coefficient must be non-zero")

a = a[..., 1:] / a[..., :1] # (..., M)
radius, _ = torch.max(
2 * torch.pow(a[..., 1:].abs(), self.exponent),
dim=-1,
keepdim=True,
)
center = -a[..., :1] / self.order
x = torch.complex(
center + radius * self.cos,
center + radius * self.sin,
)
a = F.pad(a, (1, 0), value=1)
a = a.unsqueeze(-1).to(x.dtype)

for _ in range(self.n_iter):
y = x
for m in range(self.order):
xm = x[..., m : m + 1]
v = vander(xm, N=self.order + 1)
numer = torch.matmul(v, a).squeeze(-1)

w = (xm - x) + self.eye[m : m + 1]
denom = w.prod(dim=-1, keepdim=True)

delta = numer / denom
x = x - delta * self.eye[m : m + 1]

if (y - x).abs().max() <= self.eps:
break

is_converged = torch.max((y - x).abs(), dim=-1)[0] <= self.eps
x = self.convert(x)
# Make companion matrix.
a = -a[..., 1:] / a[..., :1] # (..., M)
E = self.eye.expand(a.size()[:-1] + self.eye.size())
A = torch.cat([a.unsqueeze(-2), E], dim=-2) # (..., M, M)

return x, is_converged
# Find roots as eigenvalues.
x, _ = torch.linalg.eig(A)
x = self.convert(x)
return x
2 changes: 1 addition & 1 deletion docs/core/root_pol.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
root_pol
--------

.. autoclass:: diffsptk.DurandKernerMethod
.. autoclass:: diffsptk.PolynomialToRoots
:members:
19 changes: 8 additions & 11 deletions tests/test_root_pol.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,22 @@

@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("out_format", [0, 1])
def test_compatibility(device, out_format, M=12, B=2, n_iter=100):
root_pol = diffsptk.DurandKernerMethod(M, n_iter=n_iter, out_format=out_format)
def test_compatibility(device, out_format, M=12, B=2):
root_pol = diffsptk.PolynomialToRoots(M, out_format=out_format)

def eq(y_hat, y):
re = np.real(y_hat).flatten()
im = np.imag(y_hat).flatten()
y2 = np.empty((re.size + im.size,), dtype=y.dtype)
y2[0::2] = re
y2[1::2] = im
return U.allclose(y2, y)
y_hat = np.sort_complex(y_hat)
y = np.sort_complex(y[0::2] + 1j * y[1::2])
return U.allclose(y_hat, y)

U.check_compatibility(
device,
[lambda x: x[0], root_pol],
root_pol,
[],
f"nrand -m {M}",
f"root_pol -m {M} -i {n_iter} -o {out_format}",
f"root_pol -m {M} -i 100 -o {out_format}",
[],
eq=eq,
)

U.check_differentiable(device, [lambda x: x[0], root_pol], [B, M + 1])
U.check_differentiable(device, root_pol, [B, M + 1])

0 comments on commit abea444

Please sign in to comment.