Skip to content

Commit

Permalink
✨ Model components
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy committed Jan 19, 2023
1 parent 4ee60f6 commit af63231
Show file tree
Hide file tree
Showing 15 changed files with 247 additions and 43 deletions.
44 changes: 44 additions & 0 deletions examples/components.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import uvicorn

from flama import Component, Flama


class Address:
def __init__(self, address: str, zip_code: str):
self.address = address
self.zip_code = zip_code

def to_dict(self):
return {"address": self.address, "zip_code": self.zip_code}


class AddressComponent(Component):
def resolve(self, address: str, zip_code: str) -> Address:
return Address(address, zip_code)


class Person:
def __init__(self, name: str, age: int, address: Address):
self.name = name
self.age = age
self.address = address

def to_dict(self):
return {"name": self.name, "age": self.age, "address": self.address.to_dict()}


class PersonComponent(Component):
def resolve(self, name: str, age: int, address: Address) -> Person:
return Person(name, age, address)


app = Flama(components=[PersonComponent(), AddressComponent()])


@app.get("/foo")
def person(person: Person):
return {"person": person.to_dict()}


if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
2 changes: 1 addition & 1 deletion flama/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __len__(self) -> int:
return self._components.__len__()

def __add__(self, other: "Components") -> "Components":
return Components(self._components + list(other))
return Components(list(dict.fromkeys(self._components + list(other))))

def __eq__(self, other: object) -> bool:
if isinstance(other, Components):
Expand Down
Empty file added flama/models/__init__.py
Empty file.
43 changes: 43 additions & 0 deletions flama/models/components.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import abc
import typing

from flama.components import Component
from flama.serialize import Format, loads


class Model:
def __init__(self, model: typing.Any):
self.model = model

@abc.abstractmethod
def predict(self, x: typing.Any) -> typing.Any:
...


class TensorFlowModel(Model):
def predict(self, x: typing.Any) -> typing.Any:
return self.model.predict(x)


class SKLearnModel(Model):
def predict(self, x: typing.Any) -> typing.Any:
return self.model.predict(x)


class ModelComponentBuilder:
@classmethod
def loads(cls, data: bytes) -> Component:
load_model = loads(data)
name = {Format.tensorflow: "TensorFlowModel", Format.sklearn: "SKLearnModel"}[load_model.lib]
parent = {Format.tensorflow: TensorFlowModel, Format.sklearn: SKLearnModel}[load_model.lib]
model_class = type(name, (parent,), {})
model_obj = model_class(load_model.model)

class ModelComponent(Component):
def __init__(self, model: model_class): # type: ignore[valid-type]
self.model = model

def resolve(self, test: bool) -> model_class: # type: ignore[valid-type]
return self.model

return ModelComponent(model_obj)
40 changes: 40 additions & 0 deletions flama/models/modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import inspect
import typing

from flama.modules import Module
from flama.resources.routing import ResourceRoute

if typing.TYPE_CHECKING:
from flama.resources.resource import BaseResource

__all__ = ["ModelsModule"]


class ModelsModule(Module):
name = "resources"

def add_model(
self, path: str, resource: typing.Union["BaseResource", typing.Type["BaseResource"]], *args, **kwargs
):
"""Adds a model to this application, setting its endpoints.
:param path: Resource base path.
:param resource: Resource class.
"""
# Handle class or instance objects
resource = resource(app=self.app, *args, **kwargs) if inspect.isclass(resource) else resource

self.app.routes.append(ResourceRoute(path, resource, main_app=self.app))

def model(self, path: str, *args, **kwargs) -> typing.Callable:
"""Decorator for Model classes for adding them to the application.
:param path: Resource base path.
:return: Decorated resource class.
"""

def decorator(resource: typing.Type["BaseResource"]) -> typing.Type["BaseResource"]:
self.add_model(path, resource, *args, **kwargs)
return resource

return decorator
3 changes: 2 additions & 1 deletion flama/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,11 @@ def __init__(
main_app: "Flama" = None,
app: ASGIApp = None,
routes: typing.Sequence[BaseRoute] = None,
components: typing.Sequence[Component] = None,
name: str = None,
):
if app is None:
app = Router(routes=routes)
app = Router(routes=routes, components=components)

super().__init__(path, app, routes, name)

Expand Down
2 changes: 1 addition & 1 deletion flama/serialize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


def dumps(lib: typing.Union[str, Format], model: typing.Any) -> bytes:
return Model(lib, model).to_bytes()
return Model(Format(lib), model).to_bytes()


def dump(lib: typing.Union[str, Format], model: typing.Any, fs: typing.BinaryIO) -> None:
Expand Down
10 changes: 5 additions & 5 deletions flama/serialize/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,21 @@ class Format(enum.Enum):

@dataclasses.dataclass(frozen=True)
class Model:
lib: typing.Union[str, Format]
lib: Format
model: typing.Any

@classmethod
def serializer(cls, lib: typing.Union[str, Format]) -> Serializer:
def serializer(cls, lib: Format) -> Serializer:
try:
return SERIALIZERS[Format(lib)]
return SERIALIZERS[lib]
except ValueError:
raise ValueError("Wrong lib")

@classmethod
def from_bytes(cls, data: bytes) -> "Model":
try:
serialized_data = json.loads(codecs.decode(data, "zlib"))
lib = serialized_data["lib"]
lib = Format(serialized_data["lib"])
model = cls.serializer(lib).load(serialized_data["model"].encode())
except KeyError:
raise ValueError("Wrong data")
Expand All @@ -45,7 +45,7 @@ def from_bytes(cls, data: bytes) -> "Model":

def to_dict(self) -> typing.Dict[str, typing.Any]:
pickled_model = self.serializer(self.lib).dump(self.model).decode()
return {"lib": self.lib, "model": pickled_model}
return {"lib": self.lib.value, "model": pickled_model}

def to_json(self) -> str:
return json.dumps(self.to_dict())
Expand Down
13 changes: 13 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from contextlib import ExitStack
from time import sleep

import marshmallow
import pytest
Expand Down Expand Up @@ -120,3 +121,15 @@ def assert_recursive_contains(first, second):
assert_recursive_contains(first[i], second[i])
else:
assert first == second


def assert_read_from_file(file_path, value, max_tries=10):
read_value = None
i = 0
while not read_value and i < max_tries:
sleep(i)
with open(file_path) as f:
read_value = f.read()
i += 1

assert read_value == value
Empty file added tests/models/__init__.py
Empty file.
66 changes: 66 additions & 0 deletions tests/models/test_components.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import pytest
import tensorflow as tf
from pytest import param
from sklearn.linear_model import LogisticRegression

import flama
from flama.models.components import Model, ModelComponentBuilder


class TestCaseModelComponent:
@pytest.fixture
def tensorflow_model(self):
tf_model = tf.keras.models.Sequential(
[
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation="relu"),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation="softmax"),
]
)

tf_model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

return tf_model

@pytest.fixture
def sklearn_model(self):
return LogisticRegression()

@pytest.fixture
def tensorflow_dump(self, tensorflow_model):
return flama.dumps("tensorflow", tensorflow_model)

@pytest.fixture
def sklearn_dump(self, sklearn_model):
return flama.dumps("sklearn", sklearn_model)

@pytest.fixture(scope="function")
def model(self, request, sklearn_model, tensorflow_model):
if request.param == "tensorflow":
return tensorflow_model

if request.param == "sklearn":
return sklearn_model

raise ValueError("Unknown model")

@pytest.fixture(scope="function")
def dump(self, request, sklearn_dump, tensorflow_dump):
if request.param == "tensorflow":
return tensorflow_dump

if request.param == "sklearn":
return sklearn_dump

raise ValueError("Unknown model")

@pytest.mark.parametrize(
("dump", "model"), (param("tensorflow", "tensorflow"), param("sklearn", "sklearn")), indirect=["dump", "model"]
)
def test_build(self, dump, model):
component = ModelComponentBuilder.loads(dump)
model_wrapper = component.model
model_instance = model_wrapper.model
assert isinstance(model_wrapper, Model)
assert isinstance(model_instance, model.__class__)
18 changes: 5 additions & 13 deletions tests/test_background.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import time
from tempfile import NamedTemporaryFile

import anyio
import pytest
from conftest import assert_read_from_file

from flama import BackgroundProcessTask, BackgroundTasks, BackgroundThreadTask, Concurrency
from flama.responses import APIResponse
Expand Down Expand Up @@ -49,9 +49,7 @@ async def test(path: str, msg: str):
assert response.status_code == 200
assert response.json() == {"foo": "bar"}

time.sleep(1)
with open(tmp_file.name) as f:
assert f.read() == "foo"
assert_read_from_file(tmp_file.name, "foo")

def test_background_thread_task(self, app, client, task, tmp_file):
@app.route("/")
Expand All @@ -62,9 +60,7 @@ async def test(path: str, msg: str):
assert response.status_code == 200
assert response.json() == {"foo": "bar"}

time.sleep(1)
with open(tmp_file.name) as f:
assert f.read() == "foo"
assert_read_from_file(tmp_file.name, "foo")


class TestCaseBackgroundTasks:
Expand Down Expand Up @@ -92,9 +88,5 @@ async def test(path_1: str, msg_1: str, path_2: str, msg_2: str):
assert response.status_code == 200
assert response.json() == {"foo": "bar"}

time.sleep(1)
with open(tmp_file.name) as f:
assert f.read() == "foo"

with open(tmp_file_2.name) as f:
assert f.read() == "bar"
assert_read_from_file(tmp_file.name, "foo")
assert_read_from_file(tmp_file_2.name, "bar")
2 changes: 1 addition & 1 deletion tests/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class TestCaseComponent:
@pytest.fixture
def component(self):
class FooComponent(Component):
def resolve(self, *args, **kwargs) -> Foo:
def resolve(self, z: int, *args, **kwargs) -> Foo:
return Foo()

return FooComponent()
Expand Down
6 changes: 5 additions & 1 deletion tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ def test_mount_declarative(self, component_mock):
routes = [
Route("/", root_mock),
Mount(
"/foo", routes=[Route("/", foo_mock, methods=["GET"]), Route("/view", foo_view_mock, methods=["GET"])]
"/foo",
routes=[Route("/", foo_mock, methods=["GET"]), Route("/view", foo_view_mock, methods=["GET"])],
components=[component_mock],
),
Mount(
"/bar",
Expand Down Expand Up @@ -179,6 +181,8 @@ def test_mount_declarative(self, component_mock):
assert isinstance(mount_with_routes_route.app, Router)
mount_with_routes_router = mount_with_routes_route.app
assert mount_with_routes_router.main_app == app
assert mount_with_routes_router.components == Components([component_mock])
assert app.components == Components([component_mock])
# Check second-level routes are created an initialized
assert len(mount_with_routes_route.routes) == 2
assert mount_with_routes_route.routes[0].path == "/"
Expand Down
Loading

0 comments on commit af63231

Please sign in to comment.