diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8be179d64b..2e768e9a1d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -28,7 +28,10 @@ jobs: with: python-version: ${{ matrix.python-version }} architecture: x64 - + - uses: actions/setup-node@v2 + with: + node-version: '14' + - run: npm install -g --no-package-lock --no-save pyright - run: pip install poetry - name: Install dependencies run: poetry install diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..da350fec0b --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,4 @@ +Release type: minor + +This release adds support for Pyright and Pylance, improving the +integration with Visual Studio Code! diff --git a/TWEET.md b/TWEET.md new file mode 100644 index 0000000000..972b3172f0 --- /dev/null +++ b/TWEET.md @@ -0,0 +1,8 @@ +Strawberry $version is out! This time with improved support for Pyright and +Pylance, @code's default language server for Python! + +--- + +This release adds support for Pyright and Pylance, improving the integration +with Visual Studio Code! Now you can use Pylance instead of MyPy in VS Code and +get proper autocomplete and intellisense 🎉 diff --git a/docs/README.md b/docs/README.md index 0cb4f52e79..40ee5e0d1a 100644 --- a/docs/README.md +++ b/docs/README.md @@ -38,6 +38,10 @@ - [Tools](./guides/tools.md) - [Schema export](./guides/schema-export.md) +## Editor integration + +- [Visual Studio Code](./editors/vscode.md) + ## Concepts - [Async](./concepts/async.md) diff --git a/docs/editors/pylance.png b/docs/editors/pylance.png new file mode 100644 index 0000000000..25997be387 Binary files /dev/null and b/docs/editors/pylance.png differ diff --git a/docs/editors/vscode.md b/docs/editors/vscode.md new file mode 100644 index 0000000000..d56c76f75c --- /dev/null +++ b/docs/editors/vscode.md @@ -0,0 +1,37 @@ +--- +title: Visual studio code +--- + +# Visual studio code + +Strawberry comes with support for both MyPy and Pylance, Microsoft's own +language server for Python. + +This guide will explain how to configure Visual Studio Code and Pylance to +work with Strawberry. + +## Install Pylance + +The first thing we need to do is to install +[Pylance](https://marketplace.visualstudio.com/items?itemName=ms-python.vscode-pylance), +this is the extension that enables type checking and intellisense for Visual +Studio Code. + +Once the extension is installed, we need to configure it to enable type +checking. To do so we need to change or add the following two settings: + +```json +{ + "python.languageServer": "Pylance", + "python.analysis.typeCheckingMode": "basic" +} +``` + +The first settings tells the editor to use Pylance as the language server. The +second setting tells the editor to enable type checking by using the basic type +checking mode. At the moment strict mode is not supported. + +Once you have configured the settings, you can restart VS Code and you should be +getting type checking errors in vscode. + +![Pylance showing a type error](./pylance.png) diff --git a/pyproject.toml b/pyproject.toml index 44fca7beb0..6dc8423770 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,10 @@ classifiers = [ ] include = ["strawberry/py.typed"] +[build-system] +requires = ["poetry>=0.12"] +build-backend = "poetry.masonry.api" + [tool.poetry.dependencies] python = "^3.7" starlette = {version = ">=0.13.6,<0.17.0", optional = true} @@ -112,12 +116,17 @@ sections = ["FUTURE", "STDLIB", "PYTEST", "THIRDPARTY", "DJANGO", "GRAPHQL", "FI addopts = "-s --emoji --mypy-ini-file=mypy_tests.ini --benchmark-disable" DJANGO_SETTINGS_MODULE = "tests.django.django_settings" testpaths = ["tests/"] -[build-system] -requires = ["poetry>=0.12"] -build-backend = "poetry.masonry.api" [tool.autopub] git-username = "Botberry" git-email = "bot@strawberry.rocks" project-name = "🍓" append-github-contributor = true + +[tool.pyright] +include = ["strawberry"] +exclude = ["**/__pycache__",] +reportMissingImports = true +reportMissingTypeStubs = false +pythonVersion = "3.7" +stubPath = "" diff --git a/strawberry/federation.py b/strawberry/federation.py index 761b89ea03..cc507d14c3 100644 --- a/strawberry/federation.py +++ b/strawberry/federation.py @@ -1,4 +1,6 @@ -from typing import Callable, List, Optional, Type, Union, cast +from typing import Any, Callable, List, Optional, Type, TypeVar, Union, cast, overload + +from typing_extensions import Literal from graphql import ( GraphQLField, @@ -11,6 +13,7 @@ ) from graphql.type.definition import GraphQLArgument +from strawberry.arguments import UNSET from strawberry.custom_scalar import ScalarDefinition from strawberry.enum import EnumDefinition from strawberry.permission import BasePermission @@ -19,45 +22,104 @@ from strawberry.union import StrawberryUnion from strawberry.utils.inspect import get_func_args -from .field import FederationFieldParams, field as base_field +from .field import ( + _RESOLVER_TYPE, + FederationFieldParams, + StrawberryField, + field as base_field, +) from .object_type import FederationTypeParams, type as base_type from .printer import print_schema from .schema import Schema as BaseSchema +from .utils.typing import __dataclass_transform__ -def type( - cls: Type = None, +T = TypeVar("T") + + +@overload +def field( *, - name: str = None, - description: str = None, - keys: List[str] = None, - extend: bool = False -): - return base_type( - cls, - name=name, - description=description, - federation=FederationTypeParams(keys=keys or [], extend=extend), - ) + resolver: Callable[[], T], + name: Optional[str] = None, + is_subscription: bool = False, + description: Optional[str] = None, + provides: Optional[List[str]] = None, + requires: Optional[List[str]] = None, + external: bool = False, + init: Literal[False] = False, + permission_classes: Optional[List[Type[BasePermission]]] = None, + deprecation_reason: Optional[str] = None, + default: Any = UNSET, + default_factory: Union[Callable, object] = UNSET, +) -> T: + ... +@overload def field( - resolver: Optional[Callable] = None, *, name: Optional[str] = None, + is_subscription: bool = False, + description: Optional[str] = None, provides: Optional[List[str]] = None, requires: Optional[List[str]] = None, external: bool = False, + init: Literal[True] = True, + permission_classes: Optional[List[Type[BasePermission]]] = None, + deprecation_reason: Optional[str] = None, + default: Any = UNSET, + default_factory: Union[Callable, object] = UNSET, +) -> Any: + ... + + +@overload +def field( + resolver: _RESOLVER_TYPE, + *, + name: Optional[str] = None, is_subscription: bool = False, description: Optional[str] = None, - permission_classes: Optional[List[Type[BasePermission]]] = None -): + provides: Optional[List[str]] = None, + requires: Optional[List[str]] = None, + external: bool = False, + permission_classes: Optional[List[Type[BasePermission]]] = None, + deprecation_reason: Optional[str] = None, + default: Any = UNSET, + default_factory: Union[Callable, object] = UNSET, +) -> StrawberryField: + ... + + +def field( + resolver=None, + *, + name=None, + is_subscription=False, + description=None, + provides=None, + requires=None, + external=False, + permission_classes=None, + deprecation_reason=None, + default=UNSET, + default_factory=UNSET, + # This init parameter is used by PyRight to determine whether this field + # is added in the constructor or not. It is not used to change + # any behavior at the moment. + init=None, +) -> Any: return base_field( resolver=resolver, name=name, is_subscription=is_subscription, description=description, permission_classes=permission_classes, + deprecation_reason=deprecation_reason, + default=default, + default_factory=default_factory, + init=init, federation=FederationFieldParams( provides=provides or [], requires=requires or [], external=external ), @@ -100,6 +162,25 @@ def _resolve_type(self, value, _type): return entity_type +@__dataclass_transform__( + order_default=True, field_descriptors=(base_field, field, StrawberryField) +) +def type( + cls: Type = None, + *, + name: str = None, + description: str = None, + keys: List[str] = None, + extend: bool = False, +): + return base_type( + cls, + name=name, + description=description, + federation=FederationTypeParams(keys=keys or [], extend=extend), + ) + + class Schema(BaseSchema): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/strawberry/field.py b/strawberry/field.py index 9bc5d54542..823ead4551 100644 --- a/strawberry/field.py +++ b/strawberry/field.py @@ -12,9 +12,11 @@ Type, TypeVar, Union, + overload, ) from cached_property import cached_property # type: ignore +from typing_extensions import Literal from strawberry.annotation import StrawberryAnnotation from strawberry.arguments import UNSET, StrawberryArgument @@ -267,8 +269,45 @@ def is_async(self) -> bool: return self._has_async_permission_classes or self._has_async_base_resolver +T = TypeVar("T") + + +@overload +def field( + *, + resolver: Callable[[], T], + name: Optional[str] = None, + is_subscription: bool = False, + description: Optional[str] = None, + init: Literal[False] = False, + permission_classes: Optional[List[Type[BasePermission]]] = None, + federation: Optional[FederationFieldParams] = None, + deprecation_reason: Optional[str] = None, + default: Any = UNSET, + default_factory: Union[Callable, object] = UNSET, +) -> T: + ... + + +@overload +def field( + *, + name: Optional[str] = None, + is_subscription: bool = False, + description: Optional[str] = None, + init: Literal[True] = True, + permission_classes: Optional[List[Type[BasePermission]]] = None, + federation: Optional[FederationFieldParams] = None, + deprecation_reason: Optional[str] = None, + default: Any = UNSET, + default_factory: Union[Callable, object] = UNSET, +) -> Any: + ... + + +@overload def field( - resolver: Optional[_RESOLVER_TYPE] = None, + resolver: _RESOLVER_TYPE, *, name: Optional[str] = None, is_subscription: bool = False, @@ -279,6 +318,25 @@ def field( default: Any = UNSET, default_factory: Union[Callable, object] = UNSET, ) -> StrawberryField: + ... + + +def field( + resolver=None, + *, + name=None, + is_subscription=False, + description=None, + permission_classes=None, + federation=None, + deprecation_reason=None, + default=UNSET, + default_factory=UNSET, + # This init parameter is used by PyRight to determine whether this field + # is added in the constructor or not. It is not used to change + # any behavior at the moment. + init=None, +) -> Any: """Annotates a method or property as a GraphQL field. This is normally used inside a type declaration: @@ -308,6 +366,7 @@ def field( ) if resolver: + assert init is not True, "Can't set init as True when passing a resolver." return field_(resolver) return field_ diff --git a/strawberry/object_type.py b/strawberry/object_type.py index 85bb61bd82..215d3e36ce 100644 --- a/strawberry/object_type.py +++ b/strawberry/object_type.py @@ -1,12 +1,12 @@ import dataclasses -from functools import partial from typing import List, Optional, Type, cast from .exceptions import MissingFieldAnnotationError, MissingReturnAnnotationError -from .field import StrawberryField +from .field import StrawberryField, field from .types.type_resolver import _get_fields from .types.types import FederationTypeParams, TypeDefinition from .utils.str_converters import to_camel_case +from .utils.typing import __dataclass_transform__ def _get_interfaces(cls: Type) -> List[TypeDefinition]: @@ -38,29 +38,29 @@ def _check_field_annotations(cls: Type): cls_annotations = cls.__dict__.get("__annotations__", {}) cls.__annotations__ = cls_annotations - for field_name, field in cls.__dict__.items(): - if not isinstance(field, (StrawberryField, dataclasses.Field)): + for field_name, field_ in cls.__dict__.items(): + if not isinstance(field_, (StrawberryField, dataclasses.Field)): # Not a dataclasses.Field, nor a StrawberryField. Ignore continue # If the field is a StrawberryField we need to do a bit of extra work # to make sure dataclasses.dataclass is ready for it - if isinstance(field, StrawberryField): + if isinstance(field_, StrawberryField): # Make sure the cls has an annotation if field_name not in cls_annotations: # If the field uses the default resolver, the field _must_ be # annotated - if not field.base_resolver: + if not field_.base_resolver: raise MissingFieldAnnotationError(field_name) # The resolver _must_ have a return type annotation # TODO: Maybe check this immediately when adding resolver to # field - if field.base_resolver.type_annotation is None: + if field_.base_resolver.type_annotation is None: raise MissingReturnAnnotationError(field_name) - cls_annotations[field_name] = field.base_resolver.type_annotation + cls_annotations[field_name] = field_.base_resolver.type_annotation # TODO: Make sure the cls annotation agrees with the field's type # >>> if cls_annotations[field_name] != field.base_resolver.type: @@ -115,13 +115,14 @@ def _process_type( # so we need to restore them, this will change in future, but for now this # solution should suffice - for field in fields: - if field.base_resolver and field.python_name: - setattr(cls, field.python_name, field.base_resolver.wrapped_func) + for field_ in fields: + if field_.base_resolver and field_.python_name: + setattr(cls, field_.python_name, field_.base_resolver.wrapped_func) return cls +@__dataclass_transform__(order_default=True, field_descriptors=(field, StrawberryField)) def type( cls: Type = None, *, @@ -158,8 +159,48 @@ def wrap(cls): return wrap(cls) -input = partial(type, is_input=True) -interface = partial(type, is_interface=True) +@__dataclass_transform__(order_default=True, field_descriptors=(field, StrawberryField)) +def input( + cls: Type = None, + *, + name: str = None, + description: str = None, + federation: Optional[FederationTypeParams] = None, +): + """Annotates a class as a GraphQL Input type. + Example usage: + >>> @strawberry.input: + >>> class X: + >>> field_abc: str = "ABC" + """ + + return type( + cls, name=name, description=description, federation=federation, is_input=True + ) + + +@__dataclass_transform__(order_default=True, field_descriptors=(field, StrawberryField)) +def interface( + cls: Type = None, + *, + name: str = None, + description: str = None, + federation: Optional[FederationTypeParams] = None, +): + """Annotates a class as a GraphQL Interface. + Example usage: + >>> @strawberry.interface: + >>> class X: + >>> field_abc: str + """ + + return type( + cls, + name=name, + description=description, + federation=federation, + is_interface=True, + ) __all__ = [ diff --git a/strawberry/utils/typing.py b/strawberry/utils/typing.py index f548cf4ce3..5f0daaf505 100644 --- a/strawberry/utils/typing.py +++ b/strawberry/utils/typing.py @@ -1,6 +1,15 @@ -import typing from collections.abc import AsyncGenerator -from typing import Type, TypeVar +from typing import _GenericAlias # type: ignore +from typing import ( # type: ignore + Any, + Callable, + ClassVar, + Generic, + Tuple, + Type, + TypeVar, + Union, +) def is_list(annotation: Type) -> bool: @@ -16,7 +25,7 @@ def is_union(annotation: Type) -> bool: annotation_origin = getattr(annotation, "__origin__", None) - return annotation_origin == typing.Union + return annotation_origin == Union def is_optional(annotation: Type) -> bool: @@ -52,16 +61,16 @@ def get_list_annotation(annotation: Type) -> Type: def is_concrete_generic(annotation: type) -> bool: - ignored_generics = (list, tuple, typing.Union, typing.ClassVar, AsyncGenerator) + ignored_generics = (list, tuple, Union, ClassVar, AsyncGenerator) return ( - isinstance(annotation, typing._GenericAlias) # type:ignore + isinstance(annotation, _GenericAlias) # type:ignore and annotation.__origin__ not in ignored_generics ) def is_generic_subclass(annotation: type) -> bool: return isinstance(annotation, type) and issubclass( - annotation, typing.Generic # type:ignore + annotation, Generic # type:ignore ) @@ -84,11 +93,24 @@ def is_type_var(annotation: Type) -> bool: def get_parameters(annotation: Type): if ( - isinstance(annotation, typing._GenericAlias) # type:ignore + isinstance(annotation, _GenericAlias) # type:ignore or isinstance(annotation, type) - and issubclass(annotation, typing.Generic) # type:ignore - and annotation is not typing.Generic + and issubclass(annotation, Generic) # type:ignore + and annotation is not Generic ): return annotation.__parameters__ else: return () # pragma: no cover + + +_T = TypeVar("_T") + + +def __dataclass_transform__( + *, + eq_default: bool = True, + order_default: bool = False, + kw_only_default: bool = False, + field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()), +) -> Callable[[_T], _T]: + return lambda a: a diff --git a/tests/pyright/__init__.py b/tests/pyright/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/pyright/test_federation.py b/tests/pyright/test_federation.py new file mode 100644 index 0000000000..e3ae132626 --- /dev/null +++ b/tests/pyright/test_federation.py @@ -0,0 +1,56 @@ +import pytest + +from .utils import Result, pyright_exist, run_pyright + + +pytestmark = pytest.mark.skipif( + not pyright_exist(), reason="These tests require pyright" +) +CODE = """ +import strawberry + +def get_user_age() -> int: + return 0 + + +@strawberry.federation.type +class User: + name: str + age: int = strawberry.field(resolver=get_user_age) + something_else: int = strawberry.federation.field(resolver=get_user_age) + + +User(name="Patrick") +User(n="Patrick") + +reveal_type(User) +reveal_type(User.__init__) +""" + + +def test_pyright(): + results = run_pyright(CODE) + + assert results == [ + Result( + type="error", + message='No parameter named "n" (reportGeneralTypeIssues)', + line=16, + column=6, + ), + Result( + type="error", + message='Argument missing for parameter "name" (reportGeneralTypeIssues)', + line=16, + column=1, + ), + Result( + type="info", message='Type of "User" is "Type[User]"', line=18, column=13 + ), + Result( + type="info", + message='Type of "User.__init__" is "(self: User, name: str) -> None"', + line=19, + column=13, + ), + ] diff --git a/tests/pyright/test_fields.py b/tests/pyright/test_fields.py new file mode 100644 index 0000000000..462e22fdc1 --- /dev/null +++ b/tests/pyright/test_fields.py @@ -0,0 +1,53 @@ +import pytest + +from .utils import Result, pyright_exist, run_pyright + + +pytestmark = pytest.mark.skipif( + not pyright_exist(), reason="These tests require pyright" +) + + +CODE = """ +import strawberry + + +@strawberry.type +class User: + name: str + + +User(name="Patrick") +User(n="Patrick") + +reveal_type(User) +reveal_type(User.__init__) +""" + + +def test_pyright(): + results = run_pyright(CODE) + + assert results == [ + Result( + type="error", + message='No parameter named "n" (reportGeneralTypeIssues)', + line=11, + column=6, + ), + Result( + type="error", + message='Argument missing for parameter "name" (reportGeneralTypeIssues)', + line=11, + column=1, + ), + Result( + type="info", message='Type of "User" is "Type[User]"', line=13, column=13 + ), + Result( + type="info", + message='Type of "User.__init__" is "(self: User, name: str) -> None"', + line=14, + column=13, + ), + ] diff --git a/tests/pyright/test_fields_input.py b/tests/pyright/test_fields_input.py new file mode 100644 index 0000000000..cda831bdde --- /dev/null +++ b/tests/pyright/test_fields_input.py @@ -0,0 +1,51 @@ +import pytest + +from .utils import Result, pyright_exist, run_pyright + + +pytestmark = pytest.mark.skipif( + not pyright_exist(), reason="These tests require pyright" +) +CODE = """ +import strawberry + + +@strawberry.input +class User: + name: str + + +User(name="Patrick") +User(n="Patrick") + +reveal_type(User) +reveal_type(User.__init__) +""" + + +def test_pyright(): + results = run_pyright(CODE) + + assert results == [ + Result( + type="error", + message='No parameter named "n" (reportGeneralTypeIssues)', + line=11, + column=6, + ), + Result( + type="error", + message='Argument missing for parameter "name" (reportGeneralTypeIssues)', + line=11, + column=1, + ), + Result( + type="info", message='Type of "User" is "Type[User]"', line=13, column=13 + ), + Result( + type="info", + message='Type of "User.__init__" is "(self: User, name: str) -> None"', + line=14, + column=13, + ), + ] diff --git a/tests/pyright/test_fields_resolver.py b/tests/pyright/test_fields_resolver.py new file mode 100644 index 0000000000..e1a09b503e --- /dev/null +++ b/tests/pyright/test_fields_resolver.py @@ -0,0 +1,55 @@ +import pytest + +from .utils import Result, pyright_exist, run_pyright + + +pytestmark = pytest.mark.skipif( + not pyright_exist(), reason="These tests require pyright" +) +CODE = """ +import strawberry + +def get_user_age() -> int: + return 0 + + +@strawberry.type +class User: + name: str + age: int = strawberry.field(resolver=get_user_age) + + +User(name="Patrick") +User(n="Patrick") + +reveal_type(User) +reveal_type(User.__init__) +""" + + +def test_pyright(): + results = run_pyright(CODE) + + assert results == [ + Result( + type="error", + message='No parameter named "n" (reportGeneralTypeIssues)', + line=15, + column=6, + ), + Result( + type="error", + message='Argument missing for parameter "name" (reportGeneralTypeIssues)', + line=15, + column=1, + ), + Result( + type="info", message='Type of "User" is "Type[User]"', line=17, column=13 + ), + Result( + type="info", + message='Type of "User.__init__" is "(self: User, name: str) -> None"', + line=18, + column=13, + ), + ] diff --git a/tests/pyright/utils.py b/tests/pyright/utils.py new file mode 100644 index 0000000000..c9951d3b20 --- /dev/null +++ b/tests/pyright/utils.py @@ -0,0 +1,54 @@ +import os +import shutil +import subprocess +import tempfile +from dataclasses import dataclass +from typing import List, cast + +from typing_extensions import Literal + + +ResultType = Literal["error", "info"] + + +@dataclass +class Result: + type: ResultType + message: str + line: int + column: int + + @classmethod + def from_output_line(cls, output_line: str) -> "Result": + # an output line looks like: filename.py:11:6 - type: Message + + file_info, result = output_line.split("-", maxsplit=1) + + line, column = [int(value) for value in file_info.split(":")[1:]] + type_, message = [value.strip() for value in result.split(":", maxsplit=1)] + type_ = cast(ResultType, type_) + + return cls(type=type_, message=message, line=line, column=column) + + +def run_pyright(code: str) -> List[Result]: + with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f: + f.write(code) + + result = subprocess.run(["pyright", f.name], stdout=subprocess.PIPE) + + os.remove(f.name) + + output = result.stdout.decode("utf-8") + + results: List[Result] = [] + + for line in output.splitlines(): + if line.strip().startswith(f"{f.name}:"): + results.append(Result.from_output_line(line)) + + return results + + +def pyright_exist() -> bool: + return shutil.which("pyright") is not None