Skip to content

Commit

Permalink
Merge pull request #1230 from marshmallow-code/propagate_partial_to_n…
Browse files Browse the repository at this point in the history
…ested_containers

Fix propagation of "partial" to Nested containers
  • Loading branch information
sloria committed Jun 9, 2019
2 parents a5597e2 + fe85631 commit 0d718a2
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 23 deletions.
8 changes: 4 additions & 4 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ def _deserialize(self, value, attr, data, **kwargs):
errors = {}
for idx, each in enumerate(value):
try:
result.append(self.container.deserialize(each))
result.append(self.container.deserialize(each, **kwargs))
except ValidationError as error:
if error.valid_data is not None:
result.append(error.valid_data)
Expand Down Expand Up @@ -682,7 +682,7 @@ def _deserialize(self, value, attr, data, **kwargs):

for idx, (container, each) in enumerate(zip(self.tuple_fields, value)):
try:
result.append(container.deserialize(each))
result.append(container.deserialize(each, **kwargs))
except ValidationError as error:
if error.valid_data is not None:
result.append(error.valid_data)
Expand Down Expand Up @@ -1346,7 +1346,7 @@ def _deserialize(self, value, attr, data, **kwargs):
keys = {}
for key in value.keys():
try:
keys[key] = self.key_container.deserialize(key)
keys[key] = self.key_container.deserialize(key, **kwargs)
except ValidationError as error:
errors[key]['key'] = error.messages

Expand All @@ -1359,7 +1359,7 @@ def _deserialize(self, value, attr, data, **kwargs):
else:
for key, val in value.items():
try:
deser_val = self.value_container.deserialize(val)
deser_val = self.value_container.deserialize(val, **kwargs)
except ValidationError as error:
errors[key]['value'] = error.messages
if error.valid_data is not None and key in keys:
Expand Down
29 changes: 12 additions & 17 deletions src/marshmallow/schema.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
"""The :class:`Schema` class, including its metaclass and options (class Meta)."""
from collections import defaultdict, OrderedDict
from collections.abc import Mapping
import datetime as dt
import uuid
import decimal
import copy
import inspect
import json
import warnings
from collections.abc import Mapping

from marshmallow import base, fields as ma_fields, class_registry
from marshmallow.error_store import ErrorStore
from marshmallow.fields import Nested
from marshmallow.exceptions import ValidationError, StringNotCollectionError
from marshmallow.orderedset import OrderedSet
from marshmallow.decorators import (
Expand Down Expand Up @@ -635,22 +634,18 @@ def _deserialize(
):
continue
d_kwargs = {}
if isinstance(field_obj, Nested):
# Allow partial loading of nested schemas.
if partial_is_collection:
prefix = field_name + '.'
len_prefix = len(prefix)
sub_partial = [
f[len_prefix:]
for f in partial if f.startswith(prefix)
]
else:
sub_partial = partial
# Allow partial loading of nested schemas.
if partial_is_collection:
prefix = field_name + '.'
len_prefix = len(prefix)
sub_partial = [
f[len_prefix:]
for f in partial if f.startswith(prefix)
]
d_kwargs['partial'] = sub_partial
getter = lambda val: field_obj.deserialize(
val, field_name,
data, **d_kwargs
)
else:
d_kwargs['partial'] = partial
getter = lambda val: field_obj.deserialize(val, field_name, data, **d_kwargs)
value = self._call_and_store(
getter_func=getter,
data=raw_value,
Expand Down
100 changes: 98 additions & 2 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_error_raised_if_missing_is_set_on_required_field(self):

def test_custom_field_receives_attr_and_obj(self):
class MyField(fields.Field):
def _deserialize(self, val, attr, data):
def _deserialize(self, val, attr, data, **kwargs):
assert attr == 'name'
assert data['foo'] == 42
return val
Expand All @@ -61,7 +61,7 @@ class MySchema(Schema):

def test_custom_field_receives_data_key_if_set(self):
class MyField(fields.Field):
def _deserialize(self, val, attr, data):
def _deserialize(self, val, attr, data, **kwargs):
assert attr == 'name'
assert data['foo'] == 42
return val
Expand Down Expand Up @@ -243,3 +243,99 @@ class MySchema(Schema):
elif field_unknown == RAISE or (schema_unknown == RAISE and not field_unknown):
with pytest.raises(ValidationError):
MySchema().load({'nested': {'x': 1}})


class TestListNested:

def test_list_nested_partial_propagated_to_nested(self):

class Child(Schema):
name = fields.String(required=True)
age = fields.Integer(required=True)

class Family(Schema):
children = fields.List(fields.Nested(Child))

payload = {'children': [{'name': 'Lucette'}]}

for val in (True, ('children.age', )):
result = Family(partial=val).load(payload)
assert result['children'][0]['name'] == 'Lucette'
result = Family().load(payload, partial=val)
assert result['children'][0]['name'] == 'Lucette'

for val in (False, ('children.name', )):
with pytest.raises(ValidationError) as excinfo:
result = Family(partial=val).load(payload)
assert excinfo.value.args[0] == {
'children': {0: {'age': ['Missing data for required field.']}},
}
with pytest.raises(ValidationError) as excinfo:
result = Family().load(payload, partial=val)
assert excinfo.value.args[0] == {
'children': {0: {'age': ['Missing data for required field.']}},
}


class TestTupleNested:

def test_tuple_nested_partial_propagated_to_nested(self):

class Child(Schema):
name = fields.String(required=True)
age = fields.Integer(required=True)

class Family(Schema):
children = fields.Tuple((fields.Nested(Child), ))

payload = {'children': [{'name': 'Lucette'}]}

for val in (True, ('children.age', )):
result = Family(partial=val).load(payload)
assert result['children'][0]['name'] == 'Lucette'
result = Family().load(payload, partial=val)
assert result['children'][0]['name'] == 'Lucette'

for val in (False, ('children.name', )):
with pytest.raises(ValidationError) as excinfo:
result = Family(partial=val).load(payload)
assert excinfo.value.args[0] == {
'children': {0: {'age': ['Missing data for required field.']}},
}
with pytest.raises(ValidationError) as excinfo:
result = Family().load(payload, partial=val)
assert excinfo.value.args[0] == {
'children': {0: {'age': ['Missing data for required field.']}},
}


class TestDictNested:

def test_dict_nested_partial_propagated_to_nested(self):

class Child(Schema):
name = fields.String(required=True)
age = fields.Integer(required=True)

class Family(Schema):
children = fields.Dict(values=fields.Nested(Child))

payload = {'children': {'daughter': {'name': 'Lucette'}}}

for val in (True, ('children.age', )):
result = Family(partial=val).load(payload)
assert result['children']['daughter']['name'] == 'Lucette'
result = Family().load(payload, partial=val)
assert result['children']['daughter']['name'] == 'Lucette'

for val in (False, ('children.name', )):
with pytest.raises(ValidationError) as excinfo:
result = Family(partial=val).load(payload)
assert excinfo.value.args[0] == {
'children': {'daughter': {'value': {'age': ['Missing data for required field.']}}},
}
with pytest.raises(ValidationError) as excinfo:
result = Family().load(payload, partial=val)
assert excinfo.value.args[0] == {
'children': {'daughter': {'value': {'age': ['Missing data for required field.']}}},
}

0 comments on commit 0d718a2

Please sign in to comment.