Skip to content

Commit

Permalink
add load directionn
Browse files Browse the repository at this point in the history
  • Loading branch information
HansBug committed Mar 1, 2021
1 parent dd070a8 commit a9824bc
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 12 deletions.
2 changes: 1 addition & 1 deletion pyspj/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .base import SPJResult
from .continuity import ContinuitySPJResult
from .general import load_result
from .general import load_result, to_continuity, to_simple, ResultType
from .simple import SimpleSPJResult
87 changes: 77 additions & 10 deletions pyspj/models/general.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import unique, IntEnum
from typing import Tuple, Optional

from .base import SPJResult
Expand Down Expand Up @@ -48,19 +49,85 @@ def _load_result_from_tuple(data: tuple) -> SPJResult:
return _load_from_values(_correctness, _score, _message, _detail)


def load_result(data) -> SPJResult:
@unique
class ResultType(IntEnum):
FREE = 0
SIMPLE = 1
CONTINUITY = 2

@classmethod
def loads(cls, value) -> 'ResultType':
"""
Load result type from value
:param value: raw value
:return: result type object
"""
if isinstance(value, cls):
return value
elif isinstance(value, str):
if value.upper() in cls.__members__.keys():
return cls.__members__[value.upper()]
else:
raise KeyError('Unknown result type - {actual}.'.format(actual=repr(value)))
elif isinstance(value, int):
_mapping = {v.value: v for k, v in cls.__members__.items()}
if value in _mapping.keys():
return _mapping[value]
else:
raise ValueError('Unknown result type value - {actual}'.format(actual=repr(value)))
else:
raise TypeError('Int, str or {cls} expected but {actual} found.'.format(
cls=cls.__name__,
actual=repr(type(value).__name__)
))


def load_result(data, type_=None) -> SPJResult:
"""
load result from all kinds of data
:param data: raw data
:param type_: result type
:return: spj result
"""
if isinstance(data, SimpleSPJResult):
return data
elif isinstance(data, ContinuitySPJResult):
return data
elif isinstance(data, dict):
return _load_result_from_dict(data)
elif isinstance(data, (list, tuple)):
return _load_result_from_tuple(tuple(data))

def _func():
if isinstance(data, SimpleSPJResult):
return data
elif isinstance(data, ContinuitySPJResult):
return data
elif isinstance(data, dict):
return _load_result_from_dict(data)
elif isinstance(data, (list, tuple)):
return _load_result_from_tuple(tuple(data))
else:
return SimpleSPJResult(not not data)

_result = _func()
type_ = ResultType.loads(type_ or ResultType.FREE)
if type_ == ResultType.SIMPLE:
return to_simple(_result)
elif type_ == ResultType.CONTINUITY:
return to_continuity(_result)
else:
return SimpleSPJResult(not not data)
return _result


def to_simple(data) -> SimpleSPJResult:
"""
to simple result
:param data: original data
:return: simple spj result
"""
return SimpleSPJResult(**load_result(data).to_json())


def to_continuity(data) -> ContinuitySPJResult:
"""
to continuity result
:param data: original data
:return: continuity spj result
"""
_dict = dict(load_result(data).to_json())
if 'score' not in _dict:
_dict['score'] = 0.0
return ContinuitySPJResult(**_dict)
56 changes: 55 additions & 1 deletion test/models/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from pyspj.models import load_result, SimpleSPJResult, ContinuitySPJResult
from pyspj.models import load_result, SimpleSPJResult, ContinuitySPJResult, ResultType


@pytest.mark.unittest
Expand All @@ -25,12 +25,66 @@ def test_continuity(self):
assert load_result(((True, 0.5), '123')) == ContinuitySPJResult(True, 0.5, '123')
assert load_result(((True, 0.5), '123', '12345')) == result

def test_simple_force(self):
result = SimpleSPJResult(True, '123', '12345')
assert load_result(result, 'simple') == result
assert load_result(result.to_json(), 'simple') == result
assert load_result((True,), 'simple') == SimpleSPJResult(True, )
assert load_result(True, 'simple') == SimpleSPJResult(True, )
assert load_result(None, 'simple') == SimpleSPJResult(False, )
assert load_result((True, '123'), 'simple') == SimpleSPJResult(True, '123')
assert load_result((True, '123', '12345'), 'simple') == result

result = ContinuitySPJResult(True, 0.5, '123', '12345')
assert load_result(result, 'simple') == SimpleSPJResult(True, '123', '12345')
assert load_result(result.to_json(), 'simple') == SimpleSPJResult(True, '123', '12345')
assert load_result(((True, 0.5),), 'simple') == SimpleSPJResult(True)
assert load_result(((True, 0.5), '123'), 'simple') == SimpleSPJResult(True, '123')
assert load_result(((True, 0.5), '123', '12345'), 'simple') == SimpleSPJResult(True, '123', '12345')

def test_continuity_force(self):
result = SimpleSPJResult(True, '123', '12345')
assert load_result(result, 'continuity') == ContinuitySPJResult(True, 0.0, '123', '12345')
assert load_result(result.to_json(), 'continuity') == ContinuitySPJResult(True, 0.0, '123', '12345')
assert load_result((True,), 'continuity') == ContinuitySPJResult(True, 0.0)
assert load_result(True, 'continuity') == ContinuitySPJResult(True, 0.0)
assert load_result(None, 'continuity') == ContinuitySPJResult(False, 0.0, )
assert load_result((True, '123'), 'continuity') == ContinuitySPJResult(True, 0.0, '123')
assert load_result((True, '123', '12345'), 'continuity') == ContinuitySPJResult(True, 0.0, '123', '12345')

result = ContinuitySPJResult(True, 0.5, '123', '12345')
assert load_result(result, 'continuity') == result
assert load_result(result.to_json(), 'continuity') == result
assert load_result(((True, 0.5),), 'continuity') == ContinuitySPJResult(True, 0.5)
assert load_result(((True, 0.5), '123'), 'continuity') == ContinuitySPJResult(True, 0.5, '123')
assert load_result(((True, 0.5), '123', '12345'), 'continuity') == result

def test_invalid(self):
with pytest.raises(ValueError):
assert load_result(())
with pytest.raises(ValueError):
assert load_result((1, 2, 3, 4))

def test_result_type(self):
assert ResultType.loads(ResultType.FREE) == ResultType.FREE
assert ResultType.loads(ResultType.SIMPLE) == ResultType.SIMPLE
assert ResultType.loads(ResultType.CONTINUITY) == ResultType.CONTINUITY

assert ResultType.loads('free') == ResultType.FREE
assert ResultType.loads('simple') == ResultType.SIMPLE
assert ResultType.loads('continuity') == ResultType.CONTINUITY
with pytest.raises(KeyError):
ResultType.loads('sdkfjlsd')

assert ResultType.loads(0) == ResultType.FREE
assert ResultType.loads(1) == ResultType.SIMPLE
assert ResultType.loads(2) == ResultType.CONTINUITY
with pytest.raises(ValueError):
ResultType.loads(-100)

with pytest.raises(TypeError):
ResultType.loads([])


if __name__ == "__main__":
pytest.main([os.path.abspath(__file__)])

0 comments on commit a9824bc

Please sign in to comment.