Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Field take default on feature #1767

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,4 @@ Contributors (chronological)
- Stephen Rosen `@sirosen <https://github.com/sirosen>`_
- Vladimir Mikhaylov `@vemikhaylov <https://github.com/vemikhaylov>`_
- Stephen Eaton `@madeinoz67 <https://github.com/madeinoz67>`_
- Dor Meiri `@dormeiri <https://github.com/dormeiri>`_
12 changes: 9 additions & 3 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def __init__(
dump_only: bool = False,
error_messages: typing.Optional[typing.Dict[str, str]] = None,
metadata: typing.Optional[typing.Mapping[str, typing.Any]] = None,
take_default_on=None,
**additional_metadata
) -> None:
self.default = default
Expand Down Expand Up @@ -191,6 +192,7 @@ def __init__(
raise ValueError("'missing' must not be set for required fields.")
self.required = required
self.missing = missing
self.take_default_on = tuple() if take_default_on is None else take_default_on

metadata = metadata or {}
self.metadata = {**metadata, **additional_metadata}
Expand Down Expand Up @@ -218,7 +220,8 @@ def __repr__(self) -> str:
"validate={self.validate}, required={self.required}, "
"load_only={self.load_only}, dump_only={self.dump_only}, "
"missing={self.missing}, allow_none={self.allow_none}, "
"error_messages={self.error_messages})>".format(
"error_messages={self.error_messages}, "
"take_default_on={self.take_default_on})>".format(
ClassName=self.__class__.__name__, self=self
)
)
Expand Down Expand Up @@ -323,10 +326,10 @@ def serialize(
"""
if self._CHECK_ATTRIBUTE:
value = self.get_value(obj, attr, accessor=accessor)
if value is missing_ and hasattr(self, "default"):
if self._should_take_default(value) and hasattr(self, "default"):
default = self.default
value = default() if callable(default) else default
if value is missing_:
if self._should_take_default(value):
return value
else:
value = None
Expand Down Expand Up @@ -416,6 +419,9 @@ def _deserialize(
"""
return value

def _should_take_default(self, value):
return value is missing_ or value in self.take_default_on

# Properties

@property
Expand Down
2 changes: 1 addition & 1 deletion tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_repr(self):
"validate=None, required=False, "
"load_only=False, dump_only=False, "
"missing={missing}, allow_none=False, "
"error_messages={error_messages})>".format(
"error_messages={error_messages}, take_default_on=())>".format(
default, missing=missing, error_messages=field.error_messages
)
)
Expand Down
45 changes: 45 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2748,6 +2748,51 @@ class MySchema(Schema):
assert errors["allow_none_field"][0] == "<custom>"


class TestTakeDefaultOn:
class MySchema(Schema):
int_take_default_on_none = fields.Int(default=42, take_default_on=[None])
int_take_default_on_value = fields.Int(default=42, take_default_on=[0, 1])
str_take_default_on_none = fields.Str(default="foo", take_default_on=[None])
str_take_default_on_value = fields.Str(
default="foo", take_default_on=["", "bar"]
)

@pytest.fixture()
def schema(self):
return self.MySchema()

@pytest.fixture()
def default_values(self):
return dict(
int_take_default_on_none=42,
int_take_default_on_value=42,
str_take_default_on_none="foo",
str_take_default_on_value="foo",
)

def test_default_taken_not_missing(self, schema, default_values):
data = dict(
int_take_default_on_none=None,
int_take_default_on_value=0,
str_take_default_on_none=None,
str_take_default_on_value="bar",
)
assert schema.dump(data) == default_values

def test_default_taken_missing(self, schema, default_values):
data = dict()
assert schema.dump(data) == default_values

def test_default_not_taken(self, schema, default_values):
data = dict(
int_take_default_on_none=-1,
int_take_default_on_value=-1,
str_take_default_on_none="baz",
str_take_default_on_value="baz",
)
assert schema.dump(data) == data


class TestDefaults:
class MySchema(Schema):
int_no_default = fields.Int(allow_none=True)
Expand Down
11 changes: 11 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,17 @@ def test_integer_field_default_set_to_none(self, user):
field = fields.Integer(default=None)
assert field.serialize("age", user) is None

def test_integer_field_take_default_on(self, user):
field = fields.Integer(default=0, take_default_on=[None, 1])
user.age = 42
assert field.serialize("age", user) == 42
del user.age
assert field.serialize("age", user) == 0
user.age = None
assert field.serialize("age", user) == 0
user.age = 1
assert field.serialize("age", user) == 0

def test_uuid_field(self, user):
user.uuid1 = uuid.UUID("12345678123456781234567812345678")
user.uuid2 = None
Expand Down