Skip to content

Commit

Permalink
add histogram
Browse files Browse the repository at this point in the history
  • Loading branch information
takenori-y committed Mar 27, 2024
1 parent 7fda5bf commit edf2cc4
Show file tree
Hide file tree
Showing 10 changed files with 257 additions and 10 deletions.
50 changes: 48 additions & 2 deletions diffsptk/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,46 @@ def grpdelay(b=None, a=None, *, fft_length=512, alpha=1, gamma=1, **kwargs):
)


def histogram(x, n_bin=10, lower_bound=0, upper_bound=1, norm=False, softness=1e-3):
    """Compute histogram.

    This is the functional counterpart of the ``Histogram`` module: all arguments
    are forwarded unchanged to ``nn.Histogram._func``.

    Parameters
    ----------
    x : Tensor [shape=(..., T)]
        Input data.

    n_bin : int >= 1
        Number of bins, :math:`K`.

    lower_bound : float < U
        Lower bound of the histogram, :math:`L`.

    upper_bound : float > L
        Upper bound of the histogram, :math:`U`.

    norm : bool
        If True, normalize the histogram.

    softness : float > 0
        A smoothing parameter. The smaller value makes the output closer to the true
        histogram, but the gradient vanishes.

    Returns
    -------
    out : Tensor [shape=(..., K)]
        Histogram in [L, U].

    """
    options = {
        "n_bin": n_bin,
        "lower_bound": lower_bound,
        "upper_bound": upper_bound,
        "norm": norm,
        "softness": softness,
    }
    return nn.Histogram._func(x, **options)


def ialaw(y, abs_max=1, a=87.6):
"""Expand waveform by A-law algorithm.
Expand Down Expand Up @@ -1846,7 +1886,7 @@ def yingram(x, sample_rate=22050, lag_min=22, lag_max=None, n_bin=20):
)


def zcross(x, frame_length, norm=False):
def zcross(x, frame_length, norm=False, softness=1e-3):
"""Compute zero-crossing rate.
Parameters
Expand All @@ -1860,13 +1900,19 @@ def zcross(x, frame_length, norm=False):
norm : bool
If True, divide zero-crossing rate by frame length.
softness : float > 0
A smoothing parameter. The smaller value makes the output closer to the true
zero-crossing rate, but the gradient vanishes.
Returns
-------
out : Tensor [shape=(..., T/L)]
Zero-crossing rate.
"""
return nn.ZeroCrossingAnalysis._func(x, frame_length=frame_length, norm=norm)
return nn.ZeroCrossingAnalysis._func(
x, frame_length=frame_length, norm=norm, softness=softness
)


def zerodf(x, b, frame_period=80, ignore_gain=False):
Expand Down
2 changes: 2 additions & 0 deletions diffsptk/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .entropy import Entropy
from .excite import ExcitationGeneration
from .fbank import MelFilterBankAnalysis
from .fbank import MelFilterBankAnalysis as FBANK
from .fftcep import CepstralAnalysis
from .frame import Frame
from .freqt import FrequencyTransform
Expand All @@ -28,6 +29,7 @@
from .gmm import GaussianMixtureModeling as GMM
from .gnorm import GeneralizedCepstrumGainNormalization
from .grpdelay import GroupDelay
from .histogram import Histogram
from .ialaw import ALawExpansion
from .icqt import InverseConstantQTransform
from .icqt import InverseConstantQTransform as ICQT
Expand Down
2 changes: 1 addition & 1 deletion diffsptk/modules/cqt.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(
tuning=0,
filter_scale=1,
norm=1,
sparsity=0.01,
sparsity=1e-2,
window="hann",
scale=True,
**kwargs,
Expand Down
114 changes: 114 additions & 0 deletions diffsptk/modules/histogram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #

import torch
import torch.nn as nn

from ..misc.utils import to


class Histogram(nn.Module):
    """See `this page <https://sp-nitech.github.io/sptk/latest/main/histogram.html>`_
    for details.

    Each bin is evaluated as the difference of two sigmoids located at the bin
    edges, summed over the last axis of the input.

    Parameters
    ----------
    n_bin : int >= 1
        Number of bins, :math:`K`.

    lower_bound : float < U
        Lower bound of the histogram, :math:`L`.

    upper_bound : float > L
        Upper bound of the histogram, :math:`U`.

    norm : bool
        If True, normalize the histogram.

    softness : float > 0
        A smoothing parameter. The smaller value makes the output closer to the true
        histogram, but the gradient vanishes.

    References
    ----------
    .. [1] M. Avi-Aharon et al., "DeepHist: Differentiable joint and color histogram
           layers for image-to-image translation," *arXiv preprint arXiv:2005.03995*,
           2020.

    """

    def __init__(
        self, n_bin=10, lower_bound=0, upper_bound=1, norm=False, softness=1e-3
    ):
        super(Histogram, self).__init__()

        assert 1 <= n_bin
        assert lower_bound < upper_bound
        assert 0 < softness

        self.norm = norm
        self.softness = softness
        # The half-width is derived from the bounds rather than from the spacing
        # of the centers, because a single-bin histogram has no second center.
        self.half_width = self._half_width(n_bin, lower_bound, upper_bound)

        centers = self._precompute(n_bin, lower_bound, upper_bound)
        self.register_buffer("centers", centers)

    def forward(self, x):
        """Compute histogram.

        Parameters
        ----------
        x : Tensor [shape=(..., T)]
            Input data.

        Returns
        -------
        out : Tensor [shape=(..., K)]
            Histogram.

        Examples
        --------
        >>> x = diffsptk.ramp(9)
        >>> histogram = diffsptk.Histogram(n_bin=4, lower_bound=-0.1, upper_bound=9.1)
        >>> h = histogram(x)
        >>> h
        tensor([3., 2., 2., 3.])

        """
        return self._forward(
            x, self.norm, self.softness, self.centers, self.half_width
        )

    @staticmethod
    def _forward(x, norm, softness, centers, half_width=None):
        """Evaluate the soft histogram of ``x`` for the given bin centers.

        ``half_width`` is half of the bin width. When it is omitted, it is
        inferred from the spacing of the first two centers, which requires at
        least two bins; pass it explicitly to support a single bin.
        """
        if half_width is None:
            half_width = 0.5 * (centers[1] - centers[0])

        y = x.unsqueeze(-2) - centers.unsqueeze(-1)  # (..., K, T)
        upper = torch.sigmoid((y + half_width) / softness)
        lower = torch.sigmoid((y - half_width) / softness)
        h = (upper - lower).sum(-1)
        if norm:
            # Out-of-place division leaves the unnormalized counts untouched.
            h = h / h.sum(-1, keepdim=True)
        return h

    @staticmethod
    def _func(x, n_bin, lower_bound, upper_bound, norm, softness):
        """Functional entry point: build the centers on the fly and evaluate."""
        centers = Histogram._precompute(
            n_bin, lower_bound, upper_bound, dtype=x.dtype, device=x.device
        )
        half_width = Histogram._half_width(n_bin, lower_bound, upper_bound)
        return Histogram._forward(x, norm, softness, centers, half_width)

    @staticmethod
    def _half_width(n_bin, lower_bound, upper_bound):
        """Return half of the width of one bin."""
        return 0.5 * (upper_bound - lower_bound) / n_bin

    @staticmethod
    def _precompute(n_bin, lower_bound, upper_bound, dtype=None, device=None):
        """Return the centers of ``n_bin`` equal-width bins over [L, U]."""
        width = (upper_bound - lower_bound) / n_bin
        bias = lower_bound + 0.5 * width
        # The centers are generated in double precision before conversion.
        indices = torch.arange(n_bin, dtype=torch.double, device=device)
        centers = indices * width + bias
        return to(centers, dtype=dtype)
2 changes: 1 addition & 1 deletion diffsptk/modules/icqt.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(
tuning=0,
filter_scale=1,
norm=1,
sparsity=0.01,
sparsity=1e-2,
window="hann",
scale=True,
**kwargs,
Expand Down
16 changes: 11 additions & 5 deletions diffsptk/modules/zcross.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

class ZeroCrossingAnalysis(nn.Module):
"""See `this page <https://sp-nitech.github.io/sptk/latest/main/zcross.html>`_
for details. **Note that this module cannot compute gradient**.
for details.
Parameters
----------
Expand All @@ -32,15 +32,21 @@ class ZeroCrossingAnalysis(nn.Module):
norm : bool
If True, divide zero-crossing rate by frame length.
softness : float > 0
A smoothing parameter. The smaller value makes the output closer to the true
zero-crossing rate, but the gradient vanishes.
"""

def __init__(self, frame_length, norm=False):
def __init__(self, frame_length, norm=False, softness=1e-3):
super(ZeroCrossingAnalysis, self).__init__()

assert 1 <= frame_length
assert 0 < softness

self.frame_length = frame_length
self.norm = norm
self.softness = softness

def forward(self, x):
"""Compute zero-crossing rate.
Expand All @@ -66,11 +72,11 @@ def forward(self, x):
tensor([2., 1.])
"""
return self._forward(x, self.frame_length, self.norm)
return self._forward(x, self.frame_length, self.norm, self.softness)

@staticmethod
def _forward(x, frame_length, norm):
x = torch.sign(x)
def _forward(x, frame_length, norm, softness):
x = torch.tanh(x / softness)
x = replicate1(x, right=False)
x = x.unfold(-1, frame_length + 1, frame_length)
z = 0.5 * (x[..., 1:] - x[..., :-1]).abs().sum(-1)
Expand Down
2 changes: 2 additions & 0 deletions docs/modules/entropy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ entropy
:members:

.. autofunction:: diffsptk.functional.entropy

.. seealso:: :ref:`histogram`
11 changes: 11 additions & 0 deletions docs/modules/histogram.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
.. _histogram:

histogram
---------

.. autoclass:: diffsptk.Histogram
:members:

.. autofunction:: diffsptk.functional.histogram

.. seealso:: :ref:`entropy`
60 changes: 60 additions & 0 deletions tests/test_histogram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #

import pytest
import torch

import diffsptk
import tests.utils as U


@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("module", [False, True])
@pytest.mark.parametrize("norm", [False, True])
def test_compatibility(
    device, module, norm, K=4, lower_bound=-1, upper_bound=1, L=50, B=2
):
    """Run the compatibility check for the module and functional histogram.

    ``module`` selects between ``diffsptk.Histogram`` and
    ``diffsptk.functional.histogram`` via ``U.choice``; the same keyword
    settings are supplied in either case.
    """
    settings = {
        "n_bin": K,
        "lower_bound": lower_bound,
        "upper_bound": upper_bound,
        "norm": norm,
        "softness": 1e-4,
    }
    histogram = U.choice(
        module,
        diffsptk.Histogram,
        diffsptk.functional.histogram,
        {},
        settings,
    )

    # Command strings handed to the checker; "-n" is added only when normalizing.
    opt = "-n" if norm else ""
    source_cmd = f"nrand -l {B*L}"
    target_cmd = f"histogram -t {L} -b {K} -l {lower_bound} -u {upper_bound} {opt}"
    U.check_compatibility(
        device,
        histogram,
        [],
        [source_cmd],
        target_cmd,
        [],
        dx=L,
        dy=K,
    )


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_differentiability(device, K=4, L=50):
    """Run the differentiability check on a ``K``-bin histogram over [-1, 1]."""
    module = diffsptk.Histogram(
        n_bin=K,
        lower_bound=-1,
        upper_bound=1,
        softness=1e-1,
    )
    input_shape = [L]
    U.check_differentiability(device, module, input_shape)
8 changes: 7 additions & 1 deletion tests/test_zcross.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_compatibility(device, module, norm, L=10, T=50):
diffsptk.ZeroCrossingAnalysis,
diffsptk.functional.zcross,
{},
{"frame_length": L, "norm": norm},
{"frame_length": L, "norm": norm, "softness": 1e-3},
)

opt = "-o 1" if norm else ""
Expand All @@ -41,3 +41,9 @@ def test_compatibility(device, module, norm, L=10, T=50):
f"zcross -l {L} {opt}",
[],
)


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_differentiability(device, L=10, T=50):
    """Run the differentiability check on zero-crossing analysis."""
    module = diffsptk.ZeroCrossingAnalysis(frame_length=L, softness=1e-1)
    input_shape = [T]
    U.check_differentiability(device, module, input_shape)

0 comments on commit edf2cc4

Please sign in to comment.