Skip to content

Commit

Permalink
Merge pull request #22 from sp-nitech/istft
Browse files Browse the repository at this point in the history
Add istft
  • Loading branch information
takenori-y committed Feb 7, 2023
2 parents 51c0433 + 8c9b893 commit b78d577
Show file tree
Hide file tree
Showing 14 changed files with 209 additions and 29 deletions.
2 changes: 2 additions & 0 deletions diffsptk/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from .interpolate import Interpolation
from .ipqmf import InversePseudoQuadratureMirrorFilterBanks
from .ipqmf import InversePseudoQuadratureMirrorFilterBanks as IPQMF
from .istft import InverseShortTermFourierTransform
from .istft import InverseShortTermFourierTransform as ISTFT
from .iulaw import MuLawExpansion
from .ivq import InverseVectorQuantization
from .lar2par import LogAreaRatioToParcorCoefficients
Expand Down
2 changes: 1 addition & 1 deletion diffsptk/core/c2acr.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def forward(self, c):
Examples
--------
>>> c = torch.randn(5)
>>> c = diffsptk.nrand(4)
>>> c
tensor([-0.1751, 0.1950, -0.3211, 0.3523, -0.5453])
>>> c2acr = diffsptk.CepstrumToAutocorrelation(4, 16)
Expand Down
18 changes: 9 additions & 9 deletions diffsptk/core/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,16 @@ class Delay(nn.Module):
start : int [scalar]
Start point, :math:`S`. If negative, advance signal.
keep_len : bool [scalar]
keeplen : bool [scalar]
If True, output has the same length of input.
"""

def __init__(self, start, keep_len=False):
def __init__(self, start, keeplen=False):
super(Delay, self).__init__()

self.start = start
self.keep_len = keep_len
self.keeplen = keeplen

def forward(self, x, dim=-1):
"""Delay signal.
Expand All @@ -61,29 +61,29 @@ def forward(self, x, dim=-1):
>>> y = delay(x)
>>> y
tensor([0., 0., 1., 2., 3.])
>>> delay = diffsptk.Delay(2, keep_len=True)
>>> delay = diffsptk.Delay(2, keeplen=True)
>>> y = delay(x)
>>> y
tensor([0., 0., 1.])
"""
# Generate zeros if needed.
if self.start > 0 or self.keep_len:
if self.start > 0 or self.keeplen:
shape = list(x.shape)
shape[dim] = abs(self.start)
zeros = torch.zeros(*shape, dtype=x.dtype, device=x.device)

# Delay.
# Delay signal.
if 0 < self.start:
y = torch.cat((zeros, x), dim=dim)
if self.keep_len:
if self.keeplen:
y, _ = torch.split(y, [y.size(dim) - self.start, self.start], dim=dim)
return y

# Advance
# Advance signal.
if self.start < 0:
_, y = torch.split(x, [-self.start, x.size(dim) + self.start], dim=dim)
if self.keep_len:
if self.keeplen:
y = torch.cat((y, zeros), dim=dim)
return y

Expand Down
2 changes: 1 addition & 1 deletion diffsptk/core/gnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class GeneralizedCepstrumGainNormalization(nn.Module):
Order of cepstrum, :math:`M`.
gamma : float [-1 <= gamma <= 1]
Gamma.
Gamma, :math:`\\gamma`.
c : int >= 1 [scalar]
Number of stages.
Expand Down
2 changes: 1 addition & 1 deletion diffsptk/core/ignorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class GeneralizedCepstrumInverseGainNormalization(nn.Module):
Order of cepstrum, :math:`M`.
gamma : float [-1 <= gamma <= 1]
Gamma.
Gamma, :math:`\\gamma`.
c : int >= 1 [scalar]
Number of stages.
Expand Down
90 changes: 90 additions & 0 deletions diffsptk/core/istft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #

import torch
import torch.nn as nn

from ..misc.utils import Lambda
from .unframe import Unframe


class InverseShortTermFourierTransform(nn.Module):
"""This is the opposite module to ShortTermFourierTransform.
Parameters
----------
frame_length : int >= 1 [scalar]
Frame length, :math:`L`.
frame_peirod : int >= 1 [scalar]
Frame period, :math:`P`.
fft_length : int >= L [scalar]
Number of FFT bins, :math:`N`.
norm : ['none', 'power', 'magnitude']
Normalization type of window.
window : ['blackman', 'hamming', 'hanning', 'bartlett', 'trapezoidal', \
'rectangular']
Window type.
"""

def __init__(
self,
frame_length,
frame_period,
fft_length,
norm="power",
window="blackman",
):
super(InverseShortTermFourierTransform, self).__init__()

self.ifft = Lambda(
lambda x: torch.fft.irfft(x, n=fft_length)[..., :frame_length]
)
self.unframe = Unframe(frame_length, frame_period, norm=norm, window=window)

def forward(self, y, out_length=None):
"""Compute inverse short-term Fourier transform.
Parameters
----------
y : Tensor [shape=(..., T/P, N/2+1)]
Complex spectrum.
Returns
-------
x : Tensor [shape=(..., T)]
Waveform.
Examples
--------
>>> x = diffsptk.ramp(1, 3)
>>> x
tensor([1., 2., 3.])
>>> stft_params = {"frame_length": 3, "frame_period": 1, "fft_length": 8}
>>> stft = diffsptk.STFT(**stft_params, out_format="complex")
>>> istft = diffsptk.ISTFT(**stft_params)
>>> y = istft(stft(x), out_length=3)
>>> y
tensor([1., 2., 3.])
"""
x = self.ifft(y)
x = self.unframe(x, out_length=out_length)
return x
1 change: 1 addition & 0 deletions diffsptk/core/levdur.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class PseudoLevinsonDurbinRecursion(nn.Module):
"""See `this page <https://sp-nitech.github.io/sptk/latest/main/levdur.html>`_
for details. Note that the current implementation does not use the Durbin's
algorithm though the class name includes it.
"""

def __init__(self):
Expand Down
16 changes: 10 additions & 6 deletions diffsptk/core/stft.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
# limitations under the License. #
# ------------------------------------------------------------------------ #

import torch
import torch.nn as nn

from ..misc.utils import Lambda
from .frame import Frame
from .spec import Spectrum
from .window import Window
Expand All @@ -27,13 +29,13 @@ class ShortTermFourierTransform(nn.Module):
Parameters
----------
frame_length : int >= 1 [scalar]
Frame length, :math:`L_1`.
Frame length, :math:`L`.
frame_peirod : int >= 1 [scalar]
Frame period, :math:`P`.
fft_length : int >= L1 [scalar]
Number of FFT bins, :math:`L_2`.
fft_length : int >= L [scalar]
Number of FFT bins, :math:`N`.
zmean : bool [scalar]
If True, perform mean subtraction on each frame.
Expand All @@ -45,7 +47,7 @@ class ShortTermFourierTransform(nn.Module):
'rectangular']
Window type.
out_format : ['db', 'log-magnitude', 'magnitude', 'power']
out_format : ['db', 'log-magnitude', 'magnitude', 'power', 'complex']
Output format.
eps : float >= 0 [scalar]
Expand Down Expand Up @@ -73,7 +75,9 @@ def __init__(
self.stft = nn.Sequential(
Frame(frame_length, frame_period, zmean=zmean),
Window(frame_length, fft_length, norm=norm, window=window),
Spectrum(
Lambda(torch.fft.rfft)
if out_format == "complex"
else Spectrum(
fft_length,
out_format=out_format,
eps=eps,
Expand All @@ -91,7 +95,7 @@ def forward(self, x):
Returns
-------
y : Tensor [shape=(..., T/P, L/2+1)]
y : Tensor [shape=(..., T/P, N/2+1)]
Spectrum.
Examples
Expand Down
33 changes: 27 additions & 6 deletions diffsptk/core/unframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
# limitations under the License. #
# ------------------------------------------------------------------------ #

import torch
import torch.nn as nn
import torch.nn.functional as F

from .window import Window


class Unframe(nn.Module):
"""This is the opposite module to Frame.
Expand All @@ -34,9 +35,23 @@ class Unframe(nn.Module):
If True, assume that the center of data is the center of frame, otherwise
assume that the center of data is the left edge of frame.
norm : ['none', 'power', 'magnitude']
Normalization type of window.
window : ['blackman', 'hamming', 'hanning', 'bartlett', 'trapezoidal', \
'rectangular']
Window type.
"""

def __init__(self, frame_length, frame_period, center=True):
def __init__(
self,
frame_length,
frame_period,
center=True,
norm="none",
window="rectangular",
):
super(Unframe, self).__init__()

self.frame_length = frame_length
Expand All @@ -50,6 +65,11 @@ def __init__(self, frame_length, frame_period, center=True):
else:
self.left_pad_width = 0

self.register_buffer(
"window",
Window(frame_length, window=window, norm=norm).window.view(1, -1, 1),
)

def forward(self, y, out_length=None):
"""Revert framed waveform.
Expand Down Expand Up @@ -84,9 +104,8 @@ def forward(self, y, out_length=None):
"""
d = y.dim()
assert 2 <= d <= 4

N = y.size(-2)
assert 2 <= d <= 4, "Input must be 2D, 3D, or 4D tensor"

def fold(x):
x = F.fold(
Expand All @@ -100,13 +119,15 @@ def fold(x):
x = x[..., 0, 0, s:e]
return x

w = self.window.repeat(1, 1, N)
x = y.transpose(-2, -1)

if d == 2:
x = x.unsqueeze(0)

n = fold(torch.ones_like(x))
w = fold(w)
x = fold(x)
x = x / n
x = x / w

if d == 2:
x = x.squeeze(0)
Expand Down
10 changes: 10 additions & 0 deletions diffsptk/misc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@
import torch


class Lambda(torch.nn.Module):
def __init__(self, func, **opt):
super(Lambda, self).__init__()
self.func = func
self.opt = opt

def forward(self, x):
return self.func(x, **self.opt)


def is_power_of_two(n):
return (n != 0) and (n & (n - 1) == 0)

Expand Down
11 changes: 11 additions & 0 deletions docs/core/istft.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
.. _istft:

istft
-----

.. autoclass:: diffsptk.ISTFT

.. autoclass:: diffsptk.InverseShortTermFourierTransform
:members:

.. seealso:: :ref:`frame` :ref:`window` :ref:`spec` :ref:`stft`
2 changes: 1 addition & 1 deletion docs/core/stft.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ stft
.. autoclass:: diffsptk.ShortTermFourierTransform
:members:

.. seealso:: :ref:`frame` :ref:`window` :ref:`spec`
.. seealso:: :ref:`frame` :ref:`window` :ref:`spec` :ref:`istft`
8 changes: 4 additions & 4 deletions tests/test_delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@

@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("S", [-2, 0, 2])
@pytest.mark.parametrize("keep_len", [False, True])
def test_compatibility(device, S, keep_len, T=20, B=2):
delay = diffsptk.Delay(S, keep_len)
@pytest.mark.parametrize("keeplen", [False, True])
def test_compatibility(device, S, keeplen, T=20, B=2):
delay = diffsptk.Delay(S, keeplen=keeplen)

opt = "-k" if keep_len else ""
opt = "-k" if keeplen else ""
U.check_compatibility(
device,
delay,
Expand Down
Loading

0 comments on commit b78d577

Please sign in to comment.