def histogram(x, n_bin=10, lower_bound=0, upper_bound=1, norm=False, softness=1e-3):
    """Compute the histogram of the input data.

    Parameters
    ----------
    x : Tensor [shape=(..., T)]
        Input data.

    n_bin : int >= 1
        Number of bins, :math:`K`.

    lower_bound : float < U
        Lower bound of the histogram, :math:`L`.

    upper_bound : float > L
        Upper bound of the histogram, :math:`U`.

    norm : bool
        If True, normalize the histogram.

    softness : float > 0
        A smoothing parameter. Smaller values bring the output closer to the
        true histogram, at the cost of a vanishing gradient.

    Returns
    -------
    out : Tensor [shape=(..., K)]
        Histogram in [L, U].

    """
    # Thin functional wrapper: every option is forwarded unchanged to the
    # stateless entry point of the module class.
    options = {
        "n_bin": n_bin,
        "lower_bound": lower_bound,
        "upper_bound": upper_bound,
        "norm": norm,
        "softness": softness,
    }
    return nn.Histogram._func(x, **options)
@@ -1846,7 +1886,7 @@ def yingram(x, sample_rate=22050, lag_min=22, lag_max=None, n_bin=20): ) -def zcross(x, frame_length, norm=False): +def zcross(x, frame_length, norm=False, softness=1e-3): """Compute zero-crossing rate. Parameters @@ -1860,13 +1900,19 @@ def zcross(x, frame_length, norm=False): norm : bool If True, divide zero-crossing rate by frame length. + softness : float > 0 + A smoothing parameter. The smaller value makes the output closer to the true + zero-crossing rate, but the gradient vanishes. + Returns ------- out : Tensor [shape=(..., T/L)] Zero-crossing rate. """ - return nn.ZeroCrossingAnalysis._func(x, frame_length=frame_length, norm=norm) + return nn.ZeroCrossingAnalysis._func( + x, frame_length=frame_length, norm=norm, softness=softness + ) def zerodf(x, b, frame_period=80, ignore_gain=False): diff --git a/diffsptk/modules/__init__.py b/diffsptk/modules/__init__.py index 5bbfa8d..eacd8e1 100644 --- a/diffsptk/modules/__init__.py +++ b/diffsptk/modules/__init__.py @@ -20,6 +20,7 @@ from .entropy import Entropy from .excite import ExcitationGeneration from .fbank import MelFilterBankAnalysis +from .fbank import MelFilterBankAnalysis as FBANK from .fftcep import CepstralAnalysis from .frame import Frame from .freqt import FrequencyTransform @@ -28,6 +29,7 @@ from .gmm import GaussianMixtureModeling as GMM from .gnorm import GeneralizedCepstrumGainNormalization from .grpdelay import GroupDelay +from .histogram import Histogram from .ialaw import ALawExpansion from .icqt import InverseConstantQTransform from .icqt import InverseConstantQTransform as ICQT diff --git a/diffsptk/modules/cqt.py b/diffsptk/modules/cqt.py index c61b2a9..07bb7bb 100644 --- a/diffsptk/modules/cqt.py +++ b/diffsptk/modules/cqt.py @@ -96,7 +96,7 @@ def __init__( tuning=0, filter_scale=1, norm=1, - sparsity=0.01, + sparsity=1e-2, window="hann", scale=True, **kwargs, diff --git a/diffsptk/modules/histogram.py b/diffsptk/modules/histogram.py new file mode 100644 index 
class Histogram(nn.Module):
    """Compute a smooth, differentiable approximation of a histogram.

    Each bin is modelled as the difference of two sigmoids placed at the bin
    edges. For a small ``softness`` a sample lying inside a bin contributes
    approximately one count to it, while the count still varies smoothly with
    the sample value so that gradients can flow.

    Parameters
    ----------
    n_bin : int >= 1
        Number of bins, :math:`K`.

    lower_bound : float < U
        Lower bound of the histogram, :math:`L`.

    upper_bound : float > L
        Upper bound of the histogram, :math:`U`.

    norm : bool
        If True, normalize the histogram.

    softness : float > 0
        A smoothing parameter. The smaller value makes the output closer to the true
        histogram, but the gradient vanishes.

    References
    ----------
    .. [1] M. Avi-Aharon et al., "DeepHist: Differentiable joint and color histogram
           layers for image-to-image translation," *arXiv preprint arXiv:2005.03995*,
           2020.

    """

    def __init__(
        self, n_bin=10, lower_bound=0, upper_bound=1, norm=False, softness=1e-3
    ):
        super(Histogram, self).__init__()

        assert 1 <= n_bin
        assert lower_bound < upper_bound
        assert 0 < softness

        self.norm = norm
        self.softness = softness
        # Keep the bin width explicitly: it cannot be recovered from the bin
        # centers when there is only a single bin.
        self.width = (upper_bound - lower_bound) / n_bin

        centers = self._precompute(n_bin, lower_bound, upper_bound)
        self.register_buffer("centers", centers)

    def forward(self, x):
        """Compute histogram.

        Parameters
        ----------
        x : Tensor [shape=(..., T)]
            Input data.

        Returns
        -------
        out : Tensor [shape=(..., K)]
            Histogram.

        Examples
        --------
        >>> x = diffsptk.ramp(9)
        >>> histogram = diffsptk.Histogram(n_bin=4, lower_bound=-0.1, upper_bound=9.1)
        >>> h = histogram(x)
        >>> h
        tensor([3., 2., 2., 3.])

        """
        return self._forward(x, self.norm, self.softness, self.centers, self.width)

    @staticmethod
    def _forward(x, norm, softness, centers, width=None):
        """Accumulate soft bin counts.

        Parameters
        ----------
        x : Tensor [shape=(..., T)]
            Input data.

        norm : bool
            If True, divide the counts by their sum over the bins.

        softness : float > 0
            Scale of the sigmoids that approximate the bin edges.

        centers : Tensor [shape=(K,)]
            Bin centers.

        width : float or None
            Bin width. If None, it is taken as the spacing between the first
            two centers, which requires at least two bins.

        Returns
        -------
        out : Tensor [shape=(..., K)]
            Histogram.

        """
        if width is None:
            # Backward-compatible path for callers that only pass the centers.
            width = centers[1] - centers[0]
        half_width = 0.5 * width

        # Signed distance of every sample from every bin center.
        y = x.unsqueeze(-2) - centers.unsqueeze(-1)  # (..., K, T)

        # Soft indicator of "sample lies within half a bin width of the center".
        h = torch.sigmoid((y + half_width) / softness) - torch.sigmoid(
            (y - half_width) / softness
        )
        h = h.sum(-1)
        if norm:
            # Out-of-place division: do not modify a tensor of the autograd graph.
            h = h / h.sum(-1, keepdim=True)
        return h

    @staticmethod
    def _func(x, n_bin, lower_bound, upper_bound, norm, softness):
        """Stateless entry point that builds the bin centers on the fly."""
        centers = Histogram._precompute(
            n_bin, lower_bound, upper_bound, dtype=x.dtype, device=x.device
        )
        width = (upper_bound - lower_bound) / n_bin
        return Histogram._forward(x, norm, softness, centers, width)

    @staticmethod
    def _precompute(n_bin, lower_bound, upper_bound, dtype=None, device=None):
        """Return the centers of ``n_bin`` equal-width bins spanning the bounds."""
        width = (upper_bound - lower_bound) / n_bin
        bias = lower_bound + 0.5 * width
        # Computed in double precision first, then converted by the project helper.
        centers = torch.arange(n_bin, dtype=torch.double, device=device) * width + bias
        return to(centers, dtype=dtype)
a/diffsptk/modules/zcross.py b/diffsptk/modules/zcross.py index ed41c84..e84aa2f 100644 --- a/diffsptk/modules/zcross.py +++ b/diffsptk/modules/zcross.py @@ -22,7 +22,7 @@ class ZeroCrossingAnalysis(nn.Module): """See `this page `_ - for details. **Note that this module cannot compute gradient**. + for details. Parameters ---------- @@ -32,15 +32,21 @@ class ZeroCrossingAnalysis(nn.Module): norm : bool If True, divide zero-crossing rate by frame length. + softness : float > 0 + A smoothing parameter. The smaller value makes the output closer to the true + zero-crossing rate, but the gradient vanishes. + """ - def __init__(self, frame_length, norm=False): + def __init__(self, frame_length, norm=False, softness=1e-3): super(ZeroCrossingAnalysis, self).__init__() assert 1 <= frame_length + assert 0 < softness self.frame_length = frame_length self.norm = norm + self.softness = softness def forward(self, x): """Compute zero-crossing rate. @@ -66,11 +72,11 @@ def forward(self, x): tensor([2., 1.]) """ - return self._forward(x, self.frame_length, self.norm) + return self._forward(x, self.frame_length, self.norm, self.softness) @staticmethod - def _forward(x, frame_length, norm): - x = torch.sign(x) + def _forward(x, frame_length, norm, softness): + x = torch.tanh(x / softness) x = replicate1(x, right=False) x = x.unfold(-1, frame_length + 1, frame_length) z = 0.5 * (x[..., 1:] - x[..., :-1]).abs().sum(-1) diff --git a/docs/modules/entropy.rst b/docs/modules/entropy.rst index a675d5b..d661a89 100644 --- a/docs/modules/entropy.rst +++ b/docs/modules/entropy.rst @@ -7,3 +7,5 @@ entropy :members: .. autofunction:: diffsptk.functional.entropy + +.. seealso:: :ref:`histogram` diff --git a/docs/modules/histogram.rst b/docs/modules/histogram.rst new file mode 100644 index 0000000..9649bf4 --- /dev/null +++ b/docs/modules/histogram.rst @@ -0,0 +1,11 @@ +.. _histogram: + +histogram +--------- + +.. autoclass:: diffsptk.Histogram + :members: + +.. 
autofunction:: diffsptk.functional.histogram + +.. seealso:: :ref:`entropy` diff --git a/tests/test_histogram.py b/tests/test_histogram.py new file mode 100644 index 0000000..1391979 --- /dev/null +++ b/tests/test_histogram.py @@ -0,0 +1,60 @@ +# ------------------------------------------------------------------------ # +# Copyright 2022 SPTK Working Group # +# # +# Licensed under the Apache License, Version 2.0 (the "License"); # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http://www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +# ------------------------------------------------------------------------ # + +import pytest +import torch + +import diffsptk +import tests.utils as U + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +@pytest.mark.parametrize("module", [False, True]) +@pytest.mark.parametrize("norm", [False, True]) +def test_compatibility( + device, module, norm, K=4, lower_bound=-1, upper_bound=1, L=50, B=2 +): + histogram = U.choice( + module, + diffsptk.Histogram, + diffsptk.functional.histogram, + {}, + { + "n_bin": K, + "lower_bound": lower_bound, + "upper_bound": upper_bound, + "norm": norm, + "softness": 1e-4, + }, + ) + + opt = "-n" if norm else "" + U.check_compatibility( + device, + histogram, + [], + [f"nrand -l {B*L}"], + f"histogram -t {L} -b {K} -l {lower_bound} -u {upper_bound} {opt}", + [], + dx=L, + dy=K, + ) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_differentiability(device, K=4, L=50): + histogram = diffsptk.Histogram(K, lower_bound=-1, upper_bound=1, softness=1e-1) + U.check_differentiability(device, 
histogram, [L]) diff --git a/tests/test_zcross.py b/tests/test_zcross.py index 63a043b..3ec1987 100644 --- a/tests/test_zcross.py +++ b/tests/test_zcross.py @@ -29,7 +29,7 @@ def test_compatibility(device, module, norm, L=10, T=50): diffsptk.ZeroCrossingAnalysis, diffsptk.functional.zcross, {}, - {"frame_length": L, "norm": norm}, + {"frame_length": L, "norm": norm, "softness": 1e-3}, ) opt = "-o 1" if norm else "" @@ -41,3 +41,9 @@ def test_compatibility(device, module, norm, L=10, T=50): f"zcross -l {L} {opt}", [], ) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_differentiability(device, L=10, T=50): + zcross = diffsptk.ZeroCrossingAnalysis(L, softness=1e-1) + U.check_differentiability(device, zcross, [T])