Skip to content

Commit

Permalink
✨ Models serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy committed Jan 19, 2023
1 parent a719a59 commit 0dc9320
Show file tree
Hide file tree
Showing 12 changed files with 1,516 additions and 297 deletions.
1 change: 1 addition & 0 deletions flama/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from flama.modules import Module # noqa
from flama.pagination import * # noqa
from flama.routing import * # noqa
from flama.serialize import * # noqa
11 changes: 8 additions & 3 deletions flama/schemas/_libs/marshmallow/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from flama.schemas.adapter import Adapter
from flama.schemas.exceptions import SchemaGenerationError, SchemaValidationError

if typing.TYPE_CHECKING:
from apispec.ext.marshmallow import OpenAPIConverter

__all__ = ["MarshmallowAdapter"]


Expand Down Expand Up @@ -92,14 +95,16 @@ def to_json_schema(
) -> typing.Dict[str, typing.Any]:
json_schema: typing.Dict[str, typing.Any]
try:
plugin = MarshmallowPlugin(schema_name_resolver=lambda x: resolve_schema_cls(x).__name__)
plugin = MarshmallowPlugin(
schema_name_resolver=lambda x: resolve_schema_cls(x).__name__ # type: ignore[no-any-return]
)
APISpec("", "", "3.1.0", [plugin])
converter = plugin.converter
converter: "OpenAPIConverter" = plugin.converter # type: ignore[assignment]

if (inspect.isclass(schema) and issubclass(schema, marshmallow.fields.Field)) or isinstance(
schema, marshmallow.fields.Field
):
json_schema = converter.field2property(schema)
json_schema = converter.field2property(schema) # type: ignore[arg-type]
elif inspect.isclass(schema) and issubclass(schema, marshmallow.Schema):
json_schema = converter.schema2jsonschema(schema)
elif isinstance(schema, marshmallow.Schema):
Expand Down
21 changes: 21 additions & 0 deletions flama/serialize/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import typing

from flama.serialize.types import Format, Model

__all__ = ["dump", "dumps", "load", "loads"]


def dumps(lib: typing.Union[str, Format], model: typing.Any) -> bytes:
return Model(lib, model).to_bytes()


def dump(lib: typing.Union[str, Format], model: typing.Any, fs: typing.BinaryIO) -> None:
fs.write(dumps(lib, model))


def loads(data: bytes) -> Model:
return Model.from_bytes(data)


def load(fs: typing.BinaryIO) -> Model:
return loads(fs.read())
14 changes: 14 additions & 0 deletions flama/serialize/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import abc
import typing

__all__ = ["Serializer"]


class Serializer(metaclass=abc.ABCMeta):
@abc.abstractmethod
def dump(self, obj: typing.Any) -> bytes:
...

@abc.abstractmethod
def load(self, model: bytes) -> typing.Any:
...
Empty file.
13 changes: 13 additions & 0 deletions flama/serialize/serializers/sklearn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import codecs
import pickle
import typing

from flama.serialize.base import Serializer


class SKLearnSerializer(Serializer):
def dump(self, obj: typing.Any) -> bytes:
return codecs.encode(pickle.dumps(obj), "base64")

def load(self, model: bytes) -> typing.Any:
return pickle.loads(codecs.decode(model, "base64"))
13 changes: 13 additions & 0 deletions flama/serialize/serializers/tensorflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import codecs
import pickle
import typing

from flama.serialize.base import Serializer


class TensorflowSerializer(Serializer):
def dump(self, obj: typing.Any) -> bytes:
return codecs.encode(pickle.dumps(obj), "base64")

def load(self, model: bytes) -> typing.Any:
return pickle.loads(codecs.decode(model, "base64"))
54 changes: 54 additions & 0 deletions flama/serialize/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import codecs
import dataclasses
import enum
import json
import typing

from flama.serialize.base import Serializer
from flama.serialize.serializers.sklearn import SKLearnSerializer
from flama.serialize.serializers.tensorflow import TensorflowSerializer


class Format(enum.Enum):
sklearn = "sklearn"
tensorflow = "tensorflow"


SERIALIZERS = {
Format.sklearn: SKLearnSerializer(),
Format.tensorflow: TensorflowSerializer(),
}


@dataclasses.dataclass(frozen=True)
class Model:
lib: typing.Union[str, Format]
model: typing.Any

@classmethod
def serializer(cls, lib: typing.Union[str, Format]) -> Serializer:
try:
return SERIALIZERS[Format(lib)]
except ValueError:
raise ValueError("Wrong lib")

@classmethod
def from_bytes(cls, data: bytes) -> "Model":
try:
serialized_data = json.loads(codecs.decode(data, "zlib"))
lib = serialized_data["lib"]
model = cls.serializer(lib).load(serialized_data["model"].encode())
except KeyError:
raise ValueError("Wrong data")

return cls(lib, model)

def to_dict(self) -> typing.Dict[str, typing.Any]:
pickled_model = self.serializer(self.lib).dump(self.model).decode()
return {"lib": self.lib, "model": pickled_model}

def to_json(self) -> str:
return json.dumps(self.to_dict())

def to_bytes(self) -> bytes:
return codecs.encode(self.to_json().encode(), "zlib")
Loading

0 comments on commit 0dc9320

Please sign in to comment.