Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unframe #21

Merged
merged 1 commit into from
Feb 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
add unframe
  • Loading branch information
takenori-y committed Feb 6, 2023
commit cba67c22b98ae2bfab0fb239bc1e7503ee523848
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,9 @@ The latest stable release can be installed through PyPI by running
```sh
pip install diffsptk
```
Alternatively,
The development release can be installed from the master branch:
```sh
git clone https://github.com/sp-nitech/diffsptk.git
pip install -e diffsptk
pip install git+https://github.com/sp-nitech/diffsptk.git@master
```


Expand Down
1 change: 1 addition & 0 deletions diffsptk/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
from .stft import ShortTermFourierTransform
from .stft import ShortTermFourierTransform as STFT
from .ulaw import MuLawCompression
from .unframe import Unframe
from .vq import VectorQuantization
from .window import Window
from .zcross import ZeroCrossingAnalysis
Expand Down
16 changes: 8 additions & 8 deletions diffsptk/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,20 @@ def forward(self, x):

Returns
-------
y : Tensor [shape=(..., N, L)]
y : Tensor [shape=(..., T/P, L)]
Framed waveform.

Examples
--------
>>> x = torch.arange(1, 10)
>>> frame = diffsptk.Frame(5, 3)
>>> x = diffsptk.ramp(1, 9)
>>> frame = diffsptk.Frame(5, 2)
>>> y = frame(x)
>>> y
tensor([[0, 0, 1, 2, 3],
[1, 2, 3, 4, 5],
[3, 4, 5, 6, 7],
[5, 6, 7, 8, 9],
[7, 8, 9, 0, 0]])
tensor([[0., 0., 1., 2., 3.],
[1., 2., 3., 4., 5.],
[3., 4., 5., 6., 7.],
[5., 6., 7., 8., 9.],
[7., 8., 9., 0., 0.]])

"""
y = self.pad(x)
Expand Down
2 changes: 1 addition & 1 deletion diffsptk/core/stft.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def forward(self, x):

Returns
-------
y : Tensor [shape=(..., N, L/2+1)]
y : Tensor [shape=(..., T/P, L/2+1)]
Spectrum.

Examples
Expand Down
114 changes: 114 additions & 0 deletions diffsptk/core/unframe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #

import torch
import torch.nn as nn
import torch.nn.functional as F


class Unframe(nn.Module):
"""This is the opposite module to Frame.

Parameters
----------
frame_length : int >= 1 [scalar]
Frame length, :math:`L`.

frame_peirod : int >= 1 [scalar]
Frame period, :math:`P`.

center : bool [scalar]
If True, assume that the center of data is the center of frame, otherwise
assume that the center of data is the left edge of frame.

"""

def __init__(self, frame_length, frame_period, center=True):
super(Unframe, self).__init__()

self.frame_length = frame_length
self.frame_period = frame_period

assert 1 <= self.frame_length
assert 1 <= self.frame_period

if center:
self.left_pad_width = self.frame_length // 2
else:
self.left_pad_width = 0

def forward(self, y, out_length=None):
"""Revert framed waveform.

Parameters
----------
y : Tensor [shape=(..., T/P, L)]
Framed waveform.

out_length : int [scalar]
Length of original signal, `T`.

Returns
-------
x : Tensor [shape=(..., T)]
Waveform.

Examples
--------
>>> x = diffsptk.ramp(1, 9)
>>> frame = diffsptk.Frame(5, 2)
>>> y = frame(x)
>>> y
tensor([[0., 0., 1., 2., 3.],
[1., 2., 3., 4., 5.],
[3., 4., 5., 6., 7.],
[5., 6., 7., 8., 9.],
[7., 8., 9., 0., 0.]])
>>> unframe = diffsptk.Unframe(5, 2)
>>> z = unframe(y, out_length=x.size(0))
>>> z
tensor([1., 2., 3., 4., 5., 6., 7., 8., 9.])

"""
d = y.dim()
assert 2 <= d <= 4

N = y.size(-2)

def fold(x):
x = F.fold(
x,
(1, (N - 1) * self.frame_period + self.frame_length),
(1, self.frame_length),
stride=(1, self.frame_period),
)
s = self.left_pad_width
e = None if out_length is None else s + out_length
x = x[..., 0, 0, s:e]
return x

x = y.transpose(-2, -1)
if d == 2:
x = x.unsqueeze(0)

n = fold(torch.ones_like(x))
x = fold(x)
x = x / n

if d == 2:
x = x.squeeze(0)

return x
2 changes: 1 addition & 1 deletion docs/core/frame.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ frame
.. autoclass:: diffsptk.Frame
:members:

.. seealso:: :ref:`window`
.. seealso:: :ref:`unframe` :ref:`window`
9 changes: 9 additions & 0 deletions docs/core/unframe.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
.. _unframe:

unframe
-------

.. autoclass:: diffsptk.Unframe
:members:

.. seealso:: :ref:`frame` :ref:`window`
48 changes: 48 additions & 0 deletions tests/test_unframe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #

import pytest
import torch

import diffsptk
import tests.utils as U


@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("fl", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("fp", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("center", [True, False])
def test_compatibility(device, fl, fp, center, T=20):
if device == "cuda" and not torch.cuda.is_available():
return
if fl < fp:
return

frame = diffsptk.Frame(fl, fp, center=center)
unframe = diffsptk.Unframe(fl, fp, center=center)

x = diffsptk.ramp(T)
y = frame(x)

x2 = diffsptk.ramp(torch.max(y))
z = unframe(y, out_length=x2.size(-1))
assert torch.allclose(x2, z)


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_differentiable(device, fl=5, fp=3, B=2, N=4):
unframe = diffsptk.Unframe(fl, fp)
U.check_differentiable(device, unframe, [B, N, fl])