Skip to content

Commit

Permalink
Merge pull request marshmallow-code#2044 from marshmallow-code/rework…
Browse files Browse the repository at this point in the history
…_enum

Merge EnumSymbol and EnumValue into Enum field
  • Loading branch information
lafrech committed Sep 15, 2022
2 parents 3759755 + db37b73 commit 3e15e19
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 80 deletions.
106 changes: 53 additions & 53 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import math
import typing
import warnings
from enum import Enum
from enum import Enum as EnumType
from collections.abc import Mapping as _Mapping

from marshmallow import validate, utils, class_registry, types
Expand Down Expand Up @@ -60,8 +60,7 @@
"IPInterface",
"IPv4Interface",
"IPv6Interface",
"EnumSymbol",
"EnumValue",
"Enum",
"Method",
"Function",
"Str",
Expand Down Expand Up @@ -1856,43 +1855,16 @@ class IPv6Interface(IPInterface):
DESERIALIZATION_CLASS = ipaddress.IPv6Interface


class EnumSymbol(String):
"""An Enum field (de)serializing enum members by symbol (name) as string.
class Enum(Field):
"""An Enum field (de)serializing enum members by symbol (name) or by value.
:param enum Enum: Enum class
:param boolean|Schema|Field by_value: Whether to (de)serialize by value or by name,
or Field class or instance to use to (de)serialize by value. Defaults to False.
.. versionadded:: 3.18.0
"""

default_error_messages = {
"unknown": "Must be one of: {choices}.",
}

def __init__(self, enum: type[Enum], **kwargs):
self.enum = enum
self.choices = ", ".join(enum.__members__)
super().__init__(**kwargs)

def _serialize(self, value, attr, obj, **kwargs):
if value is None:
return None
return value.name

def _deserialize(self, value, attr, data, **kwargs):
value = super()._deserialize(value, attr, data, **kwargs)
try:
return getattr(self.enum, value)
except AttributeError as exc:
raise self.make_error("unknown", choices=self.choices) from exc


class EnumValue(Field):
"""An Enum field (de)serializing enum members by value.
A Field must be provided to (de)serialize the value.
:param cls_or_instance: Field class or instance.
:param enum Enum: Enum class
If `by_value` is `False` (default), enum members are (de)serialized by symbol (name).
If it is `True`, they are (de)serialized by value using :class:`Field`.
If it is a field instance or class, they are (de)serialized by value using this field.
.. versionadded:: 3.18.0
"""
Expand All @@ -1901,31 +1873,59 @@ class EnumValue(Field):
"unknown": "Must be one of: {choices}.",
}

def __init__(self, cls_or_instance: Field | type, enum: type[Enum], **kwargs):
def __init__(
self,
enum: type[EnumType],
*,
by_value: bool | Field | type = False,
**kwargs,
):
super().__init__(**kwargs)
try:
self.field = resolve_field_instance(cls_or_instance)
except FieldInstanceResolutionError as error:
raise ValueError(
"The enum field must be a subclass or instance of "
"marshmallow.base.FieldABC."
) from error
self.enum = enum
self.choices = ", ".join(
[str(self.field._serialize(m.value, None, None)) for m in enum]
)
self.by_value = by_value

# Serialization by name
if by_value is False:
self.field: Field = String()
self.choices_text = ", ".join(
str(self.field._serialize(m, None, None)) for m in enum.__members__
)
# Serialization by value
else:
if by_value is True:
self.field = Field()
else:
try:
self.field = resolve_field_instance(by_value)
except FieldInstanceResolutionError as error:
raise ValueError(
'"by_value" must be either a bool or a subclass or instance of '
"marshmallow.base.FieldABC."
) from error
self.choices_text = ", ".join(
str(self.field._serialize(m.value, None, None)) for m in enum
)

def _serialize(self, value, attr, obj, **kwargs):
if value is None:
return None
return self.field._serialize(value.value, attr, obj, **kwargs)
if self.by_value:
val = value.value
else:
val = value.name
return self.field._serialize(val, attr, obj, **kwargs)

def _deserialize(self, value, attr, data, **kwargs):
value = self.field._deserialize(value, attr, data, **kwargs)
val = self.field._deserialize(value, attr, data, **kwargs)
if self.by_value:
try:
return self.enum(val)
except ValueError as error:
raise self.make_error("unknown", choices=self.choices_text) from error
try:
return self.enum(value)
except ValueError as exc:
raise self.make_error("unknown", choices=self.choices) from exc
return getattr(self.enum, val)
except AttributeError as error:
raise self.make_error("unknown", choices=self.choices_text) from error


class Method(Field):
Expand Down
6 changes: 3 additions & 3 deletions tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ class DateEnum(Enum):
fields.IPInterface,
fields.IPv4Interface,
fields.IPv6Interface,
functools.partial(fields.EnumSymbol, GenderEnum),
functools.partial(fields.EnumValue, fields.String, HairColorEnum),
functools.partial(fields.EnumValue, fields.Integer, GenderEnum),
functools.partial(fields.Enum, GenderEnum),
functools.partial(fields.Enum, HairColorEnum, by_value=fields.String),
functools.partial(fields.Enum, GenderEnum, by_value=fields.Integer),
]


Expand Down
64 changes: 46 additions & 18 deletions tests/test_deserialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -1097,54 +1097,82 @@ def test_invalid_ipv6interface_deserialization(self, in_value):

assert excinfo.value.args[0] == "Not a valid IPv6 interface."

def test_enumsymbol_field_deserialization(self):
field = fields.EnumSymbol(GenderEnum)
def test_enum_field_by_symbol_deserialization(self):
field = fields.Enum(GenderEnum)
assert field.deserialize("male") == GenderEnum.male

def test_enumsymbol_field_invalid_value(self):
field = fields.EnumSymbol(GenderEnum)
def test_enum_field_by_symbol_invalid_value(self):
field = fields.Enum(GenderEnum)
with pytest.raises(
ValidationError, match="Must be one of: male, female, non_binary."
):
field.deserialize("dummy")

def test_enumsymbol_field_not_string(self):
field = fields.EnumSymbol(GenderEnum)
def test_enum_field_by_symbol_not_string(self):
field = fields.Enum(GenderEnum)
with pytest.raises(ValidationError, match="Not a valid string."):
field.deserialize(12)

def test_enumvalue_field_deserialization(self):
field = fields.EnumValue(fields.String, HairColorEnum)
def test_enum_field_by_value_true_deserialization(self):
field = fields.Enum(HairColorEnum, by_value=True)
assert field.deserialize("black hair") == HairColorEnum.black
field = fields.EnumValue(fields.Integer, GenderEnum)
field = fields.Enum(GenderEnum, by_value=True)
assert field.deserialize(1) == GenderEnum.male
field = fields.EnumValue(fields.Date(format="%d/%m/%Y"), DateEnum)

def test_enum_field_by_value_field_deserialization(self):
field = fields.Enum(HairColorEnum, by_value=fields.String)
assert field.deserialize("black hair") == HairColorEnum.black
field = fields.Enum(GenderEnum, by_value=fields.Integer)
assert field.deserialize(1) == GenderEnum.male
field = fields.Enum(DateEnum, by_value=fields.Date(format="%d/%m/%Y"))
assert field.deserialize("29/02/2004") == DateEnum.date_1

def test_enumvalue_field_invalid_value(self):
field = fields.EnumValue(fields.String, HairColorEnum)
def test_enum_field_by_value_true_invalid_value(self):
field = fields.Enum(HairColorEnum, by_value=True)
with pytest.raises(
ValidationError,
match="Must be one of: black hair, brown hair, blond hair, red hair.",
):
field.deserialize("dummy")
field = fields.EnumValue(fields.Integer, GenderEnum)
field = fields.Enum(GenderEnum, by_value=True)
with pytest.raises(ValidationError, match="Must be one of: 1, 2, 3."):
field.deserialize(12)
field = fields.EnumValue(fields.Date(format="%d/%m/%Y"), DateEnum)

def test_enum_field_by_value_field_invalid_value(self):
field = fields.Enum(HairColorEnum, by_value=fields.String)
with pytest.raises(
ValidationError,
match="Must be one of: black hair, brown hair, blond hair, red hair.",
):
field.deserialize("dummy")
field = fields.Enum(GenderEnum, by_value=fields.Integer)
with pytest.raises(ValidationError, match="Must be one of: 1, 2, 3."):
field.deserialize(12)
field = fields.Enum(DateEnum, by_value=fields.Date(format="%d/%m/%Y"))
with pytest.raises(
ValidationError, match="Must be one of: 29/02/2004, 29/02/2008, 29/02/2012."
):
field.deserialize("28/02/2004")

def test_enumvalue_field_wrong_type(self):
field = fields.EnumValue(fields.String, HairColorEnum)
def test_enum_field_by_value_true_wrong_type(self):
field = fields.Enum(HairColorEnum, by_value=True)
with pytest.raises(
ValidationError,
match="Must be one of: black hair, brown hair, blond hair, red hair.",
):
field.deserialize("dummy")
field = fields.Enum(GenderEnum, by_value=True)
with pytest.raises(ValidationError, match="Must be one of: 1, 2, 3."):
field.deserialize(12)

def test_enum_field_by_value_field_wrong_type(self):
field = fields.Enum(HairColorEnum, by_value=fields.String)
with pytest.raises(ValidationError, match="Not a valid string."):
field.deserialize(12)
field = fields.EnumValue(fields.Integer, GenderEnum)
field = fields.Enum(GenderEnum, by_value=fields.Integer)
with pytest.raises(ValidationError, match="Not a valid integer."):
field.deserialize("dummy")
field = fields.EnumValue(fields.Date(format="%d/%m/%Y"), DateEnum)
field = fields.Enum(DateEnum, by_value=fields.Date(format="%d/%m/%Y"))
with pytest.raises(ValidationError, match="Not a valid date."):
field.deserialize("30/02/2004")

Expand Down
21 changes: 15 additions & 6 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,20 +255,29 @@ def test_ipv6_interface_field(self, user):
== ipv6interface_exploded_string
)

def test_enumsymbol_field_serialization(self, user):
def test_enum_field_by_symbol_serialization(self, user):
user.sex = GenderEnum.male
field = fields.EnumSymbol(GenderEnum)
field = fields.Enum(GenderEnum)
assert field.serialize("sex", user) == "male"

def test_enumvalue_field_serialization(self, user):
def test_enum_field_by_value_true_serialization(self, user):
user.hair_color = HairColorEnum.black
field = fields.EnumValue(fields.String, HairColorEnum)
field = fields.Enum(HairColorEnum, by_value=True)
assert field.serialize("hair_color", user) == "black hair"
user.sex = GenderEnum.male
field = fields.EnumValue(fields.Integer, GenderEnum)
field = fields.Enum(GenderEnum, by_value=True)
assert field.serialize("sex", user) == 1
user.some_date = DateEnum.date_1
field = fields.EnumValue(fields.Date(format="%d/%m/%Y"), DateEnum)

def test_enum_field_by_value_field_serialization(self, user):
user.hair_color = HairColorEnum.black
field = fields.Enum(HairColorEnum, by_value=fields.String)
assert field.serialize("hair_color", user) == "black hair"
user.sex = GenderEnum.male
field = fields.Enum(GenderEnum, by_value=fields.Integer)
assert field.serialize("sex", user) == 1
user.some_date = DateEnum.date_1
field = fields.Enum(DateEnum, by_value=fields.Date(format="%d/%m/%Y"))
assert field.serialize("some_date", user) == "29/02/2004"

def test_decimal_field(self, user):
Expand Down

0 comments on commit 3e15e19

Please sign in to comment.