Skip to content

Commit

Permalink
Merge pull request marshmallow-code#2017 from marshmallow-code/enum
Browse files Browse the repository at this point in the history
Add Enum field
  • Loading branch information
lafrech committed Sep 4, 2022
2 parents 1ddf058 + 7cc1de4 commit f61dc84
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 6 deletions.
75 changes: 75 additions & 0 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import math
import typing
import warnings
from enum import Enum
from collections.abc import Mapping as _Mapping

from marshmallow import validate, utils, class_registry, types
Expand Down Expand Up @@ -59,6 +60,8 @@
"IPInterface",
"IPv4Interface",
"IPv6Interface",
"EnumSymbol",
"EnumValue",
"Method",
"Function",
"Str",
Expand Down Expand Up @@ -1855,6 +1858,78 @@ class IPv6Interface(IPInterface):
DESERIALIZATION_CLASS = ipaddress.IPv6Interface


class EnumSymbol(String):
"""An Enum field (de)serializing enum members by symbol (name) as string.
:param enum Enum: Enum class
.. versionadded:: 3.18.0
"""

default_error_messages = {
"unknown": "Must be one of: {choices}.",
}

def __init__(self, enum: type[Enum], **kwargs):
self.enum = enum
self.choices = ", ".join(enum.__members__)
super().__init__(**kwargs)

def _serialize(self, value, attr, obj, **kwargs):
if value is None:
return None
return value.name

def _deserialize(self, value, attr, data, **kwargs):
value = super()._deserialize(value, attr, data, **kwargs)
try:
return getattr(self.enum, value)
except AttributeError as exc:
raise self.make_error("unknown", choices=self.choices) from exc


class EnumValue(Field):
"""An Enum field (de)serializing enum members by value.
A Field must be provided to (de)serialize the value.
:param cls_or_instance: Field class or instance.
:param enum Enum: Enum class
.. versionadded:: 3.18.0
"""

default_error_messages = {
"unknown": "Must be one of: {choices}.",
}

def __init__(self, cls_or_instance: Field | type, enum: type[Enum], **kwargs):
super().__init__(**kwargs)
try:
self.field = resolve_field_instance(cls_or_instance)
except FieldInstanceResolutionError as error:
raise ValueError(
"The enum field must be a subclass or instance of "
"marshmallow.base.FieldABC."
) from error
self.enum = enum
self.choices = ", ".join(
[str(self.field._serialize(m.value, None, None)) for m in enum]
)

def _serialize(self, value, attr, obj, **kwargs):
if value is None:
return None
return self.field._serialize(value.value, attr, obj, **kwargs)

def _deserialize(self, value, attr, data, **kwargs):
value = self.field._deserialize(value, attr, data, **kwargs)
try:
return self.enum(value)
except ValueError as exc:
raise self.make_error("unknown", choices=self.choices) from exc


class Method(Field):
"""A field that takes the value returned by a `Schema` method.
Expand Down
35 changes: 31 additions & 4 deletions tests/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Test utilities and fixtures."""
import functools
import datetime as dt
import uuid
from enum import Enum, IntEnum

import simplejson

Expand All @@ -12,6 +14,25 @@
central = pytz.timezone("US/Central")


class GenderEnum(IntEnum):
male = 1
female = 2
non_binary = 3


class HairColorEnum(Enum):
black = "black hair"
brown = "brown hair"
blond = "blond hair"
red = "red hair"


class DateEnum(Enum):
date_1 = dt.date(2004, 2, 29)
date_2 = dt.date(2008, 2, 29)
date_3 = dt.date(2012, 2, 29)


ALL_FIELDS = [
fields.String,
fields.Integer,
Expand All @@ -33,8 +54,12 @@
fields.IPInterface,
fields.IPv4Interface,
fields.IPv6Interface,
functools.partial(fields.EnumSymbol, GenderEnum),
functools.partial(fields.EnumValue, fields.String, HairColorEnum),
functools.partial(fields.EnumValue, fields.Integer, GenderEnum),
]


##### Custom asserts #####


Expand Down Expand Up @@ -69,7 +94,8 @@ def __init__(
birthdate=None,
birthtime=None,
balance=100,
sex="male",
sex=GenderEnum.male,
hair_color=HairColorEnum.black,
employer=None,
various_data=None,
):
Expand All @@ -86,15 +112,16 @@ def __init__(
self.email = email
self.balance = balance
self.registered = registered
self.hair_colors = ["black", "brown", "blond", "redhead"]
self.sex_choices = ("male", "female")
self.hair_colors = list(HairColorEnum.__members__)
self.sex_choices = list(GenderEnum.__members__)
self.finger_count = 10
self.uid = uuid.uuid1()
self.time_registered = time_registered or dt.time(1, 23, 45, 6789)
self.birthdate = birthdate or dt.date(2013, 1, 23)
self.birthtime = birthtime or dt.time(0, 1, 2, 3333)
self.activation_date = dt.date(2013, 12, 11)
self.sex = sex
self.hair_color = hair_color
self.employer = employer
self.relatives = []
self.various_data = various_data or {
Expand Down Expand Up @@ -180,7 +207,7 @@ class UserSchema(Schema):
birthtime = fields.Time()
activation_date = fields.Date()
since_created = fields.TimeDelta()
sex = fields.Str(validate=validate.OneOf(["male", "female"]))
sex = fields.Str(validate=validate.OneOf(list(GenderEnum.__members__)))
various_data = fields.Dict()

class Meta:
Expand Down
61 changes: 60 additions & 1 deletion tests/test_deserialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,15 @@
from marshmallow.exceptions import ValidationError
from marshmallow.validate import Equal

from tests.base import assert_date_equal, assert_time_equal, central, ALL_FIELDS
from tests.base import (
assert_date_equal,
assert_time_equal,
central,
ALL_FIELDS,
GenderEnum,
HairColorEnum,
DateEnum,
)


class TestDeserializingNone:
Expand Down Expand Up @@ -1089,6 +1097,57 @@ def test_invalid_ipv6interface_deserialization(self, in_value):

assert excinfo.value.args[0] == "Not a valid IPv6 interface."

def test_enumsymbol_field_deserialization(self):
field = fields.EnumSymbol(GenderEnum)
assert field.deserialize("male") == GenderEnum.male

def test_enumsymbol_field_invalid_value(self):
field = fields.EnumSymbol(GenderEnum)
with pytest.raises(
ValidationError, match="Must be one of: male, female, non_binary."
):
field.deserialize("dummy")

def test_enumsymbol_field_not_string(self):
field = fields.EnumSymbol(GenderEnum)
with pytest.raises(ValidationError, match="Not a valid string."):
field.deserialize(12)

def test_enumvalue_field_deserialization(self):
field = fields.EnumValue(fields.String, HairColorEnum)
assert field.deserialize("black hair") == HairColorEnum.black
field = fields.EnumValue(fields.Integer, GenderEnum)
assert field.deserialize(1) == GenderEnum.male
field = fields.EnumValue(fields.Date(format="%d/%m/%Y"), DateEnum)
assert field.deserialize("29/02/2004") == DateEnum.date_1

def test_enumvalue_field_invalid_value(self):
field = fields.EnumValue(fields.String, HairColorEnum)
with pytest.raises(
ValidationError,
match="Must be one of: black hair, brown hair, blond hair, red hair.",
):
field.deserialize("dummy")
field = fields.EnumValue(fields.Integer, GenderEnum)
with pytest.raises(ValidationError, match="Must be one of: 1, 2, 3."):
field.deserialize(12)
field = fields.EnumValue(fields.Date(format="%d/%m/%Y"), DateEnum)
with pytest.raises(
ValidationError, match="Must be one of: 29/02/2004, 29/02/2008, 29/02/2012."
):
field.deserialize("28/02/2004")

def test_enumvalue_field_wrong_type(self):
field = fields.EnumValue(fields.String, HairColorEnum)
with pytest.raises(ValidationError, match="Not a valid string."):
field.deserialize(12)
field = fields.EnumValue(fields.Integer, GenderEnum)
with pytest.raises(ValidationError, match="Not a valid integer."):
field.deserialize("dummy")
field = fields.EnumValue(fields.Date(format="%d/%m/%Y"), DateEnum)
with pytest.raises(ValidationError, match="Not a valid date."):
field.deserialize("30/02/2004")

def test_deserialization_function_must_be_callable(self):
with pytest.raises(TypeError):
fields.Function(lambda x: None, deserialize="notvalid")
Expand Down
18 changes: 17 additions & 1 deletion tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from marshmallow import Schema, fields, missing as missing_

from tests.base import User, ALL_FIELDS, central
from tests.base import User, ALL_FIELDS, central, GenderEnum, HairColorEnum, DateEnum


class DateTimeList:
Expand Down Expand Up @@ -255,6 +255,22 @@ def test_ipv6_interface_field(self, user):
== ipv6interface_exploded_string
)

def test_enumsymbol_field_serialization(self, user):
user.sex = GenderEnum.male
field = fields.EnumSymbol(GenderEnum)
assert field.serialize("sex", user) == "male"

def test_enumvalue_field_serialization(self, user):
user.hair_color = HairColorEnum.black
field = fields.EnumValue(fields.String, HairColorEnum)
assert field.serialize("hair_color", user) == "black hair"
user.sex = GenderEnum.male
field = fields.EnumValue(fields.Integer, GenderEnum)
assert field.serialize("sex", user) == 1
user.some_date = DateEnum.date_1
field = fields.EnumValue(fields.Date(format="%d/%m/%Y"), DateEnum)
assert field.serialize("some_date", user) == "29/02/2004"

def test_decimal_field(self, user):
user.m1 = 12
user.m2 = "12.355"
Expand Down

0 comments on commit f61dc84

Please sign in to comment.