Skip to content

Commit

Permalink
✨ Enhanced model serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
migduroli authored and perdy committed Jan 19, 2023
1 parent 9c482d8 commit c8128c1
Show file tree
Hide file tree
Showing 22 changed files with 1,091 additions and 1,004 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ __pycache__/
.Python
env/
.venv/
.idea/
build/
develop-eggs/
dist/
Expand Down
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ repos:
- id: name-tests-test
args:
- --django
exclude: "asserts.py|utils.py"
- id: pretty-format-json
args:
- --autofix
Expand Down
50 changes: 20 additions & 30 deletions flama/models/components.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import abc
import json
import typing
import typing as t

from flama import exceptions
from flama.injection import Component
from flama.serialize import ModelFormat, loads
from flama.serialize import loads
from flama.serialize.types import Framework

if t.TYPE_CHECKING:
from flama.serialize.data_structures import Metadata

try:
import torch
Expand All @@ -20,27 +23,20 @@


class Model:
def __init__(self, model: typing.Any):
def __init__(self, model: t.Any, meta: "Metadata"):
self.model = model
self.meta: "Metadata" = meta

@abc.abstractmethod
def inspect(self) -> typing.Any:
...
def inspect(self) -> t.Any:
return self.meta.to_dict()

@abc.abstractmethod
def predict(self, x: typing.Any) -> typing.Any:
def predict(self, x: t.Any) -> t.Any:
...


class PyTorchModel(Model):
def inspect(self) -> typing.Any:
return {
"modules": [str(x) for x in self.model.modules()],
"parameters": {k: str(v) for k, v in self.model.named_parameters()},
"state": self.model.state_dict(),
}

def predict(self, x: typing.List[typing.List[typing.Any]]) -> typing.Any:
def predict(self, x: t.List[t.List[t.Any]]) -> t.Any:
assert torch is not None, "`torch` must be installed to use PyTorchModel."

try:
Expand All @@ -50,21 +46,15 @@ def predict(self, x: typing.List[typing.List[typing.Any]]) -> typing.Any:


class SKLearnModel(Model):
def inspect(self) -> typing.Any:
return self.model.get_params()

def predict(self, x: typing.List[typing.List[typing.Any]]) -> typing.Any:
def predict(self, x: t.List[t.List[t.Any]]) -> t.Any:
try:
return self.model.predict(x).tolist()
except ValueError as e:
raise exceptions.HTTPException(status_code=400, detail=str(e))


class TensorFlowModel(Model):
def inspect(self) -> typing.Any:
return json.loads(self.model.to_json())

def predict(self, x: typing.List[typing.List[typing.Any]]) -> typing.Any:
def predict(self, x: t.List[t.List[t.Any]]) -> t.Any:
assert tensorflow is not None, "`tensorflow` must be installed to use TensorFlowModel."

try:
Expand All @@ -77,23 +67,23 @@ class ModelComponent(Component):
def __init__(self, model):
self.model = model

def get_model_type(self) -> typing.Type[Model]:
def get_model_type(self) -> t.Type[Model]:
return self.model.__class__ # type: ignore[no-any-return]


class ModelComponentBuilder:
MODELS = {
ModelFormat.pytorch: ("PyTorchModel", PyTorchModel),
ModelFormat.sklearn: ("SKLearnModel", SKLearnModel),
ModelFormat.tensorflow: ("TensorFlowModel", TensorFlowModel),
Framework.torch: ("PyTorchModel", PyTorchModel),
Framework.sklearn: ("SKLearnModel", SKLearnModel),
Framework.tensorflow: ("TensorFlowModel", TensorFlowModel),
}

@classmethod
def loads(cls, data: bytes) -> ModelComponent:
load_model = loads(data)
name, parent = cls.MODELS[load_model.lib]
name, parent = cls.MODELS[load_model.meta.framework.lib]
model_class = type(name, (parent,), {})
model_obj = model_class(load_model.model)
model_obj = model_class(load_model.model, load_model.meta)

class SpecificModelComponent(ModelComponent):
def resolve(self) -> model_class: # type: ignore[valid-type]
Expand Down
4 changes: 2 additions & 2 deletions flama/serialize/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from flama.serialize.data_structures import ModelArtifact
from flama.serialize.dump import dump, dumps
from flama.serialize.load import load, loads
from flama.serialize.model import Model, ModelFormat

__all__ = ["dump", "dumps", "load", "loads", "Model", "ModelFormat"]
__all__ = ["dump", "dumps", "load", "loads", "ModelArtifact"]
18 changes: 15 additions & 3 deletions flama/serialize/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
import abc
import typing
import typing as t

from flama.serialize.types import Framework

__all__ = ["Serializer"]


class Serializer(metaclass=abc.ABCMeta):
lib: t.ClassVar[Framework]

@abc.abstractmethod
def dump(self, obj: t.Any, **kwargs) -> bytes:
...

@abc.abstractmethod
def load(self, model: bytes, **kwargs) -> t.Any:
...

@abc.abstractmethod
def dump(self, obj: typing.Any, **kwargs) -> bytes:
def info(self, model: t.Any) -> t.Dict[str, t.Any]:
...

@abc.abstractmethod
def load(self, model: bytes, **kwargs) -> typing.Any:
def version(self) -> str:
...
210 changes: 210 additions & 0 deletions flama/serialize/data_structures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
import codecs
import dataclasses
import datetime
import inspect
import json
import typing as t
import uuid

from flama.serialize.base import Serializer
from flama.serialize.serializers.pytorch import PyTorchSerializer
from flama.serialize.serializers.sklearn import SKLearnSerializer
from flama.serialize.serializers.tensorflow import TensorFlowSerializer
from flama.serialize.types import Framework


class FrameworkSerializers:
_SERIALIZERS = {
Framework.sklearn: SKLearnSerializer(),
Framework.tensorflow: TensorFlowSerializer(),
Framework.keras: TensorFlowSerializer(),
Framework.torch: PyTorchSerializer(),
}

@classmethod
def serializer(cls, framework: t.Union[str, Framework]) -> Serializer:
try:
return cls._SERIALIZERS[Framework(framework)]
except ValueError: # pragma: no cover
raise ValueError("Wrong framework")

@classmethod
def from_model(cls, model: t.Any) -> Serializer:
inspect_objs = [model]

try:
inspect_objs += model.__class__.__mro__
except AttributeError:
...

for obj in inspect_objs:
try:
return cls.serializer(inspect.getmodule(obj).__name__.split(".", 1)[0]) # type: ignore[union-attr]
except (ValueError, AttributeError):
...
else:
raise ValueError("Unknown model framework")


@dataclasses.dataclass
class FrameworkInfo:
"""Dataclass for storing model framework information."""

lib: Framework
version: str

@classmethod
def from_model(cls, model: t.Any) -> "FrameworkInfo":
serializer = FrameworkSerializers.from_model(model)
return cls(lib=serializer.lib, version=serializer.version())

@classmethod
def from_dict(cls, data: t.Dict[str, t.Any]) -> "FrameworkInfo":
return cls(lib=Framework[data["lib"]], version=data["version"])

def to_dict(self) -> t.Dict[str, t.Any]:
return {"lib": self.lib.value, "version": self.version}


@dataclasses.dataclass
class ModelInfo:
"""Dataclass for storing model info."""

obj: str
info: t.Dict[str, t.Any]
params: t.Optional[t.Dict[str, t.Any]] = None
metrics: t.Optional[t.Dict[str, t.Any]] = None

@classmethod
def from_model(
cls, model: t.Any, params: t.Optional[t.Dict[str, t.Any]], metrics: t.Optional[t.Dict[str, t.Any]]
) -> "ModelInfo":
return cls(
obj=model.__name__ if inspect.isclass(model) else model.__class__.__name__,
info=FrameworkSerializers.from_model(model).info(model),
params=params,
metrics=metrics,
)

@classmethod
def from_dict(cls, data: t.Dict[str, t.Any]):
return cls(obj=data["obj"], info=data["info"], params=data.get("params"), metrics=data.get("metrics"))

def to_dict(self) -> t.Dict[str, t.Any]:
return dataclasses.asdict(self)


@dataclasses.dataclass(frozen=True)
class Metadata:
"""Dataclass for storing model metadata."""

id: t.Union[str, uuid.UUID]
timestamp: datetime.datetime
framework: FrameworkInfo
model: ModelInfo
extra: t.Optional[t.Dict[str, t.Any]] = None

@classmethod
def from_model(
cls,
model: t.Any,
*,
model_id: t.Optional[t.Union[str, uuid.UUID]],
timestamp: t.Optional[datetime.datetime],
params: t.Optional[t.Dict[str, t.Any]],
metrics: t.Optional[t.Dict[str, t.Any]],
extra: t.Optional[t.Dict[str, t.Any]],
) -> "Metadata":
return cls(
id=model_id or uuid.uuid4(),
timestamp=timestamp or datetime.datetime.now(),
framework=FrameworkInfo.from_model(model),
model=ModelInfo.from_model(model, params, metrics),
extra=extra,
)

@classmethod
def from_dict(cls, data: t.Dict[str, t.Any]) -> "Metadata":
try:
id_ = uuid.UUID(data["id"])
except ValueError:
id_ = data["id"]

timestamp = (
datetime.datetime.fromisoformat(data["timestamp"])
if isinstance(data["timestamp"], str)
else data["timestamp"]
)

return cls(
id=id_,
timestamp=timestamp,
framework=FrameworkInfo.from_dict(data["framework"]),
model=ModelInfo.from_dict(data["model"]),
extra=data.get("extra"),
)

def to_dict(self) -> t.Dict[str, t.Any]:
return {
"id": str(self.id),
"timestamp": self.timestamp.isoformat(),
"framework": self.framework.to_dict(),
"model": self.model.to_dict(),
"extra": self.extra,
}


@dataclasses.dataclass(frozen=True)
class ModelArtifact:
"""ML Model wrapper to provide mechanisms for serialization and deserialization using Flama format."""

model: t.Any
meta: Metadata

@classmethod
def from_model(
cls,
model: t.Any,
*,
model_id: t.Optional[t.Union[str, uuid.UUID]] = None,
timestamp: t.Optional[datetime.datetime] = None,
params: t.Optional[t.Dict[str, t.Any]] = None,
metrics: t.Optional[t.Dict[str, t.Any]] = None,
extra: t.Optional[t.Dict[str, t.Any]] = None,
) -> "ModelArtifact":
return cls(
model=model,
meta=Metadata.from_model(
model, model_id=model_id, timestamp=timestamp, params=params, metrics=metrics, extra=extra
),
)

@classmethod
def from_dict(cls, data: t.Dict[str, t.Any], **kwargs) -> "ModelArtifact":
try:
metadata = Metadata.from_dict(data["meta"])
model = FrameworkSerializers.serializer(metadata.framework.lib).load(data["model"].encode(), **kwargs)
except KeyError: # pragma: no cover
raise ValueError("Wrong data")

return cls(model=model, meta=metadata)

@classmethod
def from_json(cls, data: str, **kwargs) -> "ModelArtifact":
return cls.from_dict(json.loads(data), **kwargs)

@classmethod
def from_bytes(cls, data: bytes, **kwargs) -> "ModelArtifact":
return cls.from_json(codecs.decode(data, "zlib"), **kwargs) # type: ignore[arg-type]

def to_dict(self, **kwargs) -> t.Dict[str, t.Any]:
return {
"model": FrameworkSerializers.serializer(self.meta.framework.lib).dump(self.model, **kwargs).decode(),
"meta": self.meta.to_dict(),
}

def to_json(self, **kwargs) -> str:
return json.dumps(self.to_dict(**kwargs))

def to_bytes(self, **kwargs) -> bytes:
return codecs.encode(self.to_json(**kwargs).encode(), "zlib")
Loading

0 comments on commit c8128c1

Please sign in to comment.