From b2cd2b4fa5bb440b2ad36456a01873ab3f1813e1 Mon Sep 17 00:00:00 2001 From: Alex Dusenbery Date: Fri, 9 Feb 2024 13:09:50 -0500 Subject: [PATCH] perf: remove n+1 queries, cache stuff for enrollment task --- license_manager/apps/api/utils.py | 27 +++++++++++++++---- license_manager/apps/subscriptions/models.py | 14 ++++++++++ .../apps/subscriptions/tests/test_models.py | 9 +++++++ 3 files changed, 45 insertions(+), 5 deletions(-) diff --git a/license_manager/apps/api/utils.py b/license_manager/apps/api/utils.py index 28544875..b7f19e3c 100644 --- a/license_manager/apps/api/utils.py +++ b/license_manager/apps/api/utils.py @@ -3,6 +3,7 @@ import os import urllib import uuid +from collections import defaultdict import boto3 from django.http import Http404 @@ -151,12 +152,22 @@ def check_missing_licenses(customer_agreement, user_emails, course_run_keys, sub logger.info('[check_missing_licenses] Starting to iterate over all `user_emails`...') + # Map licenses by email across all user_emails in a single DB query. + # Also, join the plans into the queryset, so that we don't do + # one query per license down in the loops below. + license_queryset = License.objects.filter( + subscription_plan__in=subscription_plan_filter, + user_email__in=user_emails, + ).select_related( + 'subscription_plan', + ) + licenses_by_email = defaultdict(list) + for license_record in license_queryset: + licenses_by_email[license_record.user_email].append(license_record) + for email in set(user_emails): logger.info(f'[check_missing_licenses] handling user email {email}') - filtered_licenses = License.objects.filter( - subscription_plan__in=subscription_plan_filter, - user_email=email, - ) + filtered_licenses = licenses_by_email.get(email, []) logger.info('[check_missing_licenses] user licenses for email %s: %s', email, filtered_licenses) @@ -173,11 +184,17 @@ def check_missing_licenses(customer_agreement, user_emails, course_run_keys, sub logger.info('[check_missing_licenses] handling user license %s', str(user_license.uuid)) subscription_plan = user_license.subscription_plan plan_key = f'{subscription_plan.uuid}_{course_key}' + + # TODO AED 2024-02-09: I think this chunk of code is defective. + # It's only mapping plan ids to booleans, but what we really want + # to know is, for each plan *and course*, if the plan's associated catalog + # contains the course. if plan_key in subscription_plan_course_map: plan_contains_content = subscription_plan_course_map.get(plan_key) else: plan_contains_content = subscription_plan.contains_content([course_key]) subscription_plan_course_map[plan_key] = plan_contains_content + logger.info( '[check_missing_licenses] does plan (%s) contain content?: %s', str(subscription_plan.uuid), @@ -189,7 +206,7 @@ def check_missing_licenses(customer_agreement, user_emails, course_run_keys, sub 'course_run_key': course_key, 'license_uuid': str(user_license.uuid) } - # assigned, not yet activated, incliude activation URL + # assigned, not yet activated, include activation URL if user_license.status == constants.ASSIGNED: this_enrollment['activation_link'] = get_license_activation_link( enterprise_slug, diff --git a/license_manager/apps/subscriptions/models.py b/license_manager/apps/subscriptions/models.py index 1858f73b..692a314c 100644 --- a/license_manager/apps/subscriptions/models.py +++ b/license_manager/apps/subscriptions/models.py @@ -7,6 +7,7 @@ from uuid import uuid4 from django.conf import settings +from django.core.cache import cache from django.core.serializers.json import DjangoJSONEncoder from django.core.validators import MinLengthValidator from django.db import models, transaction @@ -62,6 +63,10 @@ logger = getLogger(__name__) +CONTAINS_CONTENT_CACHE_TIMEOUT = 60 * 60 + +_CACHE_MISS = object() + class CustomerAgreement(TimeStampedModel): """ @@ -736,13 +741,22 @@ def contains_content(self, content_ids): Returns: bool: Whether the given content_ids are part of the subscription. """ + cache_key = self.get_contains_content_cache_key(content_ids) + cached_value = cache.get(cache_key, _CACHE_MISS) + if cached_value is not _CACHE_MISS: + return cached_value + enterprise_catalog_client = EnterpriseCatalogApiClient() content_in_catalog = enterprise_catalog_client.contains_content_items( self.enterprise_catalog_uuid, content_ids, ) + cache.set(cache_key, content_in_catalog, timeout=CONTAINS_CONTENT_CACHE_TIMEOUT) return content_in_catalog + def get_contains_content_cache_key(self, content_ids): + return f'plan_contains_content:{self.uuid}:{content_ids}' + history = HistoricalRecords() class Meta: diff --git a/license_manager/apps/subscriptions/tests/test_models.py b/license_manager/apps/subscriptions/tests/test_models.py index ae76f8e0..ec5bda50 100644 --- a/license_manager/apps/subscriptions/tests/test_models.py +++ b/license_manager/apps/subscriptions/tests/test_models.py @@ -5,6 +5,7 @@ import ddt import freezegun import pytest +from django.core.cache import cache from django.forms import ValidationError from django.test import TestCase from requests.exceptions import HTTPError @@ -55,7 +56,15 @@ def test_contains_content(self, contains_content, mock_enterprise_catalog_client # Mock the value from the enterprise catalog client mock_enterprise_catalog_client().contains_content_items.return_value = contains_content content_ids = ['test-key', 'another-key'] + + cache.delete(self.subscription_plan.get_contains_content_cache_key(content_ids)) + assert self.subscription_plan.contains_content(content_ids) == contains_content + + # call it again to utilize the cache + assert self.subscription_plan.contains_content(content_ids) == contains_content + + # ...but assert we only used the catalog client once mock_enterprise_catalog_client().contains_content_items.assert_called_with( self.subscription_plan.enterprise_catalog_uuid, content_ids,