Skip to content

Commit

Permalink
test:state
Browse files Browse the repository at this point in the history
  • Loading branch information
Ravleen-Solulab committed Sep 30, 2024
1 parent fe58225 commit c3640b5
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ def test_end_block(mocked_db):
"""Test the end_block logic in TxPreparationRound."""
# Mock SynchronizedData and CollectSameUntilThresholdRound behavior
mocked_sync_data = MagicMock(spec=SynchronizedData)
round_instance = TxPreparationRound(synchronized_data=mocked_sync_data) # Removed synchronized_data_class
mock_context = MagicMock() # Create a mock context
round_instance = TxPreparationRound(synchronized_data=mocked_sync_data, context=mock_context)

with patch.object(TxPreparationRound, "end_block", return_value=(mocked_sync_data, Event.DONE)):
result = round_instance.end_block()
Expand All @@ -185,3 +186,4 @@ def test_end_block(mocked_db):
with patch.object(TxPreparationRound, "end_block", return_value=(mocked_sync_data, Event.NONE)):
result = round_instance.end_block()
assert result == (mocked_sync_data, Event.NONE)

Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import pytest
from unittest.mock import MagicMock, patch
from collections import Counter
from packages.valory.skills.decision_maker_abci.states.order_subscription import SubscriptionRound
from packages.valory.skills.decision_maker_abci.states.base import Event

@pytest.fixture
def mock_context():
"""Fixture for the context."""
context = MagicMock()
context.benchmarking_mode.enabled = False
return context

@pytest.fixture
def mock_sync_data():
"""Fixture for the synchronized data."""
return MagicMock()

@pytest.fixture
def subscription_round(mock_sync_data, mock_context):
"""Fixture for SubscriptionRound."""
round_instance = SubscriptionRound(synchronized_data=mock_sync_data, context=mock_context)

# Mocking the payload_values_count property to return a Counter
def mock_payload_values_count():
return Counter({
("payload_1",): 2,
("payload_2",): 1,
})

# Use a property to mock payload_values_count
round_instance.payload_values_count = property(mock_payload_values_count)

# Mocking the most_voted_payload_values property
round_instance.most_voted_payload_values = MagicMock(return_value=((), "valid_tx_hash", "", "agreement_id"))

# Mocking the threshold_reached property
round_instance.threshold_reached = True

return round_instance

def test_end_block_valid_tx(subscription_round):
"""Test end_block with a valid transaction hash."""
subscription_round.most_voted_payload_values = ((), "valid_tx_hash", "", "agreement_id")

sync_data, event = subscription_round.end_block()

assert event != Event.SUBSCRIPTION_ERROR
assert subscription_round.synchronized_data.update.called
assert subscription_round.synchronized_data.update.call_args[1]['agreement_id'] == "agreement_id"

def test_end_block_no_tx(subscription_round):
"""Test end_block when there is no transaction payload."""
subscription_round.most_voted_payload_values = ((), SubscriptionRound.NO_TX_PAYLOAD, "", "agreement_id")

sync_data, event = subscription_round.end_block()

assert event == Event.NO_SUBSCRIPTION
subscription_round.synchronized_data.update.assert_not_called()

def test_end_block_error_tx(subscription_round):
"""Test end_block when the transaction hash is an error payload."""
subscription_round.most_voted_payload_values = ((), SubscriptionRound.ERROR_PAYLOAD, "", "agreement_id")

sync_data, event = subscription_round.end_block()

assert event == Event.SUBSCRIPTION_ERROR
subscription_round.synchronized_data.update.assert_not_called()

def test_end_block_benchmarking_mode(subscription_round, mock_context):
"""Test end_block in benchmarking mode."""
mock_context.benchmarking_mode.enabled = True

sync_data, event = subscription_round.end_block()

assert event == Event.MOCK_TX
subscription_round.synchronized_data.update.assert_not_called()

def test_end_block_threshold_not_reached(subscription_round):
"""Test end_block when the threshold is not reached."""
subscription_round.threshold_reached = False
assert subscription_round.end_block() is None
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
# ------------------------------------------------------------------------------
#
# Copyright 2023-2024 Valory AG
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ------------------------------------------------------------------------------

import pytest
from unittest.mock import MagicMock
from packages.valory.skills.decision_maker_abci.states.redeem import RedeemRound
from packages.valory.skills.decision_maker_abci.states.base import Event
from packages.valory.skills.abstract_round_abci.base import BaseSynchronizedData

@pytest.fixture
def redeem_round():
"""Fixture to set up a RedeemRound instance for testing."""
synchronized_data = MagicMock(spec=BaseSynchronizedData)
context = MagicMock()
redeem_instance = RedeemRound(synchronized_data, context)

# Set initial properties
redeem_instance.block_confirmations = 0
synchronized_data.period_count = 0
synchronized_data.db = MagicMock()

return redeem_instance

def test_initial_event(redeem_round):
"""Test that the initial event is set correctly."""
assert redeem_round.none_event == Event.NO_REDEEMING

def test_end_block_no_update(redeem_round):
"""Test the end_block behavior when no update occurs."""
# This ensures that block_confirmations and period_count are 0
redeem_round.block_confirmations = 0
redeem_round.synchronized_data.period_count = 0

# Mock the superclass's end_block to simulate behavior
redeem_round.synchronized_data.db.get = MagicMock(return_value='mock_value')

# Call the actual end_block method
result = redeem_round.end_block()

# Assert the result is a tuple and check for specific event
assert isinstance(result, tuple)
assert result[1] == Event.NO_REDEEMING # Adjust based on expected output


def test_end_block_with_update(redeem_round):
"""Test the end_block behavior when an update occurs."""
# Mock the super class's end_block to return a valid update
update_result = (redeem_round.synchronized_data, Event.NO_REDEEMING) # Use an actual event from your enum
RedeemRound.end_block = MagicMock(return_value=update_result)

result = redeem_round.end_block()
assert result == update_result

# Ensure no database update was attempted
redeem_round.synchronized_data.db.update.assert_not_called()

def test_end_block_with_period_count_update(redeem_round):
"""Test the behavior when period_count is greater than zero."""
# Set up the necessary attributes
redeem_round.synchronized_data.period_count = 1

# Directly assign a valid integer to nb_participants
redeem_round.nb_participants = 3

# Set up mock return values for db.get if needed
mock_keys = RedeemRound.selection_key
for key in mock_keys:
redeem_round.synchronized_data.db.get = MagicMock(return_value='mock_value')

# Call the actual end_block method
result = redeem_round.end_block()

# Add assertions based on what you expect the result to be
assert isinstance(result, tuple) # Ensure it returns a tuple
assert result[1] == Event.NO_REDEEMING # Adjust based on expected behavior


0 comments on commit c3640b5

Please sign in to comment.