Skip to content

Commit

Permalink
✨ Model resources
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy committed Jan 19, 2023
1 parent 4b141b5 commit ca5b7ee
Show file tree
Hide file tree
Showing 14 changed files with 430 additions and 76 deletions.
6 changes: 5 additions & 1 deletion flama/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from flama.exceptions import HTTPException
from flama.injection import Injector
from flama.lifespan import Lifespan
from flama.models.modules import ModelsModule
from flama.modules import Modules
from flama.pagination import paginator
from flama.resources import ResourcesModule
Expand All @@ -28,7 +29,7 @@
__all__ = ["Flama"]


DEFAULT_MODULES: typing.List[typing.Type["Module"]] = [SQLAlchemyModule, ResourcesModule, SchemaModule]
DEFAULT_MODULES: typing.List[typing.Type["Module"]] = [SQLAlchemyModule, ResourcesModule, SchemaModule, ModelsModule]


class Flama(Starlette):
Expand Down Expand Up @@ -106,6 +107,9 @@ def injector(self) -> Injector:
def components(self) -> "Components":
return self.router.components

def add_component(self, component: "Component"):
self.router.add_component(component)

@property
def routes(self) -> typing.List["BaseRoute"]: # type: ignore[override]
return self.router.routes
Expand Down
21 changes: 16 additions & 5 deletions flama/models/components.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import json
import typing

from flama.components import Component
Expand All @@ -11,19 +12,29 @@ class Model:
def __init__(self, model: typing.Any):
self.model = model

@abc.abstractmethod
def inspect(self) -> typing.Any:
...

@abc.abstractmethod
def predict(self, x: typing.Any) -> typing.Any:
...


class TensorFlowModel(Model):
def predict(self, x: typing.Any) -> typing.Any:
return self.model.predict(x)
def inspect(self) -> typing.Any:
return json.loads(self.model.to_json())

def predict(self, x: typing.List[typing.List[typing.Any]]) -> typing.Any:
return self.model.predict(x).tolist()


class SKLearnModel(Model):
def predict(self, x: typing.Any) -> typing.Any:
return self.model.predict(x)
def inspect(self) -> typing.Any:
return self.model.get_params()

def predict(self, x: typing.List[typing.List[typing.Any]]) -> typing.Any:
return self.model.predict(x).tolist()


class ModelComponentBuilder:
Expand All @@ -39,7 +50,7 @@ class ModelComponent(Component):
def __init__(self, model: model_class): # type: ignore[valid-type]
self.model = model

def resolve(self, test: bool) -> model_class: # type: ignore[valid-type]
def resolve(self) -> model_class: # type: ignore[valid-type]
return self.model

return ModelComponent(model_obj)
28 changes: 15 additions & 13 deletions flama/models/modules.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,32 @@
import inspect
import os
import typing

from flama.models.resource import ModelResource, ModelResourceType
from flama.modules import Module
from flama.resources.routing import ResourceRoute

if typing.TYPE_CHECKING:
from flama.resources.resource import BaseResource

__all__ = ["ModelsModule"]


class ModelsModule(Module):
name = "resources"
name = "models"

def add_model(
self, path: str, resource: typing.Union["BaseResource", typing.Type["BaseResource"]], *args, **kwargs
):
def add_model(self, path: str, model: typing.Union[str, os.PathLike], name: str, *args, **kwargs):
"""Adds a model to this application, setting its endpoints.
:param path: Resource base path.
:param resource: Resource class.
:param model: Model path.
:param name: Model name.
"""
# Handle class or instance objects
resource = resource(app=self.app, *args, **kwargs) if inspect.isclass(resource) else resource
name_ = name
model_ = model

class Resource(ModelResource, metaclass=ModelResourceType):
name = name_
model = model_

self.app.routes.append(ResourceRoute(path, resource, main_app=self.app))
resource = Resource()
self.app.resources.add_resource(path, resource) # type: ignore[attr-defined]
self.app.add_component(resource.component) # type: ignore

def model(self, path: str, *args, **kwargs) -> typing.Callable:
"""Decorator for Model classes for adding them to the application.
Expand Down
109 changes: 109 additions & 0 deletions flama/models/resource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import os
import typing

import flama.schemas
from flama.models.components import ModelComponentBuilder
from flama.resources import types, BaseResource
from flama.resources.exceptions import ResourceAttributeError
from flama.resources.resource import ResourceType
from flama.resources.routing import resource_method
import flama.schemas

if typing.TYPE_CHECKING:
from flama.components import Component
from flama.models.components import Model

__all__ = ["ModelResource", "InspectMixin", "PredictMixin", "ModelResourceType"]


class InspectMixin:
@classmethod
def _add_inspect(
mcs, name: str, verbose_name: str, ml_model_type: "Model", **kwargs
) -> typing.Dict[str, typing.Any]:
@resource_method("/", methods=["GET"], name=f"{name}-inspect")
async def inspect(self, model: ml_model_type): # type: ignore[valid-type]
return model.inspect() # type: ignore[attr-defined]

inspect.__doc__ = f"""
tags:
- {verbose_name}
summary:
Retrieve the model.
description:
Retrieve the model from this resource.
responses:
200:
description:
The model.
"""

return {"_inspect": inspect}


class PredictMixin:
@classmethod
def _add_predict(
mcs, name: str, verbose_name: str, ml_model_type: "Model", **kwargs
) -> typing.Dict[str, typing.Any]:
@resource_method("/predict/", methods=["POST"], name=f"{name}-predict")
async def predict(
self, model: ml_model_type, data: flama.schemas.schemas.MLModelInput # type: ignore[valid-type]
) -> flama.schemas.schemas.MLModelOutput:
return {"output": model.predict(data["input"])} # type: ignore[attr-defined]

predict.__doc__ = f"""
tags:
- {verbose_name}
summary:
Generate a prediction.
description:
Generate a prediction using the model from this resource.
responses:
200:
description:
The prediction generated by the model.
"""

return {"_predict": predict}


class ModelResource(BaseResource):
model: typing.Union[str, os.PathLike]


class ModelResourceType(ResourceType, InspectMixin, PredictMixin):
METHODS = ("inspect", "predict")

def __new__(mcs, name: str, bases: typing.Tuple[type], namespace: typing.Dict[str, typing.Any]):
"""Resource metaclass for defining basic behavior for ML resources:
* Create _meta attribute containing some metadata (model...).
* Adds methods related to ML resource (inspect, predict...) listed in METHODS class attribute.
:param name: Class name.
:param bases: List of superclasses.
:param namespace: Variables namespace used to create the class.
"""
try:
# Get model component
component = mcs._get_model_component(bases, namespace)
model = component.model # type: ignore[attr-defined]
namespace["component"] = component
namespace["model"] = component.model # type: ignore[attr-defined]
except AttributeError as e:
raise ResourceAttributeError(str(e), name)

metadata_namespace = {"component": component, "model": model, "model_type": type(model)}
if "_meta" in namespace:
namespace["_meta"].namespaces["ml"] = metadata_namespace
else:
namespace["_meta"] = types.Metadata(namespaces={"ml": metadata_namespace})

return super().__new__(mcs, name, bases, namespace)

@classmethod
def _get_model_component(
mcs, bases: typing.Sequence[typing.Any], namespace: typing.Dict[str, typing.Any]
) -> "Component":
with open(mcs._get_attribute("model", bases, namespace), "rb") as f:
return ModelComponentBuilder.loads(f.read())
7 changes: 3 additions & 4 deletions flama/resources/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,16 @@ class RESTResource(BaseResource):

class RESTResourceType(ResourceType):
def __new__(mcs, name: str, bases: typing.Tuple[type], namespace: typing.Dict[str, typing.Any]):
"""Resource metaclass for defining basic behavior:
* Create _meta attribute containing some metadata (model, schemas, names...).
"""Resource metaclass for defining basic behavior for REST resources:
* Create _meta attribute containing some metadata (model, schemas...).
* Adds methods related to REST resource (create, retrieve, update, delete...) listed in METHODS class attribute.
* Generate a Router with above methods.
:param name: Class name.
:param bases: List of superclasses.
:param namespace: Variables namespace used to create the class.
"""
try:
# Get model and replace it with a read-only descriptor
# Get model
model = mcs._get_model(bases, namespace)
namespace["model"] = model.table

Expand Down
3 changes: 3 additions & 0 deletions flama/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ def components(self) -> Components:
]
)

def add_component(self, component: Component):
self._components.append(component)

def mount(self, path: str, app: ASGIApp, name: str = None) -> None:
try:
main_app = self.main_app
Expand Down
18 changes: 18 additions & 0 deletions flama/schemas/_libs/marshmallow/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,29 @@ class PageNumber(marshmallow.Schema):
data = marshmallow.fields.List(marshmallow.fields.Dict(), required=True)


class MLModelInput(marshmallow.Schema):
input = marshmallow.fields.List(
marshmallow.fields.Raw(),
required=True,
metadata={"title": "input", "description": "Model input"},
)


class MLModelOutput(marshmallow.Schema):
output = marshmallow.fields.List(
marshmallow.fields.Raw(),
required=True,
metadata={"title": "output", "description": "Model output"},
)


SCHEMAS = {
"APIError": APIError,
"DropCollection": DropCollection,
"LimitOffsetMeta": LimitOffsetMeta,
"LimitOffset": LimitOffset,
"PageNumberMeta": PageNumberMeta,
"PageNumber": PageNumber,
"MLModelInput": MLModelInput,
"MLModelOutput": MLModelOutput,
}
10 changes: 10 additions & 0 deletions flama/schemas/_libs/typesystem/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,13 @@
}
)
SCHEMAS["PageNumber"] = PageNumber

MLModelInput = typesystem.Schema(
fields={"input": typesystem.fields.Array(typesystem.fields.Any(), title="input", description="Model input")}
)
SCHEMAS["MLModelInput"] = MLModelInput

MLModelOutput = typesystem.Schema(
fields={"output": typesystem.fields.Array(typesystem.fields.Any(), title="input", description="Model input")}
)
SCHEMAS["MLModelOutput"] = MLModelOutput
48 changes: 48 additions & 0 deletions tests/models/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import pytest
import tensorflow as tf
from sklearn.linear_model import LogisticRegression

import flama


def tensorflow_model(
sequential=tf.keras.models.Sequential(
[
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation="relu"),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation="softmax"),
]
)
):
tf_model = sequential

tf_model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

return tf_model


def sklearn_model():
return LogisticRegression()


@pytest.fixture(scope="function")
def model(request):
if request.param == "tensorflow":
return tensorflow_model()

if request.param == "sklearn":
return sklearn_model()

raise ValueError("Unknown model")


@pytest.fixture(scope="function")
def model_dump(request):
if request.param == "tensorflow":
return flama.dumps("tensorflow", tensorflow_model())

if request.param == "sklearn":
return flama.dumps("sklearn", sklearn_model())

raise ValueError("Unknown model")
Binary file added tests/models/sklearn_model.flm
Binary file not shown.
Binary file added tests/models/tensorflow_model.flm
Binary file not shown.
Loading

0 comments on commit ca5b7ee

Please sign in to comment.