From 7f9cd469dc30a57b9a474b2d310d45357480d565 Mon Sep 17 00:00:00 2001 From: Daan van der Kallen Date: Thu, 14 Nov 2024 16:18:04 +0100 Subject: [PATCH] Fix after filter with null --- binder/views.py | 55 ++++++++++++++++++++++++++------------------- tests/test_after.py | 51 ++++++++++++++++++++++++++++++++++++----- 2 files changed, 77 insertions(+), 29 deletions(-) diff --git a/binder/views.py b/binder/views.py index 66f355da..972692cf 100644 --- a/binder/views.py +++ b/binder/views.py @@ -21,7 +21,7 @@ from django.http.request import RawPostDataException from django.http.multipartparser import MultiPartParser from django.db import models, connections -from django.db.models import Q, F, Count +from django.db.models import Q, F, Count, Case, When from django.db.models.lookups import Transform from django.utils import timezone from django.db import transaction @@ -1569,8 +1569,7 @@ def _after_expr(self, request, after_id, include_annotations): raise BinderRequestError(f'invalid value for after_id: {after_id!r}') # Now we will build up a comparison expr based on the order by - left_exprs = [] - right_exprs = [] + whens = [] for field in ordering: # First we have to split of a leading '-' as indicating reverse @@ -1578,30 +1577,40 @@ def _after_expr(self, request, after_id, include_annotations): if reverse: field = field[1:] - # Then we build 2 exprs for the left hand side (objs in the query) - # and the right hand side (the object with the provided after id) - left_expr = F(field) + # Then we determine if nulls come last + if field.endswith('__nulls_last'): + field = field[:-12] + nulls_last = True + elif field.endswith('__nulls_first'): + field = field[:-13] + nulls_last = False + elif connections[self.model.objects.db].vendor == 'mysql': + # In MySQL null is considered to be the lowest possible value for ordering + nulls_last = reverse + else: + # In other databases null is considered to be the highest possible value for ordering + nulls_last = not reverse - right_expr = obj + # Then we determine what the value is for the obj we need to be after + value = obj for attr in field.split('__'): - right_expr = getattr(right_expr, attr) - if isinstance(right_expr, models.Model): - right_expr = right_expr.pk - right_expr = Value(right_expr) - - # To handle reverse we flip the expressions - if reverse: - left_exprs.append(right_expr) - right_exprs.append(left_expr) + value = getattr(value, attr) + if isinstance(value, models.Model): + value = value.pk + + # Now we add some conditions for the comparison + if value is None: + # If the value is None, that means we have to add a condition for when the field is not None because only then it is different + # What the result should be in that case is determined by nulls last + whens.append(When(Q(**{field + '__isnull': False}), then=Value(not nulls_last))) else: - left_exprs.append(left_expr) - right_exprs.append(right_expr) + # If the field is None we give a result based on nulls last + whens.append(When(Q(**{field: None}), then=Value(nulls_last))) + # Otherwise we check with comparisons, note that equality is intentionally left open with these two options so in that case we go on to the next field + whens.append(When(Q(**{field + '__lt': value}), then=Value(reverse))) + whens.append(When(Q(**{field + '__gt': value}), then=Value(not reverse))) - # Now we turn this into one big comparison - if len(ordering) == 1: - expr = GreaterThan(left_exprs[0], right_exprs[0]) - else: - expr = GreaterThan(Tuple(*left_exprs), Tuple(*right_exprs)) + expr = Case(*whens, default=Value(False)) return expr, required_annotations diff --git a/tests/test_after.py b/tests/test_after.py index 7a4c65ec..9597ca8d 100644 --- a/tests/test_after.py +++ b/tests/test_after.py @@ -1,3 +1,6 @@ +import unittest +import os + from django.contrib.auth.models import User from django.test import TestCase @@ -12,12 +15,12 @@ def setUp(self): self.mapping = {} zoo1 = Zoo.objects.create(name='Zoo 2') - self.mapping[Animal.objects.create(name='Animal F', zoo=zoo1).id] = 'f' + self.mapping[Animal.objects.create(name='Animal F', zoo=zoo1, birth_date='1997-03-19').id] = 'f' self.mapping[Animal.objects.create(name='Animal E', zoo=zoo1).id] = 'e' self.mapping[Animal.objects.create(name='Animal D', zoo=zoo1).id] = 'd' zoo2 = Zoo.objects.create(name='Zoo 1') - self.mapping[Animal.objects.create(name='Animal C', zoo=zoo2).id] = 'c' + self.mapping[Animal.objects.create(name='Animal C', zoo=zoo2, birth_date='2000-08-05').id] = 'c' self.mapping[Animal.objects.create(name='Animal B', zoo=zoo2).id] = 'b' self.mapping[Animal.objects.create(name='Animal A', zoo=zoo2).id] = 'a' @@ -48,17 +51,53 @@ def test_default(self): self.assertEqual(self.get(after='d'), 'cba') def test_ordered(self): - self.assertEqual(self.get('name', ), 'abcdef') + self.assertEqual(self.get('name'), 'abcdef') self.assertEqual(self.get('name', after='c'), 'def') def test_ordered_relation(self): - self.assertEqual(self.get('zoo,name', ), 'defabc') + self.assertEqual(self.get('zoo,name'), 'defabc') self.assertEqual(self.get('zoo,name', after='f'), 'abc') def test_ordered_reverse(self): - self.assertEqual(self.get('-name', ), 'fedcba') + self.assertEqual(self.get('-name'), 'fedcba') self.assertEqual(self.get('-name', after='d'), 'cba') def test_ordered_relation_field(self): - self.assertEqual(self.get('zoo.name', ), 'cbafed') + self.assertEqual(self.get('zoo.name'), 'cbafed') self.assertEqual(self.get('zoo.name', after='a'), 'fed') + + @unittest.skipIf( + os.environ.get('BINDER_TEST_MYSQL', '0') != '0', + "Only available with PostgreSQL" + ) + def test_ordered_with_null(self): + self.assertEqual(self.get('birth_date'), 'fcedba') + self.assertEqual(self.get('birth_date', after='f'), 'cedba') + self.assertEqual(self.get('birth_date', after='e'), 'dba') + + @unittest.skipIf( + os.environ.get('BINDER_TEST_MYSQL', '0') != '0', + "Only available with PostgreSQL" + ) + def test_ordered_with_null_reversed(self): + self.assertEqual(self.get('-birth_date'), 'edbacf') + self.assertEqual(self.get('-birth_date', after='c'), 'f') + self.assertEqual(self.get('-birth_date', after='b'), 'acf') + + @unittest.skipIf( + os.environ.get('BINDER_TEST_MYSQL', '0') == '0', + "Only available with MySQL" + ) + def test_ordered_with_null_mysql(self): + self.assertEqual(self.get('birth_date'), 'edbafc') + self.assertEqual(self.get('birth_date', after='f'), 'c') + self.assertEqual(self.get('birth_date', after='e'), 'dbafc') + + @unittest.skipIf( + os.environ.get('BINDER_TEST_MYSQL', '0') == '0', + "Only available with MySQL" + ) + def test_ordered_with_null_reversed_mysql(self): + self.assertEqual(self.get('-birth_date'), 'cfedba') + self.assertEqual(self.get('-birth_date', after='c'), 'fedba') + self.assertEqual(self.get('-birth_date', after='b'), 'a')