diff --git a/Cargo.lock b/Cargo.lock index dd3aa6e355..ab71803a73 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10987,6 +10987,7 @@ dependencies = [ "dojo-test-utils", "dojo-types", "dojo-world", + "futures", "katana-runner", "notify", "notify-debouncer-mini", diff --git a/bin/sozo/Cargo.toml b/bin/sozo/Cargo.toml index 752f36e41f..0bfb305818 100644 --- a/bin/sozo/Cargo.toml +++ b/bin/sozo/Cargo.toml @@ -28,6 +28,7 @@ dojo-bindgen.workspace = true dojo-lang.workspace = true dojo-types.workspace = true dojo-world = { workspace = true, features = [ "contracts", "metadata", "migration" ] } +futures.workspace = true notify = "6.0.1" notify-debouncer-mini = "0.3.0" scarb-ui.workspace = true diff --git a/bin/sozo/src/commands/auth.rs b/bin/sozo/src/commands/auth.rs index 98b9587401..a815999667 100644 --- a/bin/sozo/src/commands/auth.rs +++ b/bin/sozo/src/commands/auth.rs @@ -1,7 +1,11 @@ +use std::str::FromStr; + use anyhow::Result; use clap::{Args, Subcommand}; +use dojo_world::contracts::cairo_utils; use dojo_world::metadata::dojo_metadata_from_workspace; use scarb::core::Config; +use starknet_crypto::FieldElement; use super::options::account::AccountOptions; use super::options::starknet::StarknetOptions; @@ -15,17 +19,128 @@ pub struct AuthArgs { pub command: AuthCommand, } +#[derive(Debug, Clone, PartialEq)] +pub struct ModelContract { + pub model: FieldElement, + pub contract: String, +} + +impl FromStr for ModelContract { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + let parts: Vec<&str> = s.split(',').collect(); + + let (model, contract) = match parts.as_slice() { + [model, contract] => (model, contract), + _ => anyhow::bail!( + "Model and contract address are expected to be comma separated: `sozo auth writer \ + model_name,0x1234`" + ), + }; + + let model = cairo_utils::str_to_felt(model) + .map_err(|_| anyhow::anyhow!("Invalid model name: {}", model))?; + + Ok(ModelContract { model, contract: contract.to_string() }) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum ResourceType { + Contract(String), + Model(FieldElement), +} + +#[derive(Debug, Clone, PartialEq)] +pub struct OwnerResource { + pub resource: ResourceType, + pub owner: FieldElement, +} + +impl FromStr for OwnerResource { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + let parts: Vec<&str> = s.split(',').collect(); + + let (resource_part, owner_part) = match parts.as_slice() { + [resource, owner] => (*resource, *owner), + _ => anyhow::bail!( + "Owner and resource are expected to be comma separated: `sozo auth owner \ + resource_type:resource_name,0x1234`" + ), + }; + + let owner = FieldElement::from_hex_be(owner_part) + .map_err(|_| anyhow::anyhow!("Invalid owner address: {}", owner_part))?; + + let resource_parts = resource_part.split_once(':'); + let resource = match resource_parts { + Some(("contract", name)) => ResourceType::Contract(name.to_string()), + Some(("model", name)) => { + let model = cairo_utils::str_to_felt(name) + .map_err(|_| anyhow::anyhow!("Invalid model name: {}", name))?; + ResourceType::Model(model) + } + _ => anyhow::bail!( + "Resource is expected to be in the format `resource_type:resource_name`: `sozo \ + auth owner 0x1234,resource_type:resource_name`" + ), + }; + + Ok(OwnerResource { owner, resource }) + } +} + #[derive(Debug, Subcommand)] -pub enum AuthCommand { - #[command(about = "Auth a system with the given calldata.")] +pub enum AuthKind { + #[command(about = "Grant a contract permission to write to a model.")] Writer { #[arg(num_args = 1..)] #[arg(required = true)] #[arg(value_name = "model,contract_address")] #[arg(help = "A list of models and contract address to grant write access to. Comma \ separated values to indicate model name and contract address e.g. \ - model_name,0x1234 model_name,0x1111 ")] - models_contracts: Vec, + model_name,path::to::contract model_name,contract_address ")] + models_contracts: Vec, + }, + #[command(about = "Grant ownership of a resource.")] + Owner { + #[arg(num_args = 1..)] + #[arg(required = true)] + #[arg(value_name = "resource,owner_address")] + #[arg(help = "A list of owners and resources to grant ownership to. Comma separated \ + values to indicate owner address and resouce e.g. \ + contract:path::to::contract,0x1234 contract:contract_address,0x1111, \ + model:model_name,0xbeef")] + owners_resources: Vec, + }, +} + +#[derive(Debug, Subcommand)] +pub enum AuthCommand { + #[command(about = "Grant an auth role.")] + Grant { + #[command(subcommand)] + kind: AuthKind, + + #[command(flatten)] + world: WorldOptions, + + #[command(flatten)] + starknet: StarknetOptions, + + #[command(flatten)] + account: AccountOptions, + + #[command(flatten)] + transaction: TransactionOptions, + }, + #[command(about = "Revoke an auth role.")] + Revoke { + #[command(subcommand)] + kind: AuthKind, #[command(flatten)] world: WorldOptions, @@ -54,3 +169,54 @@ impl AuthArgs { config.tokio_handle().block_on(auth::execute(self.command, env_metadata)) } } + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use starknet_crypto::FieldElement; + + use super::*; + + #[test] + fn test_owner_resource_from_str() { + // Test valid input + let input = "contract:path::to::contract,0x1234"; + let expected_owner = FieldElement::from_hex_be("0x1234").unwrap(); + let expected_resource = ResourceType::Contract("path::to::contract".to_string()); + let expected = OwnerResource { owner: expected_owner, resource: expected_resource }; + let result = OwnerResource::from_str(input).unwrap(); + assert_eq!(result, expected); + + // Test valid input with model + let input = "model:model_name,0x1234"; + let expected_owner = FieldElement::from_hex_be("0x1234").unwrap(); + let expected_model = cairo_utils::str_to_felt("model_name").unwrap(); + let expected_resource = ResourceType::Model(expected_model); + let expected = OwnerResource { owner: expected_owner, resource: expected_resource }; + let result = OwnerResource::from_str(input).unwrap(); + assert_eq!(result, expected); + + // Test invalid input + let input = "invalid_input"; + let result = OwnerResource::from_str(input); + assert!(result.is_err()); + } + + #[test] + fn test_model_contract_from_str() { + // Test valid input + let input = "model_name,0x1234"; + let expected_model = cairo_utils::str_to_felt("model_name").unwrap(); + let expected_contract = "0x1234"; + let expected = + ModelContract { model: expected_model, contract: expected_contract.to_string() }; + let result = ModelContract::from_str(input).unwrap(); + assert_eq!(result, expected); + + // Test invalid input + let input = "invalid_input"; + let result = ModelContract::from_str(input); + assert!(result.is_err()); + } +} diff --git a/bin/sozo/src/commands/execute.rs b/bin/sozo/src/commands/execute.rs index 0a494b1333..22ccf054de 100644 --- a/bin/sozo/src/commands/execute.rs +++ b/bin/sozo/src/commands/execute.rs @@ -7,6 +7,7 @@ use starknet::core::types::FieldElement; use super::options::account::AccountOptions; use super::options::starknet::StarknetOptions; use super::options::transaction::TransactionOptions; +use super::options::world::WorldOptions; use crate::ops::execute; #[derive(Debug, Args)] @@ -31,6 +32,9 @@ pub struct ExecuteArgs { #[command(flatten)] pub account: AccountOptions, + #[command(flatten)] + pub world: WorldOptions, + #[command(flatten)] pub transaction: TransactionOptions, } diff --git a/bin/sozo/src/ops/auth.rs b/bin/sozo/src/ops/auth.rs index 222f96a7ce..03903d4df2 100644 --- a/bin/sozo/src/ops/auth.rs +++ b/bin/sozo/src/ops/auth.rs @@ -1,57 +1,77 @@ use anyhow::{Context, Result}; -use dojo_world::contracts::cairo_utils; use dojo_world::contracts::world::WorldContract; use dojo_world::metadata::Environment; use dojo_world::utils::TransactionWaiter; use starknet::accounts::Account; -use starknet::core::types::FieldElement; -use crate::commands::auth::AuthCommand; +use super::get_contract_address; +use crate::commands::auth::{AuthCommand, AuthKind, ResourceType}; pub async fn execute(command: AuthCommand, env_metadata: Option) -> Result<()> { match command { - AuthCommand::Writer { models_contracts, world, starknet, account, transaction } => { - let world_address = world.address(env_metadata.as_ref())?; - let provider = starknet.provider(env_metadata.as_ref())?; + AuthCommand::Grant { kind, world, starknet, account, transaction } => match kind { + AuthKind::Writer { models_contracts } => { + let world_address = world.address(env_metadata.as_ref())?; + let provider = starknet.provider(env_metadata.as_ref())?; - let account = account.account(&provider, env_metadata.as_ref()).await?; - let world = WorldContract::new(world_address, &account); + let account = account.account(&provider, env_metadata.as_ref()).await?; + let world = WorldContract::new(world_address, &account); - let mut calls = vec![]; + let mut calls = Vec::new(); - for mc in models_contracts { - let parts: Vec<&str> = mc.split(',').collect(); + for mc in models_contracts { + let contract = get_contract_address(&world, mc.contract).await?; + calls.push(world.grant_writer_getcall(&mc.model, &contract.into())); + } - let (model, contract_part) = match parts.as_slice() { - [model, contract] => (model.to_string(), *contract), - _ => anyhow::bail!( - "Model and contract address are expected to be comma separated: `sozo \ - auth writer model_name,0x1234`" - ), - }; + let res = account + .execute(calls) + .send() + .await + .with_context(|| "Failed to send transaction")?; - let contract = FieldElement::from_hex_be(contract_part) - .map_err(|_| anyhow::anyhow!("Invalid contract address: {}", contract_part))?; - - calls.push( - world - .grant_writer_getcall(&cairo_utils::str_to_felt(&model)?, &contract.into()), - ); + if transaction.wait { + let receipt = TransactionWaiter::new(res.transaction_hash, &provider).await?; + println!("{}", serde_json::to_string_pretty(&receipt)?); + } else { + println!("Transaction hash: {:#x}", res.transaction_hash); + } } + AuthKind::Owner { owners_resources } => { + let world_address = world.address(env_metadata.as_ref())?; + let provider = starknet.provider(env_metadata.as_ref())?; + + let account = account.account(&provider, env_metadata.as_ref()).await?; + let world = WorldContract::new(world_address, &account); + + let mut calls = Vec::new(); + + for or in owners_resources { + let resource = match &or.resource { + ResourceType::Model(name) => *name, + ResourceType::Contract(name_or_address) => { + get_contract_address(&world, name_or_address.clone()).await? + } + }; + + calls.push(world.grant_owner_getcall(&or.owner.into(), &resource)); + } + + let res = account + .execute(calls) + .send() + .await + .with_context(|| "Failed to send transaction")?; - let res = account - .execute(calls) - .send() - .await - .with_context(|| "Failed to send transaction")?; - - if transaction.wait { - let receipt = TransactionWaiter::new(res.transaction_hash, &provider).await?; - println!("{}", serde_json::to_string_pretty(&receipt)?); - } else { - println!("Transaction hash: {:#x}", res.transaction_hash); + if transaction.wait { + let receipt = TransactionWaiter::new(res.transaction_hash, &provider).await?; + println!("{}", serde_json::to_string_pretty(&receipt)?); + } else { + println!("Transaction hash: {:#x}", res.transaction_hash); + } } - } + }, + _ => todo!(), } Ok(()) diff --git a/bin/sozo/src/ops/execute.rs b/bin/sozo/src/ops/execute.rs index 868f69993c..fd0f8d1373 100644 --- a/bin/sozo/src/ops/execute.rs +++ b/bin/sozo/src/ops/execute.rs @@ -1,50 +1,24 @@ use anyhow::{Context, Result}; +use dojo_world::contracts::world::WorldContract; use dojo_world::metadata::Environment; -use dojo_world::migration::strategy::generate_salt; use dojo_world::utils::TransactionWaiter; use starknet::accounts::{Account, Call}; -use starknet::core::types::{BlockId, BlockTag, FieldElement, FunctionCall}; -use starknet::core::utils::{get_contract_address, get_selector_from_name}; -use starknet::macros::selector; -use starknet::providers::Provider; +use starknet::core::utils::get_selector_from_name; +use super::get_contract_address; use crate::commands::execute::ExecuteArgs; pub async fn execute(args: ExecuteArgs, env_metadata: Option) -> Result<()> { - let ExecuteArgs { contract, entrypoint, calldata, starknet, account, transaction } = args; + let ExecuteArgs { contract, entrypoint, calldata, starknet, world, account, transaction } = + args; let provider = starknet.provider(env_metadata.as_ref())?; - let contract_address = if contract.starts_with("0x") { - FieldElement::from_hex_be(&contract)? - } else { - let world_address = env_metadata - .as_ref() - .and_then(|env| env.world_address.as_ref()) - .cloned() - .ok_or_else(|| anyhow::anyhow!("No World Address found"))?; - - let contract_class_hash = provider - .call( - FunctionCall { - contract_address: FieldElement::from_hex_be(&world_address).unwrap(), - entry_point_selector: selector!("base"), - calldata: [].to_vec(), - }, - BlockId::Tag(BlockTag::Latest), - ) - .await?; - - get_contract_address( - generate_salt(&contract), - contract_class_hash[0], - &[], - FieldElement::from_hex_be(&world_address).unwrap(), - ) - }; - let account = account.account(&provider, env_metadata.as_ref()).await?; + let world_address = world.address(env_metadata.as_ref())?; + let world = WorldContract::new(world_address, &account); + let contract_address = get_contract_address(&world, contract).await?; let res = account .execute(vec![Call { calldata, diff --git a/bin/sozo/src/ops/mod.rs b/bin/sozo/src/ops/mod.rs index abf4eb9c5c..509e266d27 100644 --- a/bin/sozo/src/ops/mod.rs +++ b/bin/sozo/src/ops/mod.rs @@ -1,6 +1,29 @@ +use anyhow::Result; +use dojo_world::contracts::world::WorldContract; +use dojo_world::migration::strategy::generate_salt; +use starknet::accounts::ConnectedAccount; +use starknet::core::types::FieldElement; + pub mod auth; pub mod events; pub mod execute; pub mod migration; pub mod model; pub mod register; + +pub async fn get_contract_address( + world: &WorldContract, + name_or_address: String, +) -> Result { + if name_or_address.starts_with("0x") { + FieldElement::from_hex_be(&name_or_address).map_err(anyhow::Error::from) + } else { + let contract_class_hash = world.base().call().await?; + Ok(starknet::core::utils::get_contract_address( + generate_salt(&name_or_address), + contract_class_hash.into(), + &[], + world.address, + )) + } +}