diff --git a/flama/background.py b/flama/background.py index a5305649..5df2f609 100644 --- a/flama/background.py +++ b/flama/background.py @@ -31,7 +31,7 @@ def __init__( self.func = self._create_task_function(func) self.args = args self.kwargs = kwargs - self.concurrency = Concurrency(concurrency) + self.concurrency = Concurrency[concurrency] if isinstance(concurrency, str) else concurrency def _create_task_function(self, func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]: if asyncio.iscoroutinefunction(func): diff --git a/flama/serialize/__init__.py b/flama/serialize/__init__.py index 4c560bd7..6cc9cae1 100644 --- a/flama/serialize/__init__.py +++ b/flama/serialize/__init__.py @@ -1,5 +1,5 @@ -from flama.serialize.data_structures import ModelArtifact -from flama.serialize.dump import dump -from flama.serialize.load import load +from flama.serialize.data_structures import * # noqa +from flama.serialize.dump import * # noqa +from flama.serialize.load import * # noqa -__all__ = ["dump", "load", "ModelArtifact"] +__all__ = ["dump", "load"] # noqa diff --git a/flama/serialize/data_structures.py b/flama/serialize/data_structures.py index b9475543..6afea937 100644 --- a/flama/serialize/data_structures.py +++ b/flama/serialize/data_structures.py @@ -15,6 +15,8 @@ from flama.serialize.base import Serializer from flama.serialize.types import Framework +__all__ = ["ModelArtifact", "Compression"] + class Compression(enum.Enum): fast = "gz" @@ -31,7 +33,7 @@ def serializer(cls, framework: t.Union[str, Framework]) -> Serializer: Framework.sklearn: ("sklearn", "SKLearnSerializer"), Framework.keras: ("tensorflow", "TensorFlowSerializer"), Framework.tensorflow: ("tensorflow", "TensorFlowSerializer"), - }[Framework(framework)] + }[Framework[framework] if isinstance(framework, str) else framework] except KeyError: # pragma: no cover raise ValueError("Wrong framework") @@ -242,7 +244,8 @@ def dump( :param compression: Compression type. :param kwargs: Keyword arguments passed to library dump method. """ - with tarfile.open(path, f"w:{Compression(compression).value}") as tar: + compression = Compression[compression] if isinstance(compression, str) else compression + with tarfile.open(path, f"w:{compression.value}") as tar: if self.artifacts: for name, path in self.artifacts.items(): tar.add(path, f"artifacts/{name}")