refactor quantize
takenori-y committed Feb 1, 2024
1 parent 70954e2 commit c25d8be
Showing 5 changed files with 139 additions and 48 deletions.
73 changes: 73 additions & 0 deletions diffsptk/functional.py
@@ -86,6 +86,44 @@ def delay(x, start=0, keeplen=False, dim=-1):
     return fm.Delay._forward(x, start=start, keeplen=keeplen, dim=dim)
 
 
+def dequantize(y, abs_max=1, n_bit=8, quantizer="mid-rise"):
+    """Dequantize input.
+
+    Parameters
+    ----------
+    y : Tensor [shape=(...,)]
+        Quantized input.
+
+    abs_max : float > 0
+        Absolute maximum value of input.
+
+    n_bit : int >= 1
+        Number of quantization bits.
+
+    quantizer : ['mid-rise', 'mid-tread']
+        Quantizer.
+
+    Returns
+    -------
+    Tensor [shape=(...,)]
+        Dequantized input.
+
+    Examples
+    --------
+    >>> x = diffsptk.ramp(-4, 4)
+    >>> x
+    tensor([-4., -3., -2., -1., 0., 1., 2., 3., 4.])
+    >>> y = diffsptk.functional.quantize(x, 4, 2)
+    >>> z = diffsptk.functional.dequantize(y, 4, 2)
+    >>> z
+    tensor([-3., -3., -1., -1., 1., 1., 3., 3., 3.])
+
+    """
+    return fm.InverseUniformQuantization._forward(
+        y, abs_max=abs_max, n_bit=n_bit, quantizer=quantizer
+    )


 def grpdelay(b=None, a=None, *, fft_length=512, alpha=1, gamma=1, **kwargs):
     """Compute group delay.
@@ -190,6 +228,41 @@ def phase(b=None, a=None, *, fft_length=512, unwrap=False):
     return fm.Phase._forward(b, a, fft_length=fft_length, unwrap=unwrap)


+def quantize(x, abs_max=1, n_bit=8, quantizer="mid-rise"):
+    """Quantize input.
+
+    Parameters
+    ----------
+    x : Tensor [shape=(...,)]
+        Input.
+
+    abs_max : float > 0
+        Absolute maximum value of input.
+
+    n_bit : int >= 1
+        Number of quantization bits.
+
+    quantizer : ['mid-rise', 'mid-tread']
+        Quantizer.
+
+    Returns
+    -------
+    Tensor [shape=(...,)]
+        Quantized input.
+
+    Examples
+    --------
+    >>> x = diffsptk.ramp(-4, 4)
+    >>> y = diffsptk.functional.quantize(x, 4, 2)
+    >>> y
+    tensor([0, 0, 1, 1, 2, 2, 3, 3, 3], dtype=torch.int32)
+
+    """
+    return fm.UniformQuantization._forward(
+        x, abs_max=abs_max, n_bit=n_bit, quantizer=quantizer
+    )


 def spec(
     b=None, a=None, *, fft_length=512, eps=0, relative_floor=None, out_format="power"
 ):
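Taken together, the two new wrappers form a quantize/dequantize pair. A minimal round-trip sketch (values taken from the docstring examples above):

    import diffsptk

    x = diffsptk.ramp(-4, 4)
    y = diffsptk.functional.quantize(x, abs_max=4, n_bit=2)    # mid-rise: 2**2 = 4 levels
    z = diffsptk.functional.dequantize(y, abs_max=4, n_bit=2)
    # y -> tensor([0, 0, 1, 1, 2, 2, 3, 3, 3], dtype=torch.int32)
    # z -> tensor([-3., -3., -1., -1., 1., 1., 3., 3., 3.])

The round trip is lossy by design: every input falling in a bin comes back as that bin's midpoint.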
8 changes: 8 additions & 0 deletions diffsptk/misc/utils.py
@@ -78,6 +78,14 @@ def to_3d(x):
     return y


+def check(x, names):
+    if type(x) is int and 0 <= x < len(names):
+        x = list(names)[x]
+    if x not in names:
+        raise ValueError(f"Unsupported value: {x}")
+    return x


 def reflect(x):
     d = x.size(-1)
     y = x.view(-1, d)
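check centralizes the int-or-name normalization that each module's __init__ previously duplicated: an integer is treated as an index into the supported names, a string must match one of them, and anything else raises a uniform error. A small illustration (using the _quantization_level table added in quantize.py below):

    from diffsptk.misc.utils import check
    from diffsptk.modules.quantize import _quantization_level

    check(0, _quantization_level)            # -> "mid-rise" (index 0)
    check("mid-tread", _quantization_level)  # -> "mid-tread"
    check("linear", _quantization_level)     # ValueError: Unsupported value: linear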
47 changes: 25 additions & 22 deletions diffsptk/modules/dequantize.py
@@ -17,17 +17,20 @@
 import torch
 import torch.nn as nn
 
+from ..misc.utils import check
+from .quantize import _quantization_level
+
 
 class InverseUniformQuantization(nn.Module):
     """See `this page <https://sp-nitech.github.io/sptk/latest/main/dequantize.html>`_
     for details.
 
     Parameters
     ----------
-    abs_max : float > 0 [scalar]
+    abs_max : float > 0
         Absolute maximum value of input.
 
-    n_bit : int >= 1 [scalar]
+    n_bit : int >= 1
         Number of quantization bits.
 
     quantizer : ['mid-rise', 'mid-tread']
@@ -39,19 +42,11 @@ def __init__(self, abs_max=1, n_bit=8, quantizer="mid-rise"):
         super(InverseUniformQuantization, self).__init__()
 
         self.abs_max = abs_max
-        self.quantizer = quantizer
         self.n_bit = n_bit
+        self.quantizer = check(quantizer, _quantization_level)
 
         assert 0 < self.abs_max
-        assert 1 <= n_bit
-
-        if quantizer == 0 or quantizer == "mid-rise":
-            self.level = int(2**n_bit)
-            self.quantizer = "mid-rise"
-        elif quantizer == 1 or quantizer == "mid-tread":
-            self.level = int(2**n_bit) - 1
-            self.quantizer = "mid-tread"
-        else:
-            raise ValueError("quantizer {quantizer} is not supported")
+        assert 1 <= self.n_bit

     def forward(self, y):
         """Dequantize input.
@@ -63,7 +58,7 @@ def forward(self, y):
         Returns
         -------
-        x : Tensor [shape=(...,)]
+        Tensor [shape=(...,)]
             Dequantized input.
 
         Examples
@@ -78,13 +73,21 @@ def forward(self, y):
         tensor([-3., -3., -1., -1., 1., 1., 3., 3., 3.])
 
         """
-        if self.quantizer == "mid-rise":
-            y = y - (self.level // 2 - 0.5)
-        elif self.quantizer == "mid-tread":
-            y = y - (self.level - 1) // 2
+        return self._forward(y, self.abs_max, self.n_bit, self.quantizer)
+
+    @staticmethod
+    def _forward(y, abs_max, n_bit, quantizer):
+        try:
+            level = _quantization_level[quantizer](n_bit)
+        except KeyError:
+            raise ValueError(f"quantizer {quantizer} is not supported")
+
+        if quantizer == "mid-rise":
+            y = y - (level // 2 - 0.5)
+        elif quantizer == "mid-tread":
+            y = y - (level - 1) // 2
         else:
-            raise RuntimeError
-
-        x = y * (2 * self.abs_max / self.level)
-        x = torch.clip(x, min=-self.abs_max, max=self.abs_max)
+            raise ValueError(f"quantizer {quantizer} is not supported")
+        x = y * (2 * abs_max / level)
+        x = torch.clip(x, min=-abs_max, max=abs_max)
         return x
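For concreteness, mid-rise reconstruction with abs_max=4 and n_bit=2 gives level = 4, so _forward computes x = (y - 1.5) * (2 * 4 / 4) = 2y - 3 before clipping, which reproduces the docstring output. A quick sketch with the module interface (assuming the class is re-exported at the package top level, as other diffsptk modules are):

    import torch
    from diffsptk import InverseUniformQuantization

    dequantize = InverseUniformQuantization(abs_max=4, n_bit=2, quantizer="mid-rise")
    print(dequantize(torch.tensor([0, 1, 2, 3])))  # tensor([-3., -1.,  1.,  3.])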
47 changes: 27 additions & 20 deletions diffsptk/modules/quantize.py
@@ -17,6 +17,13 @@
 import torch
 import torch.nn as nn
 
+from ..misc.utils import check
+
+_quantization_level = {
+    "mid-rise": lambda n_bit: 1 << n_bit,
+    "mid-tread": lambda n_bit: (1 << n_bit) - 1,
+}
+
 
 class Floor(torch.autograd.Function):
     @staticmethod
@@ -59,19 +66,11 @@ def __init__(self, abs_max=1, n_bit=8, quantizer="mid-rise"):
         super(UniformQuantization, self).__init__()
 
         self.abs_max = abs_max
-        self.quantizer = quantizer
         self.n_bit = n_bit
+        self.quantizer = check(quantizer, _quantization_level)
 
         assert 0 < self.abs_max
-        assert 1 <= n_bit
-
-        if quantizer == 0 or quantizer == "mid-rise":
-            self.level = int(2**n_bit)
-            self.quantizer = "mid-rise"
-        elif quantizer == 1 or quantizer == "mid-tread":
-            self.level = int(2**n_bit) - 1
-            self.quantizer = "mid-tread"
-        else:
-            raise ValueError("quantizer {quantizer} is not supported")
+        assert 1 <= self.n_bit

     def forward(self, x):
         """Quantize input.
@@ -83,7 +82,7 @@ def forward(self, x):
         Returns
         -------
-        y : Tensor [shape=(...,)]
+        Tensor [shape=(...,)]
             Quantized input.
 
         Examples
@@ -95,15 +94,23 @@ def forward(self, x):
         tensor([0, 0, 1, 1, 2, 2, 3, 3, 3], dtype=torch.int32)
 
         """
-        x = x * self.level / (2 * self.abs_max)
-        if self.quantizer == "mid-rise":
-            x = x + self.level // 2
+        return self._forward(x, self.abs_max, self.n_bit, self.quantizer)
+
+    @staticmethod
+    def _forward(x, abs_max, n_bit, quantizer):
+        try:
+            level = _quantization_level[quantizer](n_bit)
+        except KeyError:
+            raise ValueError(f"quantizer {quantizer} is not supported")
+
+        x = x * (level / (2 * abs_max))
+        if quantizer == "mid-rise":
+            x += level // 2
             y = Floor.apply(x)
-        elif self.quantizer == "mid-tread":
-            x = x + (self.level - 1) // 2
+        elif quantizer == "mid-tread":
+            x += (level - 1) // 2
             y = Round.apply(x)
         else:
-            raise RuntimeError
-
-        y = torch.clip(y, min=0, max=self.level - 1)
+            raise ValueError(f"quantizer {quantizer} is not supported")
+        y = torch.clip(y, min=0, max=level - 1)
         return y
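The module-level _quantization_level table is now the single place the level count is defined: mid-rise uses all 2**n_bit codes, while mid-tread keeps a code for exact zero and therefore has one fewer. For example:

    _quantization_level["mid-rise"](8)   # 256
    _quantization_level["mid-tread"](8)  # 255

Since dequantize.py imports this table, the two modules can no longer drift apart on how many levels a given n_bit implies.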
12 changes: 6 additions & 6 deletions diffsptk/modules/spec.py
@@ -17,6 +17,7 @@
 import torch
 import torch.nn as nn
 
+from ..misc.utils import check
 from ..misc.utils import remove_gain
 
 _spec2spec = {
@@ -52,6 +53,7 @@ def __init__(self, fft_length, eps=0, relative_floor=None, out_format="power"):

         self.fft_length = fft_length
         self.eps = eps
+        self.out_format = check(out_format, _spec2spec)
 
         assert 2 <= self.fft_length
         assert 0 <= self.eps
@@ -62,11 +64,6 @@ def __init__(self, fft_length, eps=0, relative_floor=None, out_format="power"):
             assert relative_floor < 0
             self.relative_floor = 10 ** (relative_floor / 10)
 
-        if type(out_format) is int and 0 <= out_format <= len(_spec2spec.keys()):
-            out_format = list(_spec2spec.keys())[out_format]
-        assert out_format in _spec2spec.keys()
-        self.out_format = out_format
-
     def forward(self, b=None, a=None):
         """Compute spectrum.
@@ -125,5 +122,8 @@ def _forward(b, a, fft_length, eps, relative_floor, out_format):
         if relative_floor is not None:
             m, _ = torch.max(s, dim=-1, keepdim=True)
             s = torch.maximum(s, m * relative_floor)
-        s = _spec2spec[out_format](s)
+        try:
+            s = _spec2spec[out_format](s)
+        except KeyError:
+            raise ValueError(f"out_format {out_format} is not supported.")
         return s
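With check in place, an unknown out_format now fails at construction time with a readable message instead of a bare KeyError deep in forward. A sketch (assuming the class defined in spec.py is exported as diffsptk.Spectrum):

    import diffsptk

    spec = diffsptk.Spectrum(fft_length=8, out_format="power")  # name form still works
    diffsptk.Spectrum(fft_length=8, out_format="foo")           # ValueError: Unsupported value: foo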
