diff --git a/.github/workflows/pull-request.yaml b/.github/workflows/pull-request.yaml index 43265093..d662a8bd 100644 --- a/.github/workflows/pull-request.yaml +++ b/.github/workflows/pull-request.yaml @@ -26,3 +26,5 @@ jobs: - name: Unit Tests run: python -m pytest tests/unit + env: + NODE_URL: https://rpc.ankr.com/eth diff --git a/src/abis/load.py b/src/abis/load.py index f52adb86..46a89716 100644 --- a/src/abis/load.py +++ b/src/abis/load.py @@ -13,10 +13,10 @@ # TODO - following this issue: https://github.com/ethereum/web3.py/issues/3017 from web3.contract import Contract # type: ignore -from src.config import config +from src.config import IOConfig from src.logger import set_log -ABI_PATH = config.io_config.project_root_dir / Path("src/abis") +ABI_PATH = IOConfig.from_env().project_root_dir / Path("src/abis") log = set_log(__name__) @@ -60,9 +60,11 @@ def get_contract( # don't have to import a bunch of stuff to get the contract then want -def weth9(web3: Optional[Web3] = None) -> Contract | Type[Contract]: +def weth9( + web3: Optional[Web3] = None, address: ChecksumAddress | None = None +) -> Contract | Type[Contract]: """Returns an instance of WETH9 Contract""" - return IndexedContract.WETH9.get_contract(web3, config.payment_config.weth_address) + return IndexedContract.WETH9.get_contract(web3, address) def erc20( diff --git a/src/config.py b/src/config.py index a4cddd05..14775317 100644 --- a/src/config.py +++ b/src/config.py @@ -291,6 +291,8 @@ def from_network(network: Network) -> AccountingConfig: ) -config = AccountingConfig.from_network(Network(os.environ.get("NETWORK", "mainnet"))) - -web3 = Web3(Web3.HTTPProvider(config.node_config.node_url)) +web3 = Web3( + Web3.HTTPProvider( + NodeConfig.from_network(Network(os.environ.get("NETWORK", "mainnet"))).node_url + ) +) diff --git a/src/fetch/payouts.py b/src/fetch/payouts.py index 100a1b1a..1bc79912 100644 --- a/src/fetch/payouts.py +++ b/src/fetch/payouts.py @@ -14,7 +14,7 @@ from dune_client.types import Address from pandas import DataFrame, Series -from src.config import config +from src.config import AccountingConfig from src.fetch.dune import DuneFetcher from src.fetch.prices import exchange_rate_atoms from src.models.accounting_period import AccountingPeriod @@ -269,7 +269,9 @@ class TokenConversion: eth_to_token: Callable -def extend_payment_df(pdf: DataFrame, converter: TokenConversion) -> DataFrame: +def extend_payment_df( + pdf: DataFrame, converter: TokenConversion, config: AccountingConfig +) -> DataFrame: """ Extending the basic columns returned by SQL Query with some after-math: - reward_eth as difference of payment and execution_cost @@ -296,12 +298,13 @@ def extend_payment_df(pdf: DataFrame, converter: TokenConversion) -> DataFrame: return pdf -def prepare_transfers( +def prepare_transfers( # pylint: disable=too-many-arguments payout_df: DataFrame, period: AccountingPeriod, final_protocol_fee_wei: int, partner_fee_tax_wei: int, partner_fees_wei: dict[str, int], + config: AccountingConfig, ) -> PeriodPayouts: """ Manipulates the payout DataFrame to split into ETH and COW. @@ -392,6 +395,7 @@ def construct_payout_dataframe( slippage_df: DataFrame, reward_target_df: DataFrame, service_fee_df: DataFrame, + config: AccountingConfig, ) -> DataFrame: """ Method responsible for joining datasets related to payouts. @@ -445,7 +449,7 @@ def construct_payout_dataframe( def construct_partner_fee_payments( - partner_fees_df: DataFrame, + partner_fees_df: DataFrame, config: AccountingConfig ) -> tuple[dict[str, int], int]: """Compute actual partner fee payments taking partner fee tax into account The result is a tuple. The first entry is a dictionary that contains the destination address of @@ -485,13 +489,21 @@ def construct_partner_fee_payments( def construct_payouts( - orderbook: MultiInstanceDBFetcher, dune: DuneFetcher, ignore_slippage_flag: bool + orderbook: MultiInstanceDBFetcher, + dune: DuneFetcher, + ignore_slippage_flag: bool, + config: AccountingConfig, ) -> list[Transfer]: """Workflow of solver reward payout logic post-CIP27""" # pylint: disable-msg=too-many-locals quote_rewards_df = orderbook.get_quote_rewards(dune.start_block, dune.end_block) - batch_rewards_df = orderbook.get_solver_rewards(dune.start_block, dune.end_block) + batch_rewards_df = orderbook.get_solver_rewards( + dune.start_block, + dune.end_block, + config.reward_config.batch_reward_cap_upper, + config.reward_config.batch_reward_cap_lower, + ) partner_fees_df = batch_rewards_df[["partner_list", "partner_fee_eth"]] batch_rewards_df = batch_rewards_df.drop( ["partner_list", "partner_fee_eth"], axis=1 @@ -542,18 +554,20 @@ def construct_payouts( pdf=merged_df, # provide token conversion functions (ETH <--> COW) converter=converter, + config=config, ), # Dune: Fetch Solver Slippage & Reward Targets slippage_df=slippage_df, reward_target_df=reward_target_df, service_fee_df=service_fee_df, + config=config, ) # Sort by solver before breaking this data frame into Transfer objects. complete_payout_df = complete_payout_df.sort_values("solver") # compute partner fees partner_fees_wei, total_partner_fee_wei_untaxed = construct_partner_fee_payments( - partner_fees_df + partner_fees_df, config ) raw_protocol_fee_wei = int(complete_payout_df.protocol_fee_eth.sum()) final_protocol_fee_wei = raw_protocol_fee_wei - total_partner_fee_wei_untaxed @@ -584,6 +598,7 @@ def construct_payouts( final_protocol_fee_wei, partner_fee_tax_wei, partner_fees_wei, + config, ) for overdraft in payouts.overdrafts: dune.log_saver.print(str(overdraft), Category.OVERDRAFT) diff --git a/src/fetch/prices.py b/src/fetch/prices.py index fd51aeb4..f86f51e3 100644 --- a/src/fetch/prices.py +++ b/src/fetch/prices.py @@ -12,11 +12,11 @@ from coinpaprika import client as cp from dune_client.types import Address -from src.config import config +from src.config import IOConfig log = logging.getLogger(__name__) logging.config.fileConfig( - fname=config.io_config.log_config_file.absolute(), disable_existing_loggers=False + fname=IOConfig.from_env().log_config_file.absolute(), disable_existing_loggers=False ) client = cp.Client() diff --git a/src/fetch/transfer_file.py b/src/fetch/transfer_file.py index 7383ddb8..b2c7a008 100644 --- a/src/fetch/transfer_file.py +++ b/src/fetch/transfer_file.py @@ -4,6 +4,7 @@ from __future__ import annotations +import os import ssl from dataclasses import asdict @@ -14,7 +15,7 @@ from gnosis.eth.ethereum_client import EthereumClient from slack.web.client import WebClient -from src.config import config +from src.config import AccountingConfig, Network from src.fetch.dune import DuneFetcher from src.fetch.payouts import construct_payouts from src.models.accounting_period import AccountingPeriod @@ -26,7 +27,11 @@ from src.utils.script_args import generic_script_init -def manual_propose(transfers: list[Transfer], period: AccountingPeriod) -> None: +def manual_propose( + transfers: list[Transfer], + period: AccountingPeriod, + config: AccountingConfig, +) -> None: """ Entry point to manual creation of rewards payout transaction. This function generates the CSV transfer file to be pasted into the COW Safe app @@ -49,6 +54,7 @@ def auto_propose( log_saver: PrintStore, slack_client: WebClient, dry_run: bool, + config: AccountingConfig, ) -> None: """ Entry point auto creation of rewards payout transaction. @@ -66,6 +72,7 @@ def auto_propose( transactions = prepend_unwrap_if_necessary( client, config.payment_config.payment_safe_address, + wrapped_native_token=config.payment_config.weth_address, transactions=[t.as_multisend_tx() for t in transfers], ) if len(transactions) > len(transfers): @@ -105,6 +112,8 @@ def main() -> None: args = generic_script_init(description="Fetch Complete Reimbursement") + config = AccountingConfig.from_network(Network(os.environ["NETWORK"])) + orderbook = MultiInstanceDBFetcher( [config.orderbook_config.prod_db_url, config.orderbook_config.barn_db_url] ) @@ -122,6 +131,7 @@ def main() -> None: orderbook=orderbook, dune=dune, ignore_slippage_flag=args.ignore_slippage, + config=config, ) payout_transfers = [] @@ -136,18 +146,20 @@ def main() -> None: if args.post_tx: ssl_context = ssl.create_default_context(cafile=certifi.where()) ssl_context.verify_mode = ssl.CERT_REQUIRED + slack_client = WebClient( + token=config.io_config.slack_token, + # https://stackoverflow.com/questions/59808346/python-3-slack-client-ssl-sslcertverificationerror + ssl=ssl_context, + ) auto_propose( transfers=payout_transfers, log_saver=dune.log_saver, - slack_client=WebClient( - token=config.io_config.slack_token, - # https://stackoverflow.com/questions/59808346/python-3-slack-client-ssl-sslcertverificationerror - ssl=ssl_context, - ), + slack_client=slack_client, dry_run=args.dry_run, + config=config, ) else: - manual_propose(transfers=payout_transfers, period=dune.period) + manual_propose(transfers=payout_transfers, period=dune.period, config=config) if __name__ == "__main__": diff --git a/src/logger.py b/src/logger.py index d93a5d07..a387e597 100644 --- a/src/logger.py +++ b/src/logger.py @@ -3,7 +3,9 @@ import logging.config from logging import Logger -from src.config import config +from src.config import IOConfig + +io_config = IOConfig.from_env() # TODO - use this in every file that logs (and prints). @@ -13,7 +15,7 @@ def set_log(name: str) -> Logger: log = logging.getLogger(name) logging.config.fileConfig( - fname=config.io_config.log_config_file.absolute(), + fname=io_config.log_config_file.absolute(), disable_existing_loggers=False, ) return log diff --git a/src/models/token.py b/src/models/token.py index 9a73d41c..031071e9 100644 --- a/src/models/token.py +++ b/src/models/token.py @@ -9,7 +9,7 @@ from dune_client.types import Address -from src.config import config, web3 +from src.config import web3 from src.utils.token_details import get_token_decimals @@ -48,10 +48,6 @@ def __init__(self, address: str | Address, decimals: Optional[int] = None): address = Address(address) self.address = address - if address == config.payment_config.cow_token_address: - # Avoid Web3 Calls for main branch of program. - decimals = 18 - self.decimals = ( decimals if decimals is not None else get_token_decimals(web3, address) ) diff --git a/src/multisend.py b/src/multisend.py index 49ae23d7..90972909 100644 --- a/src/multisend.py +++ b/src/multisend.py @@ -12,12 +12,12 @@ from gnosis.safe.multi_send import MultiSend, MultiSendOperation, MultiSendTx from gnosis.safe.safe import Safe -from src.config import config, web3 +from src.config import web3, IOConfig from src.abis.load import weth9 log = logging.getLogger(__name__) logging.config.fileConfig( - fname=config.io_config.log_config_file.absolute(), disable_existing_loggers=False + fname=IOConfig.from_env().log_config_file.absolute(), disable_existing_loggers=False ) # This contract address can be removed once this issue is resolved: @@ -42,6 +42,7 @@ def prepend_unwrap_if_necessary( client: EthereumClient, safe_address: ChecksumAddress, transactions: list[MultiSendTx], + wrapped_native_token: ChecksumAddress, skip_validation: bool = False, ) -> list[MultiSendTx]: """ @@ -53,7 +54,7 @@ def prepend_unwrap_if_necessary( # Amount of outgoing ETH from transfer eth_needed = sum(t.value for t in transactions) if eth_balance < eth_needed: - weth = weth9(client.w3) + weth = weth9(client.w3, wrapped_native_token) weth_balance = weth.functions.balanceOf(safe_address).call() if weth_balance + eth_balance < eth_needed: message = ( diff --git a/src/pg_client.py b/src/pg_client.py index 4e753f34..ce6255e4 100644 --- a/src/pg_client.py +++ b/src/pg_client.py @@ -8,7 +8,6 @@ from sqlalchemy import create_engine from sqlalchemy.engine import Engine -from src.config import config from src.logger import set_log from src.utils.query_file import open_query @@ -31,7 +30,13 @@ def exec_query(cls, query: str, engine: Engine) -> DataFrame: """Executes query on DB engine""" return pd.read_sql(sql=query, con=engine) - def get_solver_rewards(self, start_block: str, end_block: str) -> DataFrame: + def get_solver_rewards( + self, + start_block: str, + end_block: str, + reward_cap_upper: int, + reward_cap_lower: int, + ) -> DataFrame: """ Returns aggregated solver rewards for accounting period defined by block range """ @@ -39,23 +44,15 @@ def get_solver_rewards(self, start_block: str, end_block: str) -> DataFrame: open_query("orderbook/prod_batch_rewards.sql") .replace("{{start_block}}", start_block) .replace("{{end_block}}", end_block) - .replace( - "{{EPSILON_LOWER}}", str(config.reward_config.batch_reward_cap_lower) - ) - .replace( - "{{EPSILON_UPPER}}", str(config.reward_config.batch_reward_cap_upper) - ) + .replace("{{EPSILON_LOWER}}", str(reward_cap_lower)) + .replace("{{EPSILON_UPPER}}", str(reward_cap_upper)) ) batch_reward_query_barn = ( open_query("orderbook/barn_batch_rewards.sql") .replace("{{start_block}}", start_block) .replace("{{end_block}}", end_block) - .replace( - "{{EPSILON_LOWER}}", str(config.reward_config.batch_reward_cap_lower) - ) - .replace( - "{{EPSILON_UPPER}}", str(config.reward_config.batch_reward_cap_upper) - ) + .replace("{{EPSILON_LOWER}}", str(reward_cap_lower)) + .replace("{{EPSILON_UPPER}}", str(reward_cap_upper)) ) results = [] diff --git a/src/utils/query_file.py b/src/utils/query_file.py index 17ea5df7..14a8d558 100644 --- a/src/utils/query_file.py +++ b/src/utils/query_file.py @@ -6,7 +6,9 @@ import os -from src.config import config +from src.config import IOConfig + +io_config = IOConfig.from_env() def open_query(filename: str) -> str: @@ -23,9 +25,9 @@ def open_dashboard_query(filename: str) -> str: def query_file(filename: str) -> str: """Returns proper path for filename in QUERY_PATH""" - return os.path.join(config.io_config.query_dir, filename) + return os.path.join(io_config.query_dir, filename) def dashboard_file(filename: str) -> str: """Returns proper path for filename in DASHBOARD_PATH""" - return os.path.join(config.io_config.dashboard_dir, filename) + return os.path.join(io_config.dashboard_dir, filename) diff --git a/src/utils/token_details.py b/src/utils/token_details.py index 720fb542..f69ac85e 100644 --- a/src/utils/token_details.py +++ b/src/utils/token_details.py @@ -9,12 +9,12 @@ from web3 import Web3 from src.abis.load import erc20 -from src.config import config +from src.config import IOConfig log = logging.getLogger(__name__) logging.config.fileConfig( - fname=config.io_config.log_config_file.absolute(), disable_existing_loggers=False + fname=IOConfig.from_env().log_config_file.absolute(), disable_existing_loggers=False ) diff --git a/tests/e2e/test_prices.py b/tests/e2e/test_prices.py index 748dff3a..7ca0f535 100644 --- a/tests/e2e/test_prices.py +++ b/tests/e2e/test_prices.py @@ -3,7 +3,7 @@ from dune_client.types import Address -from src.config import config +from src.config import AccountingConfig, Network from src.fetch.prices import ( TokenId, exchange_rate_atoms, @@ -16,12 +16,13 @@ class TestPrices(unittest.TestCase): def setUp(self) -> None: + self.config = AccountingConfig.from_network(Network.MAINNET) self.some_date = datetime.strptime("2024-09-01", "%Y-%m-%d") self.cow_price = usd_price(TokenId.COW, self.some_date) self.eth_price = usd_price(TokenId.ETH, self.some_date) self.usdc_price = usd_price(TokenId.USDC, self.some_date) - self.cow_address = config.reward_config.reward_token_address - self.weth_address = Address(config.payment_config.weth_address) + self.cow_address = self.config.reward_config.reward_token_address + self.weth_address = Address(self.config.payment_config.weth_address) self.usdc_address = Address("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48") def test_usd_price(self): diff --git a/tests/queries/test_batch_rewards.py b/tests/queries/test_batch_rewards.py index 19d1f9ed..88b02272 100644 --- a/tests/queries/test_batch_rewards.py +++ b/tests/queries/test_batch_rewards.py @@ -3,12 +3,16 @@ import pandas.testing from pandas import DataFrame +from src.config import RewardConfig, Network from src.pg_client import MultiInstanceDBFetcher class TestBatchRewards(unittest.TestCase): def setUp(self) -> None: db_url = "postgres:postgres@localhost:5432/postgres" + reward_config = RewardConfig.from_network(Network.MAINNET) + self.batch_reward_cap_upper = reward_config.batch_reward_cap_upper + self.batch_reward_cap_lower = reward_config.batch_reward_cap_lower self.fetcher = MultiInstanceDBFetcher([db_url]) with open( "./tests/queries/batch_rewards_test_db.sql", "r", encoding="utf-8" @@ -17,7 +21,12 @@ def setUp(self) -> None: def test_get_batch_rewards(self): start_block, end_block = "0", "100" - batch_rewards = self.fetcher.get_solver_rewards(start_block, end_block) + batch_rewards = self.fetcher.get_solver_rewards( + start_block, + end_block, + self.batch_reward_cap_upper, + self.batch_reward_cap_lower, + ) expected = DataFrame( { "solver": [ diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py index f1585bfd..ab5e969f 100644 --- a/tests/unit/test_models.py +++ b/tests/unit/test_models.py @@ -7,7 +7,7 @@ from web3 import Web3 from src.abis.load import erc20 -from src.config import config +from src.config import PaymentConfig, Network from src.fetch.transfer_file import Transfer from src.models.accounting_period import AccountingPeriod from src.models.token import Token @@ -19,6 +19,7 @@ class TestTransfer(unittest.TestCase): def setUp(self) -> None: + self.payment_config = PaymentConfig.from_network(Network.MAINNET) self.token_1 = Token(Address.from_int(1), 18) self.token_2 = Token(Address.from_int(2), 18) @@ -39,7 +40,7 @@ def test_basic_as_multisend_tx(self): ), ) erc20_transfer = Transfer( - token=Token(config.payment_config.cow_token_address), + token=Token(self.payment_config.cow_token_address), recipient=Address(receiver), amount_wei=15, ) @@ -47,7 +48,7 @@ def test_basic_as_multisend_tx(self): erc20_transfer.as_multisend_tx(), MultiSendTx( operation=MultiSendOperation.CALL, - to=config.payment_config.cow_token_address.address, + to=self.payment_config.cow_token_address.address, value=0, data=erc20().encodeABI(fn_name="transfer", args=[receiver, 15]), ), @@ -94,7 +95,7 @@ def test_summarize(self): [ Transfer(token=None, recipient=receiver, amount_wei=eth_amount), Transfer( - token=Token(config.payment_config.cow_token_address), + token=Token(self.payment_config.cow_token_address), recipient=receiver, amount_wei=cow_amount, ), diff --git a/tests/unit/test_multisend.py b/tests/unit/test_multisend.py index 3ae0731f..93f99d27 100644 --- a/tests/unit/test_multisend.py +++ b/tests/unit/test_multisend.py @@ -6,7 +6,7 @@ from web3 import Web3 from src.abis.load import weth9 -from src.config import config +from src.config import Network, PaymentConfig from src.fetch.transfer_file import Transfer from src.models.token import Token from src.multisend import build_encoded_multisend, prepend_unwrap_if_necessary @@ -16,12 +16,11 @@ class TestMultiSend(unittest.TestCase): def setUp(self) -> None: node_url = "https://rpc.ankr.com/eth" self.client = EthereumClient(URI(node_url)) + self.payment_config = PaymentConfig.from_network(Network.MAINNET) def test_prepend_unwrap(self): many_eth = 99999999 * 10**18 - safe_address = Web3().to_checksum_address( - "0xA03be496e67Ec29bC62F01a428683D7F9c204930" - ) + safe_address = self.payment_config.payment_safe_address big_native_transfer = Transfer( token=None, recipient=Address.zero(), amount_wei=many_eth ).as_multisend_tx() @@ -32,10 +31,11 @@ def test_prepend_unwrap(self): client=self.client, safe_address=safe_address, transactions=[big_native_transfer], + wrapped_native_token=self.payment_config.weth_address, ) eth_balance = self.client.get_balance(safe_address) - weth = weth9(self.client.w3) + weth = weth9(self.client.w3, self.payment_config.weth_address) weth_balance = weth.functions.balanceOf(safe_address).call() transactions = [ @@ -46,7 +46,11 @@ def test_prepend_unwrap(self): ).as_multisend_tx() ] transactions = prepend_unwrap_if_necessary( - self.client, safe_address, transactions, skip_validation=True + self.client, + safe_address, + transactions, + self.payment_config.weth_address, + skip_validation=True, ) self.assertEqual(2, len(transactions)) @@ -63,7 +67,7 @@ def test_prepend_unwrap(self): def test_multisend_encoding(self): receiver = Address("0xde786877a10dbb7eba25a4da65aecf47654f08ab") - cow_token = Token(config.payment_config.cow_token_address) + cow_token = Token(self.payment_config.cow_token_address) self.assertEqual( build_encoded_multisend([], client=self.client), "0x8d80ff0a" # MethodID diff --git a/tests/unit/test_payouts.py b/tests/unit/test_payouts.py index f4528bd4..3b95e5e8 100644 --- a/tests/unit/test_payouts.py +++ b/tests/unit/test_payouts.py @@ -5,7 +5,7 @@ from dune_client.types import Address from pandas import DataFrame -from src.config import config +from src.config import AccountingConfig, Network from src.fetch.payouts import ( extend_payment_df, normalize_address_field, @@ -25,6 +25,7 @@ class TestPayoutTransformations(unittest.TestCase): """Contains tests all stray methods in src/fetch/payouts.py""" def setUp(self) -> None: + self.config = AccountingConfig.from_network(Network.MAINNET) self.solvers = list( map( str, @@ -52,7 +53,7 @@ def setUp(self) -> None: map( str, [ - config.reward_config.cow_bonding_pool, + self.config.reward_config.cow_bonding_pool, Address.from_int(10), Address.from_int(11), Address.from_int(12), @@ -97,7 +98,9 @@ def test_extend_payment_df(self): "network_fee_eth": self.network_fee_eth, } base_payout_df = DataFrame(base_data_dict) - result = extend_payment_df(base_payout_df, converter=self.mock_converter) + result = extend_payment_df( + base_payout_df, converter=self.mock_converter, config=self.config + ) expected_data_dict = { "solver": self.solvers, "num_quotes": self.num_quotes, @@ -210,6 +213,7 @@ def test_construct_payouts(self): } ), converter=self.mock_converter, + config=self.config, ) slippages = DataFrame( @@ -239,6 +243,7 @@ def test_construct_payouts(self): slippage_df=slippages, reward_target_df=reward_targets, service_fee_df=service_fee_df, + config=self.config, ) expected = DataFrame( { @@ -278,7 +283,7 @@ def test_construct_payouts(self): "0x0000000000000000000000000000000000000008", ], "pool_address": [ - str(config.reward_config.cow_bonding_pool), + str(self.config.reward_config.cow_bonding_pool), "0x0000000000000000000000000000000000000010", "0x0000000000000000000000000000000000000011", "0x0000000000000000000000000000000000000012", @@ -296,10 +301,10 @@ def test_construct_payouts(self): str(self.solvers[3]), ], "reward_token_address": [ - str(config.reward_config.reward_token_address), - str(config.reward_config.reward_token_address), - str(config.reward_config.reward_token_address), - str(config.reward_config.reward_token_address), + str(self.config.reward_config.reward_token_address), + str(self.config.reward_config.reward_token_address), + str(self.config.reward_config.reward_token_address), + str(self.config.reward_config.reward_token_address), ], } ) @@ -354,17 +359,17 @@ def test_prepare_transfers(self): "0x0000000000000000000000000000000000000008", ], "reward_token_address": [ - str(config.reward_config.reward_token_address), - str(config.reward_config.reward_token_address), - str(config.reward_config.reward_token_address), - str(config.reward_config.reward_token_address), + str(self.config.reward_config.reward_token_address), + str(self.config.reward_config.reward_token_address), + str(self.config.reward_config.reward_token_address), + str(self.config.reward_config.reward_token_address), ], } ) period = AccountingPeriod("1985-03-10", 1) protocol_fee_amount = sum(self.protocol_fee_eth) payout_transfers = prepare_transfers( - full_payout_data, period, protocol_fee_amount, 0, {} + full_payout_data, period, protocol_fee_amount, 0, {}, self.config ) self.assertEqual( [ @@ -374,31 +379,31 @@ def test_prepare_transfers(self): amount_wei=1, ), Transfer( - token=Token(config.payment_config.cow_token_address), + token=Token(self.config.payment_config.cow_token_address), recipient=Address(self.reward_targets[0]), amount_wei=600000000000000000, ), Transfer( - token=Token(config.payment_config.cow_token_address), + token=Token(self.config.payment_config.cow_token_address), recipient=Address(self.reward_targets[1]), amount_wei=12000000000000000000, ), Transfer( - token=Token(config.payment_config.cow_token_address), + token=Token(self.config.payment_config.cow_token_address), recipient=Address(self.reward_targets[2]), amount_wei=90000000000000000000, ), Transfer( - token=Token(config.payment_config.cow_token_address), + token=Token(self.config.payment_config.cow_token_address), recipient=Address(self.reward_targets[3]), amount_wei=int( 180000000000000000000 - * (1 - config.reward_config.service_fee_factor) + * (1 - self.config.reward_config.service_fee_factor) ), ), Transfer( token=None, - recipient=config.protocol_fee_config.protocol_fee_safe, + recipient=self.config.protocol_fee_config.protocol_fee_safe, amount_wei=3000000000000000, ), ], @@ -420,11 +425,12 @@ def test_prepare_transfers(self): class TestRewardAndPenaltyDatum(unittest.TestCase): def setUp(self) -> None: + self.config = AccountingConfig.from_network(Network.MAINNET) self.solver = Address.from_int(1) self.solver_name = "Solver1" self.reward_target = Address.from_int(2) self.buffer_accounting_target = Address.from_int(3) - self.cow_token_address = config.payment_config.cow_token_address + self.cow_token_address = self.config.payment_config.cow_token_address self.cow_token = Token(self.cow_token_address) self.conversion_rate = 1000 @@ -444,7 +450,7 @@ def sample_record( primary_reward_eth=primary_reward, primary_reward_cow=primary_reward * self.conversion_rate, slippage_eth=slippage, - quote_reward_cow=config.reward_config.quote_reward_cow * num_quotes, + quote_reward_cow=self.config.reward_config.quote_reward_cow * num_quotes, service_fee=service_fee, reward_token_address=self.cow_token_address, )