Skip to content

Commit

Permalink
[Multi-Chain] Pass config explicitly everywhere (#432)
Browse files Browse the repository at this point in the history
This PR is based on #412. It creates config objects explicitly in the
payout script and passes it through the code as argument. Before, the
config was a global object.

This makes it easier, for example, to set up configuration at run time,
combine config with command line arguments, and separate tests from the
rest of the code.
  • Loading branch information
fhenneke authored Nov 18, 2024
1 parent 0870ebf commit 7f8d801
Show file tree
Hide file tree
Showing 17 changed files with 141 additions and 89 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/pull-request.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,5 @@ jobs:
- name: Unit Tests
run:
python -m pytest tests/unit
env:
NODE_URL: https://rpc.ankr.com/eth
10 changes: 6 additions & 4 deletions src/abis/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
# TODO - following this issue: https://github.com/ethereum/web3.py/issues/3017
from web3.contract import Contract # type: ignore

from src.config import config
from src.config import IOConfig
from src.logger import set_log

ABI_PATH = config.io_config.project_root_dir / Path("src/abis")
ABI_PATH = IOConfig.from_env().project_root_dir / Path("src/abis")

log = set_log(__name__)

Expand Down Expand Up @@ -60,9 +60,11 @@ def get_contract(
# don't have to import a bunch of stuff to get the contract then want


def weth9(web3: Optional[Web3] = None) -> Contract | Type[Contract]:
def weth9(
web3: Optional[Web3] = None, address: ChecksumAddress | None = None
) -> Contract | Type[Contract]:
"""Returns an instance of WETH9 Contract"""
return IndexedContract.WETH9.get_contract(web3, config.payment_config.weth_address)
return IndexedContract.WETH9.get_contract(web3, address)


def erc20(
Expand Down
8 changes: 5 additions & 3 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,8 @@ def from_network(network: Network) -> AccountingConfig:
)


config = AccountingConfig.from_network(Network(os.environ.get("NETWORK", "mainnet")))

web3 = Web3(Web3.HTTPProvider(config.node_config.node_url))
web3 = Web3(
Web3.HTTPProvider(
NodeConfig.from_network(Network(os.environ.get("NETWORK", "mainnet"))).node_url
)
)
29 changes: 22 additions & 7 deletions src/fetch/payouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from dune_client.types import Address
from pandas import DataFrame, Series

from src.config import config
from src.config import AccountingConfig
from src.fetch.dune import DuneFetcher
from src.fetch.prices import exchange_rate_atoms
from src.models.accounting_period import AccountingPeriod
Expand Down Expand Up @@ -269,7 +269,9 @@ class TokenConversion:
eth_to_token: Callable


def extend_payment_df(pdf: DataFrame, converter: TokenConversion) -> DataFrame:
def extend_payment_df(
pdf: DataFrame, converter: TokenConversion, config: AccountingConfig
) -> DataFrame:
"""
Extending the basic columns returned by SQL Query with some after-math:
- reward_eth as difference of payment and execution_cost
Expand All @@ -296,12 +298,13 @@ def extend_payment_df(pdf: DataFrame, converter: TokenConversion) -> DataFrame:
return pdf


def prepare_transfers(
def prepare_transfers( # pylint: disable=too-many-arguments
payout_df: DataFrame,
period: AccountingPeriod,
final_protocol_fee_wei: int,
partner_fee_tax_wei: int,
partner_fees_wei: dict[str, int],
config: AccountingConfig,
) -> PeriodPayouts:
"""
Manipulates the payout DataFrame to split into ETH and COW.
Expand Down Expand Up @@ -392,6 +395,7 @@ def construct_payout_dataframe(
slippage_df: DataFrame,
reward_target_df: DataFrame,
service_fee_df: DataFrame,
config: AccountingConfig,
) -> DataFrame:
"""
Method responsible for joining datasets related to payouts.
Expand Down Expand Up @@ -445,7 +449,7 @@ def construct_payout_dataframe(


def construct_partner_fee_payments(
partner_fees_df: DataFrame,
partner_fees_df: DataFrame, config: AccountingConfig
) -> tuple[dict[str, int], int]:
"""Compute actual partner fee payments taking partner fee tax into account
The result is a tuple. The first entry is a dictionary that contains the destination address of
Expand Down Expand Up @@ -485,13 +489,21 @@ def construct_partner_fee_payments(


def construct_payouts(
orderbook: MultiInstanceDBFetcher, dune: DuneFetcher, ignore_slippage_flag: bool
orderbook: MultiInstanceDBFetcher,
dune: DuneFetcher,
ignore_slippage_flag: bool,
config: AccountingConfig,
) -> list[Transfer]:
"""Workflow of solver reward payout logic post-CIP27"""
# pylint: disable-msg=too-many-locals

quote_rewards_df = orderbook.get_quote_rewards(dune.start_block, dune.end_block)
batch_rewards_df = orderbook.get_solver_rewards(dune.start_block, dune.end_block)
batch_rewards_df = orderbook.get_solver_rewards(
dune.start_block,
dune.end_block,
config.reward_config.batch_reward_cap_upper,
config.reward_config.batch_reward_cap_lower,
)
partner_fees_df = batch_rewards_df[["partner_list", "partner_fee_eth"]]
batch_rewards_df = batch_rewards_df.drop(
["partner_list", "partner_fee_eth"], axis=1
Expand Down Expand Up @@ -542,18 +554,20 @@ def construct_payouts(
pdf=merged_df,
# provide token conversion functions (ETH <--> COW)
converter=converter,
config=config,
),
# Dune: Fetch Solver Slippage & Reward Targets
slippage_df=slippage_df,
reward_target_df=reward_target_df,
service_fee_df=service_fee_df,
config=config,
)
# Sort by solver before breaking this data frame into Transfer objects.
complete_payout_df = complete_payout_df.sort_values("solver")

# compute partner fees
partner_fees_wei, total_partner_fee_wei_untaxed = construct_partner_fee_payments(
partner_fees_df
partner_fees_df, config
)
raw_protocol_fee_wei = int(complete_payout_df.protocol_fee_eth.sum())
final_protocol_fee_wei = raw_protocol_fee_wei - total_partner_fee_wei_untaxed
Expand Down Expand Up @@ -584,6 +598,7 @@ def construct_payouts(
final_protocol_fee_wei,
partner_fee_tax_wei,
partner_fees_wei,
config,
)
for overdraft in payouts.overdrafts:
dune.log_saver.print(str(overdraft), Category.OVERDRAFT)
Expand Down
4 changes: 2 additions & 2 deletions src/fetch/prices.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
from coinpaprika import client as cp
from dune_client.types import Address

from src.config import config
from src.config import IOConfig

log = logging.getLogger(__name__)
logging.config.fileConfig(
fname=config.io_config.log_config_file.absolute(), disable_existing_loggers=False
fname=IOConfig.from_env().log_config_file.absolute(), disable_existing_loggers=False
)

client = cp.Client()
Expand Down
28 changes: 20 additions & 8 deletions src/fetch/transfer_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from __future__ import annotations

import os
import ssl
from dataclasses import asdict

Expand All @@ -14,7 +15,7 @@
from gnosis.eth.ethereum_client import EthereumClient
from slack.web.client import WebClient

from src.config import config
from src.config import AccountingConfig, Network
from src.fetch.dune import DuneFetcher
from src.fetch.payouts import construct_payouts
from src.models.accounting_period import AccountingPeriod
Expand All @@ -26,7 +27,11 @@
from src.utils.script_args import generic_script_init


def manual_propose(transfers: list[Transfer], period: AccountingPeriod) -> None:
def manual_propose(
transfers: list[Transfer],
period: AccountingPeriod,
config: AccountingConfig,
) -> None:
"""
Entry point to manual creation of rewards payout transaction.
This function generates the CSV transfer file to be pasted into the COW Safe app
Expand All @@ -49,6 +54,7 @@ def auto_propose(
log_saver: PrintStore,
slack_client: WebClient,
dry_run: bool,
config: AccountingConfig,
) -> None:
"""
Entry point auto creation of rewards payout transaction.
Expand All @@ -66,6 +72,7 @@ def auto_propose(
transactions = prepend_unwrap_if_necessary(
client,
config.payment_config.payment_safe_address,
wrapped_native_token=config.payment_config.weth_address,
transactions=[t.as_multisend_tx() for t in transfers],
)
if len(transactions) > len(transfers):
Expand Down Expand Up @@ -105,6 +112,8 @@ def main() -> None:

args = generic_script_init(description="Fetch Complete Reimbursement")

config = AccountingConfig.from_network(Network(os.environ["NETWORK"]))

orderbook = MultiInstanceDBFetcher(
[config.orderbook_config.prod_db_url, config.orderbook_config.barn_db_url]
)
Expand All @@ -122,6 +131,7 @@ def main() -> None:
orderbook=orderbook,
dune=dune,
ignore_slippage_flag=args.ignore_slippage,
config=config,
)

payout_transfers = []
Expand All @@ -136,18 +146,20 @@ def main() -> None:
if args.post_tx:
ssl_context = ssl.create_default_context(cafile=certifi.where())
ssl_context.verify_mode = ssl.CERT_REQUIRED
slack_client = WebClient(
token=config.io_config.slack_token,
# https://stackoverflow.com/questions/59808346/python-3-slack-client-ssl-sslcertverificationerror
ssl=ssl_context,
)
auto_propose(
transfers=payout_transfers,
log_saver=dune.log_saver,
slack_client=WebClient(
token=config.io_config.slack_token,
# https://stackoverflow.com/questions/59808346/python-3-slack-client-ssl-sslcertverificationerror
ssl=ssl_context,
),
slack_client=slack_client,
dry_run=args.dry_run,
config=config,
)
else:
manual_propose(transfers=payout_transfers, period=dune.period)
manual_propose(transfers=payout_transfers, period=dune.period, config=config)


if __name__ == "__main__":
Expand Down
6 changes: 4 additions & 2 deletions src/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import logging.config
from logging import Logger

from src.config import config
from src.config import IOConfig

io_config = IOConfig.from_env()


# TODO - use this in every file that logs (and prints).
Expand All @@ -13,7 +15,7 @@ def set_log(name: str) -> Logger:
log = logging.getLogger(name)

logging.config.fileConfig(
fname=config.io_config.log_config_file.absolute(),
fname=io_config.log_config_file.absolute(),
disable_existing_loggers=False,
)
return log
6 changes: 1 addition & 5 deletions src/models/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from dune_client.types import Address

from src.config import config, web3
from src.config import web3
from src.utils.token_details import get_token_decimals


Expand Down Expand Up @@ -48,10 +48,6 @@ def __init__(self, address: str | Address, decimals: Optional[int] = None):
address = Address(address)
self.address = address

if address == config.payment_config.cow_token_address:
# Avoid Web3 Calls for main branch of program.
decimals = 18

self.decimals = (
decimals if decimals is not None else get_token_decimals(web3, address)
)
Expand Down
7 changes: 4 additions & 3 deletions src/multisend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
from gnosis.safe.multi_send import MultiSend, MultiSendOperation, MultiSendTx
from gnosis.safe.safe import Safe

from src.config import config, web3
from src.config import web3, IOConfig
from src.abis.load import weth9

log = logging.getLogger(__name__)
logging.config.fileConfig(
fname=config.io_config.log_config_file.absolute(), disable_existing_loggers=False
fname=IOConfig.from_env().log_config_file.absolute(), disable_existing_loggers=False
)

# This contract address can be removed once this issue is resolved:
Expand All @@ -42,6 +42,7 @@ def prepend_unwrap_if_necessary(
client: EthereumClient,
safe_address: ChecksumAddress,
transactions: list[MultiSendTx],
wrapped_native_token: ChecksumAddress,
skip_validation: bool = False,
) -> list[MultiSendTx]:
"""
Expand All @@ -53,7 +54,7 @@ def prepend_unwrap_if_necessary(
# Amount of outgoing ETH from transfer
eth_needed = sum(t.value for t in transactions)
if eth_balance < eth_needed:
weth = weth9(client.w3)
weth = weth9(client.w3, wrapped_native_token)
weth_balance = weth.functions.balanceOf(safe_address).call()
if weth_balance + eth_balance < eth_needed:
message = (
Expand Down
25 changes: 11 additions & 14 deletions src/pg_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine

from src.config import config
from src.logger import set_log
from src.utils.query_file import open_query

Expand All @@ -31,31 +30,29 @@ def exec_query(cls, query: str, engine: Engine) -> DataFrame:
"""Executes query on DB engine"""
return pd.read_sql(sql=query, con=engine)

def get_solver_rewards(self, start_block: str, end_block: str) -> DataFrame:
def get_solver_rewards(
self,
start_block: str,
end_block: str,
reward_cap_upper: int,
reward_cap_lower: int,
) -> DataFrame:
"""
Returns aggregated solver rewards for accounting period defined by block range
"""
batch_reward_query_prod = (
open_query("orderbook/prod_batch_rewards.sql")
.replace("{{start_block}}", start_block)
.replace("{{end_block}}", end_block)
.replace(
"{{EPSILON_LOWER}}", str(config.reward_config.batch_reward_cap_lower)
)
.replace(
"{{EPSILON_UPPER}}", str(config.reward_config.batch_reward_cap_upper)
)
.replace("{{EPSILON_LOWER}}", str(reward_cap_lower))
.replace("{{EPSILON_UPPER}}", str(reward_cap_upper))
)
batch_reward_query_barn = (
open_query("orderbook/barn_batch_rewards.sql")
.replace("{{start_block}}", start_block)
.replace("{{end_block}}", end_block)
.replace(
"{{EPSILON_LOWER}}", str(config.reward_config.batch_reward_cap_lower)
)
.replace(
"{{EPSILON_UPPER}}", str(config.reward_config.batch_reward_cap_upper)
)
.replace("{{EPSILON_LOWER}}", str(reward_cap_lower))
.replace("{{EPSILON_UPPER}}", str(reward_cap_upper))
)
results = []

Expand Down
Loading

0 comments on commit 7f8d801

Please sign in to comment.