-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
430 additions
and
76 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import os | ||
import typing | ||
|
||
import flama.schemas | ||
from flama.models.components import ModelComponentBuilder | ||
from flama.resources import types, BaseResource | ||
from flama.resources.exceptions import ResourceAttributeError | ||
from flama.resources.resource import ResourceType | ||
from flama.resources.routing import resource_method | ||
import flama.schemas | ||
|
||
if typing.TYPE_CHECKING: | ||
from flama.components import Component | ||
from flama.models.components import Model | ||
|
||
__all__ = ["ModelResource", "InspectMixin", "PredictMixin", "ModelResourceType"] | ||
|
||
|
||
class InspectMixin: | ||
@classmethod | ||
def _add_inspect( | ||
mcs, name: str, verbose_name: str, ml_model_type: "Model", **kwargs | ||
) -> typing.Dict[str, typing.Any]: | ||
@resource_method("/", methods=["GET"], name=f"{name}-inspect") | ||
async def inspect(self, model: ml_model_type): # type: ignore[valid-type] | ||
return model.inspect() # type: ignore[attr-defined] | ||
|
||
inspect.__doc__ = f""" | ||
tags: | ||
- {verbose_name} | ||
summary: | ||
Retrieve the model. | ||
description: | ||
Retrieve the model from this resource. | ||
responses: | ||
200: | ||
description: | ||
The model. | ||
""" | ||
|
||
return {"_inspect": inspect} | ||
|
||
|
||
class PredictMixin: | ||
@classmethod | ||
def _add_predict( | ||
mcs, name: str, verbose_name: str, ml_model_type: "Model", **kwargs | ||
) -> typing.Dict[str, typing.Any]: | ||
@resource_method("/predict/", methods=["POST"], name=f"{name}-predict") | ||
async def predict( | ||
self, model: ml_model_type, data: flama.schemas.schemas.MLModelInput # type: ignore[valid-type] | ||
) -> flama.schemas.schemas.MLModelOutput: | ||
return {"output": model.predict(data["input"])} # type: ignore[attr-defined] | ||
|
||
predict.__doc__ = f""" | ||
tags: | ||
- {verbose_name} | ||
summary: | ||
Generate a prediction. | ||
description: | ||
Generate a prediction using the model from this resource. | ||
responses: | ||
200: | ||
description: | ||
The prediction generated by the model. | ||
""" | ||
|
||
return {"_predict": predict} | ||
|
||
|
||
class ModelResource(BaseResource): | ||
model: typing.Union[str, os.PathLike] | ||
|
||
|
||
class ModelResourceType(ResourceType, InspectMixin, PredictMixin): | ||
METHODS = ("inspect", "predict") | ||
|
||
def __new__(mcs, name: str, bases: typing.Tuple[type], namespace: typing.Dict[str, typing.Any]): | ||
"""Resource metaclass for defining basic behavior for ML resources: | ||
* Create _meta attribute containing some metadata (model...). | ||
* Adds methods related to ML resource (inspect, predict...) listed in METHODS class attribute. | ||
:param name: Class name. | ||
:param bases: List of superclasses. | ||
:param namespace: Variables namespace used to create the class. | ||
""" | ||
try: | ||
# Get model component | ||
component = mcs._get_model_component(bases, namespace) | ||
model = component.model # type: ignore[attr-defined] | ||
namespace["component"] = component | ||
namespace["model"] = component.model # type: ignore[attr-defined] | ||
except AttributeError as e: | ||
raise ResourceAttributeError(str(e), name) | ||
|
||
metadata_namespace = {"component": component, "model": model, "model_type": type(model)} | ||
if "_meta" in namespace: | ||
namespace["_meta"].namespaces["ml"] = metadata_namespace | ||
else: | ||
namespace["_meta"] = types.Metadata(namespaces={"ml": metadata_namespace}) | ||
|
||
return super().__new__(mcs, name, bases, namespace) | ||
|
||
@classmethod | ||
def _get_model_component( | ||
mcs, bases: typing.Sequence[typing.Any], namespace: typing.Dict[str, typing.Any] | ||
) -> "Component": | ||
with open(mcs._get_attribute("model", bases, namespace), "rb") as f: | ||
return ModelComponentBuilder.loads(f.read()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import pytest | ||
import tensorflow as tf | ||
from sklearn.linear_model import LogisticRegression | ||
|
||
import flama | ||
|
||
|
||
def tensorflow_model( | ||
sequential=tf.keras.models.Sequential( | ||
[ | ||
tf.keras.layers.Flatten(input_shape=(28, 28)), | ||
tf.keras.layers.Dense(128, activation="relu"), | ||
tf.keras.layers.Dropout(0.2), | ||
tf.keras.layers.Dense(10, activation="softmax"), | ||
] | ||
) | ||
): | ||
tf_model = sequential | ||
|
||
tf_model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]) | ||
|
||
return tf_model | ||
|
||
|
||
def sklearn_model(): | ||
return LogisticRegression() | ||
|
||
|
||
@pytest.fixture(scope="function") | ||
def model(request): | ||
if request.param == "tensorflow": | ||
return tensorflow_model() | ||
|
||
if request.param == "sklearn": | ||
return sklearn_model() | ||
|
||
raise ValueError("Unknown model") | ||
|
||
|
||
@pytest.fixture(scope="function") | ||
def model_dump(request): | ||
if request.param == "tensorflow": | ||
return flama.dumps("tensorflow", tensorflow_model()) | ||
|
||
if request.param == "sklearn": | ||
return flama.dumps("sklearn", sklearn_model()) | ||
|
||
raise ValueError("Unknown model") |
Binary file not shown.
Binary file not shown.
Oops, something went wrong.