diff --git a/binder/models.py b/binder/models.py index 762f60fc..717b4866 100644 --- a/binder/models.py +++ b/binder/models.py @@ -10,6 +10,7 @@ from django import forms from django.db import models +from django.db.models import Value from django.db.models.fields.files import FieldFile, FileField from django.contrib.postgres.fields import CITextField, ArrayField, DateTimeRangeField as DTRangeField from django.core import checks @@ -29,6 +30,29 @@ from . import history +@models.CharField.register_lookup +@models.TextField.register_lookup +class FuzzyLookup(models.Lookup): + + lookup_name = 'fuzzy' + + def get_prep_lookup(self): + assert isinstance(self.rhs, str) + pattern = ['%'] + for part in self.rhs.split(): + for char in part: + if char in '%_[\\': + char.append('\\') + pattern.append(char) + pattern.append('%') + return Value(''.join(pattern)) + + def as_sql(self, compiler, connection): + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) + return f'{lhs} ilike {rhs} escape \'\\\'', (*lhs_params, *rhs_params) + + class DateTimeRangeField(DTRangeField): default_error_messages = { @@ -347,7 +371,7 @@ def clean_value(self, qualifier, v): class TextFieldFilter(FieldFilter): fields = [models.CharField, models.TextField] - allowed_qualifiers = [None, 'in', 'iexact', 'contains', 'icontains', 'startswith', 'istartswith', 'endswith', 'iendswith', 'exact', 'isnull'] + allowed_qualifiers = [None, 'in', 'iexact', 'contains', 'icontains', 'startswith', 'istartswith', 'endswith', 'iendswith', 'exact', 'isnull', 'fuzzy'] # Always valid(?) def clean_value(self, qualifier, v): diff --git a/tests/filters/test_text_filters.py b/tests/filters/test_text_filters.py new file mode 100644 index 00000000..5dd054e4 --- /dev/null +++ b/tests/filters/test_text_filters.py @@ -0,0 +1,51 @@ +import unittest, os + +from django.test import TestCase, Client +from django.contrib.auth.models import User + +from binder.json import jsonloads + +from ..testapp.models import Zoo + + +@unittest.skipIf( + os.environ.get('BINDER_TEST_MYSQL', '0') != '0', + "Only available with PostgreSQL" +) +class TextFiltersTest(TestCase): + + def setUp(self): + super().setUp() + u = User(username='testuser', is_active=True, is_superuser=True) + u.set_password('test') + u.save() + + self.client = Client() + r = self.client.login(username='testuser', password='test') + self.assertTrue(r) + + Zoo(name='Burgers Zoo').save() + Zoo(name='Artis').save() + Zoo(name='Apenheul').save() + Zoo(name='Ouwehand Zoo').save() + + def test_filter_fuzzy(self): + response = self.client.get('/zoo/', data={'.name:fuzzy': 'b zo'}) + self.assertEqual(response.status_code, 200) + result = jsonloads(response.content) + self.assertEqual(1, len(result['data'])) + self.assertEqual('Burgers Zoo', result['data'][0]['name']) + + response = self.client.get('/zoo/', data={'.name:fuzzy': ' zo '}) + self.assertEqual(response.status_code, 200) + result = jsonloads(response.content) + self.assertEqual(2, len(result['data'])) + self.assertEqual('Burgers Zoo', result['data'][0]['name']) + self.assertEqual('Ouwehand Zoo', result['data'][1]['name']) + + response = self.client.get('/zoo/', data={'.name:fuzzy': 'ar'}) + self.assertEqual(response.status_code, 200) + result = jsonloads(response.content) + self.assertEqual(response.status_code, 200) + self.assertEqual(1, len(result['data'])) + self.assertEqual('Artis', result['data'][0]['name'])