From e3542e663d8e06218d0f8ac157070b809b28e674 Mon Sep 17 00:00:00 2001 From: andrei-marinica Date: Thu, 3 Jun 2021 18:13:24 +0300 Subject: [PATCH] callbacks in modules --- .../async/forwarder/src/call_transf_exec.rs | 8 +- .../feature-tests/async/forwarder/src/lib.rs | 59 ---------- elrond-wasm-derive/src/contract_impl.rs | 20 ++-- .../src/generate/callback_gen.rs | 102 +++++++++++++----- .../interaction/callback_selector_result.rs | 8 ++ elrond-wasm/src/types/interaction/mod.rs | 2 + 6 files changed, 106 insertions(+), 93 deletions(-) create mode 100644 elrond-wasm/src/types/interaction/callback_selector_result.rs diff --git a/contracts/feature-tests/async/forwarder/src/call_transf_exec.rs b/contracts/feature-tests/async/forwarder/src/call_transf_exec.rs index e7f4f02f1b..53f679e0ee 100644 --- a/contracts/feature-tests/async/forwarder/src/call_transf_exec.rs +++ b/contracts/feature-tests/async/forwarder/src/call_transf_exec.rs @@ -65,6 +65,12 @@ pub trait ForwarderTransferExecuteModule { let gas_left_after = self.blockchain().get_gas_left(); - (gas_left_before, gas_left_after, Self::BigUint::zero(), token).into() + ( + gas_left_before, + gas_left_after, + Self::BigUint::zero(), + token, + ) + .into() } } diff --git a/contracts/feature-tests/async/forwarder/src/lib.rs b/contracts/feature-tests/async/forwarder/src/lib.rs index 7bbf21bd4f..2dff9484fc 100644 --- a/contracts/feature-tests/async/forwarder/src/lib.rs +++ b/contracts/feature-tests/async/forwarder/src/lib.rs @@ -40,63 +40,4 @@ pub trait Forwarder: }; self.send().direct_egld(to, amount, data); } - - #[callback(retrieve_funds_callback)] - fn retrieve_funds_callback_root( - &self, - #[payment_token] token: TokenIdentifier, - #[payment] payment: Self::BigUint, - ) { - // manual callback forwarding to modules is currently necessary - self.retrieve_funds_callback(token, payment) - } - - #[callback(send_funds_twice_callback)] - fn send_funds_twice_callback_root( - &self, - to: &Address, - token_identifier: &TokenIdentifier, - amount: &Self::BigUint, - ) -> AsyncCall { - // manual callback forwarding to modules is currently necessary - self.send_funds_twice_callback(to, token_identifier, amount) - } - - #[callback(esdt_issue_callback)] - fn esdt_issue_callback_root( - &self, - caller: &Address, - #[payment_token] token_identifier: TokenIdentifier, - #[payment] returned_tokens: Self::BigUint, - #[call_result] result: AsyncCallResult<()>, - ) { - // manual callback forwarding to modules is currently necessary - self.esdt_issue_callback(caller, token_identifier, returned_tokens, result) - } - - #[callback(nft_issue_callback)] - fn nft_issue_callback_root( - &self, - caller: &Address, - #[call_result] result: AsyncCallResult, - ) { - // manual callback forwarding to modules is currently necessary - self.nft_issue_callback(caller, result) - } - - #[callback(sft_issue_callback)] - fn sft_issue_callback_root( - &self, - caller: &Address, - #[call_result] result: AsyncCallResult, - ) { - // manual callback forwarding to modules is currently necessary - self.sft_issue_callback(caller, result) - } - - #[callback(change_roles_callback)] - fn change_roles_callback_root(&self, #[call_result] result: AsyncCallResult<()>) { - // manual callback forwarding to modules is currently necessary - self.change_roles_callback(result) - } } diff --git a/elrond-wasm-derive/src/contract_impl.rs b/elrond-wasm-derive/src/contract_impl.rs index b1a3f1b02d..64ff96aec7 100644 --- a/elrond-wasm-derive/src/contract_impl.rs +++ b/elrond-wasm-derive/src/contract_impl.rs @@ -16,15 +16,15 @@ pub fn contract_implementation( contract: &ContractTrait, is_contract_main: bool, ) -> proc_macro2::TokenStream { - let proxy_trait_imports = generate_all_proxy_trait_imports(&contract); + let proxy_trait_imports = generate_all_proxy_trait_imports(contract); let trait_name_ident = contract.trait_name.clone(); - let method_impls = extract_method_impls(&contract); - let call_methods = generate_call_methods(&contract); - let auto_impl_defs = generate_auto_impl_defs(&contract); - let auto_impls = generate_auto_impls(&contract); - let endpoints = generate_wasm_endpoints(&contract); - let function_selector_body = generate_function_selector_body(&contract); - let callback_body = generate_callback_body(&contract.methods); + let method_impls = extract_method_impls(contract); + let call_methods = generate_call_methods(contract); + let auto_impl_defs = generate_auto_impl_defs(contract); + let auto_impls = generate_auto_impls(contract); + let endpoints = generate_wasm_endpoints(contract); + let function_selector_body = generate_function_selector_body(contract); + let (callback_selector_body, callback_body) = generate_callback_selector_and_main(contract); let where_self_big_int = snippets::where_self_big_int(); let (callbacks_def, callbacks_impl, callback_proxies_obj) = @@ -77,6 +77,10 @@ pub fn contract_implementation( #function_selector_body } + fn callback_selector<'a>(&self, mut ___cb_data_deserializer___: elrond_wasm::hex_call_data::HexCallDataDeserializer<'a>) -> elrond_wasm::types::CallbackSelectorResult<'a> { + #callback_selector_body + } + fn callback(&self) { #callback_body } diff --git a/elrond-wasm-derive/src/generate/callback_gen.rs b/elrond-wasm-derive/src/generate/callback_gen.rs index 411cce4e92..7f3a765d1b 100644 --- a/elrond-wasm-derive/src/generate/callback_gen.rs +++ b/elrond-wasm-derive/src/generate/callback_gen.rs @@ -6,14 +6,46 @@ use super::{ payable_gen::*, util::*, }; -use crate::model::{Method, PublicRole}; +use crate::model::{ContractTrait, Method, PublicRole, Supertrait}; -pub fn generate_callback_body(methods: &[Method]) -> proc_macro2::TokenStream { - let raw_decl = find_raw_callback(methods); +pub fn generate_callback_selector_and_main( + contract: &ContractTrait, +) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) { + let raw_decl = find_raw_callback(&contract.methods); if let Some(raw) = raw_decl { - generate_call_method_body(&raw) + let as_call_method = generate_call_method_body(&raw); + let cb_selector_body = quote! { + #as_call_method + elrond_wasm::types::CallbackSelectorResult::Processed + }; + let cb_main_body = quote! { + let _ = self.callback_selector(elrond_wasm::hex_call_data::HexCallDataDeserializer::new(&[])); + }; + (cb_selector_body, cb_main_body) } else { - generate_callback_body_regular(methods) + let match_arms: Vec = match_arms(contract.methods.as_slice()); + let module_calls: Vec = + module_calls(contract.supertraits.as_slice()); + if match_arms.is_empty() && module_calls.is_empty() { + let cb_selector_body = quote! { + elrond_wasm::types::CallbackSelectorResult::NotProcessed(___cb_data_deserializer___) + }; + let cb_main_body = quote! {}; + (cb_selector_body, cb_main_body) + } else { + let cb_selector_body = callback_selector_body(match_arms, module_calls); + let cb_main_body = quote! { + let ___tx_hash___ = elrond_wasm::api::BlockchainApi::get_tx_hash(&self.blockchain()); + let ___cb_data_raw___ = elrond_wasm::api::StorageReadApi::storage_load_boxed_bytes(&self.get_storage_raw(), &___tx_hash___.as_bytes()); + elrond_wasm::api::StorageWriteApi::storage_store_slice_u8(&self.get_storage_raw(), &___tx_hash___.as_bytes(), &[]); // cleanup + let mut ___cb_data_deserializer___ = elrond_wasm::hex_call_data::HexCallDataDeserializer::new(___cb_data_raw___.as_slice()); + if let elrond_wasm::types::CallbackSelectorResult::NotProcessed(_) = + self::EndpointWrappers::callback_selector(self, ___cb_data_deserializer___) { + self.error_api().signal_error(err_msg::CALLBACK_BAD_FUNC); + } + }; + (cb_selector_body, cb_main_body) + } } } @@ -24,13 +56,31 @@ fn find_raw_callback(methods: &[Method]) -> Option { .cloned() } -fn generate_callback_body_regular(methods: &[Method]) -> proc_macro2::TokenStream { - let mut has_call_result = false; - let match_arms: Vec = methods +fn callback_selector_body( + match_arms: Vec, + module_calls: Vec, +) -> proc_macro2::TokenStream { + quote! { + let mut ___call_result_loader___ = EndpointDynArgLoader::new(self.argument_api()); + match ___cb_data_deserializer___.get_func_name() { + [] => { + return elrond_wasm::types::CallbackSelectorResult::Processed; + } + #(#match_arms)* + _ => {}, + } + #(#module_calls)* + elrond_wasm::types::CallbackSelectorResult::NotProcessed(___cb_data_deserializer___) + } +} + +fn match_arms(methods: &[Method]) -> Vec { + methods .iter() .filter_map(|m| { if let PublicRole::Callback(callback) = &m.public_role { let payable_snippet = generate_payable_snippet(m); + let mut has_call_result = false; let arg_init_snippets: Vec = m .method_args .iter() @@ -71,6 +121,7 @@ fn generate_callback_body_regular(methods: &[Method]) -> proc_macro2::TokenStrea ___cb_closure_loader___.assert_no_more_args(); #call_result_assert_no_more_args #body_with_result ; + return elrond_wasm::types::CallbackSelectorResult::Processed; }, }; Some(match_arm) @@ -78,23 +129,24 @@ fn generate_callback_body_regular(methods: &[Method]) -> proc_macro2::TokenStrea None } }) - .collect(); - if match_arms.is_empty() { - // no callback code needed - quote! {} - } else { - quote! { - let ___tx_hash___ = elrond_wasm::api::BlockchainApi::get_tx_hash(&self.blockchain()); - let ___cb_data_raw___ = elrond_wasm::api::StorageReadApi::storage_load_boxed_bytes(&self.get_storage_raw(), &___tx_hash___.as_bytes()); - elrond_wasm::api::StorageWriteApi::storage_store_slice_u8(&self.get_storage_raw(), &___tx_hash___.as_bytes(), &[]); // cleanup - let mut ___cb_data_deserializer___ = elrond_wasm::hex_call_data::HexCallDataDeserializer::new(___cb_data_raw___.as_slice()); - let mut ___call_result_loader___ = EndpointDynArgLoader::new(self.argument_api()); + .collect() +} - match ___cb_data_deserializer___.get_func_name() { - [] => { return; } - #(#match_arms)* - other => self.error_api().signal_error(err_msg::CALLBACK_BAD_FUNC) +pub fn module_calls(supertraits: &[Supertrait]) -> Vec { + supertraits + .iter() + .map(|supertrait| { + let module_path = &supertrait.module_path; + quote! { + match #module_path EndpointWrappers::callback_selector(self, ___cb_data_deserializer___) { + elrond_wasm::types::CallbackSelectorResult::Processed => { + return elrond_wasm::types::CallbackSelectorResult::Processed; + }, + elrond_wasm::types::CallbackSelectorResult::NotProcessed(recovered_deser) => { + ___cb_data_deserializer___ = recovered_deser; + }, + } } - } - } + }) + .collect() } diff --git a/elrond-wasm/src/types/interaction/callback_selector_result.rs b/elrond-wasm/src/types/interaction/callback_selector_result.rs new file mode 100644 index 0000000000..fea6c05c25 --- /dev/null +++ b/elrond-wasm/src/types/interaction/callback_selector_result.rs @@ -0,0 +1,8 @@ +use crate::HexCallDataDeserializer; + +/// Used internally between the `callback` and `callback_selector` methods. +/// It is likely to be removed in the future. +pub enum CallbackSelectorResult<'a> { + Processed, + NotProcessed(HexCallDataDeserializer<'a>), +} diff --git a/elrond-wasm/src/types/interaction/mod.rs b/elrond-wasm/src/types/interaction/mod.rs index 59e2b6c086..d271fce299 100644 --- a/elrond-wasm/src/types/interaction/mod.rs +++ b/elrond-wasm/src/types/interaction/mod.rs @@ -1,6 +1,7 @@ mod arg_buffer; mod async_call; mod callback_call; +mod callback_selector_result; mod contract_call; mod send_egld; mod send_esdt; @@ -9,6 +10,7 @@ mod send_token; pub use arg_buffer::ArgBuffer; pub use async_call::AsyncCall; pub use callback_call::CallbackCall; +pub use callback_selector_result::CallbackSelectorResult; pub use contract_call::{new_contract_call, ContractCall}; pub use send_egld::SendEgld; pub use send_esdt::SendEsdt;