diff --git a/src/objects/api/constants.py b/src/objects/api/constants.py index 0765673b..6e4e61fa 100644 --- a/src/objects/api/constants.py +++ b/src/objects/api/constants.py @@ -9,3 +9,4 @@ class Operators(models.TextChoices): lt = "lt", _("lower than") lte = "lte", _("lower than or equal to") icontains = "icontains", _("case-insensitive partial match") + in_list = "in", _("contains") diff --git a/src/objects/api/v1/filters.py b/src/objects/api/v1/filters.py index bb36d9f6..4bdb24ac 100644 --- a/src/objects/api/v1/filters.py +++ b/src/objects/api/v1/filters.py @@ -101,6 +101,10 @@ def filter_data_attrs(self, queryset, name, value: str): queryset = queryset.filter( **{f"data__{variable}__icontains": str_value} ) + elif operator == "in": + # in must be a list + values = str_value.split("|") + queryset = queryset.filter(**{f"data__{variable}__in": values}) else: # gt, gte, lt, lte operators diff --git a/src/objects/api/v2/filters.py b/src/objects/api/v2/filters.py index cccdc19a..785d8fa8 100644 --- a/src/objects/api/v2/filters.py +++ b/src/objects/api/v2/filters.py @@ -109,6 +109,10 @@ def filter_data_attrs(self, queryset, name, value: str): queryset = queryset.filter( **{f"data__{variable}__icontains": str_value} ) + elif operator == "in": + # in must be a list + values = str_value.split("|") + queryset = queryset.filter(**{f"data__{variable}__in": values}) else: # gt, gte, lt, lte operators diff --git a/src/objects/api/validators.py b/src/objects/api/validators.py index f36dd06c..50c2cd2c 100644 --- a/src/objects/api/validators.py +++ b/src/objects/api/validators.py @@ -81,9 +81,11 @@ def validate_data_attrs(value: str): } raise serializers.ValidationError(message, code=code) - if operator not in (Operators.exact, Operators.icontains) and isinstance( - string_to_value(val), str - ): + if operator not in ( + Operators.exact, + Operators.icontains, + Operators.in_list, + ) and isinstance(string_to_value(val), str): message = _( "Operator `%(operator)s` supports only dates and/or numeric values" ) % {"operator": operator} diff --git a/src/objects/tests/v1/test_filters.py b/src/objects/tests/v1/test_filters.py index c337ef93..5c74a3c5 100644 --- a/src/objects/tests/v1/test_filters.py +++ b/src/objects/tests/v1/test_filters.py @@ -328,6 +328,31 @@ def test_filter_exclude_old_records(self): data = response.json() self.assertEqual(len(data), 0) + def test_filter_in_string(self): + record = ObjectRecordFactory.create( + data={"name": "demo1"}, object__object_type=self.object_type + ) + record2 = ObjectRecordFactory.create( + data={"name": "demo2"}, object__object_type=self.object_type + ) + ObjectRecordFactory.create( + data={"name": "demo3"}, object__object_type=self.object_type + ) + + response = self.client.get(self.url, {"data_attrs": "name__in__demo1|demo2"}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + + data = response.json() + self.assertEqual(len(data), 2) + self.assertEqual( + data[0]["url"], + f"http://testserver{reverse('object-detail', args=[record2.object.uuid])}", + ) + self.assertEqual( + data[1]["url"], + f"http://testserver{reverse('object-detail', args=[record.object.uuid])}", + ) + class FilterDateTests(TokenAuthMixin, APITestCase): @classmethod diff --git a/src/objects/tests/v2/test_filters.py b/src/objects/tests/v2/test_filters.py index 52ed5ddc..186cf0ea 100644 --- a/src/objects/tests/v2/test_filters.py +++ b/src/objects/tests/v2/test_filters.py @@ -400,6 +400,33 @@ def test_filter_date_field_gte(self): self.assertEqual(len(data), 0) + def test_filter_in_string(self): + record = ObjectRecordFactory.create( + data={"name": "demo1"}, object__object_type=self.object_type + ) + record2 = ObjectRecordFactory.create( + data={"name": "demo2"}, object__object_type=self.object_type + ) + ObjectRecordFactory.create( + data={"name": "demo3"}, object__object_type=self.object_type + ) + + response = self.client.get(self.url, {"data_attrs": "name__in__demo1|demo2"}) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + data = response.json()["results"] + + self.assertEqual(len(data), 2) + self.assertEqual( + data[0]["url"], + f"http://testserver{reverse('object-detail', args=[record2.object.uuid])}", + ) + self.assertEqual( + data[1]["url"], + f"http://testserver{reverse('object-detail', args=[record.object.uuid])}", + ) + class FilterDateTests(TokenAuthMixin, APITestCase): @classmethod