-
Notifications
You must be signed in to change notification settings - Fork 96
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
167 additions
and
157 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
# 2019-May | ||
# github.com/ludlows | ||
# Python Wrapper for PESQ Score (narrowband and wideband) | ||
|
||
import numpy as np | ||
from multiprocessing import Pool, Queue, Process, cpu_count | ||
from functools import partial | ||
from .cypesq import cypesq, cypesq_retvals, cypesq_error_message as pesq_error_message | ||
from .cypesq import PesqError, InvalidSampleRateError, OutOfMemoryError | ||
from .cypesq import BufferTooShortError, NoUtterancesError | ||
|
||
|
||
__all__ = ['pesq', 'pesq_batch', 'PesqError', 'InvalidSampleRateError', 'OutOfMemoryError', 'BufferTooShortError', | ||
'NoUtterancesError'] | ||
|
||
USAGE = """ | ||
Run model on reference(ref) and degraded(deg) | ||
Sample rate (fs) - No default. Must select either 8000 or 16000. | ||
Note there is narrow band (nb) mode only when sampling rate is 8000Hz. | ||
""" | ||
|
||
USAGE_BATCH = USAGE + """ | ||
The shapes of ref and deg should be identical if both are 2D numpy arrays. | ||
Once the `ref` is 1D and the `deg` is 2D, the broadcast operation is applied. | ||
""" | ||
|
||
|
||
def _check_fs_mode(mode, fs, usage=USAGE): | ||
if mode != 'wb' and mode != 'nb': | ||
print(usage) | ||
raise ValueError("mode should be either 'nb' or 'wb'") | ||
|
||
if fs != 8000 and fs != 16000: | ||
print(usage) | ||
raise ValueError("fs (sampling frequency) should be either 8000 or 16000") | ||
|
||
if fs == 8000 and mode == 'wb': | ||
print(usage) | ||
raise ValueError("no wide band mode if fs = 8000") | ||
|
||
|
||
def _pesq_inner(ref, deg, fs=16000, mode='wb', on_error=PesqError.RAISE_EXCEPTION): | ||
""" | ||
Args: | ||
ref: numpy 1D array, reference audio signal | ||
deg: numpy 1D array, degraded audio signal | ||
fs: integer, sampling rate | ||
mode: 'wb' (wide-band) or 'nb' (narrow-band) | ||
on_error: PesqError.RAISE_EXCEPTION (default) or PesqError.RETURN_VALUES | ||
Returns: | ||
pesq_score: float, P.862.2 Prediction (MOS-LQO) | ||
""" | ||
max_val = max(np.max(np.abs(ref / 1.0)), np.max(np.abs(deg / 1.0))) | ||
if mode == 'wb': | ||
mode_code = 1 | ||
else: | ||
mode_code = 0 | ||
if on_error == PesqError.RETURN_VALUES: | ||
return cypesq_retvals( | ||
fs, | ||
(ref / max_val).astype(np.float32), | ||
(deg / max_val).astype(np.float32), | ||
mode_code | ||
) | ||
return cypesq( | ||
fs, | ||
(ref / max_val).astype(np.float32), | ||
(deg / max_val).astype(np.float32), | ||
mode_code | ||
) | ||
|
||
|
||
def _processor_coordinator(func, args_q, results_q): | ||
while True: | ||
index, arg = args_q.get() | ||
if index is None: | ||
break | ||
try: | ||
result = func(*arg) | ||
except Exception as e: | ||
result = e | ||
results_q.put((index, result)) | ||
|
||
|
||
def _processor_mapping(func, args, n_processor): | ||
args_q = Queue(maxsize=1) | ||
results_q = Queue() | ||
processors = [Process(target=_processor_coordinator, args=(func, args_q, results_q)) for _ in range(n_processor)] | ||
for p in processors: | ||
p.daemon = True | ||
p.start() | ||
for i, arg in enumerate(args): | ||
args_q.put((i, arg)) | ||
# send stop messages | ||
for _ in range(n_processor): | ||
args_q.put((None, None)) | ||
results = [results_q.get() for _ in range(len(args))] | ||
[p.join() for p in processors] | ||
return [v[1] for v in sorted(results)] | ||
|
||
|
||
def pesq(fs, ref, deg, mode='wb', on_error=PesqError.RAISE_EXCEPTION): | ||
""" | ||
Args: | ||
ref: numpy 1D array, reference audio signal | ||
deg: numpy 1D array, degraded audio signal | ||
fs: integer, sampling rate | ||
mode: 'wb' (wide-band) or 'nb' (narrow-band) | ||
on_error: PesqError.RAISE_EXCEPTION (default) or PesqError.RETURN_VALUES | ||
Returns: | ||
pesq_score: float, P.862.2 Prediction (MOS-LQO) | ||
""" | ||
_check_fs_mode(mode, fs, USAGE) | ||
return _pesq_inner(ref, deg, fs, mode, on_error) | ||
|
||
|
||
def pesq_batch(fs, ref, deg, mode, n_processor=cpu_count(), on_error=PesqError.RAISE_EXCEPTION): | ||
""" | ||
Running `pesq` using multiple processors | ||
Args: | ||
on_error: | ||
ref: numpy 1D (n_sample,) or 2D array (n_file, n_sample), reference audio signal | ||
deg: numpy 1D (n_sample,) or 2D array (n_file, n_sample), degraded audio signal | ||
fs: integer, sampling rate | ||
mode: 'wb' (wide-band) or 'nb' (narrow-band) | ||
n_processor: cpu_count() (default) or number of processors (chosen by the user) or 0 (without multiprocessing) | ||
on_error: PesqError.RAISE_EXCEPTION (default) or PesqError.RETURN_VALUES | ||
Returns: | ||
pesq_score: list of pesq scores, P.862.2 Prediction (MOS-LQO) | ||
""" | ||
_check_fs_mode(mode, fs, USAGE_BATCH) | ||
# check dimension | ||
if len(ref.shape) == 1: | ||
if len(deg.shape) == 1 and ref.shape == deg.shape: | ||
return [_pesq_inner(ref, deg, fs, mode, PesqError.RETURN_VALUES)] | ||
elif len(deg.shape) == 2 and ref.shape[-1] == deg.shape[-1]: | ||
if n_processor <= 0: | ||
pesq_score = [np.nan for i in range(deg.shape[0])] | ||
for i in range(deg.shape[0]): | ||
pesq_score[i] = _pesq_inner(ref, deg[i, :], fs, mode, on_error) | ||
return pesq_score | ||
else: | ||
with Pool(n_processor) as p: | ||
return p.map(partial(_pesq_inner, ref, fs=fs, mode=mode, on_error=on_error), | ||
[deg[i, :] for i in range(deg.shape[0])]) | ||
else: | ||
raise ValueError("The shapes of `deg` is invalid!") | ||
elif len(ref.shape) == 2: | ||
if deg.shape == ref.shape: | ||
if n_processor <= 0: | ||
pesq_score = [np.nan for i in range(deg.shape[0])] | ||
for i in range(deg.shape[0]): | ||
pesq_score[i] = _pesq_inner(ref[i, :], deg[i, :], fs, mode, on_error) | ||
return pesq_score | ||
else: | ||
return _processor_mapping(_pesq_inner, | ||
[(ref[i, :], deg[i, :], fs, mode, on_error) for i in range(deg.shape[0])], | ||
n_processor) | ||
else: | ||
raise ValueError("The shape of `deg` is invalid!") | ||
else: | ||
raise ValueError("The shape of `ref` should be either 1D or 2D!") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters