diff --git a/packages/valory/skills/staking_abci/tests/test_rounds.py b/packages/valory/skills/staking_abci/tests/test_rounds.py index e55d54daf..2d07da625 100644 --- a/packages/valory/skills/staking_abci/tests/test_rounds.py +++ b/packages/valory/skills/staking_abci/tests/test_rounds.py @@ -19,19 +19,18 @@ """This package contains the tests for rounds of StakingAbciApp.""" -import json from dataclasses import dataclass, field -from unittest.mock import MagicMock -from typing import Any, Callable, Dict, FrozenSet, Hashable, Mapping, Optional, List +from typing import Any, Callable, Dict, FrozenSet, Hashable, Mapping, List, Type from unittest import mock +from unittest.mock import MagicMock import pytest -from packages.valory.skills.abstract_round_abci.base import( - BaseTxPayload, +from packages.valory.skills.abstract_round_abci.base import ( + BaseTxPayload, AbciAppDB, - get_name - ) + get_name, CollectSameUntilThresholdRound +) from packages.valory.skills.abstract_round_abci.test_tools.rounds import ( BaseCollectSameUntilThresholdRoundTest, ) @@ -43,49 +42,42 @@ CheckpointCallPreparedRound, FinishedStakingRound, ServiceEvictedRound, - StakingAbciApp + StakingAbciApp, StakingState ) + @pytest.fixture def abci_app() -> StakingAbciApp: """Fixture for StakingAbciApp.""" synchronized_data = MagicMock() logger = MagicMock() context = MagicMock() - + return StakingAbciApp( synchronized_data=synchronized_data, logger=logger, context=context ) + DUMMY_SERVICE_STATE = { - "service_staking_state": 0, # Assuming 0 means UNSTAKED + "service_staking_state": StakingState.UNSTAKED.value, "tx_submitter": "dummy_submitter", - "most_voted_tx_hash": "dummy_tx_hash", + "tx_hash": "dummy_tx_hash", } -DUMMY_PARTICIPANT_TO_CHECKPOINT = { - "agent_0": "checkpoint_0", - "agent_1": "checkpoint_1", -} def get_participants() -> FrozenSet[str]: """Participants""" return frozenset([f"agent_{i}" for i in range(MAX_PARTICIPANTS)]) -def get_payloads( - payload_cls: BaseTxPayload, - data: Optional[str], -) -> Mapping[str, BaseTxPayload]: +def get_checkpoint_payloads(data: Dict) -> Mapping[str, CallCheckpointPayload]: """Get payloads.""" return { - participant: payload_cls( + participant: CallCheckpointPayload( participant, - tx_submitter="dummy_submitter", - tx_hash="dummy_tx_hash", - service_staking_state=0 + **data, ) for participant in get_participants() } @@ -110,6 +102,7 @@ class RoundTestCase: class BaseStakingRoundTestClass(BaseCollectSameUntilThresholdRoundTest): """Base test class for Staking rounds.""" + round_class: Type[CollectSameUntilThresholdRound] synchronized_data: SynchronizedData _synchronized_data_class = SynchronizedData _event_class = Event @@ -134,9 +127,6 @@ def run_test(self, test_case: RoundTestCase, **kwargs: Any) -> None: exit_event=test_case.event, ) - # Debugging line: print result after running the test - print(f"Test case {test_case.name} result: {result}") - self._complete_run(result) @@ -151,32 +141,30 @@ class TestCallCheckpointRound(BaseStakingRoundTestClass): RoundTestCase( name="Happy path", initial_data={}, - payloads=get_payloads( - payload_cls=CallCheckpointPayload, - data=json.dumps(DUMMY_SERVICE_STATE), - ), + payloads=get_checkpoint_payloads({ + "service_staking_state": StakingState.STAKED.value, + "tx_submitter": "dummy_submitter", + "tx_hash": "dummy_tx_hash", + }), final_data={ - "service_staking_state": 0, + "service_staking_state": StakingState.STAKED.value, "tx_submitter": "dummy_submitter", - "most_voted_tx_hash": "dummy_tx_hash", + "tx_hash": "dummy_tx_hash", }, event=Event.DONE, - most_voted_payload=json.dumps(DUMMY_SERVICE_STATE), + most_voted_payload=DUMMY_SERVICE_STATE["tx_submitter"], synchronized_data_attr_checks=[ - lambda synchronized_data: synchronized_data.service_staking_state == 0, + lambda synchronized_data: synchronized_data.service_staking_state == StakingState.STAKED.value, lambda synchronized_data: synchronized_data.tx_submitter == "dummy_submitter", ], ), RoundTestCase( name="Service not staked", initial_data={}, - payloads=get_payloads( - payload_cls=CallCheckpointPayload, - data=json.dumps(DUMMY_SERVICE_STATE), - ), + payloads=get_checkpoint_payloads(DUMMY_SERVICE_STATE), final_data={}, event=Event.SERVICE_NOT_STAKED, - most_voted_payload=json.dumps(DUMMY_SERVICE_STATE), + most_voted_payload=DUMMY_SERVICE_STATE["tx_submitter"], synchronized_data_attr_checks=[ lambda synchronized_data: synchronized_data.service_staking_state == 0, ], @@ -184,13 +172,14 @@ class TestCallCheckpointRound(BaseStakingRoundTestClass): RoundTestCase( name="Service evicted", initial_data={}, - payloads=get_payloads( - payload_cls=CallCheckpointPayload, - data=json.dumps(DUMMY_SERVICE_STATE), - ), + payloads=get_checkpoint_payloads({ + "service_staking_state": StakingState.EVICTED.value, + "tx_submitter": "dummy_submitter", + "tx_hash": "dummy_tx_hash", + }), final_data={}, event=Event.SERVICE_EVICTED, - most_voted_payload=json.dumps(DUMMY_SERVICE_STATE), + most_voted_payload=DUMMY_SERVICE_STATE["tx_submitter"], synchronized_data_attr_checks=[ lambda synchronized_data: synchronized_data.service_staking_state == 0, ], @@ -198,13 +187,14 @@ class TestCallCheckpointRound(BaseStakingRoundTestClass): RoundTestCase( name="Next checkpoint not reached", initial_data={}, - payloads=get_payloads( - payload_cls=CallCheckpointPayload, - data=json.dumps(DUMMY_SERVICE_STATE), - ), + payloads=get_checkpoint_payloads({ + "service_staking_state": StakingState.STAKED.value, + "tx_submitter": "dummy_submitter", + "tx_hash": None, + }), final_data={}, event=Event.NEXT_CHECKPOINT_NOT_REACHED_YET, - most_voted_payload=json.dumps(DUMMY_SERVICE_STATE), + most_voted_payload=DUMMY_SERVICE_STATE["tx_submitter"], synchronized_data_attr_checks=[ lambda synchronized_data: synchronized_data.service_staking_state == 0, ], @@ -212,19 +202,13 @@ class TestCallCheckpointRound(BaseStakingRoundTestClass): ], ) def test_run(self, test_case: RoundTestCase) -> None: - - # Run the test - self.run_test(test_case) - - def run_test(self, test_case: RoundTestCase, **kwargs: Any) -> None: - """Run the test with added debugging.""" + """Run the test.""" self.synchronized_data.update(**test_case.initial_data) test_round = self.round_class( synchronized_data=self.synchronized_data, context=mock.MagicMock() ) - print("Starting _test_round...") result = self._test_round( test_round=test_round, round_payloads=test_case.payloads, @@ -236,13 +220,9 @@ def run_test(self, test_case: RoundTestCase, **kwargs: Any) -> None: exit_event=test_case.event, ) - # Debugging line: print result after running the test - print(f"Test case {test_case.name} result: {result}") - self._complete_run(result) - class TestCheckpointCallPreparedRound: """Tests for CheckpointCallPreparedRound.""" @@ -313,6 +293,7 @@ def test_staking_abci_app_initialization(abci_app: StakingAbciApp) -> None: }, } + def test_synchronized_data_initialization() -> None: """Test the initialization and attributes of SynchronizedData.""" data = SynchronizedData(db=AbciAppDB(setup_data={"test": ["test"]}))