Skip to content

Commit

Permalink
Merge pull request #20 from sp-nitech/msvq
Browse files Browse the repository at this point in the history
Add msvq and imsvq
  • Loading branch information
takenori-y committed Feb 3, 2023
2 parents dc794f4 + 91177eb commit 495f589
Show file tree
Hide file tree
Showing 11 changed files with 307 additions and 8 deletions.
2 changes: 2 additions & 0 deletions diffsptk/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .ignorm import GeneralizedCepstrumInverseGainNormalization
from .imglsadf import PseudoInverseMGLSADigitalFilter
from .imglsadf import PseudoInverseMGLSADigitalFilter as IMLSA
from .imsvq import InverseMultiStageVectorQuantization
from .interpolate import Interpolation
from .ipqmf import InversePseudoQuadratureMirrorFilterBanks
from .ipqmf import InversePseudoQuadratureMirrorFilterBanks as IPQMF
Expand All @@ -54,6 +55,7 @@
from .mlpg import MaximumLikelihoodParameterGeneration
from .mlpg import MaximumLikelihoodParameterGeneration as MLPG
from .mpir2c import MinimumPhaseImpulseResponseToCepstrum
from .msvq import MultiStageVectorQuantization
from .ndps2c import NegativeDerivativeOfPhaseSpectrumToCepstrum
from .norm0 import AllPoleToAllZeroDigitalFilterCoefficients
from .par2lar import ParcorCoefficientsToLogAreaRatio
Expand Down
66 changes: 66 additions & 0 deletions diffsptk/core/imsvq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #

import torch
import torch.nn as nn


class InverseMultiStageVectorQuantization(nn.Module):
"""See `this page <https://sp-nitech.github.io/sptk/latest/main/imsvq.html>`_
for details.
"""

def __init__(self):
super(InverseMultiStageVectorQuantization, self).__init__()

def forward(self, indices, codebooks):
"""Perform inverse residual vector quantization.
Parameters
----------
indices : Tensor [shape=(..., Q)]
Codebook indices.
codebooks : Tensor [shape=(Q, K, M+1)]
Codebooks.
Returns
-------
xq : Tensor [shape=(..., M+1)]
Quantized vectors.
Examples
--------
>>> msvq = diffsptk.MultiStageVectorQuantization(4, 3, 2)
>>> imsvq = diffsptk.InverseMultiStageVectorQuantization()
>>> indices = torch.tensor([[0, 1], [1, 0]])
>>> xq = imsvq(indices, msvq.codebooks)
>>> xq
tensor([[-0.8029, -0.1674, 0.5697, 0.9734, 0.1920],
[ 0.0720, -1.0491, -0.4491, -0.2043, -0.3582]])
"""
target_shape = list(indices.shape[:-1])
target_shape.append(codebooks.size(-1))
xq = 0
for i in range(indices.size(-1)):
code_vector = torch.index_select(
codebooks[i], 0, indices[..., i].view(-1).long()
)
xq = xq + code_vector
xq = xq.view(target_shape)
return xq
114 changes: 114 additions & 0 deletions diffsptk/core/msvq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #

import warnings

import torch.nn as nn

warnings.simplefilter("ignore", UserWarning)
from vector_quantize_pytorch import ResidualVQ # noqa: E402


class MultiStageVectorQuantization(nn.Module):
"""See `this page <https://github.com/lucidrains/vector-quantize-pytorch>`_
for details.
Parameters
----------
order : int >= 0 [scalar]
Order of vector, :math:`M`.
codebook_size : int >= 1 [scalar]
Codebook size, :math:`K`.
n_stage : int >= 1 [scalar]
Number of stages (quantizers), :math:`Q`.
**kwargs : additional keyword arguments
See vector-quantize-pytorch repository for details.
"""

def __init__(self, order, codebook_size, n_stage, **kwargs):
super(MultiStageVectorQuantization, self).__init__()

assert 0 <= order
assert 1 <= codebook_size
assert 1 <= n_stage

self.vq = ResidualVQ(
dim=order + 1, codebook_size=codebook_size, num_quantizers=n_stage, **kwargs
).float()

@property
def codebooks(self):
return self.vq.codebooks

def forward(self, x, codebooks=None, **kwargs):
"""Perform residual vector quantization.
Parameters
----------
x : Tensor [shape=(..., M+1)]
Input vectors.
codebooks : Tensor [shape=(Q, K, M+1)]
External codebooks. If None, use internal codebooks.
**kwargs : additional keyword arguments
See vector-quantize-pytorch repository for details.
Returns
-------
xq : Tensor [shape=(..., M+1)]
Quantized vectors.
indices : Tensor [shape=(..., Q)]
Codebook indices.
losses : Tensor [shape=(Q,)]
Commitment losses.
Examples
--------
>>> x = diffsptk.nrand(4)
>>> x
tensor([-0.5206, 1.0048, -0.3370, 1.3364, -0.2933])
>>> msvq = diffsptk.MultiStageVectorQuantization(4, 3, 2).eval()
>>> xq, indices, _ = msvq(x)
>>> xq
tensor([-0.4561, 0.9835, -0.3787, -0.1488, -0.8025])
>>> indices
tensor([0, 2])
"""
if codebooks is not None:
cb = self.codebooks
for i, layer in enumerate(self.vq.layers):
layer._codebook.embed[:] = codebooks.view_as(cb)[i]

d = x.dim()
if d == 1:
x = x.unsqueeze(0)

xq, indices, losses = self.vq(x, **kwargs)

if d == 1:
xq = xq.squeeze(0)
indices = indices.squeeze(0)
losses = losses.squeeze()

return xq, indices, losses
11 changes: 8 additions & 3 deletions diffsptk/core/vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,15 @@ def __init__(self, order, codebook_size, **kwargs):
assert 0 <= order
assert 1 <= codebook_size

self.vq = VectorQuantize(order + 1, codebook_size, **kwargs).float()
self.vq = VectorQuantize(
dim=order + 1, codebook_size=codebook_size, **kwargs
).float()

@property
def codebook(self):
return self.vq.codebook

def forward(self, x, codebook=None):
def forward(self, x, codebook=None, **kwargs):
"""Perform vector quantization.
Parameters
Expand All @@ -62,6 +64,9 @@ def forward(self, x, codebook=None):
codebook : Tensor [shape=(K, M+1)]
External codebook. If None, use internal codebook.
**kwargs : additional keyword arguments
See vector-quantize-pytorch repository for details.
Returns
-------
xq : Tensor [shape=(..., M+1)]
Expand All @@ -85,7 +90,7 @@ def forward(self, x, codebook=None):
"""
if codebook is not None:
self.vq.codebook[:] = codebook
self.codebook[:] = codebook.view_as(self.vq.codebook)

d = x.dim()
if d == 1:
Expand Down
9 changes: 9 additions & 0 deletions docs/core/imsvq.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
.. _imsvq:

imsvq
-----

.. autoclass:: diffsptk.InverseMultiStageVectorQuantization
:members:

.. seealso:: :ref:`ivq` :ref:`msvq`
9 changes: 9 additions & 0 deletions docs/core/msvq.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
.. _msvq:

msvq
----

.. autoclass:: diffsptk.MultiStageVectorQuantization
:members:

.. seealso:: :ref:`vq` :ref:`imsvq`
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"torch >= 1.10.0",
"torchcrepe >= 0.0.16",
"numpy",
"vector-quantize-pytorch >= 0.7.0",
"vector-quantize-pytorch >= 0.8.0",
],
extras_require={
"dev": [
Expand Down
43 changes: 43 additions & 0 deletions tests/test_imsvq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #

import pytest

import diffsptk
import tests.utils as U


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_compatibility(device, m=9, K=4, Q=2):
imsvq = diffsptk.InverseMultiStageVectorQuantization()

tmp1 = "imsvq.tmp1"
tmp2 = "imsvq.tmp2"
tmp3 = "imsvq.tmp3"
U.check_compatibility(
device,
imsvq,
[
f"echo 0 3 1 2 3 2 1 0 | x2x +ad > {tmp1}",
f"nrand -s 234 -l {K*(m+1)} > {tmp2}",
f"nrand -s 345 -l {K*(m+1)} > {tmp3}",
],
[f"cat {tmp1}", f"cat {tmp2} {tmp3}"],
f"x2x +di {tmp1} | imsvq -m {m} -s {tmp2} -s {tmp3}",
[f"rm {tmp1} {tmp2} {tmp3}"],
dx=[Q, (K, m + 1)],
dy=m + 1,
)
4 changes: 2 additions & 2 deletions tests/test_ivq.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_compatibility(device, m=9, K=4, B=8):
def test_compatibility(device, m=9, K=4):
ivq = diffsptk.InverseVectorQuantization()

tmp1 = "ivq.tmp1"
tmp2 = "ivq.tmp2"
U.check_compatibility(
device,
ivq,
[f"ramp -l {K} > {tmp1}", f"nrand -s 234 -l {K*(m+1)} > {tmp2}"],
[f"echo 0 3 1 2 | x2x +ad > {tmp1}", f"nrand -s 234 -l {K*(m+1)} > {tmp2}"],
[f"cat {tmp1}", f"cat {tmp2}"],
f"x2x +di {tmp1} | imsvq -m {m} -s {tmp2}",
[f"rm {tmp1} {tmp2}"],
Expand Down
45 changes: 45 additions & 0 deletions tests/test_msvq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #

import pytest

import diffsptk
import tests.utils as U


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_compatibility(device, m=9, K=4, Q=2, B=8):
msvq = diffsptk.MultiStageVectorQuantization(m, K, Q)

tmp1 = "msvq.tmp1"
tmp2 = "msvq.tmp2"
tmp3 = "msvq.tmp3"
U.check_compatibility(
device,
[lambda x: x[1], msvq],
[
f"nrand -s 123 -l {B*(m+1)} > {tmp1}",
f"nrand -s 234 -l {K*(m+1)} > {tmp2}",
f"nrand -s 345 -l {K*(m+1)} > {tmp3}",
],
[f"cat {tmp1}", f"cat {tmp2} {tmp3}"],
f"msvq -m {m} -s {tmp2} -s {tmp3} < {tmp1} | x2x +id",
[f"rm {tmp1} {tmp2} {tmp3}"],
dx=[m + 1, m + 1],
dy=Q,
)

U.check_differentiable(device, [lambda x: x[2].sum(), msvq], [m + 1])
10 changes: 8 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,15 @@ def check_compatibility(
x.append(torch.from_numpy(call(cmd)).to(device))
if is_array(dx):
if dx[i] is not None:
x[-1] = x[-1].reshape(-1, dx[i])
if is_array(dx[i]):
x[-1] = x[-1].reshape(-1, *dx[i])
else:
x[-1] = x[-1].reshape(-1, dx[i])
elif dx is not None:
x[-1] = x[-1].reshape(-1, dx)
if is_array(dx):
x[-1] = x[-1].reshape(-1, *dx)
else:
x[-1] = x[-1].reshape(-1, dx)
else:
pass

Expand Down

0 comments on commit 495f589

Please sign in to comment.