diff --git a/src/objects/api/serializers.py b/src/objects/api/serializers.py index 21c00cb7..4e7cc4b4 100644 --- a/src/objects/api/serializers.py +++ b/src/objects/api/serializers.py @@ -9,6 +9,7 @@ from objects.utils.serializers import DynamicFieldsMixin from .fields import ObjectSlugRelatedField, ObjectTypeField, ObjectUrlField +from .utils import merge_patch from .validators import GeometryValidator, IsImmutableValidator, JsonSchemaValidator @@ -126,9 +127,12 @@ def update(self, instance, validated_data): # object_data is not used since all object attributes are immutable object_data = validated_data.pop("object", None) validated_data["object"] = instance.object - # in case of PATCH + # version should be set if "version" not in validated_data: validated_data["version"] = instance.version + if self.partial and "data" in validated_data: + # Apply JSON Merge Patch for record data + validated_data["data"] = merge_patch(instance.data, validated_data["data"]) record = super().create(validated_data) return record diff --git a/src/objects/api/utils.py b/src/objects/api/utils.py index 84c2b86f..b3c94da0 100644 --- a/src/objects/api/utils.py +++ b/src/objects/api/utils.py @@ -1,8 +1,10 @@ from datetime import date -from typing import Union +from typing import Dict, Union from djchoices import DjangoChoices +from objects.typing import JSONValue + def string_to_value(value: str) -> Union[str, float, date]: if is_number(value): @@ -43,3 +45,24 @@ def display_choice_values_for_help_text(choices: DjangoChoices) -> str: items.append(item) return "\n".join(items) + + +def merge_patch(target: JSONValue, patch: JSONValue) -> Dict[str, JSONValue]: + """Merge two objects together recursively. + + This is inspired by https://datatracker.ietf.org/doc/html/rfc7396 - JSON Merge Patch, + but deviating in some cases to suit our needs. + """ + + if not isinstance(patch, dict): + return patch + + if not isinstance(target, dict): + # Ignore the contents and set it to an empty dict + target = {} + for k, v in patch.items(): + # According to RFC, we should remove k from target + # if v is None. This is where we deviate. + target[k] = merge_patch(target.get(k), v) + + return target diff --git a/src/objects/api/v1/openapi.yaml b/src/objects/api/v1/openapi.yaml index 4892a747..832f404a 100644 --- a/src/objects/api/v1/openapi.yaml +++ b/src/objects/api/v1/openapi.yaml @@ -349,6 +349,8 @@ paths: patch: operationId: object_partial_update description: Update the OBJECT by creating a new RECORD with the updates values. + The provided `record.data` value will be merged recursively with the existing + record data. parameters: - in: header name: Accept-Crs diff --git a/src/objects/api/v1/views.py b/src/objects/api/v1/views.py index 24c8b240..a00a1b2c 100644 --- a/src/objects/api/v1/views.py +++ b/src/objects/api/v1/views.py @@ -39,7 +39,8 @@ description="Update the OBJECT by creating a new RECORD with the updates values." ), partial_update=extend_schema( - description="Update the OBJECT by creating a new RECORD with the updates values." + description="Update the OBJECT by creating a new RECORD with the updates values. " + "The provided `record.data` value will be merged recursively with the existing record data." ), destroy=extend_schema( description="Delete an OBJECT and all RECORDs belonging to it.", diff --git a/src/objects/api/v2/openapi.yaml b/src/objects/api/v2/openapi.yaml index 0bd6bf1d..b01fdd60 100644 --- a/src/objects/api/v2/openapi.yaml +++ b/src/objects/api/v2/openapi.yaml @@ -377,6 +377,8 @@ paths: patch: operationId: object_partial_update description: Update the OBJECT by creating a new RECORD with the updates values. + The provided `record.data` value will be merged recursively with the existing + record data. parameters: - in: header name: Accept-Crs diff --git a/src/objects/api/v2/views.py b/src/objects/api/v2/views.py index 702146b1..1cca735f 100644 --- a/src/objects/api/v2/views.py +++ b/src/objects/api/v2/views.py @@ -42,7 +42,8 @@ description="Update the OBJECT by creating a new RECORD with the updates values." ), partial_update=extend_schema( - description="Update the OBJECT by creating a new RECORD with the updates values." + description="Update the OBJECT by creating a new RECORD with the updates values. " + "The provided `record.data` value will be merged recursively with the existing record data." ), destroy=extend_schema( description="Delete an OBJECT and all RECORDs belonging to it.", diff --git a/src/objects/tests/test_merge_patch.py b/src/objects/tests/test_merge_patch.py new file mode 100644 index 00000000..8fb1d7d7 --- /dev/null +++ b/src/objects/tests/test_merge_patch.py @@ -0,0 +1,34 @@ +from unittest import TestCase + +from objects.api.utils import merge_patch + + +class MergePatchTests(TestCase): + def test_merge_patch(self): + + test_data = [ + ({"a": "b"}, {"a": "c"}, {"a": "c"}), + ({"a": "b"}, {"b": "c"}, {"a": "b", "b": "c"}), + ({"a": "b"}, {"a": None}, {"a": None}), + ({"a": "b", "b": "c"}, {"a": None}, {"a": None, "b": "c"}), + ({"a": ["b"]}, {"a": "c"}, {"a": "c"}), + ({"a": "c"}, {"a": ["b"]}, {"a": ["b"]}), + ( + {"a": {"b": "c"}}, + {"a": {"b": "d", "c": None}}, + {"a": {"b": "d", "c": None}}, + ), + ({"a": [{"b": "c"}]}, {"a": [1]}, {"a": [1]}), + (["a", "b"], ["c", "d"], ["c", "d"]), + ({"a": "b"}, ["c"], ["c"]), + ({"a": "foo"}, None, None), + ({"a": "foo"}, "bar", "bar"), + ({"e": None}, {"a": 1}, {"e": None, "a": 1}), + ([1, 2], {"a": "b", "c": None}, {"a": "b", "c": None}), + ({}, {"a": {"bb": {"ccc": None}}}, {"a": {"bb": {"ccc": None}}}), + ({"a": "b"}, {"a": "b"}, {"a": "b"}), + ] + + for target, patch, expected in test_data: + with self.subTest(): + self.assertEqual(merge_patch(target, patch), expected) diff --git a/src/objects/tests/v1/test_object_api.py b/src/objects/tests/v1/test_object_api.py index fed267a1..52667ad7 100644 --- a/src/objects/tests/v1/test_object_api.py +++ b/src/objects/tests/v1/test_object_api.py @@ -205,7 +205,10 @@ def test_patch_object_record(self, m): ) initial_record = ObjectRecordFactory.create( - version=1, object__object_type=self.object_type, start_at=date.today() + version=1, + object__object_type=self.object_type, + start_at=date.today(), + data={"name": "Name", "diameter": 20}, ) object = initial_record.object @@ -229,8 +232,10 @@ def test_patch_object_record(self, m): current_record = object.current_record self.assertEqual(current_record.version, initial_record.version) + # The actual behavior of the data merging is in test_merge_patch.py: self.assertEqual( - current_record.data, {"plantDate": "2020-04-12", "diameter": 30} + current_record.data, + {"plantDate": "2020-04-12", "diameter": 30, "name": "Name"}, ) self.assertEqual(current_record.start_at, date(2020, 1, 1)) self.assertEqual(current_record.registration_at, date(2020, 8, 8)) diff --git a/src/objects/tests/v2/test_object_api.py b/src/objects/tests/v2/test_object_api.py index 12b5a43a..7a4c8c88 100644 --- a/src/objects/tests/v2/test_object_api.py +++ b/src/objects/tests/v2/test_object_api.py @@ -1,13 +1,14 @@ import json import uuid from datetime import date, timedelta +from typing import cast import requests_mock from freezegun import freeze_time from rest_framework import status from rest_framework.test import APITestCase -from objects.core.models import Object +from objects.core.models import Object, ObjectType from objects.core.tests.factories import ( ObjectFactory, ObjectRecordFactory, @@ -33,7 +34,9 @@ class ObjectApiTests(TokenAuthMixin, APITestCase): def setUpTestData(cls): super().setUpTestData() - cls.object_type = ObjectTypeFactory(service__api_root=OBJECT_TYPES_API) + cls.object_type = cast( + ObjectType, ObjectTypeFactory(service__api_root=OBJECT_TYPES_API) + ) PermissionFactory.create( object_type=cls.object_type, mode=PermissionModes.read_and_write, @@ -227,14 +230,17 @@ def test_patch_object_record(self, m): ) initial_record = ObjectRecordFactory.create( - version=1, object__object_type=self.object_type, start_at=date.today() + version=1, + object__object_type=self.object_type, + start_at=date.today(), + data={"name": "Name", "diameter": 20}, ) object = initial_record.object url = reverse("object-detail", args=[object.uuid]) data = { "record": { - "data": {"plantDate": "2020-04-12", "diameter": 30}, + "data": {"plantDate": "2020-04-12", "diameter": 30, "name": None}, "startAt": "2020-01-01", "correctionFor": initial_record.index, }, @@ -251,8 +257,10 @@ def test_patch_object_record(self, m): current_record = object.current_record self.assertEqual(current_record.version, initial_record.version) + # The actual behavior of the data merging is in test_merge_patch.py: self.assertEqual( - current_record.data, {"plantDate": "2020-04-12", "diameter": 30} + current_record.data, + {"plantDate": "2020-04-12", "diameter": 30, "name": None}, ) self.assertEqual(current_record.start_at, date(2020, 1, 1)) self.assertEqual(current_record.registration_at, date(2020, 8, 8)) diff --git a/src/objects/typing.py b/src/objects/typing.py new file mode 100644 index 00000000..0264e153 --- /dev/null +++ b/src/objects/typing.py @@ -0,0 +1,7 @@ +from typing import Dict, List, Union + +JSONPrimitive = Union[str, int, None, float, bool] + +JSONValue = Union[JSONPrimitive, "JSONObject", List["JSONValue"]] + +JSONObject = Dict[str, JSONValue]