From bb67ec2ad515f404b050c0e8dece159c51326cc6 Mon Sep 17 00:00:00 2001 From: ludlows Date: Tue, 10 May 2022 16:45:09 +0800 Subject: [PATCH] refactor pesq module --- pesq/__init__.py | 156 +------------------------------------------ pesq/_pesq.py | 162 +++++++++++++++++++++++++++++++++++++++++++++ tests/test_pesq.py | 6 +- 3 files changed, 167 insertions(+), 157 deletions(-) create mode 100644 pesq/_pesq.py diff --git a/pesq/__init__.py b/pesq/__init__.py index 42b3693..979cb80 100644 --- a/pesq/__init__.py +++ b/pesq/__init__.py @@ -2,157 +2,5 @@ # github.com/ludlows # Python Wrapper for PESQ Score (narrowband and wideband) -import numpy as np -from multiprocessing import Pool, Queue, Process, cpu_count -from functools import partial -from .cypesq import cypesq, cypesq_retvals, cypesq_error_message as pesq_error_message -from .cypesq import PesqError, InvalidSampleRateError, OutOfMemoryError -from .cypesq import BufferTooShortError, NoUtterancesError - -USAGE = """ - Run model on reference(ref) and degraded(deg) - Sample rate (fs) - No default. Must select either 8000 or 16000. - Note there is narrow band (nb) mode only when sampling rate is 8000Hz. - """ - -USAGE_BATCH = USAGE + """ - The shapes of ref and deg should be identical if both are 2D numpy arrays. - Once the `ref` is 1D and the `deg` is 2D, the broadcast operation is applied. - """ - - -def _check_fs_mode(mode, fs, usage=USAGE): - if mode != 'wb' and mode != 'nb': - print(usage) - raise ValueError("mode should be either 'nb' or 'wb'") - - if fs != 8000 and fs != 16000: - print(usage) - raise ValueError("fs (sampling frequency) should be either 8000 or 16000") - - if fs == 8000 and mode == 'wb': - print(usage) - raise ValueError("no wide band mode if fs = 8000") - - -def _pesq_inner(ref, deg, fs=16000, mode='wb', on_error=PesqError.RAISE_EXCEPTION): - """ - Args: - ref: numpy 1D array, reference audio signal - deg: numpy 1D array, degraded audio signal - fs: integer, sampling rate - mode: 'wb' (wide-band) or 'nb' (narrow-band) - on_error: PesqError.RAISE_EXCEPTION (default) or PesqError.RETURN_VALUES - Returns: - pesq_score: float, P.862.2 Prediction (MOS-LQO) - """ - max_val = max(np.max(np.abs(ref / 1.0)), np.max(np.abs(deg / 1.0))) - if mode == 'wb': - mode_code = 1 - else: - mode_code = 0 - if on_error == PesqError.RETURN_VALUES: - return cypesq_retvals( - fs, - (ref / max_val).astype(np.float32), - (deg / max_val).astype(np.float32), - mode_code - ) - return cypesq( - fs, - (ref / max_val).astype(np.float32), - (deg / max_val).astype(np.float32), - mode_code - ) - - -def _processor_coordinator(func, args_q, results_q): - while True: - index, arg = args_q.get() - if index is None: - break - try: - result = func(*arg) - except Exception as e: - result = e - results_q.put((index, result)) - - -def _processor_mapping(func, args, n_processor): - args_q = Queue(maxsize=1) - results_q = Queue() - processors = [Process(target=_processor_coordinator, args=(func, args_q, results_q)) for _ in range(n_processor)] - for p in processors: - p.daemon = True - p.start() - for i, arg in enumerate(args): - args_q.put((i, arg)) - # send stop messages - for _ in range(n_processor): - args_q.put((None, None)) - results = [results_q.get() for _ in range(len(args))] - [p.join() for p in processors] - return [v[1] for v in sorted(results)] - - -def pesq(fs, ref, deg, mode='wb', on_error=PesqError.RAISE_EXCEPTION): - """ - Args: - ref: numpy 1D array, reference audio signal - deg: numpy 1D array, degraded audio signal - fs: integer, sampling rate - mode: 'wb' (wide-band) or 'nb' (narrow-band) - on_error: PesqError.RAISE_EXCEPTION (default) or PesqError.RETURN_VALUES - Returns: - pesq_score: float, P.862.2 Prediction (MOS-LQO) - """ - _check_fs_mode(mode, fs, USAGE) - return _pesq_inner(ref, deg, fs, mode, on_error) - - -def pesq_batch(fs, ref, deg, mode, n_processor=cpu_count(), on_error=PesqError.RAISE_EXCEPTION): - """ - Running `pesq` using multiple processors - Args: - on_error: - ref: numpy 1D (n_sample,) or 2D array (n_file, n_sample), reference audio signal - deg: numpy 1D (n_sample,) or 2D array (n_file, n_sample), degraded audio signal - fs: integer, sampling rate - mode: 'wb' (wide-band) or 'nb' (narrow-band) - n_processor: cpu_count() (default) or number of processors (chosen by the user) or 0 (without multiprocessing) - on_error: PesqError.RAISE_EXCEPTION (default) or PesqError.RETURN_VALUES - Returns: - pesq_score: list of pesq scores, P.862.2 Prediction (MOS-LQO) - """ - _check_fs_mode(mode, fs, USAGE_BATCH) - # check dimension - if len(ref.shape) == 1: - if len(deg.shape) == 1 and ref.shape == deg.shape: - return [_pesq_inner(ref, deg, fs, mode, PesqError.RETURN_VALUES)] - elif len(deg.shape) == 2 and ref.shape[-1] == deg.shape[-1]: - if n_processor <= 0: - pesq_score = [np.nan for i in range(deg.shape[0])] - for i in range(deg.shape[0]): - pesq_score[i] = _pesq_inner(ref, deg[i, :], fs, mode, on_error) - return pesq_score - else: - with Pool(n_processor) as p: - return p.map(partial(_pesq_inner, ref, fs=fs, mode=mode, on_error=on_error), - [deg[i, :] for i in range(deg.shape[0])]) - else: - raise ValueError("The shapes of `deg` is invalid!") - elif len(ref.shape) == 2: - if deg.shape == ref.shape: - if n_processor <= 0: - pesq_score = [np.nan for i in range(deg.shape[0])] - for i in range(deg.shape[0]): - pesq_score[i] = _pesq_inner(ref[i, :], deg[i, :], fs, mode, on_error) - return pesq_score - else: - return _processor_mapping(_pesq_inner, - [(ref[i, :], deg[i, :], fs, mode, on_error) for i in range(deg.shape[0])], - n_processor) - else: - raise ValueError("The shape of `deg` is invalid!") - else: - raise ValueError("The shape of `ref` should be either 1D or 2D!") +from ._pesq import pesq, pesq_batch +from ._pesq import PesqError, InvalidSampleRateError, OutOfMemoryError, BufferTooShortError, NoUtterancesError diff --git a/pesq/_pesq.py b/pesq/_pesq.py new file mode 100644 index 0000000..f9eb27a --- /dev/null +++ b/pesq/_pesq.py @@ -0,0 +1,162 @@ +# 2019-May +# github.com/ludlows +# Python Wrapper for PESQ Score (narrowband and wideband) + +import numpy as np +from multiprocessing import Pool, Queue, Process, cpu_count +from functools import partial +from .cypesq import cypesq, cypesq_retvals, cypesq_error_message as pesq_error_message +from .cypesq import PesqError, InvalidSampleRateError, OutOfMemoryError +from .cypesq import BufferTooShortError, NoUtterancesError + + +__all__ = ['pesq', 'pesq_batch', 'PesqError', 'InvalidSampleRateError', 'OutOfMemoryError', 'BufferTooShortError', + 'NoUtterancesError'] + +USAGE = """ + Run model on reference(ref) and degraded(deg) + Sample rate (fs) - No default. Must select either 8000 or 16000. + Note there is narrow band (nb) mode only when sampling rate is 8000Hz. + """ + +USAGE_BATCH = USAGE + """ + The shapes of ref and deg should be identical if both are 2D numpy arrays. + Once the `ref` is 1D and the `deg` is 2D, the broadcast operation is applied. + """ + + +def _check_fs_mode(mode, fs, usage=USAGE): + if mode != 'wb' and mode != 'nb': + print(usage) + raise ValueError("mode should be either 'nb' or 'wb'") + + if fs != 8000 and fs != 16000: + print(usage) + raise ValueError("fs (sampling frequency) should be either 8000 or 16000") + + if fs == 8000 and mode == 'wb': + print(usage) + raise ValueError("no wide band mode if fs = 8000") + + +def _pesq_inner(ref, deg, fs=16000, mode='wb', on_error=PesqError.RAISE_EXCEPTION): + """ + Args: + ref: numpy 1D array, reference audio signal + deg: numpy 1D array, degraded audio signal + fs: integer, sampling rate + mode: 'wb' (wide-band) or 'nb' (narrow-band) + on_error: PesqError.RAISE_EXCEPTION (default) or PesqError.RETURN_VALUES + Returns: + pesq_score: float, P.862.2 Prediction (MOS-LQO) + """ + max_val = max(np.max(np.abs(ref / 1.0)), np.max(np.abs(deg / 1.0))) + if mode == 'wb': + mode_code = 1 + else: + mode_code = 0 + if on_error == PesqError.RETURN_VALUES: + return cypesq_retvals( + fs, + (ref / max_val).astype(np.float32), + (deg / max_val).astype(np.float32), + mode_code + ) + return cypesq( + fs, + (ref / max_val).astype(np.float32), + (deg / max_val).astype(np.float32), + mode_code + ) + + +def _processor_coordinator(func, args_q, results_q): + while True: + index, arg = args_q.get() + if index is None: + break + try: + result = func(*arg) + except Exception as e: + result = e + results_q.put((index, result)) + + +def _processor_mapping(func, args, n_processor): + args_q = Queue(maxsize=1) + results_q = Queue() + processors = [Process(target=_processor_coordinator, args=(func, args_q, results_q)) for _ in range(n_processor)] + for p in processors: + p.daemon = True + p.start() + for i, arg in enumerate(args): + args_q.put((i, arg)) + # send stop messages + for _ in range(n_processor): + args_q.put((None, None)) + results = [results_q.get() for _ in range(len(args))] + [p.join() for p in processors] + return [v[1] for v in sorted(results)] + + +def pesq(fs, ref, deg, mode='wb', on_error=PesqError.RAISE_EXCEPTION): + """ + Args: + ref: numpy 1D array, reference audio signal + deg: numpy 1D array, degraded audio signal + fs: integer, sampling rate + mode: 'wb' (wide-band) or 'nb' (narrow-band) + on_error: PesqError.RAISE_EXCEPTION (default) or PesqError.RETURN_VALUES + Returns: + pesq_score: float, P.862.2 Prediction (MOS-LQO) + """ + _check_fs_mode(mode, fs, USAGE) + return _pesq_inner(ref, deg, fs, mode, on_error) + + +def pesq_batch(fs, ref, deg, mode, n_processor=cpu_count(), on_error=PesqError.RAISE_EXCEPTION): + """ + Running `pesq` using multiple processors + Args: + on_error: + ref: numpy 1D (n_sample,) or 2D array (n_file, n_sample), reference audio signal + deg: numpy 1D (n_sample,) or 2D array (n_file, n_sample), degraded audio signal + fs: integer, sampling rate + mode: 'wb' (wide-band) or 'nb' (narrow-band) + n_processor: cpu_count() (default) or number of processors (chosen by the user) or 0 (without multiprocessing) + on_error: PesqError.RAISE_EXCEPTION (default) or PesqError.RETURN_VALUES + Returns: + pesq_score: list of pesq scores, P.862.2 Prediction (MOS-LQO) + """ + _check_fs_mode(mode, fs, USAGE_BATCH) + # check dimension + if len(ref.shape) == 1: + if len(deg.shape) == 1 and ref.shape == deg.shape: + return [_pesq_inner(ref, deg, fs, mode, PesqError.RETURN_VALUES)] + elif len(deg.shape) == 2 and ref.shape[-1] == deg.shape[-1]: + if n_processor <= 0: + pesq_score = [np.nan for i in range(deg.shape[0])] + for i in range(deg.shape[0]): + pesq_score[i] = _pesq_inner(ref, deg[i, :], fs, mode, on_error) + return pesq_score + else: + with Pool(n_processor) as p: + return p.map(partial(_pesq_inner, ref, fs=fs, mode=mode, on_error=on_error), + [deg[i, :] for i in range(deg.shape[0])]) + else: + raise ValueError("The shapes of `deg` is invalid!") + elif len(ref.shape) == 2: + if deg.shape == ref.shape: + if n_processor <= 0: + pesq_score = [np.nan for i in range(deg.shape[0])] + for i in range(deg.shape[0]): + pesq_score[i] = _pesq_inner(ref[i, :], deg[i, :], fs, mode, on_error) + return pesq_score + else: + return _processor_mapping(_pesq_inner, + [(ref[i, :], deg[i, :], fs, mode, on_error) for i in range(deg.shape[0])], + n_processor) + else: + raise ValueError("The shape of `deg` is invalid!") + else: + raise ValueError("The shape of `ref` should be either 1D or 2D!") diff --git a/tests/test_pesq.py b/tests/test_pesq.py index df5a395..e76d87e 100755 --- a/tests/test_pesq.py +++ b/tests/test_pesq.py @@ -1,9 +1,9 @@ -import pytest +from pathlib import Path + import numpy as np +import pytest import scipy.io.wavfile -from pathlib import Path - from pesq import pesq, pesq_batch, NoUtterancesError, PesqError