From 5594eccb6fe7f925a7692edd0791f82e45865780 Mon Sep 17 00:00:00 2001 From: David Lord Date: Thu, 7 Sep 2023 10:40:36 -0700 Subject: [PATCH 1/5] use deferred evaluation of annotations --- CHANGES.rst | 1 + pyproject.toml | 2 -- src/markupsafe/__init__.py | 55 +++++++++++++++++++++----------------- src/markupsafe/_native.py | 4 ++- 4 files changed, 34 insertions(+), 28 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index bbd6c0c7..48e9a5b8 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -7,6 +7,7 @@ Unreleased - Use modern packaging metadata with ``pyproject.toml`` instead of ``setup.cfg``. :pr:`348` - Change ``distutils`` imports to ``setuptools``. :pr:`399` +- Use deferred evaluation of annotations. :pr:`400` Version 2.1.3 diff --git a/pyproject.toml b/pyproject.toml index 8b034a9a..cf982b02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,5 +51,3 @@ files = ["src/markupsafe"] show_error_codes = true pretty = true strict = true -local_partial_types = true -warn_unreachable = true diff --git a/src/markupsafe/__init__.py b/src/markupsafe/__init__.py index 55f90a99..b046057f 100644 --- a/src/markupsafe/__init__.py +++ b/src/markupsafe/__init__.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +import collections.abc as cabc import functools import re import string @@ -20,9 +23,9 @@ def __html__(self) -> str: _strip_tags_re = re.compile(r"<.*?>", re.DOTALL) -def _simple_escaping_wrapper(func: "t.Callable[_P, str]") -> "t.Callable[_P, Markup]": +def _simple_escaping_wrapper(func: t.Callable[_P, str]) -> t.Callable[_P, Markup]: @functools.wraps(func) - def wrapped(self: "Markup", *args: "_P.args", **kwargs: "_P.kwargs") -> "Markup": + def wrapped(self: Markup, *args: _P.args, **kwargs: _P.kwargs) -> Markup: arg_list = _escape_argspec(list(args), enumerate(args), self.escape) _escape_argspec(kwargs, kwargs.items(), self.escape) return self.__class__(func(self, *arg_list, **kwargs)) # type: ignore[arg-type] @@ -69,8 +72,8 @@ class Markup(str): __slots__ = () def __new__( - cls, base: t.Any = "", encoding: t.Optional[str] = None, errors: str = "strict" - ) -> "te.Self": + cls, base: t.Any = "", encoding: str | None = None, errors: str = "strict" + ) -> te.Self: if hasattr(base, "__html__"): base = base.__html__() @@ -79,22 +82,22 @@ def __new__( return super().__new__(cls, base, encoding, errors) - def __html__(self) -> "te.Self": + def __html__(self) -> te.Self: return self - def __add__(self, other: t.Union[str, "HasHTML"]) -> "te.Self": + def __add__(self, other: str | HasHTML) -> te.Self: if isinstance(other, str) or hasattr(other, "__html__"): return self.__class__(super().__add__(self.escape(other))) return NotImplemented - def __radd__(self, other: t.Union[str, "HasHTML"]) -> "te.Self": + def __radd__(self, other: str | HasHTML) -> te.Self: if isinstance(other, str) or hasattr(other, "__html__"): return self.escape(other).__add__(self) return NotImplemented - def __mul__(self, num: "te.SupportsIndex") -> "te.Self": + def __mul__(self, num: t.SupportsIndex) -> te.Self: if isinstance(num, int): return self.__class__(super().__mul__(num)) @@ -102,7 +105,7 @@ def __mul__(self, num: "te.SupportsIndex") -> "te.Self": __rmul__ = __mul__ - def __mod__(self, arg: t.Any) -> "te.Self": + def __mod__(self, arg: t.Any) -> te.Self: if isinstance(arg, tuple): # a tuple of arguments, each wrapped arg = tuple(_MarkupEscapeHelper(x, self.escape) for x in arg) @@ -118,28 +121,28 @@ def __mod__(self, arg: t.Any) -> "te.Self": def __repr__(self) -> str: return f"{self.__class__.__name__}({super().__repr__()})" - def join(self, seq: t.Iterable[t.Union[str, "HasHTML"]]) -> "te.Self": + def join(self, seq: cabc.Iterable[str | HasHTML]) -> te.Self: return self.__class__(super().join(map(self.escape, seq))) join.__doc__ = str.join.__doc__ def split( # type: ignore[override] - self, sep: t.Optional[str] = None, maxsplit: int = -1 - ) -> t.List["te.Self"]: + self, sep: str | None = None, maxsplit: int = -1 + ) -> list[te.Self]: return [self.__class__(v) for v in super().split(sep, maxsplit)] split.__doc__ = str.split.__doc__ def rsplit( # type: ignore[override] - self, sep: t.Optional[str] = None, maxsplit: int = -1 - ) -> t.List["te.Self"]: + self, sep: str | None = None, maxsplit: int = -1 + ) -> list[te.Self]: return [self.__class__(v) for v in super().rsplit(sep, maxsplit)] rsplit.__doc__ = str.rsplit.__doc__ def splitlines( # type: ignore[override] self, keepends: bool = False - ) -> t.List["te.Self"]: + ) -> list[te.Self]: return [self.__class__(v) for v in super().splitlines(keepends)] splitlines.__doc__ = str.splitlines.__doc__ @@ -169,7 +172,7 @@ def striptags(self) -> str: return self.__class__(value).unescape() @classmethod - def escape(cls, s: t.Any) -> "te.Self": + def escape(cls, s: t.Any) -> te.Self: """Escape a string. Calls :func:`escape` and ensures that for subclasses the correct type is returned. """ @@ -202,27 +205,27 @@ def escape(cls, s: t.Any) -> "te.Self": removeprefix = _simple_escaping_wrapper(str.removeprefix) removesuffix = _simple_escaping_wrapper(str.removesuffix) - def partition(self, sep: str) -> t.Tuple["te.Self", "te.Self", "te.Self"]: + def partition(self, sep: str) -> tuple[te.Self, te.Self, te.Self]: l, s, r = super().partition(self.escape(sep)) cls = self.__class__ return cls(l), cls(s), cls(r) - def rpartition(self, sep: str) -> t.Tuple["te.Self", "te.Self", "te.Self"]: + def rpartition(self, sep: str) -> tuple[te.Self, te.Self, te.Self]: l, s, r = super().rpartition(self.escape(sep)) cls = self.__class__ return cls(l), cls(s), cls(r) - def format(self, *args: t.Any, **kwargs: t.Any) -> "te.Self": + def format(self, *args: t.Any, **kwargs: t.Any) -> te.Self: formatter = EscapeFormatter(self.escape) return self.__class__(formatter.vformat(self, args, kwargs)) def format_map( # type: ignore[override] - self, map: t.Mapping[str, t.Any] - ) -> "te.Self": + self, map: cabc.Mapping[str, t.Any] + ) -> te.Self: formatter = EscapeFormatter(self.escape) return self.__class__(formatter.vformat(self, (), map)) - def __html_format__(self, format_spec: str) -> "te.Self": + def __html_format__(self, format_spec: str) -> te.Self: if format_spec: raise ValueError("Unsupported format specification for Markup.") @@ -254,11 +257,13 @@ def format_field(self, value: t.Any, format_spec: str) -> str: return str(self.escape(rv)) -_ListOrDict = t.TypeVar("_ListOrDict", t.List[t.Any], t.Dict[t.Any, t.Any]) +_ListOrDict = t.TypeVar("_ListOrDict", "list[t.Any]", "dict[t.Any, t.Any]") def _escape_argspec( - obj: _ListOrDict, iterable: t.Iterable[t.Any], escape: t.Callable[[t.Any], Markup] + obj: _ListOrDict, + iterable: cabc.Iterable[t.Any], + escape: t.Callable[[t.Any], Markup], ) -> _ListOrDict: """Helper for various string-wrapped functions.""" for key, value in iterable: @@ -277,7 +282,7 @@ def __init__(self, obj: t.Any, escape: t.Callable[[t.Any], Markup]) -> None: self.obj = obj self.escape = escape - def __getitem__(self, item: t.Any) -> "te.Self": + def __getitem__(self, item: t.Any) -> te.Self: return self.__class__(self.obj[item], self.escape) def __str__(self) -> str: diff --git a/src/markupsafe/_native.py b/src/markupsafe/_native.py index 8117b271..a0583f13 100644 --- a/src/markupsafe/_native.py +++ b/src/markupsafe/_native.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import typing as t from . import Markup @@ -27,7 +29,7 @@ def escape(s: t.Any) -> Markup: ) -def escape_silent(s: t.Optional[t.Any]) -> Markup: +def escape_silent(s: t.Any | None) -> Markup: """Like :func:`escape` but treats ``None`` as the empty string. Useful with optional values, as otherwise you get the string ``'None'`` when the value is ``None``. From a4f3da9148b7e87a7bc97ff8f53c6625af5f6de3 Mon Sep 17 00:00:00 2001 From: David Lord Date: Thu, 7 Sep 2023 11:02:03 -0700 Subject: [PATCH 2/5] type check tests --- pyproject.toml | 2 +- requirements/typing.in | 1 + requirements/typing.txt | 10 ++++- tests/__init__.py | 0 tests/conftest.py | 37 ++++++++++++---- tests/test_escape.py | 9 +++- tests/test_exception_custom_html.py | 11 ++++- tests/test_leak.py | 4 +- tests/test_markupsafe.py | 69 ++++++++++++++++------------- 9 files changed, 99 insertions(+), 44 deletions(-) create mode 100644 tests/__init__.py diff --git a/pyproject.toml b/pyproject.toml index cf982b02..8a4311b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ source = ["src", "*/site-packages"] [tool.mypy] python_version = "3.8" -files = ["src/markupsafe"] +files = ["src/markupsafe", "tests"] show_error_codes = true pretty = true strict = true diff --git a/requirements/typing.in b/requirements/typing.in index f0aa93ac..d76706ea 100644 --- a/requirements/typing.in +++ b/requirements/typing.in @@ -1 +1,2 @@ mypy +pytest diff --git a/requirements/typing.txt b/requirements/typing.txt index 10f9860f..5f8c39c0 100644 --- a/requirements/typing.txt +++ b/requirements/typing.txt @@ -1,13 +1,21 @@ -# SHA1:7983aaa01d64547827c20395d77e248c41b2572f +# SHA1:c5ca84191e9dddc0ac30ed7693225ae49943a2ba # # This file is autogenerated by pip-compile-multi # To update, run: # # pip-compile-multi # +iniconfig==2.0.0 + # via pytest mypy==1.5.1 # via -r requirements/typing.in mypy-extensions==1.0.0 # via mypy +packaging==23.1 + # via pytest +pluggy==1.3.0 + # via pytest +pytest==7.4.1 + # via -r requirements/typing.in typing-extensions==4.7.1 # via mypy diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/tests/conftest.py index d040ea8b..d648bce8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,33 @@ +from __future__ import annotations + +import typing as t +from types import ModuleType + import pytest from markupsafe import _native +from markupsafe import Markup try: from markupsafe import _speedups except ImportError: _speedups = None # type: ignore +if t.TYPE_CHECKING: + import typing_extensions as te + + class TPEscape(te.Protocol): + def __call__(self, s: t.Any) -> Markup: + ... + + class TPEscapeSilent(te.Protocol): + def __call__(self, s: t.Any | None) -> Markup: + ... + + class TPSoftStr(te.Protocol): + def __call__(self, s: t.Any) -> str: + ... + @pytest.fixture( scope="session", @@ -18,20 +39,20 @@ ), ), ) -def _mod(request): - return request.param +def _mod(request: pytest.FixtureRequest) -> ModuleType: + return t.cast(ModuleType, request.param) @pytest.fixture(scope="session") -def escape(_mod): - return _mod.escape +def escape(_mod: ModuleType) -> TPEscape: + return t.cast("TPEscape", _mod.escape) @pytest.fixture(scope="session") -def escape_silent(_mod): - return _mod.escape_silent +def escape_silent(_mod: ModuleType) -> TPEscapeSilent: + return t.cast("TPEscapeSilent", _mod.escape_silent) @pytest.fixture(scope="session") -def soft_str(_mod): - return _mod.soft_str +def soft_str(_mod: ModuleType) -> t.Callable[[t.Any], str]: + return t.cast("t.Callable[[t.Any], str]", _mod.soft_str) diff --git a/tests/test_escape.py b/tests/test_escape.py index bf53face..f3ecf378 100644 --- a/tests/test_escape.py +++ b/tests/test_escape.py @@ -1,7 +1,14 @@ +from __future__ import annotations + +import typing as t + import pytest from markupsafe import Markup +if t.TYPE_CHECKING: + from .conftest import TPEscape + @pytest.mark.parametrize( ("value", "expect"), @@ -25,5 +32,5 @@ ("\U0001F363\U0001F362&><'\"", "\U0001F363\U0001F362&><'""), ), ) -def test_escape(escape, value, expect): +def test_escape(escape: TPEscape, value: str, expect: str) -> None: assert escape(value) == Markup(expect) diff --git a/tests/test_exception_custom_html.py b/tests/test_exception_custom_html.py index ec2f10b1..f839cbb7 100644 --- a/tests/test_exception_custom_html.py +++ b/tests/test_exception_custom_html.py @@ -1,12 +1,19 @@ +from __future__ import annotations + +import typing as t + import pytest +if t.TYPE_CHECKING: + from .conftest import TPEscape + class CustomHtmlThatRaises: - def __html__(self): + def __html__(self) -> str: raise ValueError(123) -def test_exception_custom_html(escape): +def test_exception_custom_html(escape: TPEscape) -> None: """Checks whether exceptions in custom __html__ implementations are propagated correctly. diff --git a/tests/test_leak.py b/tests/test_leak.py index 55b10b98..a925e001 100644 --- a/tests/test_leak.py +++ b/tests/test_leak.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import gc import platform @@ -10,7 +12,7 @@ escape.__module__ == "markupsafe._native", reason="only test memory leak with speedups", ) -def test_markup_leaks(): +def test_markup_leaks() -> None: counts = set() for _i in range(20): diff --git a/tests/test_markupsafe.py b/tests/test_markupsafe.py index ea9a9187..0c047092 100644 --- a/tests/test_markupsafe.py +++ b/tests/test_markupsafe.py @@ -1,9 +1,18 @@ +from __future__ import annotations + +import typing as t + import pytest from markupsafe import Markup +if t.TYPE_CHECKING: + from .conftest import TPEscape + from .conftest import TPEscapeSilent + from .conftest import TPSoftStr + -def test_adding(escape): +def test_adding(escape: TPEscape) -> None: unsafe = '' safe = Markup("username") assert unsafe + safe == str(escape(unsafe)) + str(safe) @@ -22,22 +31,22 @@ def test_adding(escape): ("%.2f", 3.14, "3.14"), ), ) -def test_string_interpolation(template, data, expect): +def test_string_interpolation(template: str, data: t.Any, expect: str) -> None: assert Markup(template) % data == expect -def test_type_behavior(): +def test_type_behavior() -> None: assert type(Markup("foo") + "bar") is Markup x = Markup("foo") assert x.__html__() is x -def test_html_interop(): +def test_html_interop() -> None: class Foo: - def __html__(self): + def __html__(self) -> str: return "awesome" - def __str__(self): + def __str__(self) -> str: return "awesome" assert Markup(Foo()) == "awesome" @@ -46,18 +55,18 @@ def __str__(self): @pytest.mark.parametrize("args", ["foo", 42, ("foo", 42)]) -def test_missing_interpol(args): +def test_missing_interpol(args: t.Any) -> None: with pytest.raises(TypeError): Markup("") % args -def test_tuple_interpol(): +def test_tuple_interpol() -> None: result = Markup("%s:%s") % ("", "") expect = Markup("<foo>:<bar>") assert result == expect -def test_dict_interpol(): +def test_dict_interpol() -> None: result = Markup("%(foo)s") % {"foo": ""} expect = Markup("<foo>") assert result == expect @@ -67,7 +76,7 @@ def test_dict_interpol(): assert result == expect -def test_escaping(escape): +def test_escaping(escape: TPEscape) -> None: assert escape("\"<>&'") == ""<>&'" assert ( Markup( @@ -82,7 +91,7 @@ def test_escaping(escape): ) -def test_unescape(): +def test_unescape() -> None: assert Markup("<test>").unescape() == "" result = Markup("jack & tavi are cooler than mike & russ").unescape() @@ -97,7 +106,7 @@ def test_unescape(): assert twice == expect -def test_format(): +def test_format() -> None: result = Markup("{awesome}").format(awesome="") assert result == "<awesome>" @@ -108,39 +117,39 @@ def test_format(): assert result == "" -def test_format_map(): +def test_format_map() -> None: result = Markup("{value}").format_map({"value": ""}) assert result == "<value>" -def test_formatting_empty(): +def test_formatting_empty() -> None: formatted = Markup("{}").format(0) assert formatted == Markup("0") -def test_custom_formatting(): +def test_custom_formatting() -> None: class HasHTMLOnly: - def __html__(self): + def __html__(self) -> Markup: return Markup("") class HasHTMLAndFormat: - def __html__(self): + def __html__(self) -> Markup: return Markup("") - def __html_format__(self, spec): + def __html_format__(self, spec: str) -> Markup: return Markup("") assert Markup("{0}").format(HasHTMLOnly()) == Markup("") assert Markup("{0}").format(HasHTMLAndFormat()) == Markup("") -def test_complex_custom_formatting(): +def test_complex_custom_formatting() -> None: class User: - def __init__(self, id, username): + def __init__(self, id: int, username: str) -> None: self.id = id self.username = username - def __html_format__(self, format_spec): + def __html_format__(self, format_spec: str) -> Markup: if format_spec == "link": return Markup('{1}').format( self.id, self.__html__() @@ -150,7 +159,7 @@ def __html_format__(self, format_spec): return self.__html__() - def __html__(self): + def __html__(self) -> Markup: return Markup("{0}").format(self.username) user = User(1, "foo") @@ -159,43 +168,43 @@ def __html__(self): assert result == expect -def test_formatting_with_objects(): +def test_formatting_with_objects() -> None: class Stringable: - def __str__(self): + def __str__(self) -> str: return "строка" assert Markup("{s}").format(s=Stringable()) == Markup("строка") -def test_escape_silent(escape, escape_silent): +def test_escape_silent(escape: TPEscape, escape_silent: TPEscapeSilent) -> None: assert escape_silent(None) == Markup() assert escape(None) == Markup(None) assert escape_silent("") == Markup("<foo>") -def test_splitting(): +def test_splitting() -> None: expect = [Markup("a"), Markup("b")] assert Markup("a b").split() == expect assert Markup("a b").rsplit() == expect assert Markup("a\nb").splitlines() == expect -def test_mul(): +def test_mul() -> None: assert Markup("a") * 3 == Markup("aaa") -def test_escape_return_type(escape): +def test_escape_return_type(escape: TPEscape) -> None: assert isinstance(escape("a"), Markup) assert isinstance(escape(Markup("a")), Markup) class Foo: - def __html__(self): + def __html__(self) -> str: return "Foo" assert isinstance(escape(Foo()), Markup) -def test_soft_str(soft_str): +def test_soft_str(soft_str: TPSoftStr) -> None: assert type(soft_str("")) is str assert type(soft_str(Markup())) is Markup assert type(soft_str(15)) is str From 8f5de90d601eb4b649175a988ff5af42141619aa Mon Sep 17 00:00:00 2001 From: David Lord Date: Thu, 14 Sep 2023 09:28:27 -0700 Subject: [PATCH 3/5] implement str methods directly wrapper wasn't accurate with types, required extra calls and instance checks --- src/markupsafe/__init__.py | 110 +++++++++++++++++++++---------------- 1 file changed, 63 insertions(+), 47 deletions(-) diff --git a/src/markupsafe/__init__.py b/src/markupsafe/__init__.py index b046057f..c751c7e3 100644 --- a/src/markupsafe/__init__.py +++ b/src/markupsafe/__init__.py @@ -1,7 +1,6 @@ from __future__ import annotations import collections.abc as cabc -import functools import re import string import sys @@ -23,16 +22,6 @@ def __html__(self) -> str: _strip_tags_re = re.compile(r"<.*?>", re.DOTALL) -def _simple_escaping_wrapper(func: t.Callable[_P, str]) -> t.Callable[_P, Markup]: - @functools.wraps(func) - def wrapped(self: Markup, *args: _P.args, **kwargs: _P.kwargs) -> Markup: - arg_list = _escape_argspec(list(args), enumerate(args), self.escape) - _escape_argspec(kwargs, kwargs.items(), self.escape) - return self.__class__(func(self, *arg_list, **kwargs)) # type: ignore[arg-type] - - return wrapped # type: ignore[return-value] - - class Markup(str): """A string that is ready to be safely inserted into an HTML or XML document, either because it was escaped or because it was marked @@ -183,29 +172,72 @@ def escape(cls, s: t.Any) -> te.Self: return rv # type: ignore[return-value] - __getitem__ = _simple_escaping_wrapper(str.__getitem__) - capitalize = _simple_escaping_wrapper(str.capitalize) - title = _simple_escaping_wrapper(str.title) - lower = _simple_escaping_wrapper(str.lower) - upper = _simple_escaping_wrapper(str.upper) - replace = _simple_escaping_wrapper(str.replace) - ljust = _simple_escaping_wrapper(str.ljust) - rjust = _simple_escaping_wrapper(str.rjust) - lstrip = _simple_escaping_wrapper(str.lstrip) - rstrip = _simple_escaping_wrapper(str.rstrip) - center = _simple_escaping_wrapper(str.center) - strip = _simple_escaping_wrapper(str.strip) - translate = _simple_escaping_wrapper(str.translate) - expandtabs = _simple_escaping_wrapper(str.expandtabs) - swapcase = _simple_escaping_wrapper(str.swapcase) - zfill = _simple_escaping_wrapper(str.zfill) - casefold = _simple_escaping_wrapper(str.casefold) + def __getitem__(self, key: t.SupportsIndex | slice, /) -> te.Self: + return self.__class__(super().__getitem__(key)) + + def capitalize(self, /) -> te.Self: + return self.__class__(super().capitalize()) + + def title(self, /) -> te.Self: + return self.__class__(super().title()) + + def lower(self, /) -> te.Self: + return self.__class__(super().lower()) + + def upper(self, /) -> te.Self: + return self.__class__(super().upper()) + + def replace(self, old: str, new: str, count: t.SupportsIndex = -1, /) -> te.Self: + return self.__class__( + super().replace(self.escape(old), self.escape(new), count) + ) + + def ljust(self, width: t.SupportsIndex, fillchar: str = " ", /) -> te.Self: + return self.__class__(super().ljust(width, self.escape(fillchar))) + + def rjust(self, width: t.SupportsIndex, fillchar: str = " ", /) -> te.Self: + return self.__class__(super().rjust(width, self.escape(fillchar))) + + def lstrip(self, chars: str | None = None, /) -> te.Self: + return self.__class__(super().lstrip(self.escape(chars))) + + def rstrip(self, chars: str | None = None, /) -> te.Self: + return self.__class__(super().rstrip(self.escape(chars))) + + def center(self, width: t.SupportsIndex, fillchar: str = " ", /) -> te.Self: + return self.__class__(super().center(width, self.escape(fillchar))) + + def strip(self, chars: str | None = None, /) -> te.Self: + return self.__class__(super().strip(self.escape(chars))) + + def translate( + self, + table: cabc.Mapping[int, str | int | None], # type: ignore[override] + /, + ) -> str: + return self.__class__(super().translate(table)) + + def expandtabs(self, /, tabsize: t.SupportsIndex = 8) -> te.Self: + return self.__class__(super().expandtabs(tabsize)) + + def swapcase(self, /) -> te.Self: + return self.__class__(super().swapcase()) + + def zfill(self, width: t.SupportsIndex, /) -> te.Self: + return self.__class__(super().zfill(width)) + + def casefold(self, /) -> te.Self: + return self.__class__(super().casefold()) if sys.version_info >= (3, 9): - removeprefix = _simple_escaping_wrapper(str.removeprefix) - removesuffix = _simple_escaping_wrapper(str.removesuffix) - def partition(self, sep: str) -> tuple[te.Self, te.Self, te.Self]: + def removeprefix(self, prefix: str, /) -> te.Self: + return self.__class__(super().removeprefix(self.escape(prefix))) + + def removesuffix(self, suffix: str) -> te.Self: + return self.__class__(super().removesuffix(self.escape(suffix))) + + def partition(self, sep: str, /) -> tuple[te.Self, te.Self, te.Self]: l, s, r = super().partition(self.escape(sep)) cls = self.__class__ return cls(l), cls(s), cls(r) @@ -257,22 +289,6 @@ def format_field(self, value: t.Any, format_spec: str) -> str: return str(self.escape(rv)) -_ListOrDict = t.TypeVar("_ListOrDict", "list[t.Any]", "dict[t.Any, t.Any]") - - -def _escape_argspec( - obj: _ListOrDict, - iterable: cabc.Iterable[t.Any], - escape: t.Callable[[t.Any], Markup], -) -> _ListOrDict: - """Helper for various string-wrapped functions.""" - for key, value in iterable: - if isinstance(value, str) or hasattr(value, "__html__"): - obj[key] = escape(value) - - return obj - - class _MarkupEscapeHelper: """Helper for :meth:`Markup.__mod__`.""" From c4be0f0032f83c2a8db60768ca2b3af28b0667be Mon Sep 17 00:00:00 2001 From: David Lord Date: Fri, 15 Sep 2023 07:09:28 -0700 Subject: [PATCH 4/5] use positional-only arguments update str method signatures to match Python 3.8 --- CHANGES.rst | 2 + src/markupsafe/__init__.py | 102 ++++++++++++++++------------------- src/markupsafe/_native.py | 6 +-- src/markupsafe/_speedups.pyi | 10 ++-- 4 files changed, 56 insertions(+), 64 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 48e9a5b8..02c4b489 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -8,6 +8,8 @@ Unreleased :pr:`348` - Change ``distutils`` imports to ``setuptools``. :pr:`399` - Use deferred evaluation of annotations. :pr:`400` +- Update signatures for ``Markup`` methods to match ``str`` signatures. Use + positional-only arguments. :pr:`400` Version 2.1.3 diff --git a/src/markupsafe/__init__.py b/src/markupsafe/__init__.py index c751c7e3..a43052b3 100644 --- a/src/markupsafe/__init__.py +++ b/src/markupsafe/__init__.py @@ -10,8 +10,8 @@ import typing_extensions as te class HasHTML(te.Protocol): - def __html__(self) -> str: - pass + def __html__(self, /) -> str: + ... _P = te.ParamSpec("_P") @@ -61,82 +61,72 @@ class Markup(str): __slots__ = () def __new__( - cls, base: t.Any = "", encoding: str | None = None, errors: str = "strict" + cls, object: t.Any = "", encoding: str | None = None, errors: str = "strict" ) -> te.Self: - if hasattr(base, "__html__"): - base = base.__html__() + if hasattr(object, "__html__"): + object = object.__html__() if encoding is None: - return super().__new__(cls, base) + return super().__new__(cls, object) - return super().__new__(cls, base, encoding, errors) + return super().__new__(cls, object, encoding, errors) - def __html__(self) -> te.Self: + def __html__(self, /) -> te.Self: return self - def __add__(self, other: str | HasHTML) -> te.Self: - if isinstance(other, str) or hasattr(other, "__html__"): - return self.__class__(super().__add__(self.escape(other))) + def __add__(self, value: str | HasHTML, /) -> te.Self: + if isinstance(value, str) or hasattr(value, "__html__"): + return self.__class__(super().__add__(self.escape(value))) return NotImplemented - def __radd__(self, other: str | HasHTML) -> te.Self: - if isinstance(other, str) or hasattr(other, "__html__"): - return self.escape(other).__add__(self) + def __radd__(self, value: str | HasHTML, /) -> te.Self: + if isinstance(value, str) or hasattr(value, "__html__"): + return self.escape(value).__add__(self) return NotImplemented - def __mul__(self, num: t.SupportsIndex) -> te.Self: - if isinstance(num, int): - return self.__class__(super().__mul__(num)) + def __mul__(self, value: t.SupportsIndex, /) -> te.Self: + return self.__class__(super().__mul__(value)) - return NotImplemented - - __rmul__ = __mul__ + def __rmul__(self, value: t.SupportsIndex, /) -> te.Self: + return self.__class__(super().__mul__(value)) - def __mod__(self, arg: t.Any) -> te.Self: - if isinstance(arg, tuple): + def __mod__(self, value: t.Any, /) -> te.Self: + if isinstance(value, tuple): # a tuple of arguments, each wrapped - arg = tuple(_MarkupEscapeHelper(x, self.escape) for x in arg) - elif hasattr(type(arg), "__getitem__") and not isinstance(arg, str): + value = tuple(_MarkupEscapeHelper(x, self.escape) for x in value) + elif hasattr(type(value), "__getitem__") and not isinstance(value, str): # a mapping of arguments, wrapped - arg = _MarkupEscapeHelper(arg, self.escape) + value = _MarkupEscapeHelper(value, self.escape) else: # a single argument, wrapped with the helper and a tuple - arg = (_MarkupEscapeHelper(arg, self.escape),) + value = (_MarkupEscapeHelper(value, self.escape),) - return self.__class__(super().__mod__(arg)) + return self.__class__(super().__mod__(value)) - def __repr__(self) -> str: + def __repr__(self, /) -> str: return f"{self.__class__.__name__}({super().__repr__()})" - def join(self, seq: cabc.Iterable[str | HasHTML]) -> te.Self: - return self.__class__(super().join(map(self.escape, seq))) - - join.__doc__ = str.join.__doc__ + def join(self, iterable: cabc.Iterable[str | HasHTML], /) -> te.Self: + return self.__class__(super().join(map(self.escape, iterable))) def split( # type: ignore[override] - self, sep: str | None = None, maxsplit: int = -1 + self, /, sep: str | None = None, maxsplit: t.SupportsIndex = -1 ) -> list[te.Self]: return [self.__class__(v) for v in super().split(sep, maxsplit)] - split.__doc__ = str.split.__doc__ - def rsplit( # type: ignore[override] - self, sep: str | None = None, maxsplit: int = -1 + self, /, sep: str | None = None, maxsplit: t.SupportsIndex = -1 ) -> list[te.Self]: return [self.__class__(v) for v in super().rsplit(sep, maxsplit)] - rsplit.__doc__ = str.rsplit.__doc__ - def splitlines( # type: ignore[override] - self, keepends: bool = False + self, /, keepends: bool = False ) -> list[te.Self]: return [self.__class__(v) for v in super().splitlines(keepends)] - splitlines.__doc__ = str.splitlines.__doc__ - - def unescape(self) -> str: + def unescape(self, /) -> str: """Convert escaped markup back into a text string. This replaces HTML entities with the characters they represent. @@ -147,7 +137,7 @@ def unescape(self) -> str: return unescape(str(self)) - def striptags(self) -> str: + def striptags(self, /) -> str: """:meth:`unescape` the markup, remove tags, and normalize whitespace to single spaces. @@ -161,7 +151,7 @@ def striptags(self) -> str: return self.__class__(value).unescape() @classmethod - def escape(cls, s: t.Any) -> te.Self: + def escape(cls, s: t.Any, /) -> te.Self: """Escape a string. Calls :func:`escape` and ensures that for subclasses the correct type is returned. """ @@ -242,7 +232,7 @@ def partition(self, sep: str, /) -> tuple[te.Self, te.Self, te.Self]: cls = self.__class__ return cls(l), cls(s), cls(r) - def rpartition(self, sep: str) -> tuple[te.Self, te.Self, te.Self]: + def rpartition(self, sep: str, /) -> tuple[te.Self, te.Self, te.Self]: l, s, r = super().rpartition(self.escape(sep)) cls = self.__class__ return cls(l), cls(s), cls(r) @@ -251,13 +241,15 @@ def format(self, *args: t.Any, **kwargs: t.Any) -> te.Self: formatter = EscapeFormatter(self.escape) return self.__class__(formatter.vformat(self, args, kwargs)) - def format_map( # type: ignore[override] - self, map: cabc.Mapping[str, t.Any] + def format_map( + self, + mapping: cabc.Mapping[str, t.Any], # type: ignore[override] + /, ) -> te.Self: formatter = EscapeFormatter(self.escape) - return self.__class__(formatter.vformat(self, (), map)) + return self.__class__(formatter.vformat(self, (), mapping)) - def __html_format__(self, format_spec: str) -> te.Self: + def __html_format__(self, format_spec: str, /) -> te.Self: if format_spec: raise ValueError("Unsupported format specification for Markup.") @@ -298,19 +290,19 @@ def __init__(self, obj: t.Any, escape: t.Callable[[t.Any], Markup]) -> None: self.obj = obj self.escape = escape - def __getitem__(self, item: t.Any) -> te.Self: - return self.__class__(self.obj[item], self.escape) + def __getitem__(self, key: t.Any, /) -> te.Self: + return self.__class__(self.obj[key], self.escape) - def __str__(self) -> str: + def __str__(self, /) -> str: return str(self.escape(self.obj)) - def __repr__(self) -> str: + def __repr__(self, /) -> str: return str(self.escape(repr(self.obj))) - def __int__(self) -> int: + def __int__(self, /) -> int: return int(self.obj) - def __float__(self) -> float: + def __float__(self, /) -> float: return float(self.obj) diff --git a/src/markupsafe/_native.py b/src/markupsafe/_native.py index a0583f13..e5ac0c13 100644 --- a/src/markupsafe/_native.py +++ b/src/markupsafe/_native.py @@ -5,7 +5,7 @@ from . import Markup -def escape(s: t.Any) -> Markup: +def escape(s: t.Any, /) -> Markup: """Replace the characters ``&``, ``<``, ``>``, ``'``, and ``"`` in the string with HTML-safe sequences. Use this if you need to display text that might contain such characters in HTML. @@ -29,7 +29,7 @@ def escape(s: t.Any) -> Markup: ) -def escape_silent(s: t.Any | None) -> Markup: +def escape_silent(s: t.Any | None, /) -> Markup: """Like :func:`escape` but treats ``None`` as the empty string. Useful with optional values, as otherwise you get the string ``'None'`` when the value is ``None``. @@ -45,7 +45,7 @@ def escape_silent(s: t.Any | None) -> Markup: return escape(s) -def soft_str(s: t.Any) -> str: +def soft_str(s: t.Any, /) -> str: """Convert an object to a string if it isn't already. This preserves a :class:`Markup` string rather than converting it back to a basic string, so it will still be marked as safe and won't be escaped diff --git a/src/markupsafe/_speedups.pyi b/src/markupsafe/_speedups.pyi index f673240f..6af2376b 100644 --- a/src/markupsafe/_speedups.pyi +++ b/src/markupsafe/_speedups.pyi @@ -1,9 +1,7 @@ -from typing import Any -from typing import Optional +import typing as t from . import Markup -def escape(s: Any) -> Markup: ... -def escape_silent(s: Optional[Any]) -> Markup: ... -def soft_str(s: Any) -> str: ... -def soft_unicode(s: Any) -> str: ... +def escape(s: t.Any, /) -> Markup: ... +def escape_silent(s: t.Any | None, /) -> Markup: ... +def soft_str(s: t.Any, /) -> str: ... From 53d3f454828dc958b7c5ff972a39eb1279e5c943 Mon Sep 17 00:00:00 2001 From: David Lord Date: Fri, 15 Sep 2023 07:09:59 -0700 Subject: [PATCH 5/5] check types with pyright --- pyproject.toml | 5 +++++ requirements/dev.txt | 2 -- requirements/typing.in | 1 + requirements/typing.txt | 9 ++++++++- src/markupsafe/__init__.py | 14 ++++++++------ tests/test_markupsafe.py | 2 +- tox.ini | 5 ++++- 7 files changed, 27 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8a4311b2..52fb6b3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,3 +51,8 @@ files = ["src/markupsafe", "tests"] show_error_codes = true pretty = true strict = true + +[tool.pyright] +pythonVersion = "3.8" +include = ["src/markupsafe", "tests"] +typeCheckingMode = "basic" diff --git a/requirements/dev.txt b/requirements/dev.txt index d453726c..2d24e6f0 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -30,8 +30,6 @@ filelock==3.12.3 # virtualenv identify==2.5.27 # via pre-commit -nodeenv==1.8.0 - # via pre-commit pip-compile-multi==2.6.3 # via -r requirements/dev.in pip-tools==7.3.0 diff --git a/requirements/typing.in b/requirements/typing.in index d76706ea..8be59c5d 100644 --- a/requirements/typing.in +++ b/requirements/typing.in @@ -1,2 +1,3 @@ mypy +pyright pytest diff --git a/requirements/typing.txt b/requirements/typing.txt index 5f8c39c0..4dbce907 100644 --- a/requirements/typing.txt +++ b/requirements/typing.txt @@ -1,4 +1,4 @@ -# SHA1:c5ca84191e9dddc0ac30ed7693225ae49943a2ba +# SHA1:56a0ce21228a30a3076f16291ca773ce48ffcb2c # # This file is autogenerated by pip-compile-multi # To update, run: @@ -11,11 +11,18 @@ mypy==1.5.1 # via -r requirements/typing.in mypy-extensions==1.0.0 # via mypy +nodeenv==1.8.0 + # via pyright packaging==23.1 # via pytest pluggy==1.3.0 # via pytest +pyright==1.1.326 + # via -r requirements/typing.in pytest==7.4.1 # via -r requirements/typing.in typing-extensions==4.7.1 # via mypy + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/src/markupsafe/__init__.py b/src/markupsafe/__init__.py index a43052b3..b2d118e8 100644 --- a/src/markupsafe/__init__.py +++ b/src/markupsafe/__init__.py @@ -13,7 +13,9 @@ class HasHTML(te.Protocol): def __html__(self, /) -> str: ... - _P = te.ParamSpec("_P") + class TPEscape(te.Protocol): + def __call__(self, s: t.Any, /) -> Markup: + ... __version__ = "2.2.0.dev" @@ -259,8 +261,8 @@ def __html_format__(self, format_spec: str, /) -> te.Self: class EscapeFormatter(string.Formatter): __slots__ = ("escape",) - def __init__(self, escape: t.Callable[[t.Any], Markup]) -> None: - self.escape = escape + def __init__(self, escape: TPEscape) -> None: + self.escape: TPEscape = escape super().__init__() def format_field(self, value: t.Any, format_spec: str) -> str: @@ -286,9 +288,9 @@ class _MarkupEscapeHelper: __slots__ = ("obj", "escape") - def __init__(self, obj: t.Any, escape: t.Callable[[t.Any], Markup]) -> None: - self.obj = obj - self.escape = escape + def __init__(self, obj: t.Any, escape: TPEscape) -> None: + self.obj: t.Any = obj + self.escape: TPEscape = escape def __getitem__(self, key: t.Any, /) -> te.Self: return self.__class__(self.obj[key], self.escape) diff --git a/tests/test_markupsafe.py b/tests/test_markupsafe.py index 0c047092..d754f73e 100644 --- a/tests/test_markupsafe.py +++ b/tests/test_markupsafe.py @@ -57,7 +57,7 @@ def __str__(self) -> str: @pytest.mark.parametrize("args", ["foo", 42, ("foo", 42)]) def test_missing_interpol(args: t.Any) -> None: with pytest.raises(TypeError): - Markup("") % args + assert Markup("") % args def test_tuple_interpol() -> None: diff --git a/tox.ini b/tox.ini index 79a91c73..a5c0baca 100644 --- a/tox.ini +++ b/tox.ini @@ -19,7 +19,10 @@ commands = pre-commit run --all-files [testenv:typing] deps = -r requirements/typing.txt -commands = mypy +commands = + mypy + pyright + pyright --verifytypes markupsafe [testenv:docs] deps = -r requirements/docs.txt