Skip to content

Commit

Permalink
Add missing locks for price recalculation (saleor#15901) (saleor#15906)
Browse files Browse the repository at this point in the history
* Add missing locks for price recalculation

* Make sure to lock records in the same order

Co-authored-by: Maciej Korycinski <[email protected]>
  • Loading branch information
maarcingebala and korycins authored Apr 29, 2024
1 parent 0d785b4 commit 899bfc2
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 16 deletions.
29 changes: 23 additions & 6 deletions saleor/discount/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1532,10 +1532,20 @@ def mark_active_catalogue_promotion_rules_as_dirty(channel_ids: Iterable[int]):
rules = get_active_catalogue_promotion_rules()
PromotionRuleChannel = PromotionRule.channels.through
promotion_rules = PromotionRuleChannel.objects.filter(channel_id__in=channel_ids)
rules = rules.filter(
rule_ids = rules.filter(
Exists(promotion_rules.filter(promotionrule_id=OuterRef("id")))
)
rules.update(variants_dirty=True)
).values_list("id", flat=True)

with transaction.atomic():
rule_ids_to_update = list(
PromotionRule.objects.select_for_update(of=("self",))
.filter(id__in=rule_ids, variants_dirty=False)
.order_by("pk")
.values_list("id", flat=True)
)
PromotionRule.objects.filter(id__in=rule_ids_to_update).update(
variants_dirty=True
)


def mark_catalogue_promotion_rules_as_dirty(promotion_pks: Iterable[UUID]):
Expand All @@ -1546,6 +1556,13 @@ def mark_catalogue_promotion_rules_as_dirty(promotion_pks: Iterable[UUID]):
"""
if not promotion_pks:
return
PromotionRule.objects.filter(promotion_id__in=promotion_pks).update(
variants_dirty=True
)
with transaction.atomic():
rule_ids_to_update = list(
PromotionRule.objects.select_for_update(of=(["self"]))
.filter(promotion_id__in=promotion_pks, variants_dirty=False)
.order_by("pk")
.values_list("id", flat=True)
)
PromotionRule.objects.filter(id__in=rule_ids_to_update).update(
variants_dirty=True
)
25 changes: 21 additions & 4 deletions saleor/product/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from celery.utils.log import get_task_logger
from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist
from django.db import transaction
from django.db.models import Exists, OuterRef, Q, QuerySet
from django.utils import timezone

Expand Down Expand Up @@ -185,8 +186,17 @@ def update_variant_relations_for_active_promotion_rules_task():
channel_to_product_map = _get_channel_to_products_map(
existing_variant_relation + new_rule_to_variant_list
)
with transaction.atomic():
promotion_rule_ids = list(
PromotionRule.objects.select_for_update(of=("self",))
.filter(pk__in=ids, variants_dirty=True)
.order_by("pk")
.values_list("id", flat=True)
)
PromotionRule.objects.filter(pk__in=promotion_rule_ids).update(
variants_dirty=False
)

PromotionRule.objects.filter(pk__in=ids).update(variants_dirty=False)
mark_products_in_channels_as_dirty(channel_to_product_map, allow_replica=True)
update_variant_relations_for_active_promotion_rules_task.delay()

Expand Down Expand Up @@ -224,9 +234,16 @@ def recalculate_discounted_price_for_products_task():
settings.DATABASE_CONNECTION_REPLICA_NAME
).filter(id__in=products_ids)
update_discounted_prices_for_promotion(products, only_dirty_products=True)
ProductChannelListing.objects.filter(id__in=listing_ids).update(
discounted_price_dirty=False
)
with transaction.atomic():
channel_listings_ids = list(
ProductChannelListing.objects.select_for_update(of=("self",))
.filter(id__in=listing_ids, discounted_price_dirty=True)
.order_by("pk")
.values_list("id", flat=True)
)
ProductChannelListing.objects.filter(id__in=channel_listings_ids).update(
discounted_price_dirty=False
)
recalculate_discounted_price_for_products_task.delay()


Expand Down
14 changes: 11 additions & 3 deletions saleor/product/utils/product.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import defaultdict

from django.conf import settings
from django.db import transaction
from django.db.models import Exists, OuterRef, QuerySet

from ...discount.models import PromotionRule
Expand Down Expand Up @@ -121,6 +122,13 @@ def mark_products_in_channels_as_dirty(
listing_ids_to_update.append(id)

if listing_ids_to_update:
ProductChannelListing.objects.filter(id__in=listing_ids_to_update).update(
discounted_price_dirty=True
)
with transaction.atomic():
channel_listing_ids = list(
ProductChannelListing.objects.select_for_update(of=("self",))
.filter(id__in=listing_ids_to_update, discounted_price_dirty=False)
.order_by("pk")
.values_list("id", flat=True)
)
ProductChannelListing.objects.filter(id__in=channel_listing_ids).update(
discounted_price_dirty=True
)
12 changes: 9 additions & 3 deletions saleor/product/utils/variant_prices.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,19 +117,25 @@ def _update_or_create_listings(
):
if changed_products_listings_to_update:
ProductChannelListing.objects.bulk_update(
changed_products_listings_to_update, ["discounted_price_amount"]
sorted(changed_products_listings_to_update, key=lambda listing: listing.id),
["discounted_price_amount"],
)
if changed_variants_listings_to_update:
ProductVariantChannelListing.objects.bulk_update(
changed_variants_listings_to_update, ["discounted_price_amount"]
sorted(changed_variants_listings_to_update, key=lambda listing: listing.id),
["discounted_price_amount"],
)
if changed_variant_listing_promotion_rule_to_create:
_create_variant_listing_promotion_rule(
changed_variant_listing_promotion_rule_to_create
)
if changed_variant_listing_promotion_rule_to_update:
VariantChannelListingPromotionRule.objects.bulk_update(
changed_variant_listing_promotion_rule_to_update, ["discount_amount"]
sorted(
changed_variant_listing_promotion_rule_to_update,
key=lambda listing: listing.id,
),
["discount_amount"],
)


Expand Down

0 comments on commit 899bfc2

Please sign in to comment.