Skip to content

Commit

Permalink
Merge pull request #67 from sp-nitech/bug_fix
Browse files Browse the repository at this point in the history
Bug fix
  • Loading branch information
takenori-y committed Feb 5, 2024
2 parents 09d866e + 4f805a7 commit c41e0cd
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 16 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ diffsptk
*diffsptk* is a differentiable version of [SPTK](https://github.com/sp-nitech/SPTK) based on the PyTorch framework.

[![Latest Manual](https://img.shields.io/badge/docs-latest-blue.svg)](https://sp-nitech.github.io/diffsptk/latest/)
[![Stable Manual](https://img.shields.io/badge/docs-stable-blue.svg)](https://sp-nitech.github.io/diffsptk/1.2.0/)
[![Stable Manual](https://img.shields.io/badge/docs-stable-blue.svg)](https://sp-nitech.github.io/diffsptk/1.2.1/)
[![Downloads](https://static.pepy.tech/badge/diffsptk)](https://pepy.tech/project/diffsptk)
[![Python Version](https://img.shields.io/pypi/pyversions/diffsptk.svg)](https://pypi.python.org/pypi/diffsptk)
[![PyTorch Version](https://img.shields.io/badge/pytorch-1.11.0%20%7C%202.2.0-orange.svg)](https://pypi.python.org/pypi/diffsptk)
Expand Down
2 changes: 1 addition & 1 deletion diffsptk/core/linear_intpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def forward(self, x):
assert x.dim() == 3, "Input must be 3D tensor"
B, T, D = x.shape

x = x.mT # (B, D, T)
x = x.mT.contiguous() # (B, D, T)
x = replicate1(x, left=False)
x = F.interpolate(
x,
Expand Down
16 changes: 9 additions & 7 deletions diffsptk/core/stft.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,15 @@ def __init__(
self.stft = nn.Sequential(
Frame(frame_length, frame_period, center=center, zmean=zmean),
Window(frame_length, fft_length, norm=norm, window=window),
Lambda(torch.fft.rfft)
if out_format == "complex"
else Spectrum(
fft_length,
out_format=out_format,
eps=eps,
relative_floor=relative_floor,
(
Lambda(torch.fft.rfft)
if out_format == "complex"
else Spectrum(
fft_length,
out_format=out_format,
eps=eps,
relative_floor=relative_floor,
)
),
)

Expand Down
2 changes: 1 addition & 1 deletion diffsptk/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.2.0"
__version__ = "1.2.1"
9 changes: 7 additions & 2 deletions tests/test_linear_intpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_compatibility(device, P=4, N=10):
@pytest.mark.parametrize("P", [1, 4])
def test_compatibility(device, P, N=10):
linear_intpl = diffsptk.LinearInterpolation(P)

tmp = "linear_intpl.tmp"
Expand All @@ -30,7 +31,11 @@ def test_compatibility(device, P=4, N=10):
linear_intpl,
[f"ramp -s 1 -e {N} > {tmp}"],
f"cat {tmp}",
f"step -v 1 -l {N*P} | zerodf {tmp} -i 1 -m 0 -p {P}",
(
f"cat {tmp}"
if P == 1
else f"step -v 1 -l {N*P} | zerodf {tmp} -i 1 -m 0 -p {P}"
),
[f"rm {tmp}"],
)

Expand Down
8 changes: 4 additions & 4 deletions tests/test_mglsadf.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,15 @@ def test_compatibility(
@pytest.mark.parametrize("ignore_gain", [False, True])
@pytest.mark.parametrize("phase", ["minimum", "maximum", "zero"])
@pytest.mark.parametrize("mode", ["multi-stage", "single-stage", "freq-domain"])
def test_differentiable(device, ignore_gain, phase, mode, B=4, T=20, M=4):
def test_differentiable(device, ignore_gain, phase, mode, B=4, T=20, P=2, M=4):
if mode == "multi-stage":
params = {"cep_order": 10}
elif mode == "single-stage":
params = {"ir_length": 20, "n_fft": 32}
elif mode == "freq-domain":
params = {"frame_length": 4, "fft_length": 16}
params = {"frame_length": 6, "fft_length": 16}

mglsadf = diffsptk.MLSA(
M, 1, ignore_gain=ignore_gain, phase=phase, mode=mode, **params
M, P, ignore_gain=ignore_gain, phase=phase, mode=mode, **params
)
U.check_differentiable(device, mglsadf, [(B, T), (B, T, M + 1)])
U.check_differentiable(device, mglsadf, [(B, T), (B, T // P, M + 1)])

0 comments on commit c41e0cd

Please sign in to comment.