Skip to content

Commit

Permalink
feat: add get_models() and __contains__() in registries.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ghkdxofla authored and qcoumes committed Mar 18, 2024
1 parent 60f1955 commit be22b93
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 9 deletions.
14 changes: 14 additions & 0 deletions django_opensearch_dsl/registries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from copy import deepcopy

from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist
from django.db.models import Model
from opensearchpy.helpers.document import Document as DSLDocument
from opensearchpy.helpers.utils import AttrDict

from .apps import DODConfig
Expand Down Expand Up @@ -160,12 +162,24 @@ def delete(self, instance, **kwargs):
"""
self.update(instance, action="delete", **kwargs)

def get_models(self):
"""Get all models in the registry."""
return set(self._models.keys())

def get_indices(self, models=None):
"""Get all indices in the registry or the indices for a list of models."""
if models is not None:
return set(index for index, docs in self._indices.items() for doc in docs if doc.django.model in models)

return set(self._indices.keys())

def __contains__(self, obj):
"""Check that a model is in the registry."""
if issubclass(obj, Model):
return obj in self._models or obj in self._related_models
raise TypeError(
f"'in <{type(self).__name__}>' requires a Model subclass as left operand, not {type(dict).__name__}"
)


registry = DocumentRegistry()
4 changes: 4 additions & 0 deletions tests/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ class ModelE(models.Model):
class Meta:
app_label = "foo"

class ModelF(models.Model):
class Meta:
app_label = "foo"

def _generate_doc_mock(self, _model, index=None, mock_qs=None, _ignore_signals=False, _related_models=None):
_index = index

Expand Down
24 changes: 15 additions & 9 deletions tests/tests/test_registries.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from unittest import TestCase, mock
from unittest.mock import Mock

from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist
from django.test import override_settings

Expand Down Expand Up @@ -30,34 +29,41 @@ def test_empty_registry(self):
self.assertEqual(registry._models, {})

def test_register(self):
self.assertEqual(self.registry._models[self.ModelA], set([self.doc_a1, self.doc_a2]))
self.assertEqual(self.registry._models[self.ModelB], set([self.doc_b1]))
self.assertEqual(self.registry._models[self.ModelA], {self.doc_a1, self.doc_a2})
self.assertEqual(self.registry._models[self.ModelB], {self.doc_b1})

self.assertEqual(
self.registry._indices[self.index_1], set([self.doc_a1, self.doc_a2, self.doc_c1, self.doc_d1, self.doc_e1])
self.registry._indices[self.index_1], {self.doc_a1, self.doc_a2, self.doc_c1, self.doc_d1, self.doc_e1}
)
self.assertEqual(self.registry._indices[self.index_2], set([self.doc_b1]))
self.assertEqual(self.registry._indices[self.index_2], {self.doc_b1})

def test_register_with_related_models(self):
self.assertEqual(self.registry._related_models[self.ModelE], set([self.ModelD]))
self.assertEqual(self.registry._related_models[self.ModelE], {self.ModelD})

def test_get_related_doc(self):
instance = self.ModelE()
related_set = set()
for doc in self.registry._get_related_doc(instance):
related_set.add(doc)
self.assertEqual(related_set, set([self.doc_d1]))
self.assertEqual(related_set, {self.doc_d1})

def test_get_indices(self):
self.assertEqual(self.registry.get_indices(), set([self.index_1, self.index_2]))
self.assertEqual(self.registry.get_indices(), {self.index_1, self.index_2})

def test_get_indices_by_model(self):
self.assertEqual(self.registry.get_indices([self.ModelA]), set([self.index_1]))
self.assertEqual(self.registry.get_indices([self.ModelA]), {self.index_1})

def test_get_indices_by_unregister_model(self):
ModelC = Mock()
self.assertFalse(self.registry.get_indices([ModelC]))

def test_get_models(self):
self.assertEqual(self.registry.get_models(), {self.ModelA, self.ModelB, self.ModelC, self.ModelD, self.ModelE})

def test_contains(self):
self.assertIn(self.ModelA, self.registry)
self.assertNotIn(self.ModelF, self.registry)

def test_update_instance(self):
doc_a3 = self._generate_doc_mock(self.ModelA, self.index_1, _ignore_signals=True)

Expand Down

0 comments on commit be22b93

Please sign in to comment.