Skip to content

Commit

Permalink
🚧 Improving types with pyright (#106 #108)
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy committed Aug 2, 2023
1 parent 8dfcbdd commit af7eb96
Show file tree
Hide file tree
Showing 40 changed files with 231 additions and 283 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test_and_publish.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ jobs:
- id: ruff
name: Code style (ruff)
run: ./scripts/ruff .
- id: mypy
- id: pyright
name: Static types check
run: ./scripts/mypy .
run: ./scripts/pyright
- id: tests
name: Tests
run: ./scripts/test
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test_pull_request_branch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ jobs:
- id: ruff
name: Code style (ruff)
run: ./scripts/ruff .
- id: mypy
- id: pyright
name: Static types check
run: ./scripts/mypy .
run: ./scripts/pyright
- id: tests
name: Tests
run: ./scripts/test
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ repos:
exclude: "make"
- id: ruff
name: Ruff - Code Linter
entry: ./scripts/ruff
entry: ./scripts/ruff --fix
language: system
types: [file, python]
exclude: "make"
- id: mypy
name: Mypy - Static types check
entry: ./scripts/mypy
- id: pyright
name: Pyright - Static types check
entry: ./scripts/pyright
language: system
types: [file, python]
exclude: "(make|tests/|examples/)"
Expand Down
10 changes: 5 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ black: ## Runs black on Flama
ruff: ## Runs ruff on Flama
@./scripts/ruff .

mypy: ## Runs mypy on Flama
@./scripts/mypy .
pyright: ## Runs pyright on Flama
@./scripts/pyright

docker_push: ## Runs mypy on Flama
docker_push: ## Push docker images to registry
@./scripts/docker_push .

.PHONY: help check clean install build lint tests publish version isort black ruff mypy docker_push
.PHONY: help check clean install build lint tests publish version isort black ruff pyright docker_push
.DEFAULT_GOAL := help

help:
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
24 changes: 11 additions & 13 deletions flama/background.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,23 @@ class Concurrency(enum.Enum):

class BackgroundTask(starlette.background.BackgroundTask):
def __init__(
self, concurrency: t.Union[Concurrency, str], func: t.Callable[P, t.Any], *args: P.args, **kwargs: P.kwargs
self,
concurrency: t.Union[Concurrency, str],
func: t.Callable[P, t.Union[None, t.Awaitable[None]]],
*args: P.args,
**kwargs: P.kwargs
) -> None:
self.func = self._create_task_function(func)
self.args = args
self.kwargs = kwargs
self.concurrency = Concurrency[concurrency] if isinstance(concurrency, str) else concurrency

def _create_task_function(self, func: t.Callable[P, t.Any]) -> t.Callable[P, t.Any]:
if asyncio.iscoroutinefunction(func):

@functools.wraps(func)
async def _inner(*args, **kwargs):
await func(*args, **kwargs)

else:

@functools.wraps(func)
async def _inner(*args, **kwargs):
await concurrency.run(func, *args, **kwargs)
def _create_task_function(
self, func: t.Callable[P, t.Union[None, t.Awaitable[None]]]
) -> t.Callable[P, t.Awaitable[None]]:
@functools.wraps(func)
async def _inner(*args: P.args, **kwargs: P.kwargs):
await concurrency.run_in_thread(func, *args, **kwargs)

return _inner

Expand Down
42 changes: 30 additions & 12 deletions flama/concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,19 @@
import sys
import typing as t

from starlette.concurrency import run_in_threadpool

if sys.version_info < (3, 10): # PORT: Remove when stop supporting 3.9 # pragma: no cover
from typing_extensions import ParamSpec, TypeGuard

t.TypeGuard = TypeGuard
t.ParamSpec = ParamSpec

__all__ = ["is_async", "run"]
__all__ = ["is_async", "run", "run_in_thread"]

T = t.TypeVar("T", covariant=True)
R = t.TypeVar("R")
P = t.ParamSpec("P")


def is_async(obj: t.Any) -> t.TypeGuard[t.Callable[..., t.Awaitable]]:
def is_async(obj: t.Any) -> t.TypeGuard[t.Callable[..., t.Awaitable[t.Any]]]:
"""Check if given object is an async function, callable or partialised function.
:param obj: Object to check.
Expand All @@ -26,13 +24,33 @@ def is_async(obj: t.Any) -> t.TypeGuard[t.Callable[..., t.Awaitable]]:
while isinstance(obj, functools.partial):
obj = obj.func

return asyncio.iscoroutinefunction(obj) or (
callable(obj) and asyncio.iscoroutinefunction(obj.__call__) # type: ignore[operator]
)
return asyncio.iscoroutinefunction(obj) or (callable(obj) and asyncio.iscoroutinefunction(obj.__call__))


async def run(
func: t.Union[t.Callable[P, R], t.Callable[P, t.Awaitable[R]]],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
"""Run a function either as asyncio awaiting it if it's an async function or running it if it's a sync function.
:param func: Function to run.
:param args: Positional arguments.
:param kwargs: Keyword arguments.
:return: Function returned value.
"""
if is_async(func):
return await func(*args, **kwargs)

return func(*args, **kwargs) # type: ignore


async def run(func: t.Callable[P, t.Union[T, t.Awaitable[T]]], *args: P.args, **kwargs: P.kwargs) -> T:
"""Run a function either as asyncio awaiting it if it's an async function or running it in a threadpool if it's a
async def run_in_thread(
func: t.Union[t.Callable[P, R], t.Callable[P, t.Awaitable[R]]],
*args: P.args,
**kwargs: P.kwargs,
) -> R:
"""Run a function either as asyncio awaiting it if it's an async function or running it in a thread if it's a
sync function.
:param func: Function to run.
Expand All @@ -41,6 +59,6 @@ async def run(func: t.Callable[P, t.Union[T, t.Awaitable[T]]], *args: P.args, **
:return: Function returned value.
"""
if is_async(func):
return await func(*args, **kwargs) # type: ignore[no-any-return]
return await func(*args, **kwargs)

return await run_in_threadpool(func, *args, **kwargs) # type: ignore[arg-type]
return await asyncio.to_thread(func, *args, **kwargs) # type: ignore
2 changes: 1 addition & 1 deletion flama/debug/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ async def process_exception(
if response_started:
raise RuntimeError("Caught handled exception, but response already started.") from exc

response = await concurrency.run(handler, scope, receive, send, exc)
response = await concurrency.run_in_thread(handler, scope, receive, send, exc)

if response:
await response(scope, receive, send)
Expand Down
2 changes: 1 addition & 1 deletion flama/debug/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@
HandlerException = t.TypeVar("HandlerException", bound=Exception)
Handler = t.Callable[
[types.Scope, types.Receive, types.Send, HandlerException],
t.Union[t.Optional["http.Response"], t.Awaitable[None], t.Awaitable[t.Optional["http.Response"]]],
t.Union[t.Optional["http.Response"], t.Awaitable[t.Optional["http.Response"]]],
]
2 changes: 1 addition & 1 deletion flama/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ async def dispatch(self) -> None:
"""Dispatch a request."""
app = self.state["app"]
handler = await app.injector.inject(self.handler, **self.state)
return await concurrency.run(handler)
return await concurrency.run_in_thread(handler)


class WebSocketEndpoint:
Expand Down
4 changes: 0 additions & 4 deletions flama/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

import starlette.exceptions

import flama.schemas.exceptions

__all__ = [
"DecodeError",
"HTTPException",
Expand All @@ -16,8 +14,6 @@
"MethodNotAllowedException",
]

__all__ += flama.schemas.exceptions.__all__


class DecodeError(Exception):
"""
Expand Down
5 changes: 1 addition & 4 deletions flama/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
from flama import schemas, types
from flama.exceptions import HTTPException, SerializationError

if t.TYPE_CHECKING:
from flama.types.schema import _T_Schema

__all__ = [
"Method",
"Request",
Expand Down Expand Up @@ -114,7 +111,7 @@ class FileResponse(starlette.responses.FileResponse, Response):
class APIResponse(JSONResponse):
media_type = "application/json"

def __init__(self, content: t.Any = None, schema: t.Optional["_T_Schema"] = None, *args, **kwargs):
def __init__(self, content: t.Any = None, schema: t.Optional["schemas.Schema"] = None, *args, **kwargs):
self.schema = schema
super().__init__(content, *args, **kwargs)

Expand Down
16 changes: 10 additions & 6 deletions flama/injection/components.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import abc
import asyncio
import inspect
import typing as t
Expand All @@ -8,7 +9,7 @@
__all__ = ["Component", "Components"]


class Component:
class Component(metaclass=abc.ABCMeta):
def identity(self, parameter: Parameter) -> str:
"""Each component needs a unique identifier string that we use for lookups from the `state` dictionary when we
run the dependency injection.
Expand All @@ -22,8 +23,7 @@ def identity(self, parameter: Parameter) -> str:
parameter_type = parameter.type.__class__.__name__
component_id = f"{id(parameter.type)}:{parameter_type}"

# If `resolve_parameter` includes `Parameter` then we use an identifier that is additionally parameterized by
# the parameter name.
# If `resolve` includes `Parameter` then use an id that is additionally parameterized by the parameter name.
args = inspect.signature(self.resolve).parameters.values() # type: ignore[attr-defined]
if Parameter in [arg.annotation for arg in args]:
component_id += f":{parameter.name.lower()}"
Expand Down Expand Up @@ -65,14 +65,18 @@ async def __call__(self, *args, **kwargs):
:param kwargs: Resolve keyword arguments.
:return: Resolve result.
"""
if asyncio.iscoroutinefunction(self.resolve):
return await self.resolve(*args, **kwargs)
if asyncio.iscoroutinefunction(self.resolve): # type: ignore[attr-defined]
return await self.resolve(*args, **kwargs) # type: ignore[attr-defined]

return self.resolve(*args, **kwargs)
return self.resolve(*args, **kwargs) # type: ignore[attr-defined]

def __str__(self) -> str:
return str(self.__class__.__name__)

@abc.abstractmethod
def resolve(self, *args, **kwargs) -> t.Any:
...


class Components(t.Tuple[Component, ...]):
def __new__(cls, components=None):
Expand Down
7 changes: 4 additions & 3 deletions flama/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware

from flama import concurrency
from flama.debug.middleware import ExceptionMiddleware, ServerErrorMiddleware

try:
Expand Down Expand Up @@ -39,8 +40,8 @@ def __init__(self, middleware: "types.Middleware", **kwargs: t.Any) -> None:
self.middleware = middleware
self.kwargs = kwargs

def __call__(self, app: "types.App"):
return self.middleware(app, **self.kwargs)
async def __call__(self, app: "types.App") -> "types.App":
return await self.middleware(app, **self.kwargs)

def __repr__(self) -> str:
name = self.__class__.__name__
Expand Down Expand Up @@ -100,4 +101,4 @@ def add_middleware(self, middleware: Middleware):
del self.stack

async def __call__(self, scope: "types.Scope", receive: "types.Receive", send: "types.Send") -> None:
await self.stack(scope, receive, send)
await concurrency.run(self.stack, scope, receive, send)
2 changes: 1 addition & 1 deletion flama/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class Model:
def __init__(self, model: t.Any, meta: "Metadata", artifacts: "Artifacts"):
def __init__(self, model: t.Any, meta: "Metadata", artifacts: t.Optional["Artifacts"]):
self.model = model
self.meta = meta
self.artifacts = artifacts
Expand Down
10 changes: 5 additions & 5 deletions flama/models/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
class InspectMixin:
@classmethod
def _add_inspect(
mcs, name: str, verbose_name: str, model_model_type: t.Type["Model"], **kwargs
cls, name: str, verbose_name: str, model_model_type: t.Type["Model"], **kwargs
) -> t.Dict[str, t.Any]:
@resource_method("/", methods=["GET"], name=f"{name}-inspect")
async def inspect(self, model: model_model_type): # type: ignore[valid-type]
Expand All @@ -44,7 +44,7 @@ async def inspect(self, model: model_model_type): # type: ignore[valid-type]
class PredictMixin:
@classmethod
def _add_predict(
mcs, name: str, verbose_name: str, model_model_type: t.Type["Model"], **kwargs
cls, name: str, verbose_name: str, model_model_type: t.Type["Model"], **kwargs
) -> t.Dict[str, t.Any]:
@resource_method("/predict/", methods=["POST"], name=f"{name}-predict")
async def predict(
Expand Down Expand Up @@ -108,16 +108,16 @@ def __new__(mcs, name: str, bases: t.Tuple[type], namespace: t.Dict[str, t.Any])
return super().__new__(mcs, name, bases, namespace)

@classmethod
def _get_model_component(mcs, bases: t.Sequence[t.Any], namespace: t.Dict[str, t.Any]) -> "ModelComponent":
def _get_model_component(cls, bases: t.Sequence[t.Any], namespace: t.Dict[str, t.Any]) -> "ModelComponent":
try:
component: "ModelComponent" = mcs._get_attribute("component", bases, namespace, metadata_namespace="model")
component: "ModelComponent" = cls._get_attribute("component", bases, namespace, metadata_namespace="model")
return component
except AttributeError:
...

try:
return ModelComponentBuilder.load(
mcs._get_attribute("model_path", bases, namespace, metadata_namespace="model")
cls._get_attribute("model_path", bases, namespace, metadata_namespace="model")
)
except AttributeError:
...
Expand Down
Loading

0 comments on commit af7eb96

Please sign in to comment.