Skip to content

Commit

Permalink
check runs
Browse files Browse the repository at this point in the history
  • Loading branch information
Ravleen-Solulab committed Oct 16, 2024
1 parent a351372 commit 43e5c15
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#
# ------------------------------------------------------------------------------


"""This package contains the tests for Decision Maker"""

import json
Expand All @@ -34,153 +33,156 @@
SynchronizedData,
TxPreparationRound,
)
from packages.valory.skills.mech_interact_abci.states.base import MechMetadata


# Updated MechMetadata class to handle the correct parameters
class MechMetadata:
def __init__(self, request_id: str, data: str):
"""The class for test of Mech Data"""

def __init__(self, request_id: str, data: str) -> None:
"""Initialize MechMetadata with request ID and data."""
self.request_id = request_id
self.data = data


@pytest.fixture
def mocked_db():
def mocked_db() -> MagicMock:
"""Fixture to mock the database."""
return MagicMock()


@pytest.fixture
def sync_data(mocked_db):
def sync_data(mocked_db: MagicMock) -> SynchronizedData:
"""Fixture for SynchronizedData."""
return SynchronizedData(db=mocked_db)


def test_sampled_bet_index(sync_data, mocked_db):
def test_sampled_bet_index(sync_data: SynchronizedData, mocked_db: MagicMock) -> None:
"""Test the sampled_bet_index property."""
mocked_db.get_strict.return_value = "5"
assert sync_data.sampled_bet_index == 5
mocked_db.get_strict.assert_called_once_with("sampled_bet_index")


def test_is_mech_price_set(sync_data, mocked_db):
def test_is_mech_price_set(sync_data: SynchronizedData, mocked_db: MagicMock) -> None:
"""Test the is_mech_price_set property."""
mocked_db.get.return_value = True
assert sync_data.is_mech_price_set is True
mocked_db.get.assert_called_once_with("mech_price", False)


def test_available_mech_tools(sync_data, mocked_db):
def test_available_mech_tools(
sync_data: SynchronizedData, mocked_db: MagicMock
) -> None:
"""Test the available_mech_tools property."""
mocked_db.get_strict.return_value = '["tool1", "tool2"]'
assert sync_data.available_mech_tools == ["tool1", "tool2"]
mocked_db.get_strict.assert_called_once_with("available_mech_tools")


def test_is_policy_set(sync_data, mocked_db):
def test_is_policy_set(sync_data: SynchronizedData, mocked_db: MagicMock) -> None:
"""Test the is_policy_set property."""
mocked_db.get.return_value = True
assert sync_data.is_policy_set is True
mocked_db.get.assert_called_once_with("policy", False)


def test_has_tool_selection_run(sync_data, mocked_db):
def test_has_tool_selection_run(
sync_data: SynchronizedData, mocked_db: MagicMock
) -> None:
"""Test the has_tool_selection_run property."""
mocked_db.get.return_value = "tool1"
assert sync_data.has_tool_selection_run is True
mocked_db.get.assert_called_once_with("mech_tool", None)


def test_mech_tool(sync_data, mocked_db):
def test_mech_tool(sync_data: SynchronizedData, mocked_db: MagicMock) -> None:
"""Test the mech_tool property."""
mocked_db.get_strict.return_value = "tool1"
assert sync_data.mech_tool == "tool1"
mocked_db.get_strict.assert_called_once_with("mech_tool")


def test_utilized_tools(sync_data, mocked_db):
def test_utilized_tools(sync_data: SynchronizedData, mocked_db: MagicMock) -> None:
"""Test the utilized_tools property."""
mocked_db.get_strict.return_value = '{"tx1": "tool1"}'
assert sync_data.utilized_tools == {"tx1": "tool1"}
mocked_db.get_strict.assert_called_once_with("utilized_tools")


def test_redeemed_condition_ids(sync_data, mocked_db):
def test_redeemed_condition_ids(
sync_data: SynchronizedData, mocked_db: MagicMock
) -> None:
"""Test the redeemed_condition_ids property."""
mocked_db.get.return_value = '["cond1", "cond2"]'
assert sync_data.redeemed_condition_ids == {"cond1", "cond2"}
mocked_db.get.assert_called_once_with("redeemed_condition_ids", None)


def test_payout_so_far(sync_data, mocked_db):
def test_payout_so_far(sync_data: SynchronizedData, mocked_db: MagicMock) -> None:
"""Test the payout_so_far property."""
mocked_db.get.return_value = "100"
assert sync_data.payout_so_far == 100
mocked_db.get.assert_called_once_with("payout_so_far", None)


def test_vote(sync_data, mocked_db):
def test_vote(sync_data: SynchronizedData, mocked_db: MagicMock) -> None:
"""Test the vote property."""
mocked_db.get_strict.return_value = "1"
assert sync_data.vote == 1
mocked_db.get_strict.assert_called_once_with("vote")


def test_confidence(sync_data, mocked_db):
def test_confidence(sync_data: SynchronizedData, mocked_db: MagicMock) -> None:
"""Test the confidence property."""
mocked_db.get_strict.return_value = "0.9"
assert sync_data.confidence == 0.9
mocked_db.get_strict.assert_called_once_with("confidence")


def test_bet_amount(sync_data, mocked_db):
def test_bet_amount(sync_data: SynchronizedData, mocked_db: MagicMock) -> None:
"""Test the bet_amount property."""
mocked_db.get_strict.return_value = "50"
assert sync_data.bet_amount == 50
mocked_db.get_strict.assert_called_once_with("bet_amount")


def test_is_profitable(sync_data, mocked_db):
def test_is_profitable(sync_data: SynchronizedData, mocked_db: MagicMock) -> None:
"""Test the is_profitable property."""
mocked_db.get_strict.return_value = True
assert sync_data.is_profitable is True
mocked_db.get_strict.assert_called_once_with("is_profitable")


def test_tx_submitter(sync_data, mocked_db):
def test_tx_submitter(sync_data: SynchronizedData, mocked_db: MagicMock) -> None:
"""Test the tx_submitter property."""
mocked_db.get_strict.return_value = "submitter1"
assert sync_data.tx_submitter == "submitter1"
mocked_db.get_strict.assert_called_once_with("tx_submitter")


@patch("packages.valory.skills.decision_maker_abci.policy.EGreedyPolicy.deserialize")
def test_policy_property(mock_deserialize, sync_data, mocked_db):
# Set up mock return value for db.get_strict
def test_policy_property(
mock_deserialize: MagicMock, sync_data: SynchronizedData, mocked_db: MagicMock
) -> None:
"""Test for policy property"""
mock_policy_serialized = "serialized_policy_string"
mocked_db.get_strict.return_value = mock_policy_serialized

# Mock the expected result of deserialization
expected_policy = EGreedyPolicy(
eps=0.1
) # Provide a value for `eps`, adjust as appropriate
expected_policy = EGreedyPolicy(eps=0.1)
mock_deserialize.return_value = expected_policy

# Access the policy property to trigger the logic
result = sync_data.policy

# Assertions to ensure db and deserialize were called correctly
mocked_db.get_strict.assert_called_once_with("policy")
mock_deserialize.assert_called_once_with(mock_policy_serialized)
assert result == expected_policy


def test_mech_requests(sync_data, mocked_db):
def test_mech_requests(sync_data: SynchronizedData, mocked_db: MagicMock) -> None:
"""Test the mech_requests property."""
mocked_db.get.return_value = '[{"request_id": "1", "data": "request_data"}]'
requests = json.loads(mocked_db.get.return_value)

# Correctly create MechMetadata objects
mech_requests = [
MechMetadata(request_id=item["request_id"], data=item["data"])
for item in requests
Expand All @@ -191,7 +193,7 @@ def test_mech_requests(sync_data, mocked_db):
assert mech_requests[0].request_id == "1"


def test_weighted_accuracy(sync_data, mocked_db):
def test_weighted_accuracy(sync_data: SynchronizedData, mocked_db: MagicMock) -> None:
"""Test the weighted_accuracy property."""
selected_mech_tool = "tool1"
policy_db_name = "policy"
Expand All @@ -206,11 +208,10 @@ def test_weighted_accuracy(sync_data, mocked_db):
assert sync_data.weighted_accuracy == policy.weighted_accuracy[selected_mech_tool]


def test_end_block(mocked_db):
def test_end_block(mocked_db: MagicMock) -> None:
"""Test the end_block logic in TxPreparationRound."""
# Mock SynchronizedData and CollectSameUntilThresholdRound behavior
mocked_sync_data = MagicMock(spec=SynchronizedData)
mock_context = MagicMock() # Create a mock context
mock_context = MagicMock()
round_instance = TxPreparationRound(
synchronized_data=mocked_sync_data, context=mock_context
)
Expand All @@ -225,4 +226,4 @@ def test_end_block(mocked_db):
TxPreparationRound, "end_block", return_value=(mocked_sync_data, Event.NONE)
):
result = round_instance.end_block()
assert result == (mocked_sync_data, Event.NONE)
assert result == (mocked_sync_data, Event.NONE)
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

"""This package contains the tests for Decision Maker"""


import pytest

from packages.valory.skills.abstract_round_abci.base import (
Expand All @@ -42,28 +43,34 @@
class MockSynchronizedData(BaseSynchronizedData):
"""A mock class for SynchronizedData."""

def __init__(self, db=None):
def __init__(self, db=None) -> None:
"""Mock function"""
super().__init__(db) # Pass db to the parent class


class MockContext:
"""A mock class for context used in the rounds."""

def __init__(self):
def __init__(self) -> None:
"""Mock function"""
self.some_attribute = "mock_value" # Add any necessary attributes here


class TestFinalStates:
"""The class for test of Final States"""

@pytest.fixture
def setup_round(self):
def setup_round(self) -> tuple[MockSynchronizedData, MockContext]:
"""Fixture to set up a round instance."""
synchronized_data = MockSynchronizedData(
db="mock_db"
) # Provide a mock db value
context = MockContext()
return synchronized_data, context

def test_benchmarking_mode_disabled_round(self, setup_round):
def test_benchmarking_mode_disabled_round(
self, setup_round: tuple[MockSynchronizedData, MockContext]
) -> None:
"""Test instantiation of BenchmarkingModeDisabledRound."""
synchronized_data, context = setup_round
round_instance = BenchmarkingModeDisabledRound(
Expand All @@ -72,7 +79,9 @@ def test_benchmarking_mode_disabled_round(self, setup_round):
assert isinstance(round_instance, BenchmarkingModeDisabledRound)
assert isinstance(round_instance, DegenerateRound)

def test_finished_decision_maker_round(self, setup_round):
def test_finished_decision_maker_round(
self, setup_round: tuple[MockSynchronizedData, MockContext]
) -> None:
"""Test instantiation of FinishedDecisionMakerRound."""
synchronized_data, context = setup_round
round_instance = FinishedDecisionMakerRound(
Expand All @@ -81,7 +90,9 @@ def test_finished_decision_maker_round(self, setup_round):
assert isinstance(round_instance, FinishedDecisionMakerRound)
assert isinstance(round_instance, DegenerateRound)

def test_finished_decision_request_round(self, setup_round):
def test_finished_decision_request_round(
self, setup_round: tuple[MockSynchronizedData, MockContext]
) -> None:
"""Test instantiation of FinishedDecisionRequestRound."""
synchronized_data, context = setup_round
round_instance = FinishedDecisionRequestRound(
Expand All @@ -90,7 +101,9 @@ def test_finished_decision_request_round(self, setup_round):
assert isinstance(round_instance, FinishedDecisionRequestRound)
assert isinstance(round_instance, DegenerateRound)

def test_finished_subscription_round(self, setup_round):
def test_finished_subscription_round(
self, setup_round: tuple[MockSynchronizedData, MockContext]
) -> None:
"""Test instantiation of FinishedSubscriptionRound."""
synchronized_data, context = setup_round
round_instance = FinishedSubscriptionRound(
Expand All @@ -99,7 +112,9 @@ def test_finished_subscription_round(self, setup_round):
assert isinstance(round_instance, FinishedSubscriptionRound)
assert isinstance(round_instance, DegenerateRound)

def test_finished_without_redeeming_round(self, setup_round):
def test_finished_without_redeeming_round(
self, setup_round: tuple[MockSynchronizedData, MockContext]
) -> None:
"""Test instantiation of FinishedWithoutRedeemingRound."""
synchronized_data, context = setup_round
round_instance = FinishedWithoutRedeemingRound(
Expand All @@ -108,7 +123,9 @@ def test_finished_without_redeeming_round(self, setup_round):
assert isinstance(round_instance, FinishedWithoutRedeemingRound)
assert isinstance(round_instance, DegenerateRound)

def test_finished_without_decision_round(self, setup_round):
def test_finished_without_decision_round(
self, setup_round: tuple[MockSynchronizedData, MockContext]
) -> None:
"""Test instantiation of FinishedWithoutDecisionRound."""
synchronized_data, context = setup_round
round_instance = FinishedWithoutDecisionRound(
Expand All @@ -117,7 +134,9 @@ def test_finished_without_decision_round(self, setup_round):
assert isinstance(round_instance, FinishedWithoutDecisionRound)
assert isinstance(round_instance, DegenerateRound)

def test_refill_required_round(self, setup_round):
def test_refill_required_round(
self, setup_round: tuple[MockSynchronizedData, MockContext]
) -> None:
"""Test instantiation of RefillRequiredRound."""
synchronized_data, context = setup_round
round_instance = RefillRequiredRound(
Expand All @@ -126,7 +145,9 @@ def test_refill_required_round(self, setup_round):
assert isinstance(round_instance, RefillRequiredRound)
assert isinstance(round_instance, DegenerateRound)

def test_benchmarking_done_round(self, setup_round):
def test_benchmarking_done_round(
self, setup_round: tuple[MockSynchronizedData, MockContext]
) -> None:
"""Test instantiation of BenchmarkingDoneRound and its end_block method."""
synchronized_data, context = setup_round
round_instance = BenchmarkingDoneRound(
Expand All @@ -139,7 +160,9 @@ def test_benchmarking_done_round(self, setup_round):
with pytest.raises(SystemExit):
round_instance.end_block() # Should exit the program

def test_impossible_round(self, setup_round):
def test_impossible_round(
self, setup_round: tuple[MockSynchronizedData, MockContext]
) -> None:
"""Test instantiation of ImpossibleRound."""
synchronized_data, context = setup_round
round_instance = ImpossibleRound(
Expand Down

0 comments on commit 43e5c15

Please sign in to comment.