From f2324d7acf10e41ac017efe1f873bbfdb7e1e104 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Ribeiro?= Date: Thu, 25 Apr 2024 08:28:39 -0300 Subject: [PATCH] add codegen module --- cow_py/codegen/__init__.py | 5 + cow_py/codegen/abi_handler.py | 265 ++++++++++++++++++ cow_py/codegen/components/__init__.py | 17 ++ cow_py/codegen/components/abi_loader.py | 18 ++ cow_py/codegen/components/base_contract.py | 109 +++++++ cow_py/codegen/components/base_mixin.py | 18 ++ cow_py/codegen/components/contract_factory.py | 44 +++ cow_py/codegen/components/contract_loader.py | 43 +++ cow_py/codegen/components/get_abi_file.py | 15 + .../codegen/components/templates/__init__.py | 0 .../templates/contract_template.hbs | 14 + .../components/templates/partials/__init__.py | 0 .../templates/partials/contract_class.hbs | 6 + .../templates/partials/contract_mixin.hbs | 6 + .../templates/partials/dataclasses.hbs | 8 + .../components/templates/partials/enums.hbs | 10 + cow_py/codegen/main.py | 37 +++ cow_py/codegen/solidity_converter.py | 99 +++++++ tests/codegen/__init__.py | 0 .../codegen/components/test_base_contract.py | 60 ++++ .../components/test_contract_factory.py | 42 +++ tests/codegen/test_abi_handler.py | 210 ++++++++++++++ tests/codegen/test_solidity_converter.py | 84 ++++++ 23 files changed, 1110 insertions(+) create mode 100644 cow_py/codegen/__init__.py create mode 100644 cow_py/codegen/abi_handler.py create mode 100644 cow_py/codegen/components/__init__.py create mode 100644 cow_py/codegen/components/abi_loader.py create mode 100644 cow_py/codegen/components/base_contract.py create mode 100644 cow_py/codegen/components/base_mixin.py create mode 100644 cow_py/codegen/components/contract_factory.py create mode 100644 cow_py/codegen/components/contract_loader.py create mode 100644 cow_py/codegen/components/get_abi_file.py create mode 100644 cow_py/codegen/components/templates/__init__.py create mode 100644 cow_py/codegen/components/templates/contract_template.hbs create mode 100644 cow_py/codegen/components/templates/partials/__init__.py create mode 100644 cow_py/codegen/components/templates/partials/contract_class.hbs create mode 100644 cow_py/codegen/components/templates/partials/contract_mixin.hbs create mode 100644 cow_py/codegen/components/templates/partials/dataclasses.hbs create mode 100644 cow_py/codegen/components/templates/partials/enums.hbs create mode 100644 cow_py/codegen/main.py create mode 100644 cow_py/codegen/solidity_converter.py create mode 100644 tests/codegen/__init__.py create mode 100644 tests/codegen/components/test_base_contract.py create mode 100644 tests/codegen/components/test_contract_factory.py create mode 100644 tests/codegen/test_abi_handler.py create mode 100644 tests/codegen/test_solidity_converter.py diff --git a/cow_py/codegen/__init__.py b/cow_py/codegen/__init__.py new file mode 100644 index 0000000..00b8a0b --- /dev/null +++ b/cow_py/codegen/__init__.py @@ -0,0 +1,5 @@ +from .abi_handler import ABIHandler + +__all__ = [ + "ABIHandler", +] diff --git a/cow_py/codegen/abi_handler.py b/cow_py/codegen/abi_handler.py new file mode 100644 index 0000000..0bcac52 --- /dev/null +++ b/cow_py/codegen/abi_handler.py @@ -0,0 +1,265 @@ +import importlib.resources +import re +from typing import Any, Dict, List + +from pybars import Compiler + +from cow_py.codegen.components import templates +from cow_py.codegen.components.abi_loader import FileAbiLoader +from cow_py.codegen.components.templates import partials +from cow_py.codegen.solidity_converter import SolidityConverter + +CAMEL_TO_SNAKE_REGEX = re.compile( + r"(?<=[a-z0-9])(?=[A-Z])|" # Lowercase or digit to uppercase + r"(?<=[A-Z])(?=[A-Z][a-z])|" # Uppercase to uppercase followed by lowercase + r"(?<=[A-Za-z])(?=[0-9])|" # Letter to digit + r"(?<=[0-9])(?=[A-Z])" # Digit to uppercase +) + + +def compile_partial(partial_path: str) -> str: + with open(partial_path, "r") as file: + partial = file.read() + compiler = Compiler() + return compiler.compile(partial) + + +def get_filename_without_extension(path: str): + """ + Returns the a filename from the path, without the extension. + """ + return path.split("/")[-1].split(".")[0] + + +def to_python_conventional_name(name: str) -> str: + """Converts a camelCase or PascalCase name to a snake_case name.""" + if name.isupper(): + return name.lower() + + return CAMEL_TO_SNAKE_REGEX.sub("_", name).lower() + + +def _get_template_file() -> str: + pkg_files = importlib.resources.files(templates) + return str(next(x for x in pkg_files.iterdir() if x.suffix == ".hbs")) # type: ignore + + +def _get_partials_files() -> str: + pkg_files = importlib.resources.files(partials) + return [str(x) for x in pkg_files.iterdir() if x.suffix == ".hbs"] # type: ignore + + +class ABIHandlerError(Exception): + """Raised when an error occurs in the ABI handler.""" + + pass + + +class ABIHandler: + """ + Handles the generation of Python classes and methods from Ethereum contract ABIs. + + This class reads the ABI of a contract, processes its contents, and generates Python code that mirrors + the contract's functions and data structures. + + Attributes: + contract_name (str): Name of the contract, used for generating class names. + abi_file_path (str): Path to the ABI JSON file of the contract. + + Methods: + generate: Main method to generate Python code from the ABI.. + """ + + def __init__(self, contract_name: str, abi_file_path: str): + self.contract_name = contract_name + self.abi_file_path = abi_file_path + + def generate(self) -> str: + """ + Generates Python code representing the contract's ABI. + + This method processes the ABI file, extracting information about functions, + input/output arguments, enums, and data structures. It then uses this information + to generate corresponding Python classes and methods. + + Returns: + str: The generated Python code as a string. + + Raises: + ABIHandlerError: If an error occurs during ABI processing or code generation. + """ + try: + template_data = self._prepare_template_data() + return self._render_template(template_data) + except Exception as e: + raise ABIHandlerError(f"Error generating code: {str(e)}") from e + + def _prepare_template_data(self) -> Dict[str, Any]: + """ + Prepares data for the template rendering based on the contract's ABI. + + This method processes the ABI to extract relevant information for generating + Python code, such as methods, data classes, and enums. + + Returns: + Dict[str, Any]: A dictionary containing the structured data for rendering. + + Raises: + ABIHandlerError: If an error occurs during ABI processing. + """ + try: + methods, data_classes, enums = [], [], [] + generated_structs, generated_enums = set(), set() + + abi = FileAbiLoader(self.abi_file_path).load_abi() + + for item in abi: + if item["type"] == "function": + methods.append(self._process_function(item)) + for param in item["inputs"] + item.get("outputs", []): + self._process_parameters( + param, + data_classes, + enums, + generated_structs, + generated_enums, + ) + elif item["type"] == "event": + for param in item["inputs"]: + self._process_parameters( + param, + data_classes, + enums, + generated_structs, + generated_enums, + ) + + return { + "abiPath": self.abi_file_path, + "contractName": self.contract_name, + "methods": methods, + "dataClasses": data_classes, + "enums": enums, + } + except Exception as e: + raise ABIHandlerError(f"Error preparing template data: {str(e)}") from e + + def _process_parameters( + self, param, data_classes, enums, generated_structs, generated_enums + ): + if param["type"] == "tuple" and param["internalType"] not in generated_structs: + struct_name = SolidityConverter._get_struct_name(param["internalType"]) + properties = [ + { + "name": comp["name"], + "type": SolidityConverter.convert_type( + comp["type"], comp.get("internalType") + ), + } + for comp in param["components"] + ] + data_classes.append({"name": struct_name, "properties": properties}) + generated_structs.add(param["internalType"]) + elif ( + "enum " in param["internalType"] + and param["internalType"] not in generated_enums + ): + enum_name = SolidityConverter._get_struct_name(param["internalType"]) + enum_values = [ + {"name": item["name"], "value": item["value"]} + for item in param["components"] + ] + enums.append({"name": enum_name, "values": enum_values}) + generated_enums.add(param["internalType"]) + + def _process_function(self, function_item: Dict[str, Any]) -> Dict[str, Any]: + original_name = function_item["name"] + method_name = to_python_conventional_name(original_name) + + input_types = self._generate_function_input_args_with_types(function_item) + output_types = [ + SolidityConverter.convert_type(o["type"], o.get("internalType")) + for o in function_item.get("outputs", []) + ] + output_str = ( + "None" + if not output_types + else output_types[0] + if len(output_types) == 1 + else f'Tuple[{", ".join(output_types)}]' + ) + + return { + "name": method_name, + "inputs": input_types, + "outputType": output_str, + "originalName": original_name, + } + + def _generate_function_input_args_with_types( + self, function_item: Dict[str, Any] + ) -> List[Dict[str, Any]]: + input_args = [] + unnamed_arg_counters = {} # Track unnamed arguments of each type + + for input_item in function_item.get("inputs", []): + input_type = SolidityConverter.convert_type( + input_item["type"], input_item.get("internalType") + ) + + # Regex to transform type names like 'list[int]' into 'int_list' + base_name = re.sub(r"list\[(\w+)\]", r"\1_list", input_type.lower()) + + input_name = input_item.get("name") + if not input_name: + # If the argument is unnamed, use the base_name with a counter to create a unique name + unnamed_arg_counters[base_name] = ( + unnamed_arg_counters.get(base_name, -1) + 1 + ) + input_name = f"{base_name}_arg{unnamed_arg_counters[base_name]}" + + python_input_name = to_python_conventional_name(input_name) + + if input_item["type"] == "tuple": + struct_name = SolidityConverter._get_struct_name( + input_item["internalType"] + ) + properties = [ + { + "name": component["name"], + "type": SolidityConverter.convert_type( + component["type"], component.get("internalType") + ), + } + for component in input_item["components"] + ] + destructured_args = ", ".join( + [f"{python_input_name}.{prop['name']}" for prop in properties] + ) + input_args.append( + { + "name": python_input_name, + "type": struct_name, + "isTuple": True, + "destructuredArgs": f"({destructured_args})", + } + ) + else: + input_args.append( + {"name": python_input_name, "type": input_type, "isTuple": False} + ) + + return input_args + + def _render_template(self, data: Dict[str, Any]) -> str: + partials = { + get_filename_without_extension(partial_path): compile_partial(partial_path) + for partial_path in _get_partials_files() + } + + with open(_get_template_file(), "r") as file: + template = file.read() + + compiler = Compiler() + template = compiler.compile(template) + return template(data, partials=partials) diff --git a/cow_py/codegen/components/__init__.py b/cow_py/codegen/components/__init__.py new file mode 100644 index 0000000..df6ddb9 --- /dev/null +++ b/cow_py/codegen/components/__init__.py @@ -0,0 +1,17 @@ +from cow_py.codegen.components.abi_loader import FileAbiLoader +from cow_py.codegen.components.base_contract import BaseContract +from cow_py.codegen.components.base_mixin import BaseMixin +from cow_py.codegen.components.contract_factory import ContractFactory +from cow_py.codegen.components.contract_loader import ContractLoader +from cow_py.codegen.components.get_abi_file import get_abi_file +from cow_py.codegen.components.templates import partials + +__all__ = [ + "BaseContract", + "ContractFactory", + "FileAbiLoader", + "ContractLoader", + "BaseMixin", + "get_abi_file", + "partials", +] diff --git a/cow_py/codegen/components/abi_loader.py b/cow_py/codegen/components/abi_loader.py new file mode 100644 index 0000000..4cfa24c --- /dev/null +++ b/cow_py/codegen/components/abi_loader.py @@ -0,0 +1,18 @@ +import json +from abc import ABC, abstractmethod +from typing import Any, List + + +class AbiLoader(ABC): + @abstractmethod + def load_abi(self) -> List[Any]: + return [] + + +class FileAbiLoader(AbiLoader): + def __init__(self, file_name: str): + self.file_name = file_name + + def load_abi(self) -> List[Any]: + with open(self.file_name) as f: + return json.load(f) diff --git a/cow_py/codegen/components/base_contract.py b/cow_py/codegen/components/base_contract.py new file mode 100644 index 0000000..38bd724 --- /dev/null +++ b/cow_py/codegen/components/base_contract.py @@ -0,0 +1,109 @@ +from typing import Any, Dict, List, Optional, Tuple, Type + +from cow_py.codegen.components.contract_loader import ContractLoader +from cow_py.common.chains import Chain + + +class BaseContractError(Exception): + """Raised when an error occurs in the BaseContract class.""" + + pass + + +class BaseContract: + """ + A base class for contracts that implements common functionality. + + This class uses a singleton pattern to ensure that there's only one instance + of the contract for each contract address and chain combination. + + :ivar _instances: A dictionary to store instances of the BaseContract class. + """ + + ABI: Optional[List[Any]] = None + _instances: Dict[Tuple[Type, str, Chain], "BaseContract"] = {} + + def __new__(cls, address, chain, *args, **kwargs): + key = (cls, address, chain) + if key not in cls._instances: + cls._instances[key] = super(BaseContract, cls).__new__(cls) + return cls._instances[key] + + def __init__(self, address: str, chain: Chain, abi: List[Any] = None): + """ + Initializes the BaseContract with a contract address, chain, and optionally an ABI. + + :param address: The address of the contract on the specified chain + :param chain: The chain the contract is deployed on + :param abi: The ABI of the contract, optional + """ + if not hasattr(self, "_initialized"): # Avoid re-initialization + # Initialize the instance (only the first time) + self.contract_loader = ContractLoader(chain) + self.web3_contract = self.contract_loader.get_web3_contract( + address, abi or self.ABI or [] + ) + self._initialized = True + + @property + def address(self) -> str: + return self.web3_contract.address + + def _function_exists_in_abi(self, function_name): + """ + Checks if a function exists in the ABI of the contract. + + :param function_name: The name of the function to check for + :return: True if the function exists, False otherwise + """ + return any( + item.get("type") == "function" and item.get("name") == function_name + for item in self.web3_contract.abi + ) + + def _event_exists_in_abi(self, event_name): + """ + Checks if an event exists in the ABI of the contract. + + :param event_name: The name of the event to check for + :return: True if the event exists, False otherwise + """ + return any( + item.get("type") == "event" and item.get("name") == event_name + for item in self.web3_contract.abi + ) + + def __getattr__(self, name): + """ + Makes contract functions directly accessible as attributes of the BaseContract. + + :param name: The name of the attribute being accessed + :return: The wrapped contract function if it exists, raises AttributeError otherwise + + Raises: + BaseContractError: If an error occurs while accessing the contract function. + """ + if name == "_initialized": + # This is needed to avoid infinite recursion + raise AttributeError(name) + + try: + if hasattr(self.web3_contract, name): + return getattr(self.web3_contract, name) + + if self._event_exists_in_abi(name): + return getattr(self.web3_contract.events, name) + + if self._function_exists_in_abi(name): + function = getattr(self.web3_contract.functions, name) + + def wrapped_call(*args, **kwargs): + return function(*args, **kwargs).call() + + return wrapped_call + except Exception as e: + raise BaseContractError( + f"Error accessing attribute {name}: {str(e)}" + ) from e + + raise AttributeError(f"{self.__class__.__name__} has no attribute {name}") diff --git a/cow_py/codegen/components/base_mixin.py b/cow_py/codegen/components/base_mixin.py new file mode 100644 index 0000000..02510f0 --- /dev/null +++ b/cow_py/codegen/components/base_mixin.py @@ -0,0 +1,18 @@ +from abc import ABC + +from web3.contract.async_contract import AsyncContract + + +class BaseMixin(ABC): + web3_contract: AsyncContract + + def call_contract_method(self, method_name, *args): + """ + Generic method to call a contract function. + + :param method_name: The name of the contract method to call. + :param args: Arguments to pass to the contract method. + :return: The result of the contract method call. + """ + method = getattr(self.web3_contract.functions, method_name) + return method(*args).call() diff --git a/cow_py/codegen/components/contract_factory.py b/cow_py/codegen/components/contract_factory.py new file mode 100644 index 0000000..ccfe8e7 --- /dev/null +++ b/cow_py/codegen/components/contract_factory.py @@ -0,0 +1,44 @@ +from typing import Dict, Tuple, Type + +from cow_py.codegen.components.abi_loader import AbiLoader +from cow_py.codegen.components.base_contract import BaseContract +from cow_py.common.chains import Chain + + +class ContractFactory: + _contract_classes: Dict[Tuple[str, Chain], Type[BaseContract]] = {} + + @classmethod + def get_contract_class( + cls, contract_name: str, chain: Chain, abi_loader: AbiLoader + ) -> Type[BaseContract]: + """ + Retrieves the contract class for a given contract name and chain, creating it if it doesn't exist. + + :param contract_name: The name of the contract + :param chain: The chain the contract is deployed on + :return: The contract class for the given contract name and chain + """ + key = (contract_name, chain) + if key not in cls._contract_classes: + abi = abi_loader.load_abi() + cls._contract_classes[key] = type( + f"{contract_name}", (BaseContract,), {"ABI": abi} + ) + return cls._contract_classes[key] + + @classmethod + def create( + cls, contract_name: str, chain: Chain, address: str, abi_loader: AbiLoader + ) -> BaseContract: + """ + Creates an instance of the contract class for a given contract identifier (name or address) and chain. + + :param chain: The chain the contract is deployed on + :param contract_identifier: The name or address of the contract on the specified chain, optional + :param address_override: address with which to instantiate the contract, optional. We do this because some + pool contracts only have a MockPool contract whose ABI we'd like to use + :return: An instance of the contract class for the given contract identifier and chain + """ + contract_class = cls.get_contract_class(contract_name, chain, abi_loader) + return contract_class(address, chain) diff --git a/cow_py/codegen/components/contract_loader.py b/cow_py/codegen/components/contract_loader.py new file mode 100644 index 0000000..aa68d52 --- /dev/null +++ b/cow_py/codegen/components/contract_loader.py @@ -0,0 +1,43 @@ +from cow_py.web3.provider import Web3Provider + + +class ContractLoaderError(Exception): + """Raised when an error occurs in the ContractLoader class.""" + + pass + + +class ContractLoader: + """ + A utility class to load contract ABIs and create web3 contract instances. + """ + + def __init__(self, network): + """ + Initializes a ContractLoader instance for a specified network. + + :param network: The network the contract loader is associated with. + """ + self.network = network + self._abis = {} + + def get_web3_contract(self, contract_address, abi=None): + """ + Creates a web3 contract instance for the specified contract address and ABI. + + :param contract_address: The address of the contract. + :param abi_file_name: The file name of the ABI, optional. + :return: A web3 contract instance. + + Raises: + ContractLoaderError: If an error occurs while creating the web3 contract instance. + """ + try: + w3 = Web3Provider.get_instance(self.network) + + return w3.eth.contract( + address=w3.to_checksum_address(contract_address), + abi=abi, + ) + except Exception as e: + raise ContractLoaderError(f"Error loading contract: {str(e)}") from e diff --git a/cow_py/codegen/components/get_abi_file.py b/cow_py/codegen/components/get_abi_file.py new file mode 100644 index 0000000..721fe6d --- /dev/null +++ b/cow_py/codegen/components/get_abi_file.py @@ -0,0 +1,15 @@ +import importlib.resources + +from cow_py.contracts import abi + + +def get_abi_file(contract_name: str) -> str: + pkg_files = importlib.resources.files(abi) + return str( + next( + x + for x in pkg_files.iterdir() + if x.suffix == ".json" # type: ignore + and x.name.split(".json")[0] == contract_name + ) + ) diff --git a/cow_py/codegen/components/templates/__init__.py b/cow_py/codegen/components/templates/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cow_py/codegen/components/templates/contract_template.hbs b/cow_py/codegen/components/templates/contract_template.hbs new file mode 100644 index 0000000..a9db95d --- /dev/null +++ b/cow_py/codegen/components/templates/contract_template.hbs @@ -0,0 +1,14 @@ +{{! Import statements }} +from typing import List, Tuple, Any +from hexbytes import HexBytes +from cow_py.common.chains import Chain +from dataclasses import dataclass +from enum import Enum +from cow_py.codegen.components import BaseMixin, BaseContract, FileAbiLoader, ContractFactory, get_abi_file + +{{>enums}} +{{>dataclasses}} + +{{>contract_mixin}} + +{{>contract_class}} diff --git a/cow_py/codegen/components/templates/partials/__init__.py b/cow_py/codegen/components/templates/partials/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cow_py/codegen/components/templates/partials/contract_class.hbs b/cow_py/codegen/components/templates/partials/contract_class.hbs new file mode 100644 index 0000000..ce8a4f4 --- /dev/null +++ b/cow_py/codegen/components/templates/partials/contract_class.hbs @@ -0,0 +1,6 @@ +class {{contractName}}(BaseContract, {{contractName}}Mixin): + def __init__(self, chain: Chain = Chain.MAINNET, address: str = ""): + {{!-- TEMPORARY -- abiPath should be a resolved ABI given we're already generating this --}} + abi_loader = FileAbiLoader(get_abi_file("{{contractName}}")) + contract = ContractFactory.create('{{contractName}}', chain, address, abi_loader) + super({{contractName}}, self).__init__(address, chain, abi=contract.ABI) diff --git a/cow_py/codegen/components/templates/partials/contract_mixin.hbs b/cow_py/codegen/components/templates/partials/contract_mixin.hbs new file mode 100644 index 0000000..4a3d24b --- /dev/null +++ b/cow_py/codegen/components/templates/partials/contract_mixin.hbs @@ -0,0 +1,6 @@ +class {{contractName}}Mixin(BaseMixin): +{{#each methods}} + def {{this.name}}(self{{#each this.inputs}}{{#if @first}}, {{/if}}{{this.name}}: {{this.type}}{{#unless @last}}, {{/unless}}{{/each}}) -> {{this.outputType}}: + return self.call_contract_method('{{this.originalName}}'{{#each this.inputs}}{{#if @first}}, {{/if}}{{#if this.isTuple}}{{this.destructuredArgs}}{{else}}{{this.name}}{{/if}}{{#unless @last}}, {{/unless}}{{/each}}) + +{{/each}} diff --git a/cow_py/codegen/components/templates/partials/dataclasses.hbs b/cow_py/codegen/components/templates/partials/dataclasses.hbs new file mode 100644 index 0000000..b7e2b9d --- /dev/null +++ b/cow_py/codegen/components/templates/partials/dataclasses.hbs @@ -0,0 +1,8 @@ +{{#each dataClasses}} +@dataclass +class {{this.name}}: +{{#each this.properties}} + {{this.name}}: {{this.type}} +{{/each}} + +{{/each}} \ No newline at end of file diff --git a/cow_py/codegen/components/templates/partials/enums.hbs b/cow_py/codegen/components/templates/partials/enums.hbs new file mode 100644 index 0000000..787d2de --- /dev/null +++ b/cow_py/codegen/components/templates/partials/enums.hbs @@ -0,0 +1,10 @@ +{{#each enums}} +{{#if @first}} +# TODO: Enums must be fixed before using them. They currently only use placeholder values. +{{/if}} +class {{this.name}}(Enum): +{{#each this.values}} + {{this.name}} = {{this.value}} +{{/each}} + +{{/each}} \ No newline at end of file diff --git a/cow_py/codegen/main.py b/cow_py/codegen/main.py new file mode 100644 index 0000000..15604e1 --- /dev/null +++ b/cow_py/codegen/main.py @@ -0,0 +1,37 @@ +import importlib.resources +import os + +from cow_py.codegen.abi_handler import ABIHandler +from cow_py.contracts import abi + + +def get_all_abis(): + pkg_files = importlib.resources.files(abi) + return [ + posix_path + for posix_path in pkg_files.iterdir() + if posix_path.suffix == ".json" # type: ignore + ] + + +def main(): + contracts_abis = get_all_abis() + for abi_file_path in contracts_abis: + contract_name = str(abi_file_path).split("/")[-1].split(".json")[0] + handler = ABIHandler(contract_name, str(abi_file_path)) + + content = handler.generate() + + base_path = os.path.dirname(os.path.abspath(__file__)) + + os.makedirs(f"{base_path}/__generated__", exist_ok=True) + generated = f"{base_path}/__generated__/{contract_name}.py" + + with open(generated, "w") as f: + f.write(content) + + print("Done") + + +if __name__ == "__main__": + main() diff --git a/cow_py/codegen/solidity_converter.py b/cow_py/codegen/solidity_converter.py new file mode 100644 index 0000000..90766f0 --- /dev/null +++ b/cow_py/codegen/solidity_converter.py @@ -0,0 +1,99 @@ +import re +from typing import Optional + +SOLIDITY_TO_PYTHON_TYPES = { + "address": "str", + "bool": "bool", + "string": "str", + "bytes": "HexBytes", + "uint": "int", + "int": "int", +} +DYNAMIC_SOLIDITY_TYPES = { + f"{prefix}{i*8 if prefix != 'bytes' else i}": ( + "int" if prefix != "bytes" else "HexBytes" + ) + for prefix in ["uint", "int", "bytes"] + for i in range(1, 33) +} +SOLIDITY_TO_PYTHON_TYPES.update(DYNAMIC_SOLIDITY_TYPES) + + +class SolidityConverterError(Exception): + """Raised when an error occurs in the SolidityConverter.""" + + pass + + +class SolidityConverter: + """ + Converts Solidity data types to equivalent Python data types. + + This class provides methods to map Solidity types as found in Ethereum smart contracts' ABIs + to Python types, facilitating the generation of Python classes and methods to interact with these contracts. + + Methods: + convert_type: Converts a Solidity data type to its Python equivalent. + """ + + @staticmethod + def _get_struct_name(internal_type: str) -> str: + """ + Extracts the struct name from a given internal type. + + Args: + internal_type (str): The internal type string from an ABI, often representing a struct. + + Returns: + str: The extracted name of the struct. + + Raises: + SolidityConverterError: If the internal type is not in the expected format. + """ + if not internal_type or "struct " not in internal_type: + raise SolidityConverterError( + f"Invalid internal type for struct: {internal_type}" + ) + return internal_type.replace("struct ", "").replace(".", "_").replace("[]", "") + + @classmethod + def convert_type(cls, solidity_type: str, internal_type: str) -> str: + """ + Converts a Solidity type to the corresponding Python type. + + Args: + solidity_type (str): The Solidity type as specified in the contract's ABI. + internal_type (str): The internal type representation, used for more complex data structures. + + Returns: + str: The Python type equivalent to the given Solidity type. + """ + if re.search(r"enum", internal_type) or (re.search(r"enum", solidity_type)): + return cls._extract_enum_name(internal_type, solidity_type) + elif solidity_type == "tuple": + return cls._get_struct_name(internal_type) + else: + return cls._convert_array_or_basic_type(solidity_type) + + @staticmethod + def _extract_enum_name( + internal_type: Optional[str], solidity_type: Optional[str] = None + ) -> str: + if internal_type and re.search(r"enum", internal_type): + return internal_type.replace("enum ", "").replace(".", "_") + elif solidity_type and re.search(r"enum", solidity_type): + return solidity_type.replace("enum ", "").replace(".", "_") + raise SolidityConverterError(f"Invalid internal type for enum: {internal_type}") + + @staticmethod + def _convert_array_or_basic_type(solidity_type: str) -> str: + array_match = re.match(r"(.+?)(\[\d*\])", solidity_type) + if array_match: + base_type, array_size = array_match.groups() + if array_size == "[]": + return f'List[{SOLIDITY_TO_PYTHON_TYPES.get(base_type, "Any")}]' + else: + size = int(array_size[1:-1]) + return f'Tuple[{", ".join([SOLIDITY_TO_PYTHON_TYPES.get(base_type, "Any")] * size)}]' + else: + return SOLIDITY_TO_PYTHON_TYPES.get(solidity_type, "Any") diff --git a/tests/codegen/__init__.py b/tests/codegen/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/codegen/components/test_base_contract.py b/tests/codegen/components/test_base_contract.py new file mode 100644 index 0000000..8294dcb --- /dev/null +++ b/tests/codegen/components/test_base_contract.py @@ -0,0 +1,60 @@ +import pytest +from unittest.mock import patch, Mock +from cow_py.codegen.components.base_contract import ( + BaseContract, + Chain, +) + + +@patch("cow_py.codegen.components.base_contract.ContractLoader") +def test_base_contract_singleton(mock_loader): + address = "0x123" + chain = Chain.MAINNET + contract1 = BaseContract(address, chain) + contract2 = BaseContract(address, chain) + assert ( + contract1 is contract2 + ), "BaseContract should return the same instance for the same address and chain" + + +class MockWithoutAttributes(Mock): + # By default Mock objects allow access to any attribute, even if it doesn't exist. + # on the Base Contract class, we want to raise an AttributeError if the attribute doesn't exist. + def __getattr__(self, name: str): + if name == "balanceOf" or name == "nonExistentMethod": + raise AttributeError() + return super().__getattr__(name) + + +@pytest.fixture +def contract_with_abi(): + abi = [ + {"type": "function", "name": "balanceOf"}, + {"type": "event", "name": "Transfer"}, + ] + with patch("cow_py.codegen.components.base_contract.ContractLoader") as mock_loader: + mock_contract = MockWithoutAttributes() + mock_contract.abi = abi + mock_contract.functions = Mock( + balanceOf=Mock(return_value=Mock(call=Mock(return_value="1000"))), + ) + + mock_loader.return_value.get_web3_contract.return_value = mock_contract + contract = BaseContract("0x456", Chain.MAINNET, abi) + return contract + + +def test_base_contract_function_exists_in_abi(contract_with_abi): + assert contract_with_abi._function_exists_in_abi("balanceOf") + assert not contract_with_abi._function_exists_in_abi("transfer") + + +def test_base_contract_event_exists_in_abi(contract_with_abi): + assert contract_with_abi._event_exists_in_abi("Transfer") + assert not contract_with_abi._event_exists_in_abi("Approval") + + +def test_base_contract_getattr(contract_with_abi): + assert contract_with_abi.balanceOf() == "1000" + with pytest.raises(AttributeError): + _ = contract_with_abi.nonExistentMethod diff --git a/tests/codegen/components/test_contract_factory.py b/tests/codegen/components/test_contract_factory.py new file mode 100644 index 0000000..137c6eb --- /dev/null +++ b/tests/codegen/components/test_contract_factory.py @@ -0,0 +1,42 @@ +import unittest +from unittest.mock import MagicMock, patch +from cow_py.codegen.components.contract_factory import ContractFactory, BaseContract +from cow_py.common.chains import Chain + + +class TestContractFactory(unittest.TestCase): + def setUp(self): + self.contract_name = "MockContract" + self.chain = Chain.MAINNET + self.abi_loader = MagicMock() + self.abi_loader.load_abi = MagicMock( + return_value=[{"type": "function", "name": "mockFunction"}] + ) + self.address = "0xe91D153E0b41518A2Ce8Dd3D7944Fa863463a97d" + + def test_contract_factory_get_contract_class(self): + with patch.dict( + "cow_py.codegen.components.contract_factory.ContractFactory._contract_classes", + clear=True, + ): + first_class = ContractFactory.get_contract_class( + self.contract_name, self.chain, self.abi_loader + ) + second_class = ContractFactory.get_contract_class( + self.contract_name, self.chain, self.abi_loader + ) + + self.abi_loader.load_abi.assert_called_once() + self.assertEqual(first_class, second_class) + + def test_contract_factory_create(self): + with patch.dict( + "cow_py.codegen.components.contract_factory.ContractFactory._contract_classes", + clear=True, + ): + contract_instance = ContractFactory.create( + self.contract_name, self.chain, self.address, self.abi_loader + ) + + self.assertIsInstance(contract_instance, BaseContract) + self.assertEqual(contract_instance.address, self.address) diff --git a/tests/codegen/test_abi_handler.py b/tests/codegen/test_abi_handler.py new file mode 100644 index 0000000..6110ecb --- /dev/null +++ b/tests/codegen/test_abi_handler.py @@ -0,0 +1,210 @@ +import pytest + +from cow_py.codegen.abi_handler import ( + to_python_conventional_name, + get_filename_without_extension, + _get_template_file, + _get_partials_files, + compile_partial, + ABIHandler, + FileAbiLoader, +) + +from unittest.mock import mock_open + + +@pytest.mark.parametrize( + "input_name, expected_output", + [ + # ("GPv2Order_Data", "gp_v2_order_data"), + ("simpleTest", "simple_test"), + ("ThisIsATest", "this_is_a_test"), + ("number2IsHere", "number_2_is_here"), + ("AnotherTest123", "another_test_123"), + ("JSONData", "json_data"), + # ("GPv2Order_Data_arg_1", "gp_v2_order_data_arg_1"), + ], +) +def test_to_python_conventional_name(input_name, expected_output): + assert to_python_conventional_name(input_name) == expected_output + + +def test_compile_partial(mocker): + # Test that compile_partial correctly compiles a partial template + mocked_file_content = "test content" + mocked_compiled_content = "compiled content" + + mocker.patch("builtins.open", mock_open(read_data=mocked_file_content)) + mocker.patch("pybars.Compiler.compile", return_value=mocked_compiled_content) + result = compile_partial("fake_path") + assert result == mocked_compiled_content + + +def test_get_filename_without_extension(): + # Test that get_filename_without_extension correctly removes the extension + assert get_filename_without_extension("folder/test.py") == "test" + + +def test_get_template_file(): + # Test that _get_template_file returns the correct template file path + assert _get_template_file().endswith("contract_template.hbs") + + +def test_get_partials_files(): + # Test that _get_partials_files returns the correct list of partial files + assert all([f.endswith(".hbs") for f in _get_partials_files()]) + + +@pytest.fixture +def abi_handler(): + return ABIHandler("TestContract", "/fake/path/to/abi.json") + + +def test_abi_handler_generate(mocker, abi_handler): + # Test that ABIHandler.generate correctly generates Python code from an ABI + mocked_abi_data = [ + {"type": "function", "name": "doSomething", "inputs": [], "outputs": []} + ] + mocker.patch( + "cow_py.codegen.abi_handler.FileAbiLoader.load_abi", + return_value=mocked_abi_data, + ) + mocker.patch( + "cow_py.codegen.abi_handler.ABIHandler._prepare_template_data", + return_value={"methods": []}, + ) + mocker.patch( + "cow_py.codegen.abi_handler.ABIHandler._render_template", + return_value="class MyContract:\n pass", + ) + + # Run the method + result = abi_handler.generate() + + # Verify the output + assert ( + result == "class MyContract:\n pass" + ), "Generated Python code does not match expected output." + + +def test_abi_handler_prepare_template_data(mocker, abi_handler): + # Test that ABIHandler._prepare_template_data correctly processes the ABI + sample_abi = [ + { + "type": "function", + "name": "setValue", + "inputs": [{"name": "value", "type": "uint256"}], + "outputs": [], + }, + { + "type": "event", + "name": "ValueChanged", + "inputs": [{"name": "value", "type": "uint256"}], + }, + ] + + mocker.patch.object(FileAbiLoader, "load_abi", return_value=sample_abi) + + mocker.patch.object( + abi_handler, + "_process_function", + return_value={ + "name": "set_value", + "inputs": ["uint256"], + "outputType": "None", + "originalName": "setValue", + }, + ) + mocker.patch.object(abi_handler, "_process_parameters", autospec=True) + + result = abi_handler._prepare_template_data() + + assert result["abiPath"] == "/fake/path/to/abi.json" + assert result["contractName"] == "TestContract" + assert len(result["methods"]) == 1 + assert result["methods"][0]["name"] == "set_value" + assert "dataClasses" in result + assert "enums" in result + + +def test_abi_handler_process_parameters(abi_handler): + # Test that ABIHandler._process_parameters correctly processes function parameters + param = { + "type": "tuple", + "internalType": "struct Value", + "components": [ + {"name": "x", "type": "uint256", "internalType": "uint256"}, + {"name": "y", "type": "uint256", "internalType": "uint256"}, + ], + } + data_classes = [] + enums = [] + generated_structs = set() + generated_enums = set() + + expected_data_class = { + "name": "Value", + "properties": [ + {"name": "x", "type": "int"}, + {"name": "y", "type": "int"}, + ], + } + + abi_handler._process_parameters( + param, data_classes, enums, generated_structs, generated_enums + ) + + assert "struct Value" in generated_structs + assert data_classes[0] == expected_data_class + + +def test_abi_handler_process_function(abi_handler, mocker): + # Test that ABIHandler._process_function correctly processes a function item + function_item = { + "type": "function", + "name": "getValue", + "inputs": [{"name": "key", "type": "uint256", "internalType": "uint256"}], + "outputs": [{"name": "result", "type": "uint256", "internalType": "uint256"}], + } + + result = abi_handler._process_function(function_item) + + expected_result = { + "name": "get_value", + "inputs": [{"name": "key", "type": "int", "isTuple": False}], + "outputType": "int", + "originalName": "getValue", + } + + assert result == expected_result + + +def test_abi_handler_render_template(abi_handler, mocker): + # Test that ABIHandler._render_template correctly renders the template with data + template_data = { + "abiPath": "/fake/path/to/abi.json", + "contractName": "TestContract", + "methods": [ + { + "name": "set_value", + "inputs": ["uint256"], + "outputType": "uint256", + "originalName": "setValue", + } + ], + "dataClasses": [], + "enums": [], + } + template_string = "class {{ contractName }}:\n pass" + expected_rendered_output = "class TestContract:\n pass" + + mocker.patch("builtins.open", mocker.mock_open(read_data=template_string)) + + mocker.patch( + "pybars.Compiler.compile", + return_value=lambda x, **kwargs: expected_rendered_output, + ) + + result = abi_handler._render_template(template_data) + + assert result == expected_rendered_output diff --git a/tests/codegen/test_solidity_converter.py b/tests/codegen/test_solidity_converter.py new file mode 100644 index 0000000..8324b2d --- /dev/null +++ b/tests/codegen/test_solidity_converter.py @@ -0,0 +1,84 @@ +import pytest + +from cow_py.codegen.solidity_converter import SolidityConverter, SolidityConverterError + + +def test_solidity_converter_get_struct_name(): + internal_type = "struct MyStruct" + expected_result = "MyStruct" + result = SolidityConverter._get_struct_name(internal_type) + assert result == expected_result + + +def test_solidity_converter_get_struct_name_invalid_internal_type(): + internal_type = "uint256" + with pytest.raises(SolidityConverterError): + SolidityConverter._get_struct_name(internal_type) + + +def test_solidity_converter_convert_type_enum(): + solidity_type = "enum MyEnum" + internal_type = "" + expected_result = "MyEnum" + result = SolidityConverter.convert_type(solidity_type, internal_type) + assert result == expected_result + + +def test_solidity_converter_convert_type_array(): + solidity_type = "uint256[]" + internal_type = "" + expected_result = "List[int]" + result = SolidityConverter.convert_type(solidity_type, internal_type) + assert result == expected_result + + +def test_solidity_converter_convert_type_tuple(): + solidity_type = "tuple" + internal_type = "struct MyStruct" + expected_result = "MyStruct" + result = SolidityConverter.convert_type(solidity_type, internal_type) + assert result == expected_result + + +def test_solidity_converter_convert_type_fixed_size_array(): + solidity_type = "uint256[3]" + internal_type = "" + expected_result = "Tuple[int, int, int]" + result = SolidityConverter.convert_type(solidity_type, internal_type) + assert result == expected_result + + +def test_solidity_converter_convert_type_unknown_type(): + solidity_type = "unknown_type" + internal_type = "" + expected_result = "Any" + result = SolidityConverter.convert_type(solidity_type, internal_type) + assert result == expected_result + + +def test_solidity_converter_extract_enum_name(): + internal_type = "enum MyEnum.Option" + expected_result = "MyEnum_Option" + result = SolidityConverter._extract_enum_name(internal_type) + assert result == expected_result + + +def test_solidity_converter_convert_array_or_basic_type_dynamic_array(): + solidity_type = "address[]" + expected_result = "List[str]" + result = SolidityConverter._convert_array_or_basic_type(solidity_type) + assert result == expected_result + + +def test_solidity_converter_convert_array_or_basic_type_fixed_size_array(): + solidity_type = "bool[5]" + expected_result = "Tuple[bool, bool, bool, bool, bool]" + result = SolidityConverter._convert_array_or_basic_type(solidity_type) + assert result == expected_result + + +def test_solidity_converter_convert_array_or_basic_type_basic_type(): + solidity_type = "bytes32" + expected_result = "HexBytes" + result = SolidityConverter._convert_array_or_basic_type(solidity_type) + assert result == expected_result