From a7b49bd13f467f469de8fb352514565293b51778 Mon Sep 17 00:00:00 2001 From: Felipe Alvarado Date: Tue, 19 Nov 2024 10:59:28 +0100 Subject: [PATCH] Add exception when deactivating a non-enabled module --- .../history/indexers/tx_processor.py | 38 ++++++++++++++++++- .../history/tests/test_tx_processor.py | 38 ++++++++++++++++++- 2 files changed, 74 insertions(+), 2 deletions(-) diff --git a/safe_transaction_service/history/indexers/tx_processor.py b/safe_transaction_service/history/indexers/tx_processor.py index f6834a594..156fd1385 100644 --- a/safe_transaction_service/history/indexers/tx_processor.py +++ b/safe_transaction_service/history/indexers/tx_processor.py @@ -55,6 +55,10 @@ class OwnerCannotBeRemoved(TxProcessorException): pass +class ModuleCannotBeDisabled(TxProcessorException): + pass + + class UserOperationFailed(TxProcessorException): pass @@ -306,6 +310,38 @@ def swap_owner( ) safe_message_models.SafeMessageConfirmation.objects.filter(owner=owner).delete() + def disable_module( + self, + internal_tx: InternalTx, + safe_status: SafeStatus, + module: ChecksumAddress, + ) -> None: + """ + Disables a module for a Safe by removing it from the enabled modules list. + + :param internal_tx: + :param safe_status: + :param module: + :return: + :raises ModuleCannotBeRemoved: If the module is not in the list of enabled modules. + """ + contract_address = internal_tx._from + if module not in safe_status.enabled_modules: + logger.error( + "[%s] Error processing trace=%s with tx-hash=%s. Cannot disable module=%s . " + "Current enabled modules=%s", + contract_address, + internal_tx.trace_address, + internal_tx.ethereum_tx_id, + module, + safe_status.enabled_modules, + ) + raise ModuleCannotBeDisabled( + f"Cannot disable module {module}. Current enabled modules {safe_status.enabled_modules}" + ) + + safe_status.enabled_modules.remove(module) + def store_new_safe_status( self, safe_last_status: SafeLastStatus, internal_tx: InternalTx ) -> SafeLastStatus: @@ -511,7 +547,7 @@ def __process_decoded_transaction( self.store_new_safe_status(safe_last_status, internal_tx) elif function_name == "disableModule": logger.debug("[%s] Disabling Module", contract_address) - safe_last_status.enabled_modules.remove(arguments["module"]) + self.disable_module(internal_tx, safe_last_status, arguments["module"]) self.store_new_safe_status(safe_last_status, internal_tx) elif function_name in { "execTransactionFromModule", diff --git a/safe_transaction_service/history/tests/test_tx_processor.py b/safe_transaction_service/history/tests/test_tx_processor.py index 7255de065..193354bd8 100644 --- a/safe_transaction_service/history/tests/test_tx_processor.py +++ b/safe_transaction_service/history/tests/test_tx_processor.py @@ -16,7 +16,11 @@ SafeMessageFactory, ) -from ..indexers.tx_processor import SafeTxProcessor, SafeTxProcessorProvider +from ..indexers.tx_processor import ( + ModuleCannotBeDisabled, + SafeTxProcessor, + SafeTxProcessorProvider, +) from ..models import ( InternalTxDecoded, ModuleTransaction, @@ -576,6 +580,38 @@ def test_process_module_tx(self): module_tx.value, module_internal_tx_decoded.arguments["value"] ) + def test_process_disable_module_tx(self): + safe_tx_processor = self.tx_processor + safe_last_status = SafeLastStatusFactory(nonce=0) + safe_address = safe_last_status.address + module = Account.create().address + disable_module_tx_decoded = InternalTxDecodedFactory( + function_name="disableModule", + module=module, + internal_tx___from=safe_address, + internal_tx__value=0, + ) + + with self.assertRaises(ModuleCannotBeDisabled): + safe_tx_processor.process_decoded_transaction(disable_module_tx_decoded) + + enable_module_tx_decoded = InternalTxDecodedFactory( + function_name="enableModule", + module=module, + internal_tx___from=safe_address, + internal_tx__value=0, + ) + self.assertTrue( + safe_tx_processor.process_decoded_transaction(enable_module_tx_decoded) + ) + safe_last_status = SafeLastStatus.objects.get(address=safe_address) + self.assertEqual(safe_last_status.enabled_modules, [module]) + self.assertTrue( + safe_tx_processor.process_decoded_transaction(disable_module_tx_decoded) + ) + safe_last_status = SafeLastStatus.objects.get(address=safe_address) + self.assertEqual(safe_last_status.enabled_modules, []) + def test_store_new_safe_status(self): # Create a new SafeLastStatus safe_last_status = SafeLastStatusFactory(nonce=0)