Skip to content

Commit

Permalink
🐛 Amend some Enum instantiations from strings
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy committed Mar 17, 2023
1 parent cbd026f commit 18daa90
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
2 changes: 1 addition & 1 deletion flama/background.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(
self.func = self._create_task_function(func)
self.args = args
self.kwargs = kwargs
self.concurrency = Concurrency(concurrency)
self.concurrency = Concurrency[concurrency] if isinstance(concurrency, str) else concurrency

def _create_task_function(self, func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
if asyncio.iscoroutinefunction(func):
Expand Down
8 changes: 4 additions & 4 deletions flama/serialize/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from flama.serialize.data_structures import ModelArtifact
from flama.serialize.dump import dump
from flama.serialize.load import load
from flama.serialize.data_structures import * # noqa
from flama.serialize.dump import * # noqa
from flama.serialize.load import * # noqa

__all__ = ["dump", "load", "ModelArtifact"]
__all__ = ["dump", "load"] # noqa
7 changes: 5 additions & 2 deletions flama/serialize/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from flama.serialize.base import Serializer
from flama.serialize.types import Framework

__all__ = ["ModelArtifact", "Compression"]


class Compression(enum.Enum):
fast = "gz"
Expand All @@ -31,7 +33,7 @@ def serializer(cls, framework: t.Union[str, Framework]) -> Serializer:
Framework.sklearn: ("sklearn", "SKLearnSerializer"),
Framework.keras: ("tensorflow", "TensorFlowSerializer"),
Framework.tensorflow: ("tensorflow", "TensorFlowSerializer"),
}[Framework(framework)]
}[Framework[framework] if isinstance(framework, str) else framework]
except KeyError: # pragma: no cover
raise ValueError("Wrong framework")

Expand Down Expand Up @@ -242,7 +244,8 @@ def dump(
:param compression: Compression type.
:param kwargs: Keyword arguments passed to library dump method.
"""
with tarfile.open(path, f"w:{Compression(compression).value}") as tar:
compression = Compression[compression] if isinstance(compression, str) else compression
with tarfile.open(path, f"w:{compression.value}") as tar:
if self.artifacts:
for name, path in self.artifacts.items():
tar.add(path, f"artifacts/{name}")
Expand Down

0 comments on commit 18daa90

Please sign in to comment.