Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update root_pol #55

Merged
merged 2 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
change algorithm of root_pol
  • Loading branch information
takenori-y committed Dec 7, 2023
commit a689b1659a500570b4dad41ebd1d1eb535b2af46
2 changes: 1 addition & 1 deletion diffsptk/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
from .rlevdur import ReverseLevinsonDurbin
from .rmse import RootMeanSquaredError
from .rmse import RootMeanSquaredError as RMSE
from .root_pol import DurandKernerMethod
from .root_pol import PolynomialToRoots
from .smcep import SecondOrderAllPassMelCepstralAnalysis
from .snr import SignalToNoiseRatio
from .snr import SignalToNoiseRatio as SNR
Expand Down
83 changes: 17 additions & 66 deletions diffsptk/core/root_pol.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,44 +14,30 @@
# limitations under the License. #
# ------------------------------------------------------------------------ #

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..misc.utils import check_size
from ..misc.utils import numpy_to_torch
from ..misc.utils import vander


class DurandKernerMethod(nn.Module):
class PolynomialToRoots(nn.Module):
"""See `this page <https://sp-nitech.github.io/sptk/latest/main/root_pol.html>`_
for details.

order : int >= 1 [scalar]
Order of coefficients.

n_iter : int >= 1 [scalar]
Number of iterations.

eps : float >= 0 [scalar]
Convergence threshold.

out_format : ['rectangular', 'polar']
Output format.

"""

def __init__(self, order, n_iter=100, eps=1e-14, out_format="rectangular"):
super(DurandKernerMethod, self).__init__()
def __init__(self, order, out_format="rectangular"):
super(PolynomialToRoots, self).__init__()

self.order = order
self.n_iter = n_iter
self.eps = eps

assert 1 <= self.order
assert 1 <= self.n_iter
assert 0 <= self.eps

if out_format == 0 or out_format == "rectangular":
self.convert = lambda x: x
Expand All @@ -60,17 +46,10 @@ def __init__(self, order, n_iter=100, eps=1e-14, out_format="rectangular"):
else:
raise ValueError(f"out_format {out_format} is not supported")

ramp = np.arange(order + 1)
exponent = 1 / ramp[2:]
self.register_buffer("exponent", numpy_to_torch(exponent))

angle = ramp[:-1] * np.pi / (order / 2) + np.pi / (order * 2)
self.register_buffer("sin", numpy_to_torch(np.sin(angle)))
self.register_buffer("cos", numpy_to_torch(np.cos(angle)))
self.register_buffer("eye", torch.eye(order).bool())
self.register_buffer("eye", torch.eye(order - 1, order))

def forward(self, a):
"""Find roots of equations.
"""Find roots of polynomial.

Parameters
----------
Expand All @@ -82,53 +61,25 @@ def forward(self, a):
x : Tensor [shape=(..., M)]
Complex roots.

is_converged : Tensor [shape=(...,)]
True if convergence is reached.

Examples
--------
>>> a = torch.tensor([3, 4, 5])
>>> root_pol = diffsptk.DurandKernerMethod(a.size(-1) - 1)
>>> x, is_converged = root_pol(a)
>>> root_pol = diffsptk.PolynomialToRoots(a.size(-1) - 1)
>>> x = root_pol(a)
>>> x
tensor([[-0.6667+1.1055j, -0.6667-1.1055j]])
>>> is_converged
tensor([True])

"""
check_size(a.size(-1), self.order + 1, "dimension of coefficients")
if not torch.any(a[..., 0] == 0):
raise RuntimeError("leading coefficient must be non-zero")

a = a[..., 1:] / a[..., :1] # (..., M)
radius, _ = torch.max(
2 * torch.pow(a[..., 1:].abs(), self.exponent),
dim=-1,
keepdim=True,
)
center = -a[..., :1] / self.order
x = torch.complex(
center + radius * self.cos,
center + radius * self.sin,
)
a = F.pad(a, (1, 0), value=1)
a = a.unsqueeze(-1).to(x.dtype)

for _ in range(self.n_iter):
y = x
for m in range(self.order):
xm = x[..., m : m + 1]
v = vander(xm, N=self.order + 1)
numer = torch.matmul(v, a).squeeze(-1)

w = (xm - x) + self.eye[m : m + 1]
denom = w.prod(dim=-1, keepdim=True)

delta = numer / denom
x = x - delta * self.eye[m : m + 1]

if (y - x).abs().max() <= self.eps:
break

is_converged = torch.max((y - x).abs(), dim=-1)[0] <= self.eps
x = self.convert(x)
# Make companion matrix.
a = -a[..., 1:] / a[..., :1] # (..., M)
E = self.eye.expand(a.size()[:-1] + self.eye.size())
A = torch.cat([a.unsqueeze(-2), E], dim=-2) # (..., M, M)

return x, is_converged
# Find roots as eigenvalues.
x, _ = torch.linalg.eig(A)
x = self.convert(x)
return x
2 changes: 1 addition & 1 deletion docs/core/root_pol.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
root_pol
--------

.. autoclass:: diffsptk.DurandKernerMethod
.. autoclass:: diffsptk.PolynomialToRoots
:members:
19 changes: 8 additions & 11 deletions tests/test_root_pol.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,22 @@

@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("out_format", [0, 1])
def test_compatibility(device, out_format, M=12, B=2, n_iter=100):
root_pol = diffsptk.DurandKernerMethod(M, n_iter=n_iter, out_format=out_format)
def test_compatibility(device, out_format, M=12, B=2):
root_pol = diffsptk.PolynomialToRoots(M, out_format=out_format)

def eq(y_hat, y):
re = np.real(y_hat).flatten()
im = np.imag(y_hat).flatten()
y2 = np.empty((re.size + im.size,), dtype=y.dtype)
y2[0::2] = re
y2[1::2] = im
return U.allclose(y2, y)
y_hat = np.sort_complex(y_hat)
y = np.sort_complex(y[0::2] + 1j * y[1::2])
return U.allclose(y_hat, y)

U.check_compatibility(
device,
[lambda x: x[0], root_pol],
root_pol,
[],
f"nrand -m {M}",
f"root_pol -m {M} -i {n_iter} -o {out_format}",
f"root_pol -m {M} -i 100 -o {out_format}",
[],
eq=eq,
)

U.check_differentiable(device, [lambda x: x[0], root_pol], [B, M + 1])
U.check_differentiable(device, root_pol, [B, M + 1])
Loading