Skip to content

Commit

Permalink
start to enable batch processing features
Browse files Browse the repository at this point in the history
  • Loading branch information
ludlows committed May 5, 2022
1 parent c56b43c commit 3538396
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 5 deletions.
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,25 @@ print(pesq(rate, ref, deg, 'wb'))
print(pesq(rate, ref, deg, 'nb'))
```

# Usage for batch version

```python
def pesq_batch(fs, ref, deg, mode, on_error=PesqError.RAISE_EXCEPTION):
"""
Args:
ref: numpy 1D or 2D array, shape (n_audio, length) ,reference audio signal
deg: numpy 1D or 2D array, shape (n_audio, length), degraded audio signal
fs: integer, sampling rate
mode: 'wb' (wide-band) or 'nb' (narrow-band)
on_error: error-handling behavior, it could be PesqError.RETURN_VALUES or PesqError.RAISE_EXCEPTION by default
Returns:
pesq_score: numpy 1D array, shape (n_audio,) P.862.2 Prediction (MOS-LQO)
"""
```




# Correctness

The correctness is verified by running samples in audio folder.
Expand All @@ -98,3 +117,5 @@ Please click [here](https://github.com/ludlows/python-pesq/network/dependents) t
# Acknowledgement

This work was funded by the Natural Sciences and Engineering Research Council of Canada.

This work was also funded by the Concordia University, Montreal, Canada.
41 changes: 36 additions & 5 deletions pesq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,17 @@
from .cypesq import BufferTooShortError, NoUtterancesError

USAGE = """
Run model on reference ref and degraded deg
Sample rate (fs) - No default. Must select either 8000 or 16000.
Note there is narrow band (nb) mode only when sampling rate is 8000Hz.
Run model on reference(ref) and degraded(deg)
Sample rate (fs) - No default. Must select either 8000 or 16000.
Note there is narrow band (nb) mode only when sampling rate is 8000Hz.
"""

USAGE_BATCH = USAGE + """
The shapes of ref and deg should be same if both are 2D numpy arrays.
Once the deg is 1D numpy array, the broadcast operation is applied.
"""


def pesq(fs, ref, deg, mode, on_error=PesqError.RAISE_EXCEPTION):
"""
Args:
Expand Down Expand Up @@ -48,10 +55,34 @@ def pesq(fs, ref, deg, mode, on_error=PesqError.RAISE_EXCEPTION):
(deg/maxval).astype(np.float32),
mode_code
)
else:
return cypesq(
return cypesq(
fs,
(ref/maxval).astype(np.float32),
(deg/maxval).astype(np.float32),
mode_code
)

def pesq_batch(fs, ref, deg, mode, on_error=PesqError.RAISE_EXCEPTION):
"""
Args:
ref: numpy 1D array, reference audio signal
deg: numpy 1D or 2D array, degraded audio signal
fs: integer, sampling rate
mode: 'wb' (wide-band) or 'nb' (narrow-band)
Returns:
pesq_score: numpy 1D array, P.862.2 Prediction (MOS-LQO)
"""
# check mode
if mode != 'wb' and mode != 'nb':
print(USAGE_BATCH)
raise ValueError("mode should be either 'nb' or 'wb'")
# check fs
if fs != 8000 and fs != 16000:
print(USAGE_BATCH)
raise ValueError("fs (sampling frequency) should be either 8000 or 16000")

if fs == 8000 and mode == 'wb':
print(USAGE_BATCH)
raise ValueError("no wide band mode if fs = 8000")
# normalization

0 comments on commit 3538396

Please sign in to comment.