From 730c343890835677aa85c6663aa61100b634ab13 Mon Sep 17 00:00:00 2001 From: Lars van de Kerkhof Date: Wed, 4 Oct 2023 16:18:51 +0200 Subject: [PATCH] Implemented faster attributes save --- oscarapi/serializers/admin/product.py | 4 +- oscarapi/serializers/fields.py | 91 +++++++++------------------ oscarapi/serializers/product.py | 55 +++++++++++++--- oscarapi/tests/unit/testproduct.py | 2 + oscarapi/utils/attributes.py | 78 +++++++++++++++++++++++ 5 files changed, 161 insertions(+), 69 deletions(-) create mode 100644 oscarapi/utils/attributes.py diff --git a/oscarapi/serializers/admin/product.py b/oscarapi/serializers/admin/product.py index 1843b02c..3984bef4 100644 --- a/oscarapi/serializers/admin/product.py +++ b/oscarapi/serializers/admin/product.py @@ -145,7 +145,9 @@ def update(self, instance, validated_data): if ( self.partial ): # we need to clean up all the attributes with wrong product class - attribute_codes = product_class.attributes.values_list("code", flat=True) + attribute_codes = product_class.attributes.values_list( + "code", flat=True + ) for attribute_value in instance.attribute_values.exclude( attribute__product_class=product_class ): diff --git a/oscarapi/serializers/fields.py b/oscarapi/serializers/fields.py index 1e0599f8..c250338f 100644 --- a/oscarapi/serializers/fields.py +++ b/oscarapi/serializers/fields.py @@ -17,6 +17,7 @@ from oscar.core.loading import get_model, get_class from oscarapi import settings +from oscarapi.utils.attributes import AttributeFieldBase, attribute_details from oscarapi.utils.loading import get_api_class from oscarapi.utils.exists import bound_unique_together_get_or_create from .exceptions import FieldError @@ -27,7 +28,6 @@ create_from_breadcrumbs = get_class("catalogue.categories", "create_from_breadcrumbs") entity_internal_value = get_api_class("serializers.hooks", "entity_internal_value") RetrieveFileMixin = get_api_class(settings.FILE_DOWNLOADER_MODULE, "RetrieveFileMixin") -attribute_details = operator.itemgetter("code", "value") class TaxIncludedDecimalField(serializers.DecimalField): @@ -93,7 +93,7 @@ def use_pk_only_optimization(self): return False -class AttributeValueField(serializers.Field): +class AttributeValueField(AttributeFieldBase, serializers.Field): """ This field is used to handle the value of the ProductAttributeValue model @@ -105,28 +105,39 @@ class AttributeValueField(serializers.Field): def __init__(self, **kwargs): # this field always needs the full object kwargs["source"] = "*" - kwargs["error_messages"] = { - "no_such_option": _("{code}: Option {value} does not exist."), - "invalid": _("Wrong type, {error}."), - "attribute_validation_error": _( - "Error assigning `{value}` to {code}, {error}." - ), - "attribute_required": _("Attribute {code} is required."), - "attribute_missing": _( - "No attribute exist with code={code}, " - "please define it in the product_class first." - ), - "child_without_parent": _( - "Can not find attribute if product_class is empty and " - "parent is empty as well, child without parent?" - ), - } super(AttributeValueField, self).__init__(**kwargs) def get_value(self, dictionary): # return all the data because this field uses everything return dictionary + def get_data_attribute(self, data): + if "product" in data: + # we need the attribute to determine the type of the value + return ProductAttribute.objects.get( + code=data["code"], product_class__products__id=data["product"] + ) + elif "product_class" in data and data["product_class"] is not None: + return ProductAttribute.objects.get( + code=data["code"], product_class__slug=data.get("product_class") + ) + elif "parent" in data: + return ProductAttribute.objects.get( + code=data["code"], product_class__products__id=data["parent"] + ) + + def convert_to_internal_value(self, attribute, code, value): + internal_value = super().convert_to_internal_value(attribute, code, value) + if attribute.type in [ + attribute.IMAGE, + attribute.FILE, + ]: + image_field = ImageUrlField() + image_field._context = self.context + internal_value = image_field.to_internal_value(value) + + return internal_value + def to_internal_value(self, data): # noqa assert "product" in data or "product_class" in data or "parent" in data @@ -134,49 +145,9 @@ def to_internal_value(self, data): # noqa code, value = attribute_details(data) internal_value = value - if "product" in data: - # we need the attribute to determine the type of the value - attribute = ProductAttribute.objects.get( - code=code, product_class__products__id=data["product"] - ) - elif "product_class" in data and data["product_class"] is not None: - attribute = ProductAttribute.objects.get( - code=code, product_class__slug=data.get("product_class") - ) - elif "parent" in data: - attribute = ProductAttribute.objects.get( - code=code, product_class__products__id=data["parent"] - ) + attribute = self.get_data_attribute(data) - if attribute.required and value is None: - self.fail("attribute_required", code=code) - - # some of these attribute types need special processing, or their - # validation will fail - if attribute.type == attribute.OPTION: - internal_value = attribute.option_group.options.get(option=value) - elif attribute.type == attribute.MULTI_OPTION: - if attribute.required and not value: - self.fail("attribute_required", code=code) - internal_value = attribute.option_group.options.filter(option__in=value) - if len(value) != internal_value.count(): - non_existing = set(value) - set( - internal_value.values_list("option", flat=True) - ) - non_existing_as_error = ",".join(sorted(non_existing)) - self.fail("no_such_option", value=non_existing_as_error, code=code) - elif attribute.type == attribute.DATE: - date_field = serializers.DateField() - internal_value = date_field.to_internal_value(value) - elif attribute.type == attribute.DATETIME: - date_field = serializers.DateTimeField() - internal_value = date_field.to_internal_value(value) - elif attribute.type == attribute.ENTITY: - internal_value = entity_internal_value(attribute, value) - elif attribute.type in [attribute.IMAGE, attribute.FILE]: - image_field = ImageUrlField() - image_field._context = self.context - internal_value = image_field.to_internal_value(value) + internal_value = self.convert_to_internal_value(attribute, code, value) # the rest of the attribute types don't need special processing try: diff --git a/oscarapi/serializers/product.py b/oscarapi/serializers/product.py index 491996c6..e0d4d451 100644 --- a/oscarapi/serializers/product.py +++ b/oscarapi/serializers/product.py @@ -16,6 +16,7 @@ from oscarapi.utils.files import file_hash from oscarapi.utils.exists import find_existing_attribute_option_group from oscarapi.utils.accessors import getitems +from oscarapi.utils.attributes import AttributeConverter from oscarapi.serializers.fields import DrillDownHyperlinkedIdentityField from oscarapi.serializers.utils import ( OscarModelSerializer, @@ -197,25 +198,63 @@ class Meta: class ProductAttributeValueListSerializer(UpdateListSerializer): def to_internal_value(self, data): productclasses = set() - # attributes = set() + attributes = set() for item in data: product_class, code = getitems(item, "product_class", "code") if product_class: productclasses.add(product_class) + attributes.add(code) # if all attributes belong to the same productclass, everything is just # as expected and we can do an optimization by only resolving the productclass to the model instance and nothing else. try: - if len(productclasses) == 1: + if len(productclasses) == 1 and all(attributes): (product_class,) = productclasses pc = ProductClass.objects.get(slug=product_class) - return [ - {"value": item["value"], "attribute": item["code"], "product_class": pc} - for item in data - ] + difficult_attributes = { + at.code: at + for at in pc.attributes.filter( + type__in=[ + ProductAttribute.OPTION, + ProductAttribute.MULTI_OPTION, + ProductAttribute.DATE, + ProductAttribute.DATETIME, + ProductAttribute.ENTITY, + ] + ) + } + cv = AttributeConverter(self.context) + internal_value = [] + for item in data: + code, value = getitems(item, "code", "value") + if code is None: + internal_value.append(self.child.to_internal_value(item)) + + if code in difficult_attributes: + attribute = difficult_attributes[code] + converted_value = cv.convert_to_internal_value( + attribute, code, value + ) + internal_value.append( + { + "value": converted_value, + "attribute": attribute, + "product_class": pc, + } + ) + else: + internal_value.append( + { + "value": value, + "attribute": code, + "product_class": pc, + } + ) + return internal_value + except ProductClass.DoesNotExist: - raise Exception("productclasses", "bestaat niet", productclasses) + pass return super().to_internal_value(data) @@ -263,7 +302,7 @@ def update(self, instance, validated_data): # if we don't clear the dirty attributes all parent attributes # are marked as explicitly set, so they will be copied to the # child product. - product.attr._dirty.clear() + product.attr._dirty.clear() # pylint: disable=protected-access product.attr.save() return list(product.attr.get_values().filter(attribute__code__in=attr_codes)) diff --git a/oscarapi/tests/unit/testproduct.py b/oscarapi/tests/unit/testproduct.py index 8abedbf6..3a0771dd 100644 --- a/oscarapi/tests/unit/testproduct.py +++ b/oscarapi/tests/unit/testproduct.py @@ -994,6 +994,7 @@ def test_switch_product_class_patch(self): they may cause errors. """ product = Product.objects.get(pk=3) + self.assertEqual(product.attribute_values.count(), 11) ser = AdminProductSerializer( data={ "product_class": "t-shirt", @@ -1006,6 +1007,7 @@ def test_switch_product_class_patch(self): ) self.assertTrue(ser.is_valid(), "Something wrong %s" % ser.errors) obj = ser.save() + self.assertEqual( obj.attribute_values.count(), 2, diff --git a/oscarapi/utils/attributes.py b/oscarapi/utils/attributes.py new file mode 100644 index 00000000..c97376bd --- /dev/null +++ b/oscarapi/utils/attributes.py @@ -0,0 +1,78 @@ +import operator +from django.utils.translation import gettext_lazy as _ +from rest_framework import serializers +from rest_framework.fields import MISSING_ERROR_MESSAGE +from rest_framework.exceptions import ErrorDetail +from oscarapi.utils.loading import get_api_class + +attribute_details = operator.itemgetter("code", "value") +entity_internal_value = get_api_class("serializers.hooks", "entity_internal_value") + + +class AttributeFieldBase: + default_error_messages = { + "no_such_option": _("{code}: Option {value} does not exist."), + "invalid": _("Wrong type, {error}."), + "attribute_validation_error": _( + "Error assigning `{value}` to {code}, {error}." + ), + "attribute_required": _("Attribute {code} is required."), + "attribute_missing": _( + "No attribute exist with code={code}, " + "please define it in the product_class first." + ), + "child_without_parent": _( + "Can not find attribute if product_class is empty and " + "parent is empty as well, child without parent?" + ), + } + + def convert_to_internal_value(self, attribute, code, value): + internal_value = value + # pylint: disable=no-member + if attribute.required and value is None: + self.fail("attribute_required", code=code) + + # some of these attribute types need special processing, or their + # validation will fail + if attribute.type == attribute.OPTION: + internal_value = attribute.option_group.options.get(option=value) + elif attribute.type == attribute.MULTI_OPTION: + if attribute.required and not value: + self.fail("attribute_required", code=code) + internal_value = attribute.option_group.options.filter(option__in=value) + if len(value) != internal_value.count(): + non_existing = set(value) - set( + internal_value.values_list("option", flat=True) + ) + non_existing_as_error = ",".join(sorted(non_existing)) + self.fail("no_such_option", value=non_existing_as_error, code=code) + elif attribute.type == attribute.DATE: + date_field = serializers.DateField() + internal_value = date_field.to_internal_value(value) + elif attribute.type == attribute.DATETIME: + date_field = serializers.DateTimeField() + internal_value = date_field.to_internal_value(value) + elif attribute.type == attribute.ENTITY: + internal_value = entity_internal_value(attribute, value) + + return internal_value + + +class AttributeConverter(AttributeFieldBase): + def __init__(self, context): + self.context = context + self.errors = [] + + def fail(self, key, **kwargs): + """ + An implementation of fail that collects errors instead of raising them + """ + try: + msg = self.default_error_messages[key] + except KeyError: + class_name = self.__class__.__name__ + msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) + raise AssertionError(msg) + message_string = msg.format(**kwargs) + self.errors.append(ErrorDetail(message_string, code=key))