Skip to content

Commit

Permalink
Merge pull request #68 from sp-nitech/next_release
Browse files Browse the repository at this point in the history
Next release
  • Loading branch information
takenori-y committed Mar 10, 2024
2 parents c29ec71 + 6b3ec8d commit bf388aa
Show file tree
Hide file tree
Showing 13 changed files with 217 additions and 144 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ _build/

# tools
tools/SPTK/
tools/toml/
tools/taplo/

# misc
__pycache__/
Expand Down
38 changes: 27 additions & 11 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
PROJECT := diffsptk
MODULE :=

PYTHON_VERSION := 3.8
PYTHON_VERSION := 3.9
TORCH_VERSION := 1.11.0
TORCHAUDIO_VERSION := 0.11.0
PLATFORM := cu113
Expand All @@ -31,6 +31,7 @@ dev:
. ./venv/bin/activate; python -m pip install torch==$(TORCH_VERSION)+$(PLATFORM) torchaudio==$(TORCHAUDIO_VERSION)+$(PLATFORM) \
-f https://download.pytorch.org/whl/$(PLATFORM)/torch_stable.html; \
. ./venv/bin/activate; python -m pip install -e .[dev]
. ./venv/bin/activate; python -m pip install icc-rt

dist:
. ./venv/bin/activate; python -m build --wheel
Expand All @@ -48,14 +49,30 @@ doc-clean:
fi

check:
. ./venv/bin/activate; python -m black --check $(PROJECT) tests
. ./venv/bin/activate; python -m isort --check $(PROJECT) tests --project $(PROJECT)
. ./venv/bin/activate; python -m pflake8 $(PROJECT) tests
@if [ ! -x ./tools/taplo/taplo ]; then \
echo ""; \
echo "Error: please install taplo-cli"; \
echo ""; \
echo " make tool"; \
echo ""; \
exit 1; \
fi
. ./venv/bin/activate; python -m ruff check $(PROJECT) tests
. ./venv/bin/activate; python -m isort --check $(PROJECT) tests
./tools/taplo/taplo format --check pyproject.toml

format:
. ./venv/bin/activate; python -m black $(PROJECT) tests
. ./venv/bin/activate; python -m isort $(PROJECT) tests --project $(PROJECT)
. ./venv/bin/activate; python -m pflake8 $(PROJECT) tests
@if [ ! -x ./tools/taplo/taplo ]; then \
echo ""; \
echo "Error: please install taplo-cli"; \
echo ""; \
echo " make tool"; \
echo ""; \
exit 1; \
fi
. ./venv/bin/activate; python -m ruff format $(PROJECT) tests
. ./venv/bin/activate; python -m isort $(PROJECT) tests
./tools/taplo/taplo format pyproject.toml

test:
@if [ ! -d tools/SPTK/bin ]; then \
Expand All @@ -81,17 +98,16 @@ tool-clean:
cd tools; make clean

update:
@if [ ! -x tools/toml/toml ]; then \
@if [ ! -x ./tools/taplo/taplo ]; then \
echo ""; \
echo "Error: please install toml-cli"; \
echo "Error: please install taplo-cli"; \
echo ""; \
echo " make tool"; \
echo ""; \
exit 1; \
fi
. ./venv/bin/activate; python -m pip install --upgrade pip
@for package in $$(./tools/toml/toml get pyproject.toml project.optional-dependencies.dev | \
sed 's/"//g' | tr -d '[]' | tr , ' '); do \
@for package in $$(./tools/taplo/taplo get -f pyproject.toml project.optional-dependencies.dev); do \
. ./venv/bin/activate; python -m pip install --upgrade $$package; \
done

Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ diffsptk
*diffsptk* is a differentiable version of [SPTK](https://github.com/sp-nitech/SPTK) based on the PyTorch framework.

[![Latest Manual](https://img.shields.io/badge/docs-latest-blue.svg)](https://sp-nitech.github.io/diffsptk/latest/)
[![Stable Manual](https://img.shields.io/badge/docs-stable-blue.svg)](https://sp-nitech.github.io/diffsptk/1.2.1/)
[![Stable Manual](https://img.shields.io/badge/docs-stable-blue.svg)](https://sp-nitech.github.io/diffsptk/2.0.0/)
[![Downloads](https://static.pepy.tech/badge/diffsptk)](https://pepy.tech/project/diffsptk)
[![Python Version](https://img.shields.io/pypi/pyversions/diffsptk.svg)](https://pypi.python.org/pypi/diffsptk)
[![PyTorch Version](https://img.shields.io/badge/pytorch-1.11.0%20%7C%202.2.1-orange.svg)](https://pypi.python.org/pypi/diffsptk)
[![PyPI Version](https://img.shields.io/pypi/v/diffsptk.svg)](https://pypi.python.org/pypi/diffsptk)
[![Codecov](https://codecov.io/gh/sp-nitech/diffsptk/branch/master/graph/badge.svg)](https://app.codecov.io/gh/sp-nitech/diffsptk)
[![License](https://img.shields.io/github/license/sp-nitech/diffsptk.svg)](https://github.com/sp-nitech/diffsptk/blob/master/LICENSE)
[![GitHub Actions](https://github.com/sp-nitech/diffsptk/workflows/package/badge.svg)](https://github.com/sp-nitech/diffsptk/actions)
[![Code Style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)


Requirements
Expand Down
7 changes: 7 additions & 0 deletions diffsptk/misc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# limitations under the License. #
# ------------------------------------------------------------------------ #

from importlib import import_module

import numpy as np
import soundfile as sf
import torch
Expand All @@ -34,6 +36,11 @@ def forward(self, x):
return self.func(x, **self.opt)


def delayed_import(module_path, item_name):
module = import_module(module_path)
return getattr(module, item_name)


def is_power_of_two(n):
return (n != 0) and (n & (n - 1) == 0)

Expand Down
101 changes: 75 additions & 26 deletions diffsptk/modules/cqt.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,16 @@
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. #
# ------------------------------------------------------------------------ #

import librosa
import numpy as np
import torch
import torch.nn as nn
import torchaudio

from ..misc.utils import Lambda
from ..misc.utils import delayed_import
from ..misc.utils import numpy_to_torch
from .stft import ShortTimeFourierTransform as STFT

_vqt_filter_fft = librosa.core.constantq.__vqt_filter_fft
_bpo_to_alpha = librosa.core.constantq.__bpo_to_alpha


class ConstantQTransform(nn.Module):
"""Perform constant-Q transform based on the librosa implementation.
Expand Down Expand Up @@ -105,6 +103,14 @@ def __init__(
):
super(ConstantQTransform, self).__init__()

import librosa

et_relative_bw = delayed_import("librosa.core.constantq", "__et_relative_bw")
vqt_filter_fft = delayed_import("librosa.core.constantq", "__vqt_filter_fft")
early_downsample_count = delayed_import(
"librosa.core.constantq", "__early_downsample_count"
)

assert 1 <= frame_period
assert 1 <= sample_rate

Expand All @@ -121,16 +127,55 @@ def __init__(
tuning=tuning,
)

alpha = _bpo_to_alpha(B)
if K == 1:
alpha = et_relative_bw(B)
else:
alpha = librosa.filters._relative_bandwidth(freqs=freqs)

lengths, filter_cutoff = librosa.filters.wavelet_lengths(
freqs=freqs,
sr=sample_rate,
window=window,
filter_scale=filter_scale,
alpha=alpha,
)

if scale:
lengths, _ = librosa.filters.wavelet_lengths(
freqs=freqs,
sr=sample_rate,
window=window,
filter_scale=filter_scale,
alpha=alpha,
early_downsample = []
downsample_count = early_downsample_count(
sample_rate * 0.5, filter_cutoff, frame_period, n_octave
)
if 0 < downsample_count:
downsample_factor = 2**downsample_count
early_downsample.append(
torchaudio.transforms.Resample(
orig_freq=downsample_factor,
new_freq=1,
dtype=torch.get_default_dtype(),
**kwargs,
)
)
if scale:
downsample_scale = np.sqrt(downsample_factor)
else:
downsample_scale = downsample_factor
early_downsample.append(Lambda(lambda x: x * downsample_scale))

# Update frame period and sample rate.
frame_period //= downsample_factor
sample_rate /= downsample_factor

# Update lengths for scaling.
if scale:
lengths, _ = librosa.filters.wavelet_lengths(
freqs=freqs,
sr=sample_rate,
window=window,
filter_scale=filter_scale,
alpha=alpha,
)
self.early_downsample = nn.Sequential(*early_downsample)

if scale:
cqt_scale = np.reciprocal(np.sqrt(lengths))
else:
cqt_scale = np.ones(K)
Expand All @@ -152,17 +197,17 @@ def __init__(
for i in range(n_octave):
sl = slice(-n_filter * (i + 1), None if i == 0 else (-n_filter * i))

fft_basis, fft_length, _ = _vqt_filter_fft(
fft_basis, fft_length, _ = vqt_filter_fft(
sr[i],
freqs[sl],
filter_scale,
norm,
sparsity,
window=window,
alpha=alpha,
alpha=alpha[sl],
)

fft_basis[:] *= np.sqrt(sample_rate / sr[i])
fft_basis *= np.sqrt(sample_rate / sr[i])
self.register_buffer(
f"fft_basis_{i}", numpy_to_torch(fft_basis.todense()).T
)
Expand All @@ -181,19 +226,23 @@ def __init__(
)

if fp[i] % 2 == 0:
resample_scale = np.sqrt(2)
resamplers.append(
torchaudio.transforms.Resample(
orig_freq=2,
new_freq=1,
dtype=torch.get_default_dtype(),
**kwargs,
nn.Sequential(
torchaudio.transforms.Resample(
orig_freq=2,
new_freq=1,
dtype=torch.get_default_dtype(),
**kwargs,
),
Lambda(lambda x: x * resample_scale),
)
)
else:
resamplers.append(Lambda(lambda x: x))

self.transforms = nn.ModuleList(transforms)
self.resamplers = nn.ModuleList(resamplers)
self.resample_scale = 1 / np.sqrt(0.5)
self.frame_period = frame_period

def forward(self, x):
"""Compute constant-Q transform.
Expand All @@ -217,15 +266,15 @@ def forward(self, x):
tensor([[1.1259, 1.2069, 1.3008, 1.3885]])
"""
x = self.early_downsample(x)

cs = []
fp = self.frame_period
for i in range(len(self.transforms)):
X = self.transforms[i](x)
W = getattr(self, f"fft_basis_{i}")
cs.append(torch.matmul(X, W))
if fp % 2 == 0:
fp //= 2
x = self.resamplers[i](x) * self.resample_scale
if i != len(self.transforms) - 1:
x = self.resamplers[i](x)
c = self._trim_stack(cs) * self.cqt_scale
return c

Expand Down
21 changes: 13 additions & 8 deletions diffsptk/modules/icqt.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,15 @@
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. #
# ------------------------------------------------------------------------ #

import librosa
import numpy as np
import torch
import torch.nn as nn
import torchaudio

from ..misc.utils import delayed_import
from ..misc.utils import numpy_to_torch
from .istft import InverseShortTimeFourierTransform as ISTFT

_vqt_filter_fft = librosa.core.constantq.__vqt_filter_fft
_bpo_to_alpha = librosa.core.constantq.__bpo_to_alpha


class InverseConstantQTransform(nn.Module):
"""Perform inverse constant-Q transform based on the librosa implementation.
Expand Down Expand Up @@ -105,6 +102,11 @@ def __init__(
):
super(InverseConstantQTransform, self).__init__()

import librosa

et_relative_bw = delayed_import("librosa.core.constantq", "__et_relative_bw")
vqt_filter_fft = delayed_import("librosa.core.constantq", "__vqt_filter_fft")

assert 1 <= frame_period
assert 1 <= sample_rate

Expand All @@ -120,7 +122,10 @@ def __init__(
tuning=tuning,
)

alpha = _bpo_to_alpha(n_bin_per_octave)
if K == 1:
alpha = et_relative_bw(B)
else:
alpha = librosa.filters._relative_bandwidth(freqs=freqs)

lengths, _ = librosa.filters.wavelet_lengths(
freqs=freqs,
Expand All @@ -132,7 +137,7 @@ def __init__(
if scale:
cqt_scale = np.sqrt(lengths)
else:
cqt_scale = np.ones(n_bin)
cqt_scale = np.ones(K)
self.register_buffer("cqt_scale", numpy_to_torch(cqt_scale))

fp = [frame_period]
Expand All @@ -156,14 +161,14 @@ def __init__(
sl = slice(B * i, B * i + n_filter)
slices.append(sl)

fft_basis, fft_length, _ = _vqt_filter_fft(
fft_basis, fft_length, _ = vqt_filter_fft(
sr[i],
freqs[sl],
filter_scale,
norm,
sparsity,
window=window,
alpha=alpha,
alpha=alpha[sl],
)

fft_basis = np.asarray(fft_basis.conj().todense())
Expand Down
4 changes: 1 addition & 3 deletions diffsptk/modules/linear_intpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,7 @@ def _forward(x, upsampling_factor):
size=T * upsampling_factor + 1,
mode="linear",
align_corners=True,
)[
..., :-1
] # Remove the padded value.
)[..., :-1] # Remove the padded value.
y = x.mT.reshape(B, -1, D)

if d == 1:
Expand Down
5 changes: 4 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,13 @@
#
html_theme = 'pydata_sphinx_theme'

# Disable navigation with keys.
html_theme_options = {"navigation_with_keys": False}

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = []

# Suppress warnings.
# Do not show all members of a class automatically.
numpydoc_show_class_members = False
Loading

0 comments on commit bf388aa

Please sign in to comment.