Skip to content

Commit

Permalink
Support partial session approval (#1213)
Browse files Browse the repository at this point in the history
  • Loading branch information
tarrencev authored Jan 6, 2025
1 parent 2fc8897 commit ca079a6
Show file tree
Hide file tree
Showing 17 changed files with 305 additions and 83 deletions.
24 changes: 21 additions & 3 deletions packages/account-wasm/src/account.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,8 @@ impl CartridgeAccount {
.is_some())
}

#[wasm_bindgen(js_name = session)]
pub async fn session_metadata(
#[wasm_bindgen(js_name = getAuthorizedSessionMetadata)]
pub async fn authorized_session_metadata(
&self,
policies: Vec<Policy>,
public_key: Option<JsFelt>,
Expand All @@ -302,10 +302,28 @@ impl CartridgeAccount {
.controller
.lock()
.await
.session_metadata(&policies, public_key.map(|f| f.0))
.authorized_session_metadata(&policies, public_key.map(|f| f.0))
.map(|(_, metadata)| SessionMetadata::from(metadata)))
}

#[wasm_bindgen(js_name = isRequestedSession)]
pub async fn is_requested_session(
&self,
policies: Vec<Policy>,
public_key: Option<JsFelt>,
) -> std::result::Result<bool, JsControllerError> {
let policies = policies
.into_iter()
.map(TryFrom::try_from)
.collect::<std::result::Result<Vec<_>, _>>()?;

Ok(self
.controller
.lock()
.await
.is_requested_session(&policies, public_key.map(|f| f.0)))
}

#[wasm_bindgen(js_name = revokeSession)]
pub fn revoke_session(&self) -> Result<()> {
unimplemented!("Revoke Session not implemented");
Expand Down
25 changes: 19 additions & 6 deletions packages/account-wasm/src/types/policy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@ use super::EncodingError;
pub struct CallPolicy {
pub target: String,
pub method: String,
#[tsify(optional)]
pub authorized: Option<bool>,
}

#[derive(Tsify, Serialize, Deserialize, Debug, Clone)]
#[tsify(into_wasm_abi, from_wasm_abi)]
pub struct TypedDataPolicy {
pub scope_hash: String,
#[tsify(optional)]
pub authorized: Option<bool>,
}

#[allow(non_snake_case)]
Expand All @@ -45,15 +49,22 @@ impl TryFrom<Policy> for SdkPolicy {

fn try_from(value: Policy) -> Result<Self, Self::Error> {
match value {
Policy::Call(CallPolicy { target, method }) => Ok(SdkPolicy::Call(SdkCallPolicy {
Policy::Call(CallPolicy {
target,
method,
authorized,
}) => Ok(SdkPolicy::Call(SdkCallPolicy {
contract_address: Felt::from_str(&target)?,
selector: get_selector_from_name(&method).unwrap(),
authorized,
})),
Policy::TypedData(TypedDataPolicy {
scope_hash,
authorized,
}) => Ok(SdkPolicy::TypedData(SdkTypedDataPolicy {
scope_hash: Felt::from_str(&scope_hash)?,
authorized,
})),
Policy::TypedData(TypedDataPolicy { scope_hash }) => {
Ok(SdkPolicy::TypedData(SdkTypedDataPolicy {
scope_hash: Felt::from_str(&scope_hash)?,
}))
}
}
}
}
Expand All @@ -64,9 +75,11 @@ impl From<SdkPolicy> for Policy {
SdkPolicy::Call(call_policy) => Policy::Call(CallPolicy {
target: call_policy.contract_address.to_string(),
method: call_policy.selector.to_string(),
authorized: call_policy.authorized,
}),
SdkPolicy::TypedData(typed_data_policy) => Policy::TypedData(TypedDataPolicy {
scope_hash: typed_data_policy.scope_hash.to_string(),
authorized: typed_data_policy.authorized,
}),
}
}
Expand Down
2 changes: 1 addition & 1 deletion packages/account-wasm/src/types/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ impl From<account_sdk::account::session::hash::Session> for Session {
fn from(value: account_sdk::account::session::hash::Session) -> Self {
Session {
policies: value
.policies
.proved_policies
.into_iter()
.map(|p| p.policy.into())
.collect::<Vec<_>>(),
Expand Down
47 changes: 31 additions & 16 deletions packages/account_sdk/src/account/session/hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ use super::policy::ProvedPolicy;
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct Session {
pub inner: crate::abigen::controller::Session,
pub policies: Vec<ProvedPolicy>,
pub requested_policies: Vec<Policy>,
pub proved_policies: Vec<ProvedPolicy>,
pub metadata: String,
}

Expand All @@ -37,23 +38,32 @@ impl Session {
session_signer: &Signer,
guardian_guid: Felt,
) -> Result<Self, SignError> {
if policies.is_empty() {
return Err(SignError::NoAllowedSessionMethods);
}
let metadata = json!({ "metadata": "metadata", "max_fee": 0 });
let hashes = policies

// Only include authorized policies in the merkle tree
let authorized_policies: Vec<_> = policies.iter().filter(|&p| p.is_authorized()).collect();

let hashes = authorized_policies
.iter()
.map(Policy::as_merkle_leaf)
.map(|&p| Policy::as_merkle_leaf(p))
.collect::<Vec<Felt>>();
let policies: Vec<_> = policies

let proved_policies: Vec<_> = authorized_policies
.clone()
.into_iter()
.enumerate()
.map(|(i, method)| ProvedPolicy {
policy: method,
.map(|(i, policy)| ProvedPolicy {
policy: policy.clone(),
proof: MerkleTree::compute_proof(hashes.clone(), i),
})
.collect();
let root = MerkleTree::compute_root(hashes[0], policies[0].proof.clone());

let root = if authorized_policies.is_empty() {
Felt::ZERO
} else {
MerkleTree::compute_root(hashes[0], proved_policies[0].proof.clone())
};

Ok(Self {
inner: crate::abigen::controller::Session {
expires_at,
Expand All @@ -62,7 +72,8 @@ impl Session {
guardian_key_guid: guardian_guid,
metadata_hash: Felt::ZERO,
},
policies,
requested_policies: policies,
proved_policies,
metadata: serde_json::to_string(&metadata).unwrap(),
})
}
Expand All @@ -77,16 +88,16 @@ impl Session {
}

pub fn single_proof(&self, policy: &Policy) -> Option<Vec<Felt>> {
self.policies
self.proved_policies
.iter()
.find(|ProvedPolicy { policy: method, .. }| method == policy)
.find(|ProvedPolicy { policy: p, .. }| p == policy)
.map(|ProvedPolicy { proof, .. }| proof.clone())
}

pub fn is_authorized(&self, policy: &Policy) -> bool {
self.policies
.iter()
.any(|proved_policy| proved_policy.policy == *policy)
self.proved_policies.iter().any(|proved_policy| {
proved_policy.policy == *policy && proved_policy.policy.is_authorized()
})
}

pub fn is_expired(&self) -> bool {
Expand All @@ -103,6 +114,10 @@ impl Session {

self.inner.session_key_guid == session_key_guid
}

pub fn is_requested(&self, policy: &Policy) -> bool {
self.requested_policies.iter().any(|p| p == policy)
}
}

impl StructHashRev1 for abigen::controller::Session {
Expand Down
45 changes: 40 additions & 5 deletions packages/account_sdk/src/account/session/policy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,56 @@ impl Policy {
Policy::Call(CallPolicy {
contract_address,
selector,
authorized: Some(true),
})
}

pub fn new_typed_data(scope_hash: Felt) -> Self {
Policy::TypedData(TypedDataPolicy { scope_hash })
Policy::TypedData(TypedDataPolicy {
scope_hash,
authorized: Some(true),
})
}

pub fn is_authorized(&self) -> bool {
match self {
Policy::Call(call) => call.authorized.unwrap_or(false),
Policy::TypedData(typed_data) => typed_data.authorized.unwrap_or(false),
}
}
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
impl PartialEq for CallPolicy {
fn eq(&self, other: &Self) -> bool {
self.contract_address == other.contract_address && self.selector == other.selector
}
}

impl PartialEq for TypedDataPolicy {
fn eq(&self, other: &Self) -> bool {
self.scope_hash == other.scope_hash
}
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CallPolicy {
pub contract_address: Felt,
pub selector: Felt,
pub authorized: Option<bool>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TypedDataPolicy {
pub scope_hash: Felt,
pub authorized: Option<bool>,
}

impl From<&Call> for Policy {
fn from(call: &Call) -> Self {
Policy::Call(CallPolicy {
contract_address: call.to,
selector: call.selector,
authorized: Some(true),
})
}
}
Expand All @@ -55,6 +82,7 @@ impl From<&TypedData> for Policy {
fn from(typed_data: &TypedData) -> Self {
Self::TypedData(TypedDataPolicy {
scope_hash: typed_data.scope_hash,
authorized: Some(true),
})
}
}
Expand Down Expand Up @@ -90,8 +118,15 @@ pub trait MerkleLeaf {
impl MerkleLeaf for Policy {
fn as_merkle_leaf(&self) -> Felt {
match self {
Policy::Call(call_policy) => call_policy.as_merkle_leaf(),
Policy::TypedData(typed_data_policy) => typed_data_policy.as_merkle_leaf(),
Policy::Call(call_policy) if call_policy.authorized.unwrap_or(false) => {
call_policy.as_merkle_leaf()
}
Policy::TypedData(typed_data_policy)
if typed_data_policy.authorized.unwrap_or(false) =>
{
typed_data_policy.as_merkle_leaf()
}
_ => Felt::ZERO,
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions packages/account_sdk/src/controller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ impl Controller {
match est {
Ok(mut fee_estimate) => {
if self
.session_metadata(&Policy::from_calls(&calls), None)
.authorized_session_metadata(&Policy::from_calls(&calls), None)
.map_or(true, |(_, metadata)| !metadata.is_registered)
{
fee_estimate.overall_fee += WEBAUTHN_GAS * fee_estimate.gas_price;
Expand Down Expand Up @@ -252,7 +252,7 @@ impl Controller {

// Update is_registered to true after successful execution with a session
if let Some((key, metadata)) =
self.session_metadata(&Policy::from_calls(&calls), None)
self.authorized_session_metadata(&Policy::from_calls(&calls), None)
{
if !metadata.is_registered {
let mut updated_metadata = metadata;
Expand Down
4 changes: 3 additions & 1 deletion packages/account_sdk/src/execute_from_outside.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ impl Controller {
.map_err(ControllerError::PaymasterError)?;

// Update is_registered to true after successful execution with a session
if let Some((key, metadata)) = self.session_metadata(&Policy::from_calls(&calls), None) {
if let Some((key, metadata)) =
self.authorized_session_metadata(&Policy::from_calls(&calls), None)
{
if !metadata.is_registered {
let mut updated_metadata = metadata;
updated_metadata.is_registered = true;
Expand Down
4 changes: 2 additions & 2 deletions packages/account_sdk/src/execute_from_outside_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ async fn test_execute_from_outside_with_session() {

// Check that the session is not registered initially
let (_, initial_metadata) = controller
.session_metadata(&Policy::from_calls(&[]), None)
.authorized_session_metadata(&Policy::from_calls(&[]), None)
.expect("Failed to get session metadata");
assert!(
!initial_metadata.is_registered,
Expand Down Expand Up @@ -140,7 +140,7 @@ async fn test_execute_from_outside_with_session() {

// Check that the session is registered
let (_, metadata) = controller
.session_metadata(&Policy::from_calls(&[]), None)
.authorized_session_metadata(&Policy::from_calls(&[]), None)
.expect("Failed to get session metadata");
assert!(metadata.is_registered, "Session should be registered");
}
22 changes: 18 additions & 4 deletions packages/account_sdk/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,23 +127,37 @@ impl Controller {
Ok(txn)
}

pub fn session_metadata(
pub fn authorized_session_metadata(
&self,
policies: &[Policy],
public_key: Option<Felt>,
) -> Option<(String, SessionMetadata)> {
let key: String = Selectors::session(&self.address, &self.app_id, &self.chain_id);
let key = self.session_key();
self.storage
.session(&key)
.ok()
.flatten()
.filter(|metadata| metadata.is_valid(policies, public_key))
.filter(|metadata| metadata.is_authorized(policies, public_key))
.map(|metadata| (key, metadata))
}

pub fn is_requested_session(&self, policies: &[Policy], public_key: Option<Felt>) -> bool {
let key = self.session_key();
self.storage
.session(&key)
.ok()
.flatten()
.filter(|metadata| metadata.is_requested(policies, public_key))
.is_some()
}

pub fn session_key(&self) -> String {
Selectors::session(&self.address, &self.app_id, &self.chain_id)
}

pub fn session_account(&self, policies: &[Policy]) -> Option<SessionAccount> {
// Check if there's a valid session stored
let (_, metadata) = self.session_metadata(policies, None)?;
let (_, metadata) = self.authorized_session_metadata(policies, None)?;
let credentials = metadata.credentials.as_ref()?;
let session_signer =
Signer::Starknet(SigningKey::from_secret_scalar(credentials.private_key));
Expand Down
Loading

0 comments on commit ca079a6

Please sign in to comment.