sp-nitech/diffsptk

diffsptk

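diffsptk is a differentiable version of SPTK (the Speech Signal Processing Toolkit) built on PyTorch, so its signal processing operations can be used inside gradient-based models.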

Requirements

  • Python 3.8+
  • PyTorch 1.10.0+

Documentation

See the online reference manual for the full API documentation.

Installation

The latest stable release can be installed through PyPI by running

pip install diffsptk

Alternatively, the development version can be installed from source in editable mode:

git clone https://github.com/sp-nitech/diffsptk.git
pip install -e diffsptk

Examples

Mel-cepstral analysis and synthesis

import diffsptk

# Set analysis condition.
fl = 400
fp = 80
n_fft = 512
M = 24
alpha = 0.42

# Read waveform.
x, sr = diffsptk.read("assets/data.wav")

# Compute STFT amplitude of x.
stft = diffsptk.STFT(frame_length=fl, frame_period=fp, fft_length=n_fft)
X = stft(x)

# Estimate mel-cepstrum of x.
mcep = diffsptk.MelCepstralAnalysis(cep_order=M, fft_length=n_fft, alpha=alpha, n_iter=1)
mc = mcep(X)

# Reconstruct x: inverse-filter with -mc, then filter the residual with mc.
mlsa = diffsptk.MLSA(filter_order=M, alpha=alpha, frame_period=fp, taylor_order=30)
x_hat = mlsa(mlsa(x, -mc), mc)

# Write reconstructed waveform.
diffsptk.write("reconst.wav", x_hat, sr)

# Compute the total absolute reconstruction error.
error = (x_hat - x).abs().sum()
print(error)
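The reconstruction line `mlsa(mlsa(x, -mc), mc)` works because the MLSA filter realizes a transfer function of the form H(z) = exp(C(z)), where C(z) is built from the mel-cepstrum. Filtering with `-mc` applies exp(-C(z)), and filtering the result with `mc` applies exp(C(z)). The cascade of the two is exactly 1. The pure-Python sketch below checks this identity for the special case alpha = 0 (an ordinary cepstrum, no frequency warping). It uses the standard recursion from a cepstrum to its minimum-phase impulse response. The coefficients in `c` are made-up values, and the helper functions are hypothetical illustrations, not diffsptk code.

```python
import math


def cepstrum_to_impulse_response(c, length):
    """First `length` samples of h with H(z) = exp(sum_m c[m] z^-m), i.e. alpha = 0.

    Recursion: h[0] = exp(c[0]) and
    h[n] = (1 / n) * sum_{k=1}^{n} k * c[k] * h[n - k], with c[k] = 0 beyond the order.
    """
    h = [0.0] * length
    h[0] = math.exp(c[0])
    for n in range(1, length):
        acc = 0.0
        for k in range(1, min(n, len(c) - 1) + 1):
            acc += k * c[k] * h[n - k]
        h[n] = acc / n
    return h


def convolve_truncated(a, b, length):
    """First `length` samples of the linear convolution of a and b."""
    return [sum(a[k] * b[n - k] for k in range(n + 1)) for n in range(length)]


N = 32
c = [0.1, 0.5, -0.3, 0.2]  # made-up cepstrum of order 3
h_forward = cepstrum_to_impulse_response(c, N)  # exp(C(z))
h_inverse = cepstrum_to_impulse_response([-v for v in c], N)  # exp(-C(z))

# Inverse filter followed by forward filter: ideally a unit impulse.
cascade = convolve_truncated(h_inverse, h_forward, N)
print(cascade[0])  # close to 1.0
print(max(abs(v) for v in cascade[1:]))  # close to 0.0
```

With alpha != 0, C(z) is expressed in terms of a first-order all-pass (warped) delay, but exp(C(z)) * exp(-C(z)) = 1 still holds. The MLSA filter only approximates the exponential (the `taylor_order` parameter suggests a truncated Taylor expansion). For that reason, a small nonzero `error` is expected in the example above.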

Mel-spectrogram extraction

import diffsptk

# Set analysis condition.
fl = 400
fp = 80
n_fft = 512
n_channel = 80

# Read waveform.
x, sr = diffsptk.read("assets/data.wav")

# Compute STFT amplitude of x.
stft = diffsptk.STFT(frame_length=fl, frame_period=fp, fft_length=n_fft)
X = stft(x)

# Apply mel-filter banks to the STFT.
fbank = diffsptk.MelFilterBankAnalysis(
    n_channel=n_channel,
    fft_length=n_fft,
    sample_rate=sr,
)
Y = fbank(X)
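`MelFilterBankAnalysis` maps each STFT amplitude frame (`n_fft // 2 + 1` frequency bins) to `n_channel` mel-band values. As a rough, library-independent picture of what a mel filter bank computes, the sketch below builds triangular filters whose edges are equally spaced on the HTK-style mel scale, mel(f) = 2595 log10(1 + f / 700). Several choices here are assumptions of this sketch: the mel formula, the band limits `f_min`/`f_max`, the lack of normalization, and the 16 kHz sampling rate. diffsptk's own implementation may differ, for example by taking the logarithm of the output.

```python
import math


def hz_to_mel(f):
    # HTK-style mel scale (an assumption; other mel formulas exist).
    return 2595.0 * math.log10(1.0 + f / 700.0)


def mel_to_hz(m):
    # Inverse of hz_to_mel.
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)


def mel_filterbank(n_channel, fft_length, sample_rate, f_min=0.0, f_max=None):
    """Return an n_channel x (fft_length // 2 + 1) matrix of triangular weights."""
    if f_max is None:
        f_max = sample_rate / 2.0
    bin_freqs = [k * sample_rate / fft_length for k in range(fft_length // 2 + 1)]
    # n_channel + 2 points equally spaced in mel: lower edge, centers, upper edge.
    lo, hi = hz_to_mel(f_min), hz_to_mel(f_max)
    step = (hi - lo) / (n_channel + 1)
    edges = [mel_to_hz(lo + i * step) for i in range(n_channel + 2)]
    weights = []
    for ch in range(n_channel):
        left, center, right = edges[ch], edges[ch + 1], edges[ch + 2]
        row = []
        for f in bin_freqs:
            if left < f <= center:
                row.append((f - left) / (center - left))  # rising slope
            elif center < f < right:
                row.append((right - f) / (right - center))  # falling slope
            else:
                row.append(0.0)
        weights.append(row)
    return weights


def apply_filterbank(weights, amplitude):
    """Weighted sum of one amplitude frame within each mel band."""
    return [sum(w * a for w, a in zip(row, amplitude)) for row in weights]


# README settings (n_channel = 80, n_fft = 512) with an assumed 16 kHz rate.
W = mel_filterbank(n_channel=80, fft_length=512, sample_rate=16000)
frame = [1.0] * (512 // 2 + 1)  # a flat, made-up amplitude spectrum
Y = apply_filterbank(W, frame)
print(len(W), len(W[0]), len(Y))  # 80 257 80
```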

Subband decomposition

import diffsptk

K = 4   # Number of subbands.
M = 40  # Order of filter.

# Read waveform.
x, sr = diffsptk.read("assets/data.wav")

# Decompose x into K subbands and downsample each subband by K.
pqmf = diffsptk.PQMF(K, M)
decimate = diffsptk.Decimation(K)
y = decimate(pqmf(x), dim=-1)

# Reconstruct x: upsample each subband, apply the synthesis filter bank, and flatten.
interpolate = diffsptk.Interpolation(K)
ipqmf = diffsptk.IPQMF(K, M)
x_hat = ipqmf(interpolate(K * y, dim=-1)).reshape(-1)

# Write reconstructed waveform.
diffsptk.write("reconst.wav", x_hat, sr)

# Compute the total absolute reconstruction error.
error = (x_hat - x).abs().sum()
print(error)
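`Decimation` and `Interpolation` change the sampling rate of each PQMF subband by a factor of K. Assume they keep every K-th sample and insert K - 1 zeros between samples, respectively. Under that assumption, they reduce to the plain-Python sketch below. The functions `decimate` and `interpolate` here are hypothetical helpers, not diffsptk's API, and the signal is made up.

```python
def decimate(x, period):
    """Keep every `period`-th sample, starting from the first one."""
    return x[::period]


def interpolate(y, period):
    """Zero-stuffing: place samples `period` apart with zeros in between."""
    out = [0.0] * (len(y) * period)
    for i, value in enumerate(y):
        out[i * period] = value
    return out


K = 4  # number of subbands, as in the README example
x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]  # made-up signal
y = decimate(x, K)
z = interpolate(y, K)
print(y)  # [1.0, 5.0]
print(z)  # [1.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0, 0.0]
```

In standard multirate theory, decimating by K scales the spectrum by 1/K, so a gain of K is needed somewhere on the synthesis side. The `K * y` in the example above presumably plays that role.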

License

This software is released under the Apache License 2.0.