Skip to content

Commit

Permalink
expose parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
takenori-y committed Sep 27, 2023
1 parent 16e6aa5 commit 6ee1399
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 5 deletions.
10 changes: 9 additions & 1 deletion diffsptk/core/istft.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ class InverseShortTermFourierTransform(nn.Module):
fft_length : int >= L [scalar]
Number of FFT bins, :math:`N`.
center : bool [scalar]
If True, assume that the center of data is the center of frame, otherwise
assume that the center of data is the left edge of frame.
norm : ['none', 'power', 'magnitude']
Normalization type of window.
Expand All @@ -49,6 +53,8 @@ def __init__(
frame_length,
frame_period,
fft_length,
*,
center=True,
norm="power",
window="blackman",
):
Expand All @@ -57,7 +63,9 @@ def __init__(
self.ifft = Lambda(
lambda x: torch.fft.irfft(x, n=fft_length)[..., :frame_length]
)
self.unframe = Unframe(frame_length, frame_period, norm=norm, window=window)
self.unframe = Unframe(
frame_length, frame_period, center=center, norm=norm, window=window
)

def forward(self, y, out_length=None):
"""Compute inverse short-term Fourier transform.
Expand Down
10 changes: 7 additions & 3 deletions diffsptk/core/mglsadf.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,10 +364,13 @@ def __init__(
phase="minimum",
frame_length=400,
fft_length=512,
**kwargs,
n_fft=512,
**stft_kwargs,
):
super(FrequencyDomainFIRFilter, self).__init__()

assert 2 * frame_period < frame_length

self.ignore_gain = ignore_gain

if self.ignore_gain:
Expand All @@ -380,17 +383,18 @@ def __init__(
)

self.stft = ShortTermFourierTransform(
frame_length, frame_period, fft_length, out_format="complex", **kwargs
frame_length, frame_period, fft_length, out_format="complex", **stft_kwargs
)
self.istft = InverseShortTermFourierTransform(
frame_length, frame_period, fft_length, **kwargs
frame_length, frame_period, fft_length, **stft_kwargs
)
self.mgc2sp = MelGeneralizedCepstrumToSpectrum(
filter_order,
fft_length,
alpha=alpha,
gamma=gamma,
out_format="magnitude" if phase == "zero" else "complex",
n_fft=n_fft,
)

def forward(self, x, mc):
Expand Down
8 changes: 7 additions & 1 deletion diffsptk/core/stft.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ class ShortTermFourierTransform(nn.Module):
fft_length : int >= L [scalar]
Number of FFT bins, :math:`N`.
center : bool [scalar]
If True, assume that the center of data is the center of frame, otherwise
assume that the center of data is the left edge of frame.
zmean : bool [scalar]
If True, perform mean subtraction on each frame.
Expand All @@ -63,6 +67,8 @@ def __init__(
frame_length,
frame_period,
fft_length,
*,
center=True,
zmean=False,
norm="power",
window="blackman",
Expand All @@ -73,7 +79,7 @@ def __init__(
super(ShortTermFourierTransform, self).__init__()

self.stft = nn.Sequential(
Frame(frame_length, frame_period, zmean=zmean),
Frame(frame_length, frame_period, center=center, zmean=zmean),
Window(frame_length, fft_length, norm=norm, window=window),
Lambda(torch.fft.rfft)
if out_format == "complex"
Expand Down

0 comments on commit 6ee1399

Please sign in to comment.