Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/dev'
Browse files Browse the repository at this point in the history
merge error-handling feature into master branch
  • Loading branch information
wangmiao authored and wangmiao committed May 11, 2021
2 parents abeb477 + 18dc469 commit 6e38968
Show file tree
Hide file tree
Showing 7 changed files with 331 additions and 175 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,7 @@ dmypy.json

# Pyre type checker
.pyre/

# vIM buffer files
*.swp
*.swo
31 changes: 26 additions & 5 deletions pesq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
# Python Wrapper for PESQ Score (narrow band and wide band)

import numpy as np
from .cypesq import cypesq
from .cypesq import cypesq, cypesq_retvals, cypesq_error_message as pesq_error_message
from .cypesq import PesqError, InvalidSampleRateError, OutOfMemoryError
from .cypesq import BufferTooShortError, NoUtterancesError

USAGE = """
Run model on reference ref and degraded deg
Sample rate (fs) - No default. Must select either 8000 or 16000.
Note there is narrow band (nb) mode only when sampling rate is 8000Hz.
"""
def pesq(fs, ref, deg, mode):
def pesq(fs, ref, deg, mode, on_error=PesqError.RAISE_EXCEPTION):
"""
Args:
ref: numpy 1D array, reference audio signal
Expand All @@ -23,14 +25,33 @@ def pesq(fs, ref, deg, mode):
if mode != 'wb' and mode != 'nb':
print(USAGE)
raise ValueError("mode should be either 'nb' or 'wb'")

if fs != 8000 and fs != 16000:
print(USAGE)
raise ValueError("fs (sampling frequency) should be either 8000 or 16000")

if fs == 8000 and mode == 'wb':
print(USAGE)
raise ValueError("no wide band mode if fs = 8000")

maxval = max(np.max(np.abs(ref/1.0)), np.max(np.abs(deg/1.0)))
if mode == 'wb':
return cypesq(fs, (ref/maxval).astype(np.float32), (deg/maxval).astype(np.float32), 1)
return cypesq(fs, (ref/maxval).astype(np.float32), (deg/maxval).astype(np.float32), 0)

if mode == 'wb':
mode_code = 1
else:
mode_code = 0

if on_error == PesqError.RETURN_VALUES:
return cypesq_retvals(
fs,
(ref/maxval).astype(np.float32),
(deg/maxval).astype(np.float32),
mode_code
)
else:
return cypesq(
fs,
(ref/maxval).astype(np.float32),
(deg/maxval).astype(np.float32),
mode_code
)
93 changes: 88 additions & 5 deletions pesq/cypesq.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,56 @@
#Python Wrapper for PESQ Score (narrow band and wide band)

import cython
cimport numpy as np
cimport numpy as np

class PesqError(RuntimeError):
# Error Return Values
SUCCESS = 0
UNKNOWN = -1
INVALID_SAMPLE_RATE = -2
OUT_OF_MEMORY_REF = -3
OUT_OF_MEMORY_DEG = -4
OUT_OF_MEMORY_TMP = -5
BUFFER_TOO_SHORT = -6
NO_UTTERANCES_DETECTED = -7

# On Error Type
RAISE_EXCEPTION = 0
RETURN_VALUES = 1

class InvalidSampleRateError(PesqError):
pass

class OutOfMemoryError(PesqError):
pass

class BufferTooShortError(PesqError):
pass

class NoUtterancesError(PesqError):
pass

cdef char** cypesq_error_messages = [
"Success",
"Unknown",
"Invalid sampling rate",
"Unable to allocate memory for reference buffer",
"Unable to allocate memory for degraded buffer",
"Unable to allocate memory for temporary buffer",
"Buffer needs to be at least 1/4 of a second long",
"No utterances detected"
]

cpdef char* cypesq_error_message(int code):
global cypesq_error_messages

if code > PesqError.SUCCESS:
code = PesqError.SUCCESS

if code < PesqError.NO_UTTERANCES_DETECTED:
code = PesqError.UNKNOWN

return cypesq_error_messages[-code]

cdef extern from "pesq.h":
DEF MAXNUTTERANCES = 50
Expand Down Expand Up @@ -50,13 +99,20 @@ cdef extern from "pesqmain.h":
cdef void pesq_measure(SIGNAL_INFO * ref_info, SIGNAL_INFO * deg_info, ERROR_INFO * err_info, long * Error_Flag, char ** Error_Type)


cpdef object cypesq(long sample_rate, np.ndarray[float, ndim=1, mode="c"] ref_data, np.ndarray[float, ndim=1, mode="c"] deg_data, int mode):

cpdef object cypesq_retvals(long sample_rate,
np.ndarray[float, ndim=1, mode="c"] ref_data,
np.ndarray[float, ndim=1, mode="c"] deg_data,
int mode):
# select rate
cdef long error_flag = 0;
cdef char * error_type = "unknown";

select_rate(sample_rate, &error_flag, &error_type)
if error_flag != 0:
return -1
# They are all literals, this is not a leak (probably)
return PesqError.INVALID_SAMPLE_RATE

# assign signal
cdef long length_ref
cdef long length_deg
Expand Down Expand Up @@ -107,8 +163,35 @@ cpdef object cypesq(long sample_rate, np.ndarray[float, ndim=1, mode="c"] ref_da
err_info.mode = WB_MODE

pesq_measure(&ref_info, &deg_info, &err_info, &error_flag, &error_type);
if error_flag!=0:
return -1
if error_flag != 0:
return error_flag

return err_info.mapped_mos

cpdef object cypesq(long sample_rate,
np.ndarray[float, ndim=1, mode="c"] ref_data,
np.ndarray[float, ndim=1, mode="c"] deg_data,
int mode):
cdef object ret = cypesq_retvals(sample_rate, ref_data, deg_data, mode)

# Null and Positive are valid values.
if ret >= 0:
return ret

cdef char* error_message = cypesq_error_message(ret)

if ret == PesqError.INVALID_SAMPLE_RATE:
raise InvalidSampleRateError(error_message)

if ret in [ PesqError.OUT_OF_MEMORY_REF, PesqError.OUT_OF_MEMORY_DEG, PesqError.OUT_OF_MEMORY_TMP ]:
raise OutOfMemoryError(error_message)

if ret == PesqError.BUFFER_TOO_SHORT:
raise BufferTooShortError(error_message)

if ret == PesqError.NO_UTTERANCES_DETECTED:
raise NoUtterancesError(error_message)

# Raise unknown otherwise
raise PesqError(error_message)

15 changes: 13 additions & 2 deletions pesq/pesq.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,15 @@ extern long Align_Nfft_16k;
#ifndef PESQ_H
#define PESQ_H

#define PESQ_ERROR_SUCCESS 0
#define PESQ_ERROR_UNKNOWN -1
#define PESQ_ERROR_INVALID_SAMPLE_RATE -2
#define PESQ_ERROR_OUT_OF_MEMORY_REF -3
#define PESQ_ERROR_OUT_OF_MEMORY_DEG -4
#define PESQ_ERROR_OUT_OF_MEMORY_TMP -5
#define PESQ_ERROR_BUFFER_TOO_SHORT -6
#define PESQ_ERROR_NO_UTTERANCES_DETECTED -7

typedef struct {
char path_name[512];
char file_name [128];
Expand Down Expand Up @@ -320,8 +329,9 @@ void split_align( SIGNAL_INFO * ref_info, SIGNAL_INFO * deg_info,
long * Best_ED2, long * Best_D2, float * Best_DC2,
long * Best_BP );
void pesq_psychoacoustic_model(
SIGNAL_INFO * ref_info, SIGNAL_INFO * deg_info,
ERROR_INFO * err_info, float * ftmp);
SIGNAL_INFO * ref_info, SIGNAL_INFO * deg_info,
ERROR_INFO * err_info, long * Error_Flag, char ** Error_Type,
float * ftmp);
void apply_pesq( float * x_data, float * ref_surf,
float * y_data, float * deg_surf, long NVAD_windows, float * ftmp,
ERROR_INFO * err_info );
Expand Down Expand Up @@ -359,5 +369,6 @@ ERROR_INFO * err_info );
#ifndef A_WEIGHT
#define A_WEIGHT 0.0309
#endif

/* END OF FILE */

Loading

0 comments on commit 6e38968

Please sign in to comment.