Skip to content

Commit

Permalink
pesq_batch
Browse files Browse the repository at this point in the history
  • Loading branch information
ludlows committed May 6, 2022
1 parent 3538396 commit 4807717
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 59 deletions.
108 changes: 67 additions & 41 deletions pesq/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# 2019-May
# github.com/ludlows
# Python Wrapper for PESQ Score (narrow band and wide band)
# Python Wrapper for PESQ Score (narrowband and wideband)

import numpy as np
from .cypesq import cypesq, cypesq_retvals, cypesq_error_message as pesq_error_message
Expand All @@ -14,75 +14,101 @@
"""

# Usage text for batch scoring: the single-pair usage text followed by the
# shape rules that apply when `ref` and/or `deg` are 2D.
USAGE_BATCH = USAGE + """
The shapes of ref and deg should be identical if both are 2D numpy arrays.
Once the `ref` is 1D and the `deg` is 2D, the broadcast operation is applied.
"""


def check_fs_mode(mode, fs, usage=USAGE):
    """
    Validate the mode / sampling-rate combination before any scoring is done.

    Args:
        mode: 'wb' (wide-band) or 'nb' (narrow-band)
        fs: integer, sampling rate; only 8000 and 16000 are accepted
        usage: help text printed right before an error is raised, so each
            entry point can show its own usage message (defaults to USAGE)
    Raises:
        ValueError: if `mode` is neither 'nb' nor 'wb', if `fs` is neither
            8000 nor 16000, or if wide-band mode is requested with fs = 8000
    """
    if mode != 'wb' and mode != 'nb':
        print(usage)
        raise ValueError("mode should be either 'nb' or 'wb'")

    if fs != 8000 and fs != 16000:
        print(usage)
        raise ValueError("fs (sampling frequency) should be either 8000 or 16000")

    if fs == 8000 and mode == 'wb':
        print(usage)
        raise ValueError("no wide band mode if fs = 8000")

def pesq_inner(fs, ref, deg, mode, on_error):
    """
    Score a single reference/degraded pair. No argument validation happens
    here; `fs` and `mode` are used as given.

    Args:
        ref: numpy 1D array, reference audio signal
        deg: numpy 1D array, degraded audio signal
        fs: integer, sampling rate
        mode: 'wb' (wide-band) or 'nb' (narrow-band)
        on_error: PesqError.RETURN_VALUES selects `cypesq_retvals`; any other
            value selects `cypesq`
    Returns:
        pesq_score: float, P.862.2 Prediction (MOS-LQO)
    """
    # Dividing by 1.0 promotes integer sample arrays to float before abs/max.
    # Both signals are scaled by the same peak, so their relative level is kept.
    max_val = max(np.max(np.abs(ref / 1.0)), np.max(np.abs(deg / 1.0)))
    ref_scaled = (ref / max_val).astype(np.float32)
    deg_scaled = (deg / max_val).astype(np.float32)

    # The extension takes the mode as an integer flag: 1 = wide-band, 0 = narrow-band.
    mode_code = 1 if mode == 'wb' else 0

    if on_error == PesqError.RETURN_VALUES:
        return cypesq_retvals(fs, ref_scaled, deg_scaled, mode_code)
    return cypesq(fs, ref_scaled, deg_scaled, mode_code)


def pesq(fs, ref, deg, mode, on_error=PesqError.RAISE_EXCEPTION):
    """
    Compute the PESQ score for one reference/degraded signal pair.

    Args:
        ref: numpy 1D array, reference audio signal
        deg: numpy 1D array, degraded audio signal
        fs: integer, sampling rate
        mode: 'wb' (wide-band) or 'nb' (narrow-band)
        on_error: PesqError.RAISE_EXCEPTION (default) or PesqError.RETURN_VALUES
    Returns:
        pesq_score: float, P.862.2 Prediction (MOS-LQO)
    Raises:
        ValueError: if `mode` / `fs` fail the checks in `check_fs_mode`
    """
    # Validate first so bad arguments never reach the scoring routine.
    check_fs_mode(mode, fs, USAGE)
    return pesq_inner(fs, ref, deg, mode, on_error)


def pesq_batch(fs, ref, deg, mode, on_error=PesqError.RAISE_EXCEPTION):
    """
    Compute PESQ scores for a batch of signals.

    Accepted shape combinations:
      * ref 1D, deg 1D with equal shapes -> a single pair is scored
      * ref 1D, deg 2D with the same last dimension -> `ref` is reused for
        every row of `deg`
      * ref 2D, deg 2D with equal shapes -> rows are scored pairwise

    Args:
        ref: numpy 1D (n_sample,) or 2D array (n_file, n_sample), reference audio signal
        deg: numpy 1D (n_sample,) or 2D array (n_file, n_sample), degraded audio signal
        fs: integer, sampling rate
        mode: 'wb' (wide-band) or 'nb' (narrow-band)
        on_error: PesqError.RAISE_EXCEPTION (default) or PesqError.RETURN_VALUES
    Returns:
        pesq_score: numpy 1D float array with one entry per row of `deg`,
            P.862.2 Prediction (MOS-LQO). When both `ref` and `deg` are 1D the
            result of the single `pesq_inner` call is returned as-is instead
            of being wrapped in an array.
    Raises:
        ValueError: if `mode` / `fs` fail the checks in `check_fs_mode`, or if
            the shapes of `ref` and `deg` do not match one of the accepted
            combinations
    """
    check_fs_mode(mode, fs, USAGE_BATCH)

    ref_ndim = len(ref.shape)
    deg_ndim = len(deg.shape)

    if ref_ndim == 1:
        if deg_ndim == 1 and ref.shape == deg.shape:
            return pesq_inner(fs, ref, deg, mode, on_error)
        if deg_ndim == 2 and ref.shape[-1] == deg.shape[-1]:
            # Broadcast: the one reference is compared with every degraded row.
            # dtype=float keeps the result a float array even when every entry
            # is an integer-valued return code.
            return np.array(
                [pesq_inner(fs, ref, deg_row, mode, on_error) for deg_row in deg],
                dtype=float,
            )
        raise ValueError("The shapes of `deg` is invalid!")

    if ref_ndim == 2:
        if deg.shape == ref.shape:
            return np.array(
                [
                    pesq_inner(fs, ref_row, deg_row, mode, on_error)
                    for ref_row, deg_row in zip(ref, deg)
                ],
                dtype=float,
            )
        raise ValueError("The shape of `deg` is invalid!")

    raise ValueError("The shape of `ref` should be either 1D or 2D!")
9 changes: 4 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
# 2019-May
# github.com/ludlows
# Python Wrapper for PESQ Score (narrow band and wide band)
# Python Wrapper for PESQ Score (narrowband and wideband)
from setuptools import find_packages
from setuptools import setup, Extension


# Read the README with an explicit encoding so the build does not depend on
# the platform's default locale encoding.
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

Expand All @@ -27,20 +26,20 @@ def include_dirs(self, dirs):
# One extension module named "cypesq", built from the Cython source and the
# three C files listed here, with headers looked up in the `pesq` directory.
extensions = [
    CyPesqExtension(
        "cypesq",
        ["pesq/cypesq.pyx", "pesq/dsp.c", "pesq/pesqdsp.c", "pesq/pesqmod.c"],
        include_dirs=['pesq'],
        language="c")
]
setup(
name="pesq",
version="0.0.3",
version="0.0.4",
author="ludlows",
description="Python Wrapper for PESQ Score (narrow band and wide band)",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/ludlows/python-pesq",
packages=find_packages(),
package_data={'pesq':["*.pyx", "*.h", "dsp.c", "pesqdsp.c", "pesqmod.c"]},
package_data={'pesq': ["*.pyx", "*.h", "dsp.c", "pesqdsp.c", "pesqmod.c"]},
ext_package='pesq',
ext_modules=extensions,
setup_requires=['setuptools>=18.0', 'cython', 'numpy', 'pytest-runner'],
Expand Down
28 changes: 15 additions & 13 deletions tests/test_pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from pesq import pesq, NoUtterancesError, PesqError


def test():
data_dir = Path(__file__).parent.parent / 'audio'
ref_path = data_dir / 'speech.wav'
Expand All @@ -22,29 +23,30 @@ def test():

assert score == 1.6072081327438354, score


def test_no_utterances_nb_mode():
    """Narrow-band: a silent reference raises by default and yields the error code on request."""
    sample_rate = 8000
    silent_ref = np.zeros(sample_rate)  # one second of silence
    deg = np.random.randn(sample_rate)

    # Default error handling: the failure surfaces as an exception.
    with pytest.raises(NoUtterancesError):
        pesq(ref=silent_ref, deg=deg, fs=sample_rate, mode='nb')

    # RETURN_VALUES: the same failure is reported through the return value.
    score = pesq(ref=silent_ref, deg=deg, fs=sample_rate, mode='nb',
                 on_error=PesqError.RETURN_VALUES)

    assert score == PesqError.NO_UTTERANCES_DETECTED, score


def test_no_utterances_wb_mode():
    """Wide-band: a silent reference raises by default and yields the error code on request."""
    sample_rate = 16000
    silent_ref = np.zeros(sample_rate)  # one second of silence
    deg = np.random.randn(sample_rate)

    # Default error handling: the failure surfaces as an exception.
    with pytest.raises(NoUtterancesError):
        pesq(ref=silent_ref, deg=deg, fs=sample_rate, mode='wb')

    # RETURN_VALUES: the same failure is reported through the return value.
    score = pesq(ref=silent_ref, deg=deg, fs=sample_rate, mode='wb',
                 on_error=PesqError.RETURN_VALUES)

    assert score == PesqError.NO_UTTERANCES_DETECTED, score

0 comments on commit 4807717

Please sign in to comment.