From b91c7f3a4176cef5355c97c1f3f4de06ef9d6629 Mon Sep 17 00:00:00 2001 From: Larko <59736843+Larkooo@users.noreply.github.com> Date: Wed, 12 Jun 2024 10:56:05 -0400 Subject: [PATCH] refactor(torii-libp2p): use cainome + integration test (#2044) * feat: use cainome + integration test * fmt * chore: remove unused imports * Update tests.rs --- Cargo.lock | 1 + crates/torii/libp2p/Cargo.toml | 1 + crates/torii/libp2p/src/server/mod.rs | 85 -------------- crates/torii/libp2p/src/tests.rs | 161 +++++++++++++++++++------- crates/torii/libp2p/src/typed_data.rs | 66 +++-------- 5 files changed, 138 insertions(+), 176 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 96465fc1ff..667854d1df 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -14268,6 +14268,7 @@ version = "0.7.0-alpha.5" dependencies = [ "anyhow", "async-trait", + "cainome", "chrono", "crypto-bigint", "dojo-test-utils", diff --git a/crates/torii/libp2p/Cargo.toml b/crates/torii/libp2p/Cargo.toml index d95c0ff66a..08c0472988 100644 --- a/crates/torii/libp2p/Cargo.toml +++ b/crates/torii/libp2p/Cargo.toml @@ -26,6 +26,7 @@ starknet.workspace = true thiserror.workspace = true tracing-subscriber = { version = "0.3", features = [ "env-filter" ] } tracing.workspace = true +cainome.workspace = true [dev-dependencies] dojo-test-utils.workspace = true diff --git a/crates/torii/libp2p/src/server/mod.rs b/crates/torii/libp2p/src/server/mod.rs index 3562b5a46f..4a78db033b 100644 --- a/crates/torii/libp2p/src/server/mod.rs +++ b/crates/torii/libp2p/src/server/mod.rs @@ -20,7 +20,6 @@ use libp2p::swarm::{NetworkBehaviour, SwarmEvent}; use libp2p::{identify, identity, noise, ping, relay, tcp, yamux, PeerId, Swarm, Transport}; use libp2p_webrtc as webrtc; use rand::thread_rng; -use serde_json::Number; use starknet::core::types::{BlockId, BlockTag, FunctionCall}; use starknet::core::utils::get_selector_from_name; use starknet::providers::Provider; @@ -692,90 +691,6 @@ fn read_or_create_certificate(path: &Path) -> anyhow::Result { Ok(cert) } -// Deprecated. These should be potentially removed. As Ty -> TypedData is now done -// on the SDKs side -pub fn parse_ty_to_object(ty: &Ty) -> Result, Error> { - match ty { - Ty::Struct(struct_ty) => { - let mut object = IndexMap::new(); - for member in &struct_ty.children { - let mut member_object = IndexMap::new(); - member_object.insert("key".to_string(), PrimitiveType::Bool(member.key)); - member_object.insert( - "type".to_string(), - PrimitiveType::String(ty_to_string_type(&member.ty)), - ); - member_object.insert("value".to_string(), parse_ty_to_primitive(&member.ty)?); - object.insert(member.name.clone(), PrimitiveType::Object(member_object)); - } - Ok(object) - } - _ => Err(Error::InvalidMessageError("Expected Struct type".to_string())), - } -} - -pub fn ty_to_string_type(ty: &Ty) -> String { - match ty { - Ty::Primitive(primitive) => match primitive { - Primitive::U8(_) => "u8".to_string(), - Primitive::U16(_) => "u16".to_string(), - Primitive::U32(_) => "u32".to_string(), - Primitive::USize(_) => "usize".to_string(), - Primitive::U64(_) => "u64".to_string(), - Primitive::U128(_) => "u128".to_string(), - Primitive::U256(_) => "u256".to_string(), - Primitive::Felt252(_) => "felt252".to_string(), - Primitive::ClassHash(_) => "class_hash".to_string(), - Primitive::ContractAddress(_) => "contract_address".to_string(), - Primitive::Bool(_) => "bool".to_string(), - }, - Ty::Struct(_) => "struct".to_string(), - Ty::Tuple(_) => "tuple".to_string(), - Ty::Array(_) => "array".to_string(), - Ty::ByteArray(_) => "bytearray".to_string(), - Ty::Enum(_) => "enum".to_string(), - } -} - -pub fn parse_ty_to_primitive(ty: &Ty) -> Result { - match ty { - Ty::Primitive(primitive) => match primitive { - Primitive::U8(value) => { - Ok(PrimitiveType::Number(Number::from(value.map(|v| v as u64).unwrap_or(0u64)))) - } - Primitive::U16(value) => { - Ok(PrimitiveType::Number(Number::from(value.map(|v| v as u64).unwrap_or(0u64)))) - } - Primitive::U32(value) => { - Ok(PrimitiveType::Number(Number::from(value.map(|v| v as u64).unwrap_or(0u64)))) - } - Primitive::USize(value) => { - Ok(PrimitiveType::Number(Number::from(value.map(|v| v as u64).unwrap_or(0u64)))) - } - Primitive::U64(value) => { - Ok(PrimitiveType::Number(Number::from(value.map(|v| v).unwrap_or(0u64)))) - } - Primitive::U128(value) => Ok(PrimitiveType::String( - value.map(|v| v.to_string()).unwrap_or_else(|| "0".to_string()), - )), - Primitive::U256(value) => Ok(PrimitiveType::String( - value.map(|v| format!("{:#x}", v)).unwrap_or_else(|| "0".to_string()), - )), - Primitive::Felt252(value) => Ok(PrimitiveType::String( - value.map(|v| format!("{:#x}", v)).unwrap_or_else(|| "0".to_string()), - )), - Primitive::ClassHash(value) => Ok(PrimitiveType::String( - value.map(|v| format!("{:#x}", v)).unwrap_or_else(|| "0".to_string()), - )), - Primitive::ContractAddress(value) => Ok(PrimitiveType::String( - value.map(|v| format!("{:#x}", v)).unwrap_or_else(|| "0".to_string()), - )), - Primitive::Bool(value) => Ok(PrimitiveType::Bool(value.unwrap_or(false))), - }, - _ => Err(Error::InvalidMessageError("Expected Primitive type".to_string())), - } -} - #[cfg(test)] mod tests { use tempfile::tempdir; diff --git a/crates/torii/libp2p/src/tests.rs b/crates/torii/libp2p/src/tests.rs index 3d0e5581be..7db5ab677b 100644 --- a/crates/torii/libp2p/src/tests.rs +++ b/crates/torii/libp2p/src/tests.rs @@ -268,20 +268,25 @@ mod test { #[cfg(not(target_arch = "wasm32"))] #[tokio::test] async fn test_client_messaging() -> Result<(), Box> { + use std::time::Duration; + use dojo_test_utils::sequencer::{ get_default_test_starknet_config, SequencerConfig, TestSequencer, }; use dojo_types::schema::{Member, Struct, Ty}; + use dojo_world::contracts::abi::model::Layout; use indexmap::IndexMap; use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions}; use starknet::providers::jsonrpc::HttpTransport; use starknet::providers::JsonRpcClient; + use starknet::signers::SigningKey; use starknet_crypto::FieldElement; + use tokio::select; use tokio::time::sleep; use torii_core::sql::Sql; - use crate::server::{parse_ty_to_object, Relay}; - use crate::typed_data::{Domain, TypedData}; + use crate::server::Relay; + use crate::typed_data::{Domain, Field, SimpleField, TypedData}; use crate::types::Message; let _ = tracing_subscriber::fmt() @@ -300,7 +305,36 @@ mod test { .await; let provider = JsonRpcClient::new(HttpTransport::new(sequencer.url())); - let db = Sql::new(pool.clone(), FieldElement::from_bytes_be(&[0; 32]).unwrap()).await?; + let account = sequencer.raw_account(); + + let mut db = Sql::new(pool.clone(), FieldElement::from_bytes_be(&[0; 32]).unwrap()).await?; + + // Register the model of our Message + db.register_model( + Ty::Struct(Struct { + name: "Message".to_string(), + children: vec![ + Member { + name: "identity".to_string(), + ty: Ty::Primitive(Primitive::ContractAddress(None)), + key: true, + }, + Member { + name: "message".to_string(), + ty: Ty::ByteArray("".to_string()), + key: false, + }, + ], + }), + Layout::Fixed(vec![]), + FieldElement::ZERO, + FieldElement::ZERO, + 0, + 0, + 0, + ) + .await + .unwrap(); // Initialize the relay server let mut relay_server = Relay::new(db, provider, 9900, 9901, None, None)?; @@ -314,27 +348,57 @@ mod test { client.event_loop.lock().await.run().await; }); - let mut data = Struct { name: "Message".to_string(), children: vec![] }; - - data.children.push(Member { - name: "player".to_string(), - ty: dojo_types::schema::Ty::Primitive( - dojo_types::primitive::Primitive::ContractAddress(Some( - FieldElement::from_bytes_be(&[0; 32]).unwrap(), - )), - ), - key: true, - }); - - data.children.push(Member { - name: "message".to_string(), - ty: dojo_types::schema::Ty::Primitive(dojo_types::primitive::Primitive::U8(Some(0))), - key: false, - }); - let mut typed_data = TypedData::new( - IndexMap::new(), - "Message", + IndexMap::from_iter(vec![ + ( + "OffchainMessage".to_string(), + vec![ + Field::SimpleType(SimpleField { + name: "model".to_string(), + r#type: "shortstring".to_string(), + }), + Field::SimpleType(SimpleField { + name: "Message".to_string(), + r#type: "Model".to_string(), + }), + ], + ), + ( + "Model".to_string(), + vec![ + Field::SimpleType(SimpleField { + name: "identity".to_string(), + r#type: "ContractAddress".to_string(), + }), + Field::SimpleType(SimpleField { + name: "message".to_string(), + r#type: "string".to_string(), + }), + ], + ), + ( + "StarknetDomain".to_string(), + vec![ + Field::SimpleType(SimpleField { + name: "name".to_string(), + r#type: "shortstring".to_string(), + }), + Field::SimpleType(SimpleField { + name: "version".to_string(), + r#type: "shortstring".to_string(), + }), + Field::SimpleType(SimpleField { + name: "chainId".to_string(), + r#type: "shortstring".to_string(), + }), + Field::SimpleType(SimpleField { + name: "revision".to_string(), + r#type: "shortstring".to_string(), + }), + ], + ), + ]), + "OffchainMessage", Domain::new("Message", "1", "0x0", Some("1")), IndexMap::new(), ); @@ -346,37 +410,50 @@ mod test { typed_data.message.insert( "Message".to_string(), crate::typed_data::PrimitiveType::Object( - parse_ty_to_object(&Ty::Struct(data.clone())).unwrap(), + vec![ + ( + "identity".to_string(), + crate::typed_data::PrimitiveType::String( + account.account_address.to_string(), + ), + ), + ( + "message".to_string(), + crate::typed_data::PrimitiveType::String("mimi".to_string()), + ), + ] + .into_iter() + .collect(), ), ); + let message_hash = typed_data.encode(account.account_address).unwrap(); + let signature = + SigningKey::from_secret_scalar(account.private_key).sign(&message_hash).unwrap(); + client .command_sender .publish(Message { message: typed_data, - signature_r: FieldElement::from_bytes_be(&[0; 32]).unwrap(), - signature_s: FieldElement::from_bytes_be(&[0; 32]).unwrap(), + signature_r: signature.r, + signature_s: signature.s, }) .await?; sleep(std::time::Duration::from_secs(2)).await; - Ok(()) - // loop { - // select! { - // entity = sqlx::query("SELECT * FROM entities WHERE id = ?") - // .bind(format!("{:#x}", FieldElement::from_bytes_be(&[0; - // 32]).unwrap())).fetch_one(&pool) => { if let Ok(_) = entity { - // println!("Test OK: Received message within 5 seconds."); - // return Ok(()); - // } - // } - // _ = sleep(Duration::from_secs(5)) => { - // println!("Test Failed: Did not receive message within 5 seconds."); - // return Err("Timeout reached without receiving a message".into()); - // } - // } - // } + loop { + select! { + entity = sqlx::query("SELECT * FROM entities").fetch_one(&pool) => if entity.is_ok() { + println!("Test OK: Received message within 5 seconds."); + return Ok(()); + }, + _ = sleep(Duration::from_secs(5)) => { + println!("Test Failed: Did not receive message within 5 seconds."); + return Err("Timeout reached without receiving a message".into()); + } + } + } } #[cfg(target_arch = "wasm32")] diff --git a/crates/torii/libp2p/src/typed_data.rs b/crates/torii/libp2p/src/typed_data.rs index dc752f751b..733c7ca29d 100644 --- a/crates/torii/libp2p/src/typed_data.rs +++ b/crates/torii/libp2p/src/typed_data.rs @@ -1,11 +1,10 @@ use std::str::FromStr; +use cainome::cairo_serde::ByteArray; use indexmap::IndexMap; use serde::{Deserialize, Serialize}; use serde_json::Number; -use starknet::core::utils::{ - cairo_short_string_to_felt, get_selector_from_name, CairoShortStringToFeltError, -}; +use starknet::core::utils::{cairo_short_string_to_felt, get_selector_from_name}; use starknet_crypto::{poseidon_hash_many, FieldElement}; use crate::errors::Error; @@ -176,39 +175,6 @@ pub fn encode_type(name: &str, types: &IndexMap>) -> Result Result<(Vec, FieldElement, usize), CairoShortStringToFeltError> { - let short_strings: Vec<&str> = split_long_string(target_string); - let remainder = short_strings.last().unwrap_or(&""); - - let mut short_strings_encoded = short_strings - .iter() - .map(|&s| cairo_short_string_to_felt(s)) - .collect::, _>>()?; - - let (pending_word, pending_word_length) = if remainder.is_empty() || remainder.len() == 31 { - (FieldElement::ZERO, 0) - } else { - (short_strings_encoded.pop().unwrap(), remainder.len()) - }; - - Ok((short_strings_encoded, pending_word, pending_word_length)) -} - -fn split_long_string(long_str: &str) -> Vec<&str> { - let mut result = Vec::new(); - - let mut start = 0; - while start < long_str.len() { - let end = (start + 31).min(long_str.len()); - result.push(&long_str[start..end]); - start = end; - } - - result -} - #[derive(Debug, Default)] pub struct Ctx { pub base_type: String, @@ -273,7 +239,7 @@ fn get_hex(value: &str) -> Result { } else { // assume its a short string and encode cairo_short_string_to_felt(value) - .map_err(|_| Error::InvalidMessageError("Invalid short string".to_string())) + .map_err(|e| Error::InvalidMessageError(format!("Invalid shortstring for felt: {}", e))) } } @@ -330,8 +296,11 @@ impl PrimitiveType { let type_hash = encode_type(r#type, if ctx.is_preset { preset_types } else { types })?; - hashes.push(get_selector_from_name(&type_hash).map_err(|_| { - Error::InvalidMessageError(format!("Invalid type {} for selector", r#type)) + hashes.push(get_selector_from_name(&type_hash).map_err(|e| { + Error::InvalidMessageError(format!( + "Invalid type {} for selector: {}", + r#type, e + )) })?); for (field_name, value) in obj { @@ -368,24 +337,23 @@ impl PrimitiveType { "shortstring" => get_hex(string), "string" => { // split the string into short strings and encode - let byte_array = byte_array_from_string(string).map_err(|_| { - Error::InvalidMessageError("Invalid short string".to_string()) + let byte_array = ByteArray::from_string(string).map_err(|e| { + Error::InvalidMessageError(format!("Invalid string for bytearray: {}", e)) })?; - let mut hashes = vec![FieldElement::from(byte_array.0.len())]; + let mut hashes = vec![FieldElement::from(byte_array.data.len())]; - for hash in byte_array.0 { - hashes.push(hash); + for hash in byte_array.data { + hashes.push(hash.felt()); } - hashes.push(byte_array.1); - hashes.push(FieldElement::from(byte_array.2)); + hashes.push(byte_array.pending_word); + hashes.push(FieldElement::from(byte_array.pending_word_len)); Ok(poseidon_hash_many(hashes.as_slice())) } - "selector" => get_selector_from_name(string).map_err(|_| { - Error::InvalidMessageError(format!("Invalid type {} for selector", r#type)) - }), + "selector" => get_selector_from_name(string) + .map_err(|e| Error::InvalidMessageError(format!("Invalid selector: {}", e))), "felt" => get_hex(string), "ContractAddress" => get_hex(string), "ClassHash" => get_hex(string),