Skip to content

Commit

Permalink
add test functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ludlows committed May 6, 2022
1 parent 4807717 commit 2677454
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 9 deletions.
16 changes: 8 additions & 8 deletions pesq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"""


def check_fs_mode(mode, fs, usage=USAGE):
def _check_fs_mode(mode, fs, usage=USAGE):
if mode != 'wb' and mode != 'nb':
print(usage)
raise ValueError("mode should be either 'nb' or 'wb'")
Expand All @@ -33,7 +33,7 @@ def check_fs_mode(mode, fs, usage=USAGE):
raise ValueError("no wide band mode if fs = 8000")


def pesq_inner(fs, ref, deg, mode, on_error):
def _pesq_inner(fs, ref, deg, mode, on_error):
"""
Args:
ref: numpy 1D array, reference audio signal
Expand Down Expand Up @@ -75,8 +75,8 @@ def pesq(fs, ref, deg, mode, on_error=PesqError.RAISE_EXCEPTION):
Returns:
pesq_score: float, P.862.2 Prediction (MOS-LQO)
"""
check_fs_mode(mode, fs, USAGE)
return pesq_inner(fs, ref, deg, mode, on_error)
_check_fs_mode(mode, fs, USAGE)
return _pesq_inner(fs, ref, deg, mode, on_error)


def pesq_batch(fs, ref, deg, mode, on_error=PesqError.RAISE_EXCEPTION):
Expand All @@ -90,23 +90,23 @@ def pesq_batch(fs, ref, deg, mode, on_error=PesqError.RAISE_EXCEPTION):
Returns:
pesq_score: numpy 1D array, P.862.2 Prediction (MOS-LQO)
"""
check_fs_mode(mode, fs, USAGE_BATCH)
_check_fs_mode(mode, fs, USAGE_BATCH)
# check dimension
if len(ref.shape) == 1:
if len(deg.shape) == 1 and ref.shape == deg.shape:
return pesq_inner(fs, ref, deg, mode, on_error)
return _pesq_inner(fs, ref, deg, mode, on_error)
elif len(deg.shape) == 2 and ref.shape[-1] == deg.shape[-1]:
pesq_score = np.array([np.nan for i in range(deg.shape[0])])
for i in range(deg.shape[0]):
pesq_score[i] = pesq_inner(fs, ref, deg[i, :], mode, on_error)
pesq_score[i] = _pesq_inner(fs, ref, deg[i, :], mode, on_error)
return pesq_score
else:
raise ValueError("The shapes of `deg` is invalid!")
elif len(ref.shape) == 2:
if deg.shape == ref.shape:
pesq_score = np.array([np.nan for i in range(deg.shape[0])])
for i in range(deg.shape[0]):
pesq_score[i] = pesq_inner(fs, ref[i, :], deg[i, :], mode, on_error)
pesq_score[i] = _pesq_inner(fs, ref[i, :], deg[i, :], mode, on_error)
return pesq_score
else:
raise ValueError("The shape of `deg` is invalid!")
Expand Down
36 changes: 35 additions & 1 deletion tests/test_pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pathlib import Path

from pesq import pesq, NoUtterancesError, PesqError
from pesq import pesq, pesq_batch, NoUtterancesError, PesqError


def test():
Expand Down Expand Up @@ -50,3 +50,37 @@ def test_no_utterances_wb_mode():
on_error=PesqError.RETURN_VALUES)

assert score == PesqError.NO_UTTERANCES_DETECTED, score


def test_pesq_batch():
data_dir = Path(__file__).parent.parent / 'audio'
ref_path = data_dir / 'speech.wav'
deg_path = data_dir / 'speech_bab_0dB.wav'

sample_rate, ref = scipy.io.wavfile.read(ref_path)
sample_rate, deg = scipy.io.wavfile.read(deg_path)

n_file = 10
ideally = np.array([1.0832337141036987 for i in range(n_file)])

# 1D - 1D
score = pesq_batch(ref=ref, deg=deg, fs=sample_rate, mode='wb')
assert score == 1.0832337141036987, score

# 1D - 2D
deg_2d = np.repeat(deg[np.newaxis, :], n_file, axis=0)
scores = pesq_batch(ref=ref, deg=deg_2d, fs=sample_rate, mode='wb')
assert np.allclose(scores, ideally), scores

# 2D - 2D
ref_2d = np.repeat(ref[np.newaxis, :], n_file, axis=0)
scores = pesq_batch(ref=ref_2d, deg=deg_2d, fs=sample_rate, mode='wb')
assert np.allclose(scores, ideally), scores

# narrowband
score = pesq_batch(ref=ref, deg=deg, fs=sample_rate, mode='nb')
assert score == 1.6072081327438354, score


if __name__ == "__main__":
test_pesq_batch()

0 comments on commit 2677454

Please sign in to comment.