diff --git a/src/accounts/base.cairo b/src/accounts/base.cairo index eaca705..921a38a 100644 --- a/src/accounts/base.cairo +++ b/src/accounts/base.cairo @@ -27,9 +27,9 @@ pub mod RosettaAccount { use core::num::traits::Zero; use core::panic_with_felt252; use starknet::{ - ContractAddress, EthAddress, get_contract_address, get_caller_address, get_tx_info + ContractAddress, EthAddress, ClassHash, get_contract_address, get_caller_address, get_tx_info }; - use starknet::syscalls::{call_contract_syscall}; + use starknet::syscalls::{call_contract_syscall, replace_class_syscall}; use starknet::storage::{StoragePointerReadAccess, StoragePointerWriteAccess}; use rosettacontracts::accounts::utils::{is_valid_eth_signature, RosettanetSignature, RosettanetCall, RosettanetMulticall, prepare_multicall_context, validate_target_function, generate_tx_hash}; use crate::rosettanet::{IRosettanetDispatcher, IRosettanetDispatcherTrait}; @@ -45,6 +45,7 @@ pub mod RosettaAccount { pub const TRANSFER_ENTRYPOINT: felt252 = 0x83afd3f4caedc6eebf44246fe54e38c95e3179a5ec9ea81740eca5b482d12e; pub const MULTICALL_SELECTOR: felt252 = 0xFFFFFFFF; // multicall eth selector + pub const UPGRADE_SELECTOR: felt252 = 0xFFFFFFFE; // upgrades contract #[storage] struct Storage { @@ -71,14 +72,22 @@ pub mod RosettaAccount { let sn_target: ContractAddress = IRosettanetDispatcher { contract_address: self.registry.read() }.get_starknet_address(eth_target); assert(sn_target != starknet::contract_address_const::<0>(), 'target not registered'); - // executes multicall + // Multicall or upgrade call if(call.to == self.ethereum_address.read()) { // This is multicall - assert(*call.calldata.at(0) == MULTICALL_SELECTOR, 'wrong multicall selector'); - assert(call.value == 0, 'multicall value not zero'); - panic_with_felt252(Errors::UNIMPLEMENTED_FEATURE); - let context = prepare_multicall_context(call.calldata); // First calldata element removed inside this function - return self.execute_multicall(context); + let selector = *call.calldata.at(0); + if(selector == MULTICALL_SELECTOR) { + assert(call.value == 0, 'multicall value not zero'); + panic_with_felt252(Errors::UNIMPLEMENTED_FEATURE); + let context = prepare_multicall_context(call.calldata); // First calldata element removed inside this function + return self.execute_multicall(context); + } else if(selector == UPGRADE_SELECTOR) { + let latest_hash: ClassHash = IRosettanetDispatcher { contract_address: self.registry.read() }.latest_class(); + replace_class_syscall(latest_hash).unwrap(); + return array![array![latest_hash.into()].span()]; + } else { + panic_with_felt252(Errors::UNIMPLEMENTED_FEATURE); + } } // If value transfer, send STRK before calling contract diff --git a/src/rosettanet.cairo b/src/rosettanet.cairo index 698c723..6200e03 100644 --- a/src/rosettanet.cairo +++ b/src/rosettanet.cairo @@ -12,7 +12,7 @@ pub trait IRosettanet { fn get_starknet_address(self: @TState, eth_address: EthAddress) -> ContractAddress; fn get_ethereum_address(self: @TState, sn_address: ContractAddress) -> EthAddress; fn precalculate_starknet_account(self: @TState, eth_address: EthAddress) -> ContractAddress; - fn account_class(self: @TState) -> ClassHash; + fn latest_class(self: @TState) -> ClassHash; fn native_currency(self: @TState) -> ContractAddress; fn developer(self: @TState) -> ContractAddress; } @@ -73,13 +73,18 @@ pub mod Rosettanet { struct Storage { sn_to_eth: Map, eth_to_sn: Map, - account_class: ClassHash, + latest_class: ClassHash, + // Accounts will always deployed with initial class, so we can always precalculate the addresses. + // They may need to upgrade to the latest hash after deployment. + initial_class: ClassHash, dev: ContractAddress, strk: ContractAddress } #[constructor] - fn constructor(ref self: ContractState, developer: ContractAddress, strk: ContractAddress) { + fn constructor(ref self: ContractState, account_class: ClassHash, developer: ContractAddress, strk: ContractAddress) { + self.initial_class.write(account_class); + self.latest_class.write(account_class); self.dev.write(developer); self.strk.write(strk); @@ -106,7 +111,7 @@ pub mod Rosettanet { let eth_address_felt: felt252 = eth_address.into(); let (account, _) = deploy_syscall( - self.account_class.read(), eth_address_felt, array![eth_address_felt, get_contract_address().into()].span(), true + self.initial_class.read(), eth_address_felt, array![eth_address_felt, get_contract_address().into()].span(), true ) .unwrap(); @@ -135,7 +140,7 @@ pub mod Rosettanet { fn set_account_class(ref self: ContractState, class: ClassHash) { assert(get_caller_address() == self.dev.read(), 'only dev'); - self.account_class.write(class); + self.latest_class.write(class); self.emit(AccountClassChanged {changer: get_caller_address(), new_class: class}); } @@ -186,15 +191,15 @@ pub mod Rosettanet { let eth_address_felt: felt252 = eth_address.into(); calculate_contract_address_from_deploy_syscall( eth_address_felt, - self.account_class.read(), + self.initial_class.read(), array![eth_address_felt, get_contract_address().into()].span(), 0.try_into().unwrap() ) } - /// Returns current account class hash - fn account_class(self: @ContractState) -> ClassHash { - self.account_class.read() + /// Returns latest account class hash + fn latest_class(self: @ContractState) -> ClassHash { + self.latest_class.read() } /// Returns native currency address on current network