Skip to content

Commit

Permalink
✨ Improving types with pyright (#106 #108)
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy committed Sep 19, 2023
1 parent f07bd6a commit d3dbb98
Show file tree
Hide file tree
Showing 42 changed files with 397 additions and 631 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test_and_publish.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ jobs:
- id: ruff
name: Code style (ruff)
run: ./scripts/ruff .
- id: mypy
- id: pyright
name: Static types check
run: ./scripts/mypy .
run: ./scripts/pyright
- id: tests
name: Tests
run: ./scripts/test
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test_pull_request_branch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ jobs:
- id: ruff
name: Code style (ruff)
run: ./scripts/ruff .
- id: mypy
- id: pyright
name: Static types check
run: ./scripts/mypy .
run: ./scripts/pyright
- id: tests
name: Tests
run: ./scripts/test
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ repos:
exclude: "make"
- id: ruff
name: Ruff - Code Linter
entry: ./scripts/ruff
entry: ./scripts/ruff --fix
language: system
types: [file, python]
exclude: "make"
- id: mypy
name: Mypy - Static types check
entry: ./scripts/mypy
- id: pyright
name: Pyright - Static types check
entry: ./scripts/pyright
language: system
types: [file, python]
exclude: "(make|tests/|examples/)"
Expand Down
10 changes: 5 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ black: ## Runs black on Flama
ruff: ## Runs ruff on Flama
@./scripts/ruff .

mypy: ## Runs mypy on Flama
@./scripts/mypy .
pyright: ## Runs pyright on Flama
@./scripts/pyright

docker_push: ## Runs mypy on Flama
docker_push: ## Push docker images to registry
@./scripts/docker_push .

.PHONY: help check clean install build lint tests publish version isort black ruff mypy docker_push
.PHONY: help check clean install build lint tests publish version isort black ruff pyright docker_push
.DEFAULT_GOAL := help

help:
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
49 changes: 17 additions & 32 deletions flama/background.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import asyncio
import enum
import functools
import sys
import typing as t
from multiprocessing import Process

import starlette.background

Expand All @@ -19,48 +17,36 @@
P = t.ParamSpec("P")


class task_wrapper:
def __init__(self, target: t.Callable[P, t.Union[None, t.Awaitable[None]]]):
self.target = target
functools.update_wrapper(self, target)

async def __call__(self, *args, **kwargs):
await concurrency.run(self.target, *args, **kwargs)


class Concurrency(enum.Enum):
thread = "thread"
process = "process"


class BackgroundTask(starlette.background.BackgroundTask):
def __init__(
self, concurrency: t.Union[Concurrency, str], func: t.Callable[P, t.Any], *args: P.args, **kwargs: P.kwargs
self,
concurrency: t.Union[Concurrency, str],
func: t.Callable[P, t.Union[None, t.Awaitable[None]]],
*args: P.args,
**kwargs: P.kwargs
) -> None:
self.func = self._create_task_function(func)
self.func = task_wrapper(func)
self.args = args
self.kwargs = kwargs
self.concurrency = Concurrency[concurrency] if isinstance(concurrency, str) else concurrency

def _create_task_function(self, func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
if asyncio.iscoroutinefunction(func):

@functools.wraps(func)
async def _inner(*args, **kwargs):
await func(*args, **kwargs)

else:

@functools.wraps(func)
async def _inner(*args, **kwargs):
await concurrency.run(func, *args, **kwargs)

return _inner

def _create_process_target(self, func: t.Callable[P, t.Any]):
@functools.wraps(func)
def process_target(*args: P.args, **kwargs: P.kwargs): # pragma: no cover
policy = asyncio.get_event_loop_policy()
loop = policy.new_event_loop()
policy.set_event_loop(loop)
loop.run_until_complete(func(*args, **kwargs))

return process_target

async def __call__(self):
if self.concurrency == Concurrency.process:
Process(target=self._create_process_target(self.func), args=self.args, kwargs=self.kwargs).start()
concurrency.AsyncProcess(target=self.func, args=self.args, kwargs=self.kwargs).start()
else:
await self.func(*self.args, **self.kwargs)

Expand All @@ -72,8 +58,7 @@ def __init__(self, tasks: t.Optional[t.Sequence[BackgroundTask]] = None):
def add_task(
self, concurrency: t.Union[Concurrency, str], func: t.Callable[P, t.Any], *args: P.args, **kwargs: P.kwargs
) -> None:
task = BackgroundTask(concurrency, func, *args, **kwargs)
self.tasks.append(task)
self.tasks.append(BackgroundTask(concurrency, func, *args, **kwargs))

async def __call__(self) -> None:
for task in self.tasks:
Expand Down
39 changes: 27 additions & 12 deletions flama/concurrency.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
import asyncio
import functools
import multiprocessing
import sys
import typing as t

from starlette.concurrency import run_in_threadpool

if sys.version_info < (3, 10): # PORT: Remove when stop supporting 3.9 # pragma: no cover
from typing_extensions import ParamSpec, TypeGuard

t.TypeGuard = TypeGuard
t.ParamSpec = ParamSpec

__all__ = ["is_async", "run"]
__all__ = ["is_async", "run", "run"]

T = t.TypeVar("T", covariant=True)
R = t.TypeVar("R", covariant=True)
P = t.ParamSpec("P")


def is_async(obj: t.Any) -> t.TypeGuard[t.Callable[..., t.Awaitable]]:
def is_async(obj: t.Any) -> t.TypeGuard[t.Callable[..., t.Awaitable[t.Any]]]:
"""Check if given object is an async function, callable or partialised function.
:param obj: Object to check.
Expand All @@ -26,13 +25,15 @@ def is_async(obj: t.Any) -> t.TypeGuard[t.Callable[..., t.Awaitable]]:
while isinstance(obj, functools.partial):
obj = obj.func

return asyncio.iscoroutinefunction(obj) or (
callable(obj) and asyncio.iscoroutinefunction(obj.__call__) # type: ignore[operator]
)
return asyncio.iscoroutinefunction(obj) or (callable(obj) and asyncio.iscoroutinefunction(obj.__call__))


async def run(func: t.Callable[P, t.Union[T, t.Awaitable[T]]], *args: P.args, **kwargs: P.kwargs) -> T:
"""Run a function either as asyncio awaiting it if it's an async function or running it in a threadpool if it's a
async def run(
func: t.Union[t.Callable[P, R], t.Callable[P, t.Awaitable[R]]],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
"""Run a function either as asyncio awaiting it if it's an async function or running it in a thread if it's a
sync function.
:param func: Function to run.
Expand All @@ -41,6 +42,20 @@ async def run(func: t.Callable[P, t.Union[T, t.Awaitable[T]]], *args: P.args, **
:return: Function returned value.
"""
if is_async(func):
return await func(*args, **kwargs) # type: ignore[no-any-return]
return await func(*args, **kwargs)

return await asyncio.to_thread(func, *args, **kwargs) # type: ignore


class AsyncProcess(multiprocessing.Process):
"""Multiprocessing Process class whose target is an async function."""

def run(self):
if self._target: # type: ignore
task = self._target(*self._args, **self._kwargs) # type: ignore

return await run_in_threadpool(func, *args, **kwargs) # type: ignore[arg-type]
if is_async(self._target): # type: ignore
policy = asyncio.get_event_loop_policy()
loop = policy.new_event_loop()
policy.set_event_loop(loop)
loop.run_until_complete(task)
2 changes: 1 addition & 1 deletion flama/debug/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ async def process_exception(

response = await concurrency.run(handler, scope, receive, send, exc)

if response:
if response and concurrency.is_async(response):
await response(scope, receive, send)

def http_exception_handler(
Expand Down
2 changes: 1 addition & 1 deletion flama/debug/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@
HandlerException = t.TypeVar("HandlerException", bound=Exception)
Handler = t.Callable[
[types.Scope, types.Receive, types.Send, HandlerException],
t.Union[t.Optional["http.Response"], t.Awaitable[None], t.Awaitable[t.Optional["http.Response"]]],
t.Union[t.Optional["http.Response"], t.Awaitable[t.Optional["http.Response"]]],
]
4 changes: 0 additions & 4 deletions flama/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

import starlette.exceptions

import flama.schemas.exceptions

__all__ = [
"DecodeError",
"HTTPException",
Expand All @@ -16,8 +14,6 @@
"MethodNotAllowedException",
]

__all__ += flama.schemas.exceptions.__all__


class DecodeError(Exception):
"""
Expand Down
5 changes: 1 addition & 4 deletions flama/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
from flama import schemas, types
from flama.exceptions import HTTPException, SerializationError

if t.TYPE_CHECKING:
from flama.types.schema import _T_Schema

__all__ = [
"Method",
"Request",
Expand Down Expand Up @@ -114,7 +111,7 @@ class FileResponse(starlette.responses.FileResponse, Response):
class APIResponse(JSONResponse):
media_type = "application/json"

def __init__(self, content: t.Any = None, schema: t.Optional["_T_Schema"] = None, *args, **kwargs):
def __init__(self, content: t.Any = None, schema: t.Optional["schemas.Schema"] = None, *args, **kwargs):
self.schema = schema
super().__init__(content, *args, **kwargs)

Expand Down
16 changes: 10 additions & 6 deletions flama/injection/components.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import abc
import asyncio
import inspect
import typing as t
Expand All @@ -8,7 +9,7 @@
__all__ = ["Component", "Components"]


class Component:
class Component(metaclass=abc.ABCMeta):
def identity(self, parameter: Parameter) -> str:
"""Each component needs a unique identifier string that we use for lookups from the `state` dictionary when we
run the dependency injection.
Expand All @@ -22,8 +23,7 @@ def identity(self, parameter: Parameter) -> str:
parameter_type = parameter.type.__class__.__name__
component_id = f"{id(parameter.type)}:{parameter_type}"

# If `resolve_parameter` includes `Parameter` then we use an identifier that is additionally parameterized by
# the parameter name.
# If `resolve` includes `Parameter` then use an id that is additionally parameterized by the parameter name.
args = inspect.signature(self.resolve).parameters.values() # type: ignore[attr-defined]
if Parameter in [arg.annotation for arg in args]:
component_id += f":{parameter.name.lower()}"
Expand Down Expand Up @@ -65,14 +65,18 @@ async def __call__(self, *args, **kwargs):
:param kwargs: Resolve keyword arguments.
:return: Resolve result.
"""
if asyncio.iscoroutinefunction(self.resolve):
return await self.resolve(*args, **kwargs)
if asyncio.iscoroutinefunction(self.resolve): # type: ignore[attr-defined]
return await self.resolve(*args, **kwargs) # type: ignore[attr-defined]

return self.resolve(*args, **kwargs)
return self.resolve(*args, **kwargs) # type: ignore[attr-defined]

def __str__(self) -> str:
return str(self.__class__.__name__)

@abc.abstractmethod
def resolve(self, *args, **kwargs) -> t.Any:
...


class Components(t.Tuple[Component, ...]):
def __new__(cls, components=None):
Expand Down
17 changes: 12 additions & 5 deletions flama/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware

from flama import concurrency
from flama.debug.middleware import ExceptionMiddleware, ServerErrorMiddleware

try:
Expand Down Expand Up @@ -39,7 +40,9 @@ def __init__(self, middleware: "types.Middleware", **kwargs: t.Any) -> None:
self.middleware = middleware
self.kwargs = kwargs

def __call__(self, app: "types.App"):
def __call__(
self, app: "types.App"
) -> t.Union["types.App", t.Awaitable["types.App"], "types.MiddlewareClass", "types.MiddlewareAsyncClass"]:
return self.middleware(app, **self.kwargs)

def __repr__(self) -> str:
Expand All @@ -59,13 +62,17 @@ def __init__(self, app: "types.App", middleware: t.Sequence[Middleware], debug:
self._exception_handlers: t.Dict[
t.Union[int, t.Type[Exception]], t.Callable[["Request", Exception], "Response"]
] = {}
self._stack: t.Optional["types.App"] = None
self._stack: t.Optional[
t.Union["types.App", t.Awaitable["types.App"], "types.MiddlewareClass", "types.MiddlewareAsyncClass"]
] = None

@property
def stack(self) -> "types.App":
def stack(
self,
) -> t.Union["types.App", t.Awaitable["types.App"], "types.MiddlewareClass", "types.MiddlewareAsyncClass"]:
if self._stack is None:
self._stack = functools.reduce(
lambda app, middleware: middleware(app=app),
lambda app, middleware: middleware(app=app), # type: ignore
[
Middleware(ExceptionMiddleware, handlers=self._exception_handlers, debug=self.debug),
*self.middleware,
Expand Down Expand Up @@ -100,4 +107,4 @@ def add_middleware(self, middleware: Middleware):
del self.stack

async def __call__(self, scope: "types.Scope", receive: "types.Receive", send: "types.Send") -> None:
await self.stack(scope, receive, send)
await concurrency.run(self.stack, scope, receive, send) # type: ignore
2 changes: 1 addition & 1 deletion flama/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class Model:
def __init__(self, model: t.Any, meta: "Metadata", artifacts: "Artifacts"):
def __init__(self, model: t.Any, meta: "Metadata", artifacts: t.Optional["Artifacts"]):
self.model = model
self.meta = meta
self.artifacts = artifacts
Expand Down
8 changes: 4 additions & 4 deletions flama/models/models/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
from flama.models.base import Model

try:
import tensorflow
import tensorflow as tf
except Exception: # pragma: no cover
tensorflow = None # type: ignore
tf = None


class TensorFlowModel(Model):
def predict(self, x: t.List[t.List[t.Any]]) -> t.Any:
assert tensorflow is not None, "`tensorflow` must be installed to use TensorFlowModel."
assert tf is not None, "`tensorflow` must be installed to use TensorFlowModel."

try:
return self.model.predict(x).tolist()
except (tensorflow.errors.OpError, ValueError):
except (tf.errors.OpError, ValueError): # type: ignore
raise exceptions.HTTPException(status_code=400)
Loading

0 comments on commit d3dbb98

Please sign in to comment.