Skip to content

Commit

Permalink
fix: tests logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Adamantios committed Sep 4, 2024
1 parent 9fa9d98 commit b74817e
Showing 1 changed file with 41 additions and 60 deletions.
101 changes: 41 additions & 60 deletions packages/valory/skills/staking_abci/tests/test_rounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,18 @@

"""This package contains the tests for rounds of StakingAbciApp."""

import json
from dataclasses import dataclass, field
from unittest.mock import MagicMock
from typing import Any, Callable, Dict, FrozenSet, Hashable, Mapping, Optional, List
from typing import Any, Callable, Dict, FrozenSet, Hashable, Mapping, List, Type
from unittest import mock
from unittest.mock import MagicMock

import pytest

from packages.valory.skills.abstract_round_abci.base import(
BaseTxPayload,
from packages.valory.skills.abstract_round_abci.base import (
BaseTxPayload,
AbciAppDB,
get_name
)
get_name, CollectSameUntilThresholdRound
)
from packages.valory.skills.abstract_round_abci.test_tools.rounds import (
BaseCollectSameUntilThresholdRoundTest,
)
Expand All @@ -43,49 +42,42 @@
CheckpointCallPreparedRound,
FinishedStakingRound,
ServiceEvictedRound,
StakingAbciApp
StakingAbciApp, StakingState
)


@pytest.fixture
def abci_app() -> StakingAbciApp:
"""Fixture for StakingAbciApp."""
synchronized_data = MagicMock()
logger = MagicMock()
context = MagicMock()

return StakingAbciApp(
synchronized_data=synchronized_data,
logger=logger,
context=context
)


DUMMY_SERVICE_STATE = {
"service_staking_state": 0, # Assuming 0 means UNSTAKED
"service_staking_state": StakingState.UNSTAKED.value,
"tx_submitter": "dummy_submitter",
"most_voted_tx_hash": "dummy_tx_hash",
"tx_hash": "dummy_tx_hash",
}

DUMMY_PARTICIPANT_TO_CHECKPOINT = {
"agent_0": "checkpoint_0",
"agent_1": "checkpoint_1",
}

def get_participants() -> FrozenSet[str]:
"""Participants"""
return frozenset([f"agent_{i}" for i in range(MAX_PARTICIPANTS)])


def get_payloads(
payload_cls: BaseTxPayload,
data: Optional[str],
) -> Mapping[str, BaseTxPayload]:
def get_checkpoint_payloads(data: Dict) -> Mapping[str, CallCheckpointPayload]:
"""Get payloads."""
return {
participant: payload_cls(
participant: CallCheckpointPayload(
participant,
tx_submitter="dummy_submitter",
tx_hash="dummy_tx_hash",
service_staking_state=0
**data,
)
for participant in get_participants()
}
Expand All @@ -110,6 +102,7 @@ class RoundTestCase:
class BaseStakingRoundTestClass(BaseCollectSameUntilThresholdRoundTest):
"""Base test class for Staking rounds."""

round_class: Type[CollectSameUntilThresholdRound]
synchronized_data: SynchronizedData
_synchronized_data_class = SynchronizedData
_event_class = Event
Expand All @@ -134,9 +127,6 @@ def run_test(self, test_case: RoundTestCase, **kwargs: Any) -> None:
exit_event=test_case.event,
)

# Debugging line: print result after running the test
print(f"Test case {test_case.name} result: {result}")

self._complete_run(result)


Expand All @@ -151,80 +141,74 @@ class TestCallCheckpointRound(BaseStakingRoundTestClass):
RoundTestCase(
name="Happy path",
initial_data={},
payloads=get_payloads(
payload_cls=CallCheckpointPayload,
data=json.dumps(DUMMY_SERVICE_STATE),
),
payloads=get_checkpoint_payloads({
"service_staking_state": StakingState.STAKED.value,
"tx_submitter": "dummy_submitter",
"tx_hash": "dummy_tx_hash",
}),
final_data={
"service_staking_state": 0,
"service_staking_state": StakingState.STAKED.value,
"tx_submitter": "dummy_submitter",
"most_voted_tx_hash": "dummy_tx_hash",
"tx_hash": "dummy_tx_hash",
},
event=Event.DONE,
most_voted_payload=json.dumps(DUMMY_SERVICE_STATE),
most_voted_payload=DUMMY_SERVICE_STATE["tx_submitter"],
synchronized_data_attr_checks=[
lambda synchronized_data: synchronized_data.service_staking_state == 0,
lambda synchronized_data: synchronized_data.service_staking_state == StakingState.STAKED.value,
lambda synchronized_data: synchronized_data.tx_submitter == "dummy_submitter",
],
),
RoundTestCase(
name="Service not staked",
initial_data={},
payloads=get_payloads(
payload_cls=CallCheckpointPayload,
data=json.dumps(DUMMY_SERVICE_STATE),
),
payloads=get_checkpoint_payloads(DUMMY_SERVICE_STATE),
final_data={},
event=Event.SERVICE_NOT_STAKED,
most_voted_payload=json.dumps(DUMMY_SERVICE_STATE),
most_voted_payload=DUMMY_SERVICE_STATE["tx_submitter"],
synchronized_data_attr_checks=[
lambda synchronized_data: synchronized_data.service_staking_state == 0,
],
),
RoundTestCase(
name="Service evicted",
initial_data={},
payloads=get_payloads(
payload_cls=CallCheckpointPayload,
data=json.dumps(DUMMY_SERVICE_STATE),
),
payloads=get_checkpoint_payloads({
"service_staking_state": StakingState.EVICTED.value,
"tx_submitter": "dummy_submitter",
"tx_hash": "dummy_tx_hash",
}),
final_data={},
event=Event.SERVICE_EVICTED,
most_voted_payload=json.dumps(DUMMY_SERVICE_STATE),
most_voted_payload=DUMMY_SERVICE_STATE["tx_submitter"],
synchronized_data_attr_checks=[
lambda synchronized_data: synchronized_data.service_staking_state == 0,
],
),
RoundTestCase(
name="Next checkpoint not reached",
initial_data={},
payloads=get_payloads(
payload_cls=CallCheckpointPayload,
data=json.dumps(DUMMY_SERVICE_STATE),
),
payloads=get_checkpoint_payloads({
"service_staking_state": StakingState.STAKED.value,
"tx_submitter": "dummy_submitter",
"tx_hash": None,
}),
final_data={},
event=Event.NEXT_CHECKPOINT_NOT_REACHED_YET,
most_voted_payload=json.dumps(DUMMY_SERVICE_STATE),
most_voted_payload=DUMMY_SERVICE_STATE["tx_submitter"],
synchronized_data_attr_checks=[
lambda synchronized_data: synchronized_data.service_staking_state == 0,
],
),
],
)
def test_run(self, test_case: RoundTestCase) -> None:

# Run the test
self.run_test(test_case)

def run_test(self, test_case: RoundTestCase, **kwargs: Any) -> None:
"""Run the test with added debugging."""
"""Run the test."""
self.synchronized_data.update(**test_case.initial_data)

test_round = self.round_class(
synchronized_data=self.synchronized_data, context=mock.MagicMock()
)

print("Starting _test_round...")
result = self._test_round(
test_round=test_round,
round_payloads=test_case.payloads,
Expand All @@ -236,13 +220,9 @@ def run_test(self, test_case: RoundTestCase, **kwargs: Any) -> None:
exit_event=test_case.event,
)

# Debugging line: print result after running the test
print(f"Test case {test_case.name} result: {result}")

self._complete_run(result)



class TestCheckpointCallPreparedRound:
"""Tests for CheckpointCallPreparedRound."""

Expand Down Expand Up @@ -313,6 +293,7 @@ def test_staking_abci_app_initialization(abci_app: StakingAbciApp) -> None:
},
}


def test_synchronized_data_initialization() -> None:
"""Test the initialization and attributes of SynchronizedData."""
data = SynchronizedData(db=AbciAppDB(setup_data={"test": ["test"]}))
Expand Down

0 comments on commit b74817e

Please sign in to comment.