Skip to content

Commit

Permalink
enable multiple processors
Browse files Browse the repository at this point in the history
  • Loading branch information
ludlows committed May 7, 2022
1 parent d15c64d commit 5e86029
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 28 deletions.
18 changes: 12 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ And using 8000Hz is supported for narrow band only.
The code supports error-handling behaviors now.

```python
def pesq(fs, ref, deg, mode, on_error=PesqError.RAISE_EXCEPTION):
def pesq(fs, ref, deg, mode='wb', on_error=PesqError.RAISE_EXCEPTION):
"""
Args:
ref: numpy 1D array, reference audio signal
Expand Down Expand Up @@ -78,20 +78,26 @@ print(pesq(rate, ref, deg, 'nb'))
# Usage for batch version

```python
def pesq_batch(fs, ref, deg, mode, on_error=PesqError.RAISE_EXCEPTION):
def pesq_batch(fs, ref, deg, mode='wb', n_processor=None, on_error=PesqError.RAISE_EXCEPTION):
"""
Running `pesq` using multiple processors
Args:
ref: numpy 1D or 2D array, shape (n_file, n_sample), reference audio signal
deg: numpy 1D or 2D array, shape (n_file, n_sample), degraded audio signal
on_error:
ref: numpy 1D (n_sample,) or 2D array (n_file, n_sample), reference audio signal
deg: numpy 1D (n_sample,) or 2D array (n_file, n_sample), degraded audio signal
fs: integer, sampling rate
mode: 'wb' (wide-band) or 'nb' (narrow-band)
on_error: error-handling behavior, it could be PesqError.RETURN_VALUES or PesqError.RAISE_EXCEPTION by default
n_processor: None (without multiprocessing) or number of processors
on_error:
Returns:
pesq_score: numpy 1D array, shape (n_file,) P.862.2 Prediction (MOS-LQO)
pesq_score: list of pesq scores, P.862.2 Prediction (MOS-LQO)
"""
```
this function uses `multiprocessing` features to boost time efficiency.

When the `ref` is an 1-D numpy array and `deg` is a 2-D numpy array, the result of `pesq_batch` is identical to the value of `[pesq(fs, ref, deg[i,:],**kwargs) for i in range(deg.shape[0])]`.

When the `ref` is a 2-D numpy array and `deg` is a 2-D numpy array, the result of `pesq_batch` is identical to the value of `[pesq(fs, ref[i,:], deg[i,:],**kwargs) for i in range(deg.shape[0])]`.


# Correctness
Expand Down
79 changes: 61 additions & 18 deletions pesq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
# Python Wrapper for PESQ Score (narrowband and wideband)

import numpy as np
from .cypesq import cypesq, cypesq_retvals, cypesq_error_message as pesq_error_message
from .cypesq import PesqError, InvalidSampleRateError, OutOfMemoryError
from .cypesq import BufferTooShortError, NoUtterancesError
from multiprocessing import Pool, Queue, Process
from functools import partial
from .cypesq import cypesq, cypesq_retvals
from .cypesq import PesqError

USAGE = """
Run model on reference(ref) and degraded(deg)
Expand Down Expand Up @@ -33,7 +34,7 @@ def _check_fs_mode(mode, fs, usage=USAGE):
raise ValueError("no wide band mode if fs = 8000")


def _pesq_inner(fs, ref, deg, mode, on_error):
def _pesq_inner(ref, deg, fs=16000, mode='wb', on_error=PesqError.RAISE_EXCEPTION):
"""
Args:
ref: numpy 1D array, reference audio signal
Expand Down Expand Up @@ -64,7 +65,36 @@ def _pesq_inner(fs, ref, deg, mode, on_error):
)


def pesq(fs, ref, deg, mode, on_error=PesqError.RAISE_EXCEPTION):
def _processor_coordinator(func, args_q, results_q):
while True:
index, arg = args_q.get()
if index is None:
break
try:
result = func(*arg)
except Exception as e:
result = e
results_q.put((index, result))


def _processor_mapping(func, args, n_processor):
args_q = Queue(maxsize=1)
results_q = Queue()
processors = [Process(target=_processor_coordinator, args=(func, args_q, results_q)) for _ in range(n_processor)]
for p in processors:
p.daemon = True
p.start()
for i, arg in enumerate(args):
args_q.put((i, arg))
# send stop messages
for _ in range(n_processor):
args_q.put((None, None))
results = [results_q.get() for _ in range(len(args))]
[p.join() for p in processors]
return [v[1] for v in sorted(results)]


def pesq(fs, ref, deg, mode='wb', on_error=PesqError.RAISE_EXCEPTION):
"""
Args:
ref: numpy 1D array, reference audio signal
Expand All @@ -76,38 +106,51 @@ def pesq(fs, ref, deg, mode, on_error=PesqError.RAISE_EXCEPTION):
pesq_score: float, P.862.2 Prediction (MOS-LQO)
"""
_check_fs_mode(mode, fs, USAGE)
return _pesq_inner(fs, ref, deg, mode, on_error)
return _pesq_inner(ref, deg, fs, mode, on_error)


def pesq_batch(fs, ref, deg, mode, on_error=PesqError.RAISE_EXCEPTION):
def pesq_batch(fs, ref, deg, mode, n_processor=None, on_error=PesqError.RAISE_EXCEPTION):
"""
Running `pesq` using multiple processors
Args:
on_error:
ref: numpy 1D (n_sample,) or 2D array (n_file, n_sample), reference audio signal
deg: numpy 1D (n_sample,) or 2D array (n_file, n_sample), degraded audio signal
fs: integer, sampling rate
mode: 'wb' (wide-band) or 'nb' (narrow-band)
on_error: PesqError.RAISE_EXCEPTION (default) or PesqError.RETURN_VALUES
n_processor: None (without multiprocessing) or number of processors
on_error:
Returns:
pesq_score: numpy 1D array, P.862.2 Prediction (MOS-LQO)
pesq_score: list of pesq scores, P.862.2 Prediction (MOS-LQO)
"""
_check_fs_mode(mode, fs, USAGE_BATCH)
# check dimension
if len(ref.shape) == 1:
if len(deg.shape) == 1 and ref.shape == deg.shape:
return _pesq_inner(fs, ref, deg, mode, on_error)
return [_pesq_inner(ref, deg, fs, mode, PesqError.RETURN_VALUES)]
elif len(deg.shape) == 2 and ref.shape[-1] == deg.shape[-1]:
pesq_score = np.array([np.nan for i in range(deg.shape[0])])
for i in range(deg.shape[0]):
pesq_score[i] = _pesq_inner(fs, ref, deg[i, :], mode, on_error)
return pesq_score
if n_processor is None:
pesq_score = [np.nan for i in range(deg.shape[0])]
for i in range(deg.shape[0]):
pesq_score[i] = _pesq_inner(ref, deg[i, :], fs, mode, on_error)
return pesq_score
else:
with Pool(n_processor) as p:
return p.map(partial(_pesq_inner, ref, fs=fs, mode=mode, on_error=on_error),
[deg[i, :] for i in range(deg.shape[0])])
else:
raise ValueError("The shapes of `deg` is invalid!")
elif len(ref.shape) == 2:
if deg.shape == ref.shape:
pesq_score = np.array([np.nan for i in range(deg.shape[0])])
for i in range(deg.shape[0]):
pesq_score[i] = _pesq_inner(fs, ref[i, :], deg[i, :], mode, on_error)
return pesq_score
if n_processor is None:
pesq_score = [np.nan for i in range(deg.shape[0])]
for i in range(deg.shape[0]):
pesq_score[i] = _pesq_inner(ref[i, :], deg[i, :], fs, mode, on_error)
return pesq_score
else:
return _processor_mapping(_pesq_inner,
[(ref[i, :], deg[i, :], fs, mode, on_error) for i in range(deg.shape[0])],
n_processor)
else:
raise ValueError("The shape of `deg` is invalid!")
else:
Expand Down
21 changes: 17 additions & 4 deletions tests/test_pesq.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,35 @@ def test_pesq_batch():

# 1D - 1D
score = pesq_batch(ref=ref, deg=deg, fs=sample_rate, mode='wb')
assert score == 1.0832337141036987, score
assert score == [1.0832337141036987], score

# 1D - 2D
deg_2d = np.repeat(deg[np.newaxis, :], n_file, axis=0)
scores = pesq_batch(ref=ref, deg=deg_2d, fs=sample_rate, mode='wb')
assert np.allclose(scores, ideally), scores
assert np.allclose(np.array(scores), ideally), scores

# 2D - 2D
ref_2d = np.repeat(ref[np.newaxis, :], n_file, axis=0)
scores = pesq_batch(ref=ref_2d, deg=deg_2d, fs=sample_rate, mode='wb')
assert np.allclose(scores, ideally), scores
assert np.allclose(np.array(scores), ideally), scores

# narrowband
score = pesq_batch(ref=ref, deg=deg, fs=sample_rate, mode='nb')
assert score == 1.6072081327438354, score
assert score == [1.6072081327438354], score

# 1D - 2D multiprocessing
deg_2d = np.repeat(deg[np.newaxis, :], n_file, axis=0)
scores = pesq_batch(ref=ref, deg=deg_2d, fs=sample_rate, mode='wb', n_processor=4)
assert np.allclose(np.array(scores), ideally), scores

# 2D - 2D multiprocessing
ref_2d = np.repeat(ref[np.newaxis, :], n_file, axis=0)
scores = pesq_batch(ref=ref_2d, deg=deg_2d, fs=sample_rate, mode='wb', n_processor=4)
assert np.allclose(np.array(scores), ideally), scores


if __name__ == "__main__":
test()
test_no_utterances_nb_mode()
test_no_utterances_wb_mode()
test_pesq_batch()

0 comments on commit 5e86029

Please sign in to comment.