From e4959bf2cb6814bb61cc2b4d065c7e9ba171c8a3 Mon Sep 17 00:00:00 2001
From: Daan van der Kallen
Date: Mon, 6 Nov 2023 11:21:21 +0100
Subject: [PATCH] Add generic stats endpoint

---
 binder/router.py    |   8 ++
 binder/views.py     | 193 +++++++++++++++++++++++++++++++++++++++++++-
 tests/test_stats.py |  75 +++++++++++++++++
 3 files changed, 275 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_stats.py

diff --git a/binder/router.py b/binder/router.py
index 8e4bb9d6..ff2b2951 100644
--- a/binder/router.py
+++ b/binder/router.py
@@ -184,6 +184,14 @@ def urls(self):
                 urls.append(re_path(r'^{}/(?P<pk>[0-9]+)/{}/$'.format(route.route, ff),
                         view.as_view(), {'file_field': ff, 'router': self}, name='{}.{}'.format(name, ff)))
 
+            # Stats endpoint
+            urls.append(re_path(
+                r'^{}/stats/$'.format(route.route),
+                view.as_view(),
+                {'method': 'stats', 'router': self},
+                name='{}.stats'.format(name),
+            ))
+
             # Custom endpoints
             for m in dir(view):
                 method = getattr(view, m)
diff --git a/binder/views.py b/binder/views.py
index 3ec6ac20..3285f7a9 100644
--- a/binder/views.py
+++ b/binder/views.py
@@ -20,7 +20,7 @@
 from django.http.request import RawPostDataException
 from django.http.multipartparser import MultiPartParser
 from django.db import models, connections
-from django.db.models import Q, F
+from django.db.models import Q, F, Count, Sum, Min, Max, Avg
 from django.db.models.lookups import Transform
 from django.utils import timezone
 from django.db import transaction
@@ -35,6 +35,15 @@
 from .json import JsonResponse, jsonloads
 
 
+STAT_AGGREGATES = {
+    'count': Count,
+    'sum': Sum,
+    'min': Min,
+    'max': Max,
+    'average': Avg,
+}
+
+
 def get_joins_from_queryset(queryset):
     """
     Given a queryset returns a set of lines that are used to determine which
@@ -2834,6 +2843,188 @@ def view_history(self, request, pk=None, **kwargs):
 
         return history.view_changesets(request, changesets.order_by('-id'))
 
+    def stats(self, request):
+        """
+        Return the aggregates requested through the `stats` GET parameter.
+
+        The parameter is a JSON object that maps a result name to a stat
+        definition (see `_parse_stats`). The response is a JSON object that
+        maps every result name to the computed value.
+        """
+        # We only apply annotations when used, so we can just pretend
+        # everything is included to simplify stuff
+        try:
+            annotations = self.model.Annotations
+        except AttributeError:
+            include_annotations = {'': []}
+        else:
+            include_annotations = {'': [
+                attr
+                for attr in dir(annotations)
+                if not (attr.startswith('__') and attr.endswith('__'))
+            ]}
+
+        queryset, annotations = self._get_filtered_queryset_base(request, None, include_annotations)
+
+        # Only a missing parameter means "not provided"; a KeyError raised
+        # while parsing must not be reported as such.
+        try:
+            raw_stats = request.GET['stats']
+        except KeyError:
+            raise BinderRequestError('no stats parameter provided')
+        stats = self._parse_stats(raw_stats, include_annotations[''])
+
+        return JsonResponse({
+            # Every stat gets its own copy, applied annotations are popped from
+            # the mapping and would otherwise be missing for the next stat.
+            key: self._get_stat(request, queryset, dict(annotations), include_annotations, **stat)
+            for key, stat in stats.items()
+        })
+
+
+    def _parse_stats(self, stats, annotations):
+        """
+        Parse and validate the JSON encoded stat definitions.
+
+        Every definition is a dictionary with the optional keys `field`,
+        `aggregate`, `group_by` and `filters`. Dotted field paths are converted
+        to Django lookups and aggregate names to their aggregate class. All
+        problems found are raised together as one BinderRequestError.
+        """
+        stats = jsonloads(stats)
+
+        if not isinstance(stats, dict):
+            raise BinderRequestError('stats should be a dictionary')
+
+        errors = []
+        for stat, params in stats.items():
+            if not isinstance(params, dict):
+                errors.append(f'stats.{stat} should be a dictionary')
+                continue
+
+            # Only values of existing keys are replaced below, so mutating
+            # params while iterating over it is safe
+            for key, value in params.items():
+                if key == 'field':
+                    params['field'] = self._check_field(stat, key, value, annotations, errors)
+
+                elif key == 'aggregate':
+                    try:
+                        params['aggregate'] = STAT_AGGREGATES[value]
+                    except (KeyError, TypeError):
+                        # TypeError is raised for unhashable values like lists
+                        errors.append(f'stats.{stat}.aggregate is not a valid aggregate')
+
+                elif key == 'group_by':
+                    params['group_by'] = self._check_field(stat, key, value, annotations, errors)
+
+                elif key == 'filters':
+                    # Check the filters value, not the surrounding definition
+                    if not isinstance(value, dict):
+                        errors.append(f'stats.{stat}.filters should be a dictionary')
+
+                else:
+                    errors.append(f'stats.{stat}.{key} is not a valid key')
+
+        if errors:
+            raise BinderRequestError('\n'.join(errors))
+
+        return stats
+
+    def _check_field(self, stat, key, value, annotations, errors):
+        """
+        Convert the dotted field path `value` to a Django lookup.
+
+        Every part but the last has to be a relation. The last part has to be
+        a field or, when no relation is followed, one of `annotations`. When
+        invalid a message is appended to `errors` and None is returned.
+        """
+        if not isinstance(value, str):
+            errors.append(f'stats.{stat}.{key} should be a string')
+            return
+
+        model = self.model
+        parts = []
+        while True:
+            try:
+                head, value = value.split('.', 1)
+            except ValueError:
+                break
+
+            # Not checked with an assert, those are stripped when using -O
+            try:
+                field = model._meta.get_field(head)
+            except FieldDoesNotExist:
+                field = None
+            if field is None or not field.is_relation:
+                errors.append(f'stats.{stat}.{key} references relation {model.__name__}.{head} that does not exist')
+                return
+
+            parts.append(head)
+            if isinstance(field, django.db.models.fields.reverse_related.ForeignObjectRel):
+                model = field.related_model
+            else:
+                model = field.remote_field.model
+
+        try:
+            if parts or value not in annotations:
+                model._meta.get_field(value)
+        except FieldDoesNotExist:
+            errors.append(f'stats.{stat}.{key} references field {model.__name__}.{value} that does not exist')
+            return
+
+        parts.append(value)
+        return '__'.join(parts)
+
+
+    def _get_stat(self, request, queryset, annotations, include_annotations, field='id', aggregate=Count, group_by=None, filters=None):
+        """
+        Compute a single stat on the queryset.
+
+        Returns the aggregate of `field` over the rows matching `filters`, or
+        when `group_by` is given a dictionary that maps every value of that
+        field to the aggregate for its group.
+        """
+        for key, value in (filters or {}).items():
+            q, distinct = self._parse_filter(key, value, request, include_annotations)
+            queryset = self._apply_q_with_possible_annotations(queryset, q, annotations)
+            if distinct:
+                queryset = queryset.distinct()
+        queryset = self._apply_annotations(queryset, annotations, field, group_by)
+
+        if group_by is None:
+            return queryset.aggregate(result=aggregate(field))['result']
+        else:
+            return dict(
+                queryset
+                # Drop the ordering so it is not added to the GROUP BY
+                .order_by()
+                .values(group_by)
+                .annotate(_binder_stats_aggregate=aggregate(field))
+                .values_list(group_by, '_binder_stats_aggregate')
+            )
+
+
+    def _apply_annotations(self, queryset, annotations, *fields):
+        """
+        Apply the annotations that the given lookups start with.
+
+        Applied annotations are popped from `annotations`, lookups that are
+        None are skipped.
+        """
+        for field in fields:
+            if field is None:
+                continue
+            field = field.split('__', 1)[0]
+            try:
+                annotation = annotations.pop(field)
+            except KeyError:
+                pass
+            else:
+                queryset = queryset.annotate(**{field: annotation})
+        return queryset
+
+
 
 def api_catchall(request):
     try:
diff --git a/tests/test_stats.py b/tests/test_stats.py
new file mode 100644
index 00000000..fd0f2069
--- /dev/null
+++ b/tests/test_stats.py
@@ -0,0 +1,75 @@
+import json
+
+from django.test import TestCase
+from django.contrib.auth.models import User
+
+from .testapp.models import Animal, Caretaker, Zoo
+
+
+class StatsTest(TestCase):
+
+    def setUp(self):
+        zoo_1 = Zoo.objects.create(name='Zoo 1')
+        zoo_2 = Zoo.objects.create(name='Zoo 2')
+
+        caretaker = Caretaker.objects.create(name='Caretaker')
+
+        Animal.objects.create(name='Animal 1', zoo=zoo_1, caretaker=caretaker)
+        Animal.objects.create(name='Animal 2', zoo=zoo_2, caretaker=caretaker)
+        Animal.objects.create(name='Animal 3', zoo=zoo_2, caretaker=None)
+
+        u = User(username='testuser', is_active=True, is_superuser=True)
+        u.set_password('test')
+        u.save()
+
+        self.assertTrue(self.client.login(username='testuser', password='test'))
+
+    def get_stats(self, params={}, **stats):
+        res = self.client.get('/animal/stats/', {
+            'stats': json.dumps(stats),
+            **params,
+        })
+        if res.status_code != 200:
+            print(res.content.decode())
+        self.assertEqual(res.status_code, 200)
+        return json.loads(res.content)
+
+    def test_animals_without_caretaker(self):
+        res = self.get_stats(
+            animals_without_caretaker={
+                'filters': {'caretaker:isnull': 'true'},
+            },
+        )
+        self.assertEqual(res, {
+            'animals_without_caretaker': 1,
+        })
+
+    def test_animals_by_zoo(self):
+        res = self.get_stats(
+            animals_by_zoo={
+                'group_by': 'zoo.name',
+            },
+        )
+        self.assertEqual(res, {
+            'animals_by_zoo': {
+                'Zoo 1': 1,
+                'Zoo 2': 2,
+            },
+        })
+
+    def test_stats_filtered(self):
+        res = self.get_stats(
+            total={},
+            animals_without_caretaker={
+                'filters': {'caretaker:isnull': 'true'},
+            },
+            animals_by_zoo={
+                'group_by': 'zoo.name',
+            },
+            params={'.zoo.name': 'Zoo 1'},
+        )
+        self.assertEqual(res, {
+            'total': 1,
+            'animals_without_caretaker': 0,
+            'animals_by_zoo': {'Zoo 1': 1},
+        })