Skip to content

Commit

Permalink
Merge pull request #13 from sp-nitech/mlsadf
Browse files Browse the repository at this point in the history
Add mglsadf
  • Loading branch information
takenori-y committed Nov 22, 2022
2 parents 17dae15 + f6538f4 commit 07d25b2
Show file tree
Hide file tree
Showing 24 changed files with 816 additions and 230 deletions.
206 changes: 51 additions & 155 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ diffsptk
*diffsptk* is a differentiable version of [SPTK](https://github.com/sp-nitech/SPTK) based on the PyTorch framework.

[![Latest Manual](https://img.shields.io/badge/docs-latest-blue.svg)](https://sp-nitech.github.io/diffsptk/latest/)
[![Stable Manual](https://img.shields.io/badge/docs-stable-blue.svg)](https://sp-nitech.github.io/diffsptk/0.4.0/)
[![Stable Manual](https://img.shields.io/badge/docs-stable-blue.svg)](https://sp-nitech.github.io/diffsptk/0.5.0/)
[![Python Version](https://img.shields.io/pypi/pyversions/diffsptk.svg)](https://pypi.python.org/pypi/diffsptk)
[![PyTorch Version](https://img.shields.io/badge/pytorch-1.10.0%20%7C%201.13.0-orange.svg)](https://pypi.python.org/pypi/diffsptk)
[![PyPI Version](https://img.shields.io/pypi/v/diffsptk.svg)](https://pypi.python.org/pypi/diffsptk)
Expand Down Expand Up @@ -38,50 +38,75 @@ pip install -e diffsptk

Examples
--------
### Mel-cepstral analysis
### Mel-cepstral analysis and synthesis
```python
import diffsptk
import torch

# Generate waveform.
x = torch.randn(100)
# Set analysis condition.
fl = 400
fp = 80
n_fft = 512
M = 24
alpha = 0.42

# Compute STFT of x.
stft = diffsptk.STFT(frame_length=12, frame_period=10, fft_length=16)
# Read waveform.
x, sr = diffsptk.read("assets/data.wav")

# Compute STFT amplitude of x.
stft = diffsptk.STFT(frame_length=fl, frame_period=fp, fft_length=n_fft)
X = stft(x)

# Estimate 4-th order mel-cepstrum of x.
mcep = diffsptk.MelCepstralAnalysis(cep_order=4, fft_length=16, alpha=0.1, n_iter=1)
# Estimate mel-cepstrum of x.
mcep = diffsptk.MelCepstralAnalysis(cep_order=M, fft_length=n_fft, alpha=alpha, n_iter=1)
mc = mcep(X)

# Reconstruct x.
mlsa = diffsptk.MLSA(filter_order=M, alpha=alpha, frame_period=fp, taylor_order=30)
x_hat = mlsa(mlsa(x, -mc), mc)

# Write reconstructed waveform.
diffsptk.write("reconst.wav", x_hat, sr)

# Compute error.
error = (x_hat - x).abs().sum()
print(error)
```

### Mel-spectrogram extraction
```python
import diffsptk
import torch

# Generate waveform.
x = torch.randn(100)
# Set analysis condition.
fl = 400
fp = 80
n_fft = 512
n_channel = 80

# Compute STFT of x.
stft = diffsptk.STFT(frame_length=12, frame_period=10, fft_length=32)
# Read waveform.
x, sr = diffsptk.read("assets/data.wav")

# Compute STFT amplitude of x.
stft = diffsptk.STFT(frame_length=fl, frame_period=fp, fft_length=n_fft)
X = stft(x)

# Apply 4 mel-filter banks to the STFT.
fbank = diffsptk.MelFilterBankAnalysis(n_channel=4, fft_length=32, sample_rate=8000, floor=1e-1)
# Apply mel-filter banks to the STFT.
fbank = diffsptk.MelFilterBankAnalysis(
n_channel=n_channel,
fft_length=n_fft,
sample_rate=sr,
)
Y = fbank(X)
```

### Subband decomposition
```python
import diffsptk
import torch

K = 4 # Number of subbands.
M = 40 # Order of filter.

# Generate waveform.
x = torch.randn(100)
# Read waveform.
x, sr = diffsptk.read("assets/data.wav")

# Decompose x.
pqmf = diffsptk.PQMF(K, M)
Expand All @@ -91,144 +116,15 @@ y = decimate(pqmf(x), dim=-1)
# Reconstruct x.
interpolate = diffsptk.Interpolation(K)
ipqmf = diffsptk.IPQMF(K, M)
x_hat = ipqmf(interpolate(K * y, dim=-1))

# Compute error between two signals.
error = torch.abs(x_hat - x).sum()
```
x_hat = ipqmf(interpolate(K * y, dim=-1)).reshape(-1)

# Write reconstructed waveform.
diffsptk.write("reconst.wav", x_hat, sr)

Status
------
~~module~~ will not be implemented in this repository.
- [x] acorr
- [ ] ~~acr2csm~~
- [ ] ~~aeq~~ (*torch.allclose*)
- [ ] ~~amgcep~~
- [ ] ~~average~~ (*torch.mean*)
- [x] b2mc
- [ ] ~~bcp~~ (*torch.split*)
- [ ] ~~bcut~~
- [x] c2acr
- [x] c2mpir
- [x] c2ndps
- [x] cdist
- [ ] ~~clip~~ (*torch.clip*)
- [ ] ~~csm2acr~~
- [x] dct
- [x] decimate
- [x] delay
- [x] delta
- [x] dequantize
- [x] df2
- [x] dfs
- [ ] ~~dmp~~
- [ ] ~~dtw~~
- [ ] ~~dtw_merge~~
- [ ] ~~entropy~~ (*torch.special.entr*)
- [ ] ~~excite~~
- [ ] ~~extract~~
- [x] fbank
- [ ] ~~fd~~
- [ ] ~~fdrw~~
- [ ] ~~fft~~ (*torch.fft.fft*)
- [ ] ~~fft2~~ (*torch.fft.fft2*)
- [x] fftcep
- [ ] ~~fftr~~ (*torch.fft.rfft*)
- [ ] ~~fftr2~~ (*torch.fft.rfft2*)
- [x] frame
- [x] freqt
- [ ] ~~glogsp~~
- [ ] ~~gmm~~
- [ ] ~~gmmp~~
- [x] gnorm
- [ ] ~~gpolezero~~
- [ ] ~~grlogsp~~
- [x] grpdelay
- [ ] ~~gseries~~
- [ ] ~~gspecgram~~
- [ ] ~~gwave~~
- [ ] ~~histogram~~ (*torch.histogram*)
- [ ] ~~huffman~~
- [ ] ~~huffman_decode~~
- [ ] ~~huffman_encode~~
- [x] idct
- [ ] ~~ifft~~ (*torch.fft.ifft*)
- [ ] ~~ifft2~~ (*torch.fft.ifft2*)
- [x] ignorm
- [ ] imglsadf (*will be appeared*)
- [x] impulse
- [x] imsvq
- [x] interpolate
- [x] ipqmf
- [x] iulaw
- [x] lar2par
- [ ] ~~lbg~~
- [x] levdur
- [x] linear_intpl
- [x] lpc
- [ ] ~~lpc2c~~
- [ ] ~~lpc2lsp~~
- [x] lpc2par
- [x] lpccheck
- [ ] ~~lsp2lpc~~
- [ ] ~~lspcheck~~
- [ ] ~~lspdf~~
- [ ] ~~ltcdf~~
- [x] mc2b
- [x] mcpf
- [ ] ~~median~~ (*torch.median*)
- [ ] ~~merge~~ (*torch.cat*)
- [x] mfcc
- [x] mgc2mgc
- [x] mgc2sp
- [x] mgcep
- [ ] mglsadf (*will be appeared*)
- [ ] ~~mglsp2sp~~
- [ ] ~~minmax~~
- [x] mlpg (*support only unit variance*)
- [ ] ~~mlsacheck~~
- [x] mpir2c
- [ ] ~~mseq~~
- [ ] ~~msvq~~
- [ ] ~~nan~~ (*torch.isnan*)
- [x] ndps2c
- [x] norm0
- [ ] ~~nrand~~ (*torch.randn*)
- [x] par2lar
- [x] par2lpc
- [x] pca
- [ ] ~~pcas~~
- [x] phase
- [x] pitch
- [ ] ~~pitch_mark~~
- [ ] ~~poledf~~
- [x] pqmf
- [x] quantize
- [x] ramp
- [ ] ~~reverse~~ (*torch.flip*)
- [ ] ~~rlevdur~~
- [x] rmse
- [ ] ~~root_pol~~
- [x] sin
- [x] smcep
- [x] snr
- [x] sopr
- [x] spec
- [x] step
- [ ] ~~swab~~
- [ ] ~~symmetrize~~
- [ ] ~~train~~
- [ ] ~~transpose~~ (*torch.transpose*)
- [x] ulaw
- [ ] ~~vc~~
- [ ] ~~vopr~~
- [ ] ~~vstat~~ (*torch.var_mean*)
- [ ] ~~vsum~~ (*torch.sum*)
- [x] window
- [ ] ~~x2x~~
- [x] zcross
- [x] zerodf
# Compute error.
error = (x_hat - x).abs().sum()
print(error)
```


License
Expand Down
Binary file added assets/data.wav
Binary file not shown.
2 changes: 2 additions & 0 deletions diffsptk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from . import core
from .core import *
from .misc.signals import *
from .misc.utils import read
from .misc.utils import write
from .version import __version__
4 changes: 4 additions & 0 deletions diffsptk/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from .idct import InverseDiscreteCosineTransform
from .idct import InverseDiscreteCosineTransform as IDCT
from .ignorm import GeneralizedCepstrumInverseGainNormalization
from .imglsadf import PseudoInverseMGLSADigitalFilter
from .imglsadf import PseudoInverseMGLSADigitalFilter as IMLSA
from .interpolate import Interpolation
from .ipqmf import InversePseudoQuadratureMirrorFilterBanks
from .ipqmf import InversePseudoQuadratureMirrorFilterBanks as IPQMF
Expand All @@ -46,6 +48,8 @@
from .mgc2sp import MelGeneralizedCepstrumToSpectrum
from .mgcep import MelGeneralizedCepstralAnalysis
from .mgcep import MelGeneralizedCepstralAnalysis as MelCepstralAnalysis
from .mglsadf import PseudoMGLSADigitalFilter
from .mglsadf import PseudoMGLSADigitalFilter as MLSA
from .mlpg import MaximumLikelihoodParameterGeneration
from .mlpg import MaximumLikelihoodParameterGeneration as MLPG
from .mpir2c import MinimumPhaseImpulseResponseToCepstrum
Expand Down
30 changes: 12 additions & 18 deletions diffsptk/core/c2mpir.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
import torch
import torch.nn as nn

from ..misc.utils import cexp
from ..misc.utils import check_size


class CepstrumToMinimumPhaseImpulseResponse(nn.Module):
"""See `this page <https://sp-nitech.github.io/sptk/latest/main/c2mpir.html>`_
for details. This module may be slow due to recursive computation.
for details. The conversion uses FFT instead of recursive formula.
Parameters
----------
Expand All @@ -32,21 +33,24 @@ class CepstrumToMinimumPhaseImpulseResponse(nn.Module):
impulse_response_length : int >= 1 [scalar]
Length of impulse response, :math:`N`.
n_fft : int >> :math:`N` [scalar]
Number of FFT bins. Accurate conversion requires the large value.
"""

def __init__(self, cep_order, impulse_response_length):
def __init__(self, cep_order, impulse_response_length, n_fft=512):
super(CepstrumToMinimumPhaseImpulseResponse, self).__init__()

self.cep_order = cep_order
self.impulse_response_length = impulse_response_length
self.n_fft = n_fft

assert 0 <= self.cep_order
assert 1 <= self.impulse_response_length

self.register_buffer("ramp", torch.arange(1, self.cep_order + 1))
assert max(self.cep_order + 1, self.impulse_response_length) < self.n_fft

def forward(self, c):
"""Convert cepstrum to impulse response.
"""Convert cepstrum to minimum phase impulse response.
Parameters
----------
Expand All @@ -56,7 +60,7 @@ def forward(self, c):
Returns
-------
h : Tensor [shape=(..., N)]
Truncated impulse response.
Truncated minimum phase impulse response.
Examples
--------
Expand All @@ -69,16 +73,6 @@ def forward(self, c):
"""
check_size(c.size(-1), self.cep_order + 1, "dimension of cepstrum")

c0 = c[..., 0]
c1 = (c[..., 1:] * self.ramp).flip(-1)

h = torch.empty(
(*(c.shape[:-1]), self.impulse_response_length), device=c.device
)
h[..., 0] = torch.exp(c0)
for n in range(1, self.impulse_response_length):
s = n - self.cep_order
h[..., n] = (h[..., max(0, s) : n].clone() * c1[..., max(0, -s) :]).sum(
-1
) / n
C = torch.fft.fft(c, n=self.n_fft)
h = torch.fft.ifft(cexp(C))[..., : self.impulse_response_length].real
return h
Loading

0 comments on commit 07d25b2

Please sign in to comment.