-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
15 changed files
with
247 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import uvicorn | ||
|
||
from flama import Component, Flama | ||
|
||
|
||
class Address: | ||
def __init__(self, address: str, zip_code: str): | ||
self.address = address | ||
self.zip_code = zip_code | ||
|
||
def to_dict(self): | ||
return {"address": self.address, "zip_code": self.zip_code} | ||
|
||
|
||
class AddressComponent(Component): | ||
def resolve(self, address: str, zip_code: str) -> Address: | ||
return Address(address, zip_code) | ||
|
||
|
||
class Person: | ||
def __init__(self, name: str, age: int, address: Address): | ||
self.name = name | ||
self.age = age | ||
self.address = address | ||
|
||
def to_dict(self): | ||
return {"name": self.name, "age": self.age, "address": self.address.to_dict()} | ||
|
||
|
||
class PersonComponent(Component): | ||
def resolve(self, name: str, age: int, address: Address) -> Person: | ||
return Person(name, age, address) | ||
|
||
|
||
app = Flama(components=[PersonComponent(), AddressComponent()]) | ||
|
||
|
||
@app.get("/foo") | ||
def person(person: Person): | ||
return {"person": person.to_dict()} | ||
|
||
|
||
if __name__ == "__main__": | ||
uvicorn.run(app, host="0.0.0.0", port=8000) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import abc | ||
import typing | ||
|
||
from flama.components import Component | ||
from flama.serialize import Format, loads | ||
|
||
|
||
class Model: | ||
def __init__(self, model: typing.Any): | ||
self.model = model | ||
|
||
@abc.abstractmethod | ||
def predict(self, x: typing.Any) -> typing.Any: | ||
... | ||
|
||
|
||
class TensorFlowModel(Model): | ||
def predict(self, x: typing.Any) -> typing.Any: | ||
return self.model.predict(x) | ||
|
||
|
||
class SKLearnModel(Model): | ||
def predict(self, x: typing.Any) -> typing.Any: | ||
return self.model.predict(x) | ||
|
||
|
||
class ModelComponentBuilder: | ||
@classmethod | ||
def loads(cls, data: bytes) -> Component: | ||
load_model = loads(data) | ||
name = {Format.tensorflow: "TensorFlowModel", Format.sklearn: "SKLearnModel"}[load_model.lib] | ||
parent = {Format.tensorflow: TensorFlowModel, Format.sklearn: SKLearnModel}[load_model.lib] | ||
model_class = type(name, (parent,), {}) | ||
model_obj = model_class(load_model.model) | ||
|
||
class ModelComponent(Component): | ||
def __init__(self, model: model_class): # type: ignore[valid-type] | ||
self.model = model | ||
|
||
def resolve(self, test: bool) -> model_class: # type: ignore[valid-type] | ||
return self.model | ||
|
||
return ModelComponent(model_obj) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import inspect | ||
import typing | ||
|
||
from flama.modules import Module | ||
from flama.resources.routing import ResourceRoute | ||
|
||
if typing.TYPE_CHECKING: | ||
from flama.resources.resource import BaseResource | ||
|
||
__all__ = ["ModelsModule"] | ||
|
||
|
||
class ModelsModule(Module): | ||
name = "resources" | ||
|
||
def add_model( | ||
self, path: str, resource: typing.Union["BaseResource", typing.Type["BaseResource"]], *args, **kwargs | ||
): | ||
"""Adds a model to this application, setting its endpoints. | ||
:param path: Resource base path. | ||
:param resource: Resource class. | ||
""" | ||
# Handle class or instance objects | ||
resource = resource(app=self.app, *args, **kwargs) if inspect.isclass(resource) else resource | ||
|
||
self.app.routes.append(ResourceRoute(path, resource, main_app=self.app)) | ||
|
||
def model(self, path: str, *args, **kwargs) -> typing.Callable: | ||
"""Decorator for Model classes for adding them to the application. | ||
:param path: Resource base path. | ||
:return: Decorated resource class. | ||
""" | ||
|
||
def decorator(resource: typing.Type["BaseResource"]) -> typing.Type["BaseResource"]: | ||
self.add_model(path, resource, *args, **kwargs) | ||
return resource | ||
|
||
return decorator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import pytest | ||
import tensorflow as tf | ||
from pytest import param | ||
from sklearn.linear_model import LogisticRegression | ||
|
||
import flama | ||
from flama.models.components import Model, ModelComponentBuilder | ||
|
||
|
||
class TestCaseModelComponent: | ||
@pytest.fixture | ||
def tensorflow_model(self): | ||
tf_model = tf.keras.models.Sequential( | ||
[ | ||
tf.keras.layers.Flatten(input_shape=(28, 28)), | ||
tf.keras.layers.Dense(128, activation="relu"), | ||
tf.keras.layers.Dropout(0.2), | ||
tf.keras.layers.Dense(10, activation="softmax"), | ||
] | ||
) | ||
|
||
tf_model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]) | ||
|
||
return tf_model | ||
|
||
@pytest.fixture | ||
def sklearn_model(self): | ||
return LogisticRegression() | ||
|
||
@pytest.fixture | ||
def tensorflow_dump(self, tensorflow_model): | ||
return flama.dumps("tensorflow", tensorflow_model) | ||
|
||
@pytest.fixture | ||
def sklearn_dump(self, sklearn_model): | ||
return flama.dumps("sklearn", sklearn_model) | ||
|
||
@pytest.fixture(scope="function") | ||
def model(self, request, sklearn_model, tensorflow_model): | ||
if request.param == "tensorflow": | ||
return tensorflow_model | ||
|
||
if request.param == "sklearn": | ||
return sklearn_model | ||
|
||
raise ValueError("Unknown model") | ||
|
||
@pytest.fixture(scope="function") | ||
def dump(self, request, sklearn_dump, tensorflow_dump): | ||
if request.param == "tensorflow": | ||
return tensorflow_dump | ||
|
||
if request.param == "sklearn": | ||
return sklearn_dump | ||
|
||
raise ValueError("Unknown model") | ||
|
||
@pytest.mark.parametrize( | ||
("dump", "model"), (param("tensorflow", "tensorflow"), param("sklearn", "sklearn")), indirect=["dump", "model"] | ||
) | ||
def test_build(self, dump, model): | ||
component = ModelComponentBuilder.loads(dump) | ||
model_wrapper = component.model | ||
model_instance = model_wrapper.model | ||
assert isinstance(model_wrapper, Model) | ||
assert isinstance(model_instance, model.__class__) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.