Skip to content

Commit

Permalink
Improve types in validate module (marshmallow-code#1786)
Browse files Browse the repository at this point in the history
  • Loading branch information
sloria committed Apr 4, 2021
1 parent 40c5662 commit 6f8ec72
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 23 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
Changelog
---------

3.12.0 (unreleased)
*******************

Other changes:

- Improve types in ``marshmallow.validate``.
- Make `marshmallow.validate.Validator` an abstract base class.

3.11.1 (2021-03-29)
*******************

Expand Down
61 changes: 38 additions & 23 deletions src/marshmallow/validate.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
"""Validation classes for various types of data."""
import re
import typing
from abc import ABC, abstractmethod
from itertools import zip_longest
from operator import attrgetter
import typing

from marshmallow import types
from marshmallow.exceptions import ValidationError

_T = typing.TypeVar("_T")


class Validator:
"""Base abstract class for validators.
class Validator(ABC):
"""Abstract base class for validators.
.. note::
This class does not provide any behavior. It is only used to
This class does not provide any validation behavior. It is only used to
add a useful `__repr__` implementation for validators.
"""

Expand All @@ -32,6 +35,10 @@ def _repr_args(self) -> str:
"""
return ""

@abstractmethod
def __call__(self, value: typing.Any) -> typing.Any:
...


class URL(Validator):
"""Validate a URL.
Expand All @@ -48,7 +55,7 @@ class RegexMemoizer:
def __init__(self):
self._memoized = {}

def _regex_generator(self, relative: bool, require_tld: bool):
def _regex_generator(self, relative: bool, require_tld: bool) -> typing.Pattern:
return re.compile(
r"".join(
(
Expand Down Expand Up @@ -107,7 +114,7 @@ def _repr_args(self) -> str:
def _format_error(self, value) -> str:
return self.error.format(input=value)

def __call__(self, value) -> typing.Any:
def __call__(self, value: str) -> str:
message = self._format_error(value)
if not value:
raise ValidationError(message)
Expand Down Expand Up @@ -157,10 +164,10 @@ class Email(Validator):
def __init__(self, *, error: typing.Optional[str] = None):
self.error = error or self.default_message # type: str

def _format_error(self, value) -> typing.Any:
def _format_error(self, value: str) -> str:
return self.error.format(input=value)

def __call__(self, value) -> typing.Any:
def __call__(self, value: str) -> str:
message = self._format_error(value)

if not value or "@" not in value:
Expand Down Expand Up @@ -245,10 +252,10 @@ def _repr_args(self) -> str:
self.min, self.max, self.min_inclusive, self.max_inclusive
)

def _format_error(self, value, message: str) -> str:
def _format_error(self, value: _T, message: str) -> str:
return (self.error or message).format(input=value, min=self.min, max=self.max)

def __call__(self, value) -> typing.Any:
def __call__(self, value: _T) -> _T:
if self.min is not None and (
value < self.min if self.min_inclusive else value <= self.min
):
Expand Down Expand Up @@ -306,12 +313,12 @@ def __init__(
def _repr_args(self) -> str:
return "min={!r}, max={!r}, equal={!r}".format(self.min, self.max, self.equal)

def _format_error(self, value, message: str) -> str:
def _format_error(self, value: typing.Sized, message: str) -> str:
return (self.error or message).format(
input=value, min=self.min, max=self.max, equal=self.equal
)

def __call__(self, value) -> typing.Any:
def __call__(self, value: typing.Sized) -> typing.Sized:
length = len(value)

if self.equal is not None:
Expand Down Expand Up @@ -348,10 +355,10 @@ def __init__(self, comparable, *, error: typing.Optional[str] = None):
def _repr_args(self) -> str:
return "comparable={!r}".format(self.comparable)

def _format_error(self, value) -> str:
def _format_error(self, value: _T) -> str:
return self.error.format(input=value, other=self.comparable)

def __call__(self, value) -> typing.Any:
def __call__(self, value: _T) -> _T:
if value != self.comparable:
raise ValidationError(self._format_error(value))
return value
Expand All @@ -377,7 +384,7 @@ class Regexp(Validator):
def __init__(
self,
regex: typing.Union[str, bytes, typing.Pattern],
flags=0,
flags: int = 0,
*,
error: typing.Optional[str] = None
):
Expand All @@ -389,10 +396,18 @@ def __init__(
def _repr_args(self) -> str:
return "regex={!r}".format(self.regex)

def _format_error(self, value) -> str:
def _format_error(self, value: typing.Union[str, bytes]) -> str:
return self.error.format(input=value, regex=self.regex.pattern)

def __call__(self, value) -> typing.Any:
@typing.overload
def __call__(self, value: str) -> str:
...

@typing.overload
def __call__(self, value: bytes) -> bytes:
...

def __call__(self, value):
if self.regex.match(value) is None:
raise ValidationError(self._format_error(value))

Expand Down Expand Up @@ -421,10 +436,10 @@ def __init__(self, method: str, *, error: typing.Optional[str] = None, **kwargs)
def _repr_args(self) -> str:
return "method={!r}, kwargs={!r}".format(self.method, self.kwargs)

def _format_error(self, value) -> str:
def _format_error(self, value: typing.Any) -> str:
return self.error.format(input=value, method=self.method)

def __call__(self, value) -> str:
def __call__(self, value: typing.Any) -> typing.Any:
method = getattr(value, self.method)

if not method(**self.kwargs):
Expand Down Expand Up @@ -456,7 +471,7 @@ def _repr_args(self) -> str:
def _format_error(self, value) -> str:
return self.error.format(input=value, values=self.values_text)

def __call__(self, value) -> str:
def __call__(self, value: typing.Any) -> typing.Any:
try:
if value in self.iterable:
raise ValidationError(self._format_error(value))
Expand Down Expand Up @@ -498,7 +513,7 @@ def _format_error(self, value) -> str:
input=value, choices=self.choices_text, labels=self.labels_text
)

def __call__(self, value) -> str:
def __call__(self, value: typing.Any) -> typing.Any:
try:
if value not in self.choices:
raise ValidationError(self._format_error(value))
Expand Down Expand Up @@ -549,7 +564,7 @@ def _format_error(self, value) -> str:
value_text = ", ".join(str(val) for val in value)
return super()._format_error(value_text)

def __call__(self, value) -> typing.Any:
def __call__(self, value: typing.Sequence[_T]) -> typing.Sequence[_T]:
# We can't use set.issubset because does not handle unhashable types
for val in value:
if val not in self.choices:
Expand All @@ -574,7 +589,7 @@ def _format_error(self, value) -> str:
value_text = ", ".join(str(val) for val in value)
return super()._format_error(value_text)

def __call__(self, value) -> typing.Any:
def __call__(self, value: typing.Sequence[_T]) -> typing.Sequence[_T]:
for val in value:
if val in self.iterable:
raise ValidationError(self._format_error(value))
Expand Down

0 comments on commit 6f8ec72

Please sign in to comment.