diff --git a/django_opensearch_dsl/registries.py b/django_opensearch_dsl/registries.py index d130568..b0f904a 100644 --- a/django_opensearch_dsl/registries.py +++ b/django_opensearch_dsl/registries.py @@ -2,6 +2,8 @@ from copy import deepcopy from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist +from django.db.models import Model +from opensearchpy.helpers.document import Document as DSLDocument from opensearchpy.helpers.utils import AttrDict from .apps import DODConfig @@ -160,6 +162,10 @@ def delete(self, instance, **kwargs): """ self.update(instance, action="delete", **kwargs) + def get_models(self): + """Get all models in the registry.""" + return set(self._models.keys()) + def get_indices(self, models=None): """Get all indices in the registry or the indices for a list of models.""" if models is not None: @@ -167,5 +173,13 @@ def get_indices(self, models=None): return set(self._indices.keys()) + def __contains__(self, obj): + """Check that a model is in the registry.""" + if issubclass(obj, Model): + return obj in self._models or obj in self._related_models + raise TypeError( + f"'in <{type(self).__name__}>' requires a Model subclass as left operand, not {type(dict).__name__}" + ) + registry = DocumentRegistry() diff --git a/tests/tests/fixtures.py b/tests/tests/fixtures.py index a84aed0..fec616c 100644 --- a/tests/tests/fixtures.py +++ b/tests/tests/fixtures.py @@ -28,6 +28,10 @@ class ModelE(models.Model): class Meta: app_label = "foo" + class ModelF(models.Model): + class Meta: + app_label = "foo" + def _generate_doc_mock(self, _model, index=None, mock_qs=None, _ignore_signals=False, _related_models=None): _index = index diff --git a/tests/tests/test_registries.py b/tests/tests/test_registries.py index e73e60c..5d4ced4 100644 --- a/tests/tests/test_registries.py +++ b/tests/tests/test_registries.py @@ -1,7 +1,6 @@ from unittest import TestCase, mock from unittest.mock import Mock -from django.conf import settings from django.core.exceptions import ObjectDoesNotExist from django.test import override_settings @@ -30,34 +29,41 @@ def test_empty_registry(self): self.assertEqual(registry._models, {}) def test_register(self): - self.assertEqual(self.registry._models[self.ModelA], set([self.doc_a1, self.doc_a2])) - self.assertEqual(self.registry._models[self.ModelB], set([self.doc_b1])) + self.assertEqual(self.registry._models[self.ModelA], {self.doc_a1, self.doc_a2}) + self.assertEqual(self.registry._models[self.ModelB], {self.doc_b1}) self.assertEqual( - self.registry._indices[self.index_1], set([self.doc_a1, self.doc_a2, self.doc_c1, self.doc_d1, self.doc_e1]) + self.registry._indices[self.index_1], {self.doc_a1, self.doc_a2, self.doc_c1, self.doc_d1, self.doc_e1} ) - self.assertEqual(self.registry._indices[self.index_2], set([self.doc_b1])) + self.assertEqual(self.registry._indices[self.index_2], {self.doc_b1}) def test_register_with_related_models(self): - self.assertEqual(self.registry._related_models[self.ModelE], set([self.ModelD])) + self.assertEqual(self.registry._related_models[self.ModelE], {self.ModelD}) def test_get_related_doc(self): instance = self.ModelE() related_set = set() for doc in self.registry._get_related_doc(instance): related_set.add(doc) - self.assertEqual(related_set, set([self.doc_d1])) + self.assertEqual(related_set, {self.doc_d1}) def test_get_indices(self): - self.assertEqual(self.registry.get_indices(), set([self.index_1, self.index_2])) + self.assertEqual(self.registry.get_indices(), {self.index_1, self.index_2}) def test_get_indices_by_model(self): - self.assertEqual(self.registry.get_indices([self.ModelA]), set([self.index_1])) + self.assertEqual(self.registry.get_indices([self.ModelA]), {self.index_1}) def test_get_indices_by_unregister_model(self): ModelC = Mock() self.assertFalse(self.registry.get_indices([ModelC])) + def test_get_models(self): + self.assertEqual(self.registry.get_models(), {self.ModelA, self.ModelB, self.ModelC, self.ModelD, self.ModelE}) + + def test_contains(self): + self.assertIn(self.ModelA, self.registry) + self.assertNotIn(self.ModelF, self.registry) + def test_update_instance(self): doc_a3 = self._generate_doc_mock(self.ModelA, self.index_1, _ignore_signals=True)