diff --git a/django_mock_queries/mocks.py b/django_mock_queries/mocks.py index 82d71eb..27b1119 100644 --- a/django_mock_queries/mocks.py +++ b/django_mock_queries/mocks.py @@ -13,7 +13,6 @@ from types import MethodType -from .constants import DjangoModelDeletionCollector, DjangoDbRouter from .query import MockSet # noinspection PyUnresolvedReferences @@ -330,10 +329,6 @@ class Mocker: A decorator that patches multiple class methods with a magic mock instance that does nothing. """ - shared_mocks = {} - shared_patchers = {} - shared_original = {} - def __init__(self, cls, *methods, **kwargs): self.cls = cls self.methods = methods @@ -342,8 +337,6 @@ def __init__(self, cls, *methods, **kwargs): self.inst_patchers = {} self.inst_original = {} - self.outer = kwargs.get('outer', True) - def __enter__(self): self._patch_object_methods(self.cls, *self.methods) return self @@ -359,11 +352,8 @@ def decorated(*args, **kwargs): return decorated def __exit__(self, exc_type, exc_val, exc_tb): - for key, patcher in self.inst_patchers.items(): + for patcher in self.inst_patchers.values(): patcher.stop() - if self.outer: - for key, patcher in self.shared_patchers.items(): - patcher.stop() def _key(self, method, obj=None): return '{}.{}'.format(obj or self.cls, method) @@ -374,10 +364,10 @@ def _method_obj(self, name, obj, *sources): return d[self._key(name, obj=obj)] def method(self, name, obj=None): - return self._method_obj(name, obj, self.shared_mocks, self.inst_mocks) + return self._method_obj(name, obj, self.inst_mocks) def original_method(self, name, obj=None): - return self._method_obj(name, obj, self.shared_original, self.inst_original) + return self._method_obj(name, obj, self.inst_original) def _get_source_method(self, obj, method): source_obj = obj @@ -406,20 +396,20 @@ def _patch_method(self, method_name, source_obj, source_method): return patch_object(source_obj, source_method, **mock_args) def _patch_object_methods(self, obj, *methods, **kwargs): - if kwargs.get('shared', False): - original, patchers, mocks = self.shared_original, self.shared_patchers, self.shared_mocks - else: - original, patchers, mocks = self.inst_original, self.inst_patchers, self.inst_mocks + original, patchers, mocks = self.inst_original, self.inst_patchers, self.inst_mocks for method in methods: key = self._key(method, obj=obj) source_obj, source_method = self._get_source_method(obj, method) - original[key] = original.get(key, None) or getattr(source_obj, source_method) - patcher = self._patch_method(method, source_obj, source_method) - patchers[key] = patcher - mocks[key] = patcher.start() + if key not in original: + original[key] = getattr(source_obj, source_method) + + if key not in patchers: + patcher = self._patch_method(method, source_obj, source_method) + patchers[key] = patcher + mocks[key] = patcher.start() class ModelMocker(Mocker): @@ -427,7 +417,7 @@ class ModelMocker(Mocker): A decorator that patches django base model's db read/write methods and wires them to a MockSet. """ - default_methods = ['objects', '_do_update'] + default_methods = ['objects', '_do_update', 'delete'] if django.VERSION[0] >= 3: default_methods += ['_base_manager._insert', ] @@ -442,11 +432,8 @@ def __init__(self, cls, *methods, **kwargs): self.objects = MockSet(model=self.cls) self.objects.on('added', self._on_added) - self.state = {} - def __enter__(self): result = super(ModelMocker, self).__enter__() - self._patch_object_methods(DjangoModelDeletionCollector, 'collect', 'delete', shared=True) return result def _obj_pk(self, obj): @@ -482,23 +469,11 @@ def _do_update(self, *args, **_): else: return False - def collect(self, objects, *args, **kwargs): - model = getattr(objects, 'model', None) or objects[0] - - if not (model is self.cls or isinstance(model, self.cls)): - using = getattr(objects, 'db', None) or DjangoDbRouter.db_for_write(model._meta.model, instance=model) - self.state['collector'] = DjangoModelDeletionCollector(using=using) - - collect = self.original_method('collect', obj=DjangoModelDeletionCollector) - collect(self.state['collector'], objects, *args, **kwargs) - - self.state['model'] = model - - def delete(self, *args, **kwargs): - model = self.state.pop('model') - - if not (model is self.cls or isinstance(model, self.cls)): - delete = self.original_method('delete', obj=DjangoModelDeletionCollector) - return delete(self.state.pop('collector'), *args, **kwargs) - else: - return self.objects.filter(pk=getattr(model, self.cls._meta.pk.attname)).delete() + def delete(self, *_args, **_kwargs): + pk = self._obj_pk(self.objects[0]) + if not pk: + raise ValueError( + f"{self.cls._meta.object_name} object can't be deleted because " + f'its {self.cls._meta.pk.attname} attribute is set to None.' + ) + return self.objects.filter(pk=pk).delete() diff --git a/tests/test_mocks.py b/tests/test_mocks.py index 5505c85..08367f5 100644 --- a/tests/test_mocks.py +++ b/tests/test_mocks.py @@ -486,6 +486,27 @@ def car_added(obj): self.assertIsInstance(objects['added'], Car) self.assertEqual(objects['added'], objects['car']) + def test_model_mocker_delete_from_instance_with_nested_context_manager(self): + + def create_delete_models(): + car = Car.objects.create(speed=10) + car.delete() + + manufacturer = Manufacturer.objects.create(name='foo') + manufacturer.delete() + + def models_exist(): + return Manufacturer.objects.exists() or Car.objects.exists() + + with ModelMocker(Manufacturer), ModelMocker(Car): + create_delete_models() + assert not models_exist() + + # Test same scenario with reversed context manager order + with ModelMocker(Car), ModelMocker(Manufacturer): + create_delete_models() + assert not models_exist() + def test_model_mocker_event_updated_from_manager(self): objects = {}