diff --git a/packages/account-wasm/src/account.rs b/packages/account-wasm/src/account.rs index 84be4e9c4..d4c45cdaf 100644 --- a/packages/account-wasm/src/account.rs +++ b/packages/account-wasm/src/account.rs @@ -1,3 +1,5 @@ +use std::borrow::BorrowMut; + use account_sdk::account::session::policy::Policy as SdkPolicy; use account_sdk::controller::Controller; use account_sdk::errors::ControllerError; @@ -13,6 +15,7 @@ use url::Url; use wasm_bindgen::prelude::*; use crate::errors::JsControllerError; +use crate::sync::WasmMutex; use crate::types::call::JsCall; use crate::types::invocation::JsInvocationsDetails; use crate::types::policy::Policy; @@ -25,7 +28,7 @@ type Result = std::result::Result; #[wasm_bindgen] pub struct CartridgeAccount { - controller: Controller, + controller: WasmMutex, } #[wasm_bindgen] @@ -83,15 +86,17 @@ impl CartridgeAccount { } #[wasm_bindgen(js_name = disconnect)] - pub fn disconnect(&mut self) -> std::result::Result<(), JsControllerError> { + pub async fn disconnect(&self) -> std::result::Result<(), JsControllerError> { self.controller + .lock() + .await .disconnect() .map_err(JsControllerError::from) } #[wasm_bindgen(js_name = registerSession)] pub async fn register_session( - &mut self, + &self, policies: Vec, expires_at: u64, public_key: JsFelt, @@ -104,6 +109,8 @@ impl CartridgeAccount { let res = self .controller + .lock() + .await .register_session(methods, expires_at, public_key.0, Felt::ZERO, max_fee.0) .await .map_err(JsControllerError::from)?; @@ -112,8 +119,8 @@ impl CartridgeAccount { } #[wasm_bindgen(js_name = registerSessionCalldata)] - pub fn register_session_calldata( - &mut self, + pub async fn register_session_calldata( + &self, policies: Vec, expires_at: u64, public_key: JsFelt, @@ -122,19 +129,22 @@ impl CartridgeAccount { .into_iter() .map(TryFrom::try_from) .collect::, _>>()?; - let call = - self.controller - .register_session_call(methods, expires_at, public_key.0, Felt::ZERO)?; + let call = self.controller.lock().await.register_session_call( + methods, + expires_at, + public_key.0, + Felt::ZERO, + )?; Ok(to_value(&call.calldata)?) } #[wasm_bindgen(js_name = upgrade)] - pub fn upgrade( + pub async fn upgrade( &self, new_class_hash: JsFelt, ) -> std::result::Result { - let call = self.controller.upgrade(new_class_hash.0); + let call = self.controller.lock().await.upgrade(new_class_hash.0); Ok(JsCall { contract_address: call.to, entrypoint: "upgrade".to_string(), @@ -143,7 +153,7 @@ impl CartridgeAccount { } #[wasm_bindgen(js_name = createSession)] - pub async fn create_session(&mut self, policies: Vec, expires_at: u64) -> Result<()> { + pub async fn create_session(&self, policies: Vec, expires_at: u64) -> Result<()> { set_panic_hook(); let methods = policies @@ -151,7 +161,11 @@ impl CartridgeAccount { .map(TryFrom::try_from) .collect::, _>>()?; - self.controller.create_session(methods, expires_at).await?; + self.controller + .lock() + .await + .create_session(methods, expires_at) + .await?; Ok(()) } @@ -168,13 +182,18 @@ impl CartridgeAccount { .map(TryFrom::try_from) .collect::, _>>()?; - let fee_estimate = self.controller.estimate_invoke_fee(calls).await?; + let fee_estimate = self + .controller + .lock() + .await + .estimate_invoke_fee(calls) + .await?; Ok(to_value(&fee_estimate)?) } #[wasm_bindgen(js_name = execute)] pub async fn execute( - &mut self, + &self, calls: Vec, details: JsInvocationsDetails, ) -> std::result::Result { @@ -185,14 +204,19 @@ impl CartridgeAccount { .map(TryFrom::try_from) .collect::, _>>()?; - let result = Controller::execute(&mut self.controller, calls, details.max_fee).await?; + let result = Controller::execute( + self.controller.lock().await.borrow_mut(), + calls, + details.max_fee, + ) + .await?; Ok(to_value(&result)?) } #[wasm_bindgen(js_name = executeFromOutsideV2)] pub async fn execute_from_outside_v2( - &mut self, + &self, calls: Vec, ) -> std::result::Result { set_panic_hook(); @@ -202,13 +226,18 @@ impl CartridgeAccount { .map(TryInto::try_into) .collect::>()?; - let response = self.controller.execute_from_outside_v2(calls).await?; + let response = self + .controller + .lock() + .await + .execute_from_outside_v2(calls) + .await?; Ok(to_value(&response)?) } #[wasm_bindgen(js_name = executeFromOutsideV3)] pub async fn execute_from_outside_v3( - &mut self, + &self, calls: Vec, ) -> std::result::Result { set_panic_hook(); @@ -218,12 +247,17 @@ impl CartridgeAccount { .map(TryInto::try_into) .collect::>()?; - let response = self.controller.execute_from_outside_v3(calls).await?; + let response = self + .controller + .lock() + .await + .execute_from_outside_v3(calls) + .await?; Ok(to_value(&response)?) } #[wasm_bindgen(js_name = hasSession)] - pub fn has_session(&self, calls: Vec) -> Result { + pub async fn has_session(&self, calls: Vec) -> Result { let calls: Vec = calls .into_iter() .map(TryFrom::try_from) @@ -231,12 +265,14 @@ impl CartridgeAccount { Ok(self .controller + .lock() + .await .session_account(&SdkPolicy::from_calls(&calls)) .is_some()) } #[wasm_bindgen(js_name = hasSessionForMessage)] - pub fn has_session_for_message(&self, typed_data: String) -> Result { + pub async fn has_session_for_message(&self, typed_data: String) -> Result { let typed_data: TypedData = serde_json::from_str(&typed_data)?; let domain_hash = typed_data.domain.encode(&typed_data.types)?; let type_hash = @@ -245,12 +281,14 @@ impl CartridgeAccount { Ok(self .controller + .lock() + .await .session_account(&[SdkPolicy::new_typed_data(scope_hash)]) .is_some()) } #[wasm_bindgen(js_name = session)] - pub fn session_metadata( + pub async fn session_metadata( &self, policies: Vec, public_key: Option, @@ -262,6 +300,8 @@ impl CartridgeAccount { Ok(self .controller + .lock() + .await .session_metadata(&policies, public_key.map(|f| f.0)) .map(|(_, metadata)| SessionMetadata::from(metadata))) } @@ -277,6 +317,8 @@ impl CartridgeAccount { let signature = self .controller + .lock() + .await .sign_message(serde_json::from_str(&typed_data)?) .await .map_err(|e| JsControllerError::from(ControllerError::SignError(e)))?; @@ -288,6 +330,8 @@ impl CartridgeAccount { pub async fn get_nonce(&self) -> std::result::Result { let nonce = self .controller + .lock() + .await .get_nonce() .await .map_err(|e| JsControllerError::from(ControllerError::ProviderError(e)))?; @@ -301,6 +345,8 @@ impl CartridgeAccount { let res = self .controller + .lock() + .await .deploy() .max_fee(max_fee.0) .send() @@ -316,6 +362,8 @@ impl CartridgeAccount { let res = self .controller + .lock() + .await .delegate_account() .await .map_err(JsControllerError::from)?; @@ -407,7 +455,9 @@ impl CartridgeAccountWithMeta { fn new(controller: Controller) -> Self { let meta = CartridgeAccountMeta::new(&controller); Self { - account: CartridgeAccount { controller }, + account: CartridgeAccount { + controller: WasmMutex::new(controller), + }, meta, } } diff --git a/packages/keychain/src/components/ConfirmTransaction.tsx b/packages/keychain/src/components/ConfirmTransaction.tsx index b60f52d5f..9141fdec9 100644 --- a/packages/keychain/src/components/ConfirmTransaction.tsx +++ b/packages/keychain/src/components/ConfirmTransaction.tsx @@ -1,4 +1,4 @@ -import { useMemo, useState } from "react"; +import { useEffect, useMemo, useState } from "react"; import { ResponseCodes, SessionPolicies, toArray } from "@cartridge/controller"; import { Content, FOOTER_MIN_HEIGHT } from "components/layout"; import { TransactionDuoIcon } from "@cartridge/ui"; @@ -12,6 +12,7 @@ import { CreateSession } from "./connect"; export function ConfirmTransaction() { const { controller, context, origin, policies, setContext } = useConnection(); const [policiesUpdated, setIsPoliciesUpdated] = useState(false); + const [updateSession, setUpdateSession] = useState(false); const ctx = context as ExecuteCtx; const account = controller; @@ -46,8 +47,11 @@ export function ConfirmTransaction() { [ctx.transactions], ); - const updateSession = useMemo(() => { - if (policiesUpdated) return false; + useEffect(() => { + if (policiesUpdated) { + setUpdateSession(false); + return; + } const entries = Object.entries(callPolicies.contracts || {}); const txnsApproved = entries.every(([target, policy]) => { @@ -60,7 +64,13 @@ export function ConfirmTransaction() { // If calls are approved by dapp specified policies but not stored session // then prompt user to update session. This also accounts for expired sessions. - return txnsApproved && !account?.session(callPolicies); + if (txnsApproved && account) { + account.session(callPolicies).then((hasSession) => { + setUpdateSession(!hasSession); + }); + } else { + setUpdateSession(false); + } }, [callPolicies, policiesUpdated, policies, account]); if (updateSession && policies) { diff --git a/packages/keychain/src/components/connect/RegisterSession.tsx b/packages/keychain/src/components/connect/RegisterSession.tsx index f5db1af5b..dac9c7180 100644 --- a/packages/keychain/src/components/connect/RegisterSession.tsx +++ b/packages/keychain/src/components/connect/RegisterSession.tsx @@ -1,5 +1,5 @@ import { Content } from "components/layout"; -import { useCallback, useMemo, useState } from "react"; +import { useCallback, useEffect, useState } from "react"; import { useConnection } from "hooks/connection"; import { SessionConsent } from "components/connect"; import { ExecutionContainer } from "components/ExecutionContainer"; @@ -23,23 +23,31 @@ export function RegisterSession({ }) { const { controller, theme } = useConnection(); const [expiresAt] = useState(SESSION_EXPIRATION); + const [transactions, setTransactions] = useState< + | { + contractAddress: string; + entrypoint: string; + calldata: string[]; + }[] + | undefined + >(undefined); - const transactions = useMemo(() => { - if (!publicKey || !controller) return; - - const calldata = controller.registerSessionCalldata( - expiresAt, - policies, - publicKey, - ); - - return [ - { - contractAddress: controller.address, - entrypoint: "register_session", - calldata, - }, - ]; + useEffect(() => { + if (!publicKey || !controller) { + setTransactions(undefined); + } else { + controller + .registerSessionCalldata(expiresAt, policies, publicKey) + .then((calldata) => { + setTransactions([ + { + contractAddress: controller.address, + entrypoint: "register_session", + calldata, + }, + ]); + }); + } }, [controller, expiresAt, policies, publicKey]); const onRegisterSession = useCallback( diff --git a/packages/keychain/src/hooks/connection.ts b/packages/keychain/src/hooks/connection.ts index 57e9c73a6..e9e043cf7 100644 --- a/packages/keychain/src/hooks/connection.ts +++ b/packages/keychain/src/hooks/connection.ts @@ -224,12 +224,13 @@ export function useConnectionValue() { }, [rpcUrl, controller]); const logout = useCallback(() => { - window.controller?.disconnect(); - setController(undefined); + window.controller?.disconnect().then(() => { + setController(undefined); - context?.resolve?.({ - code: ResponseCodes.NOT_CONNECTED, - message: "User logged out", + context?.resolve?.({ + code: ResponseCodes.NOT_CONNECTED, + message: "User logged out", + }); }); }, [context, setController]); diff --git a/packages/keychain/src/hooks/upgrade.ts b/packages/keychain/src/hooks/upgrade.ts index 0ad094071..5efa4ce72 100644 --- a/packages/keychain/src/hooks/upgrade.ts +++ b/packages/keychain/src/hooks/upgrade.ts @@ -1,4 +1,5 @@ -import { useCallback, useEffect, useMemo, useState } from "react"; +import { JsCall } from "@cartridge/account-wasm"; +import { useCallback, useEffect, useState } from "react"; import { addAddressPadding, Call } from "starknet"; import { ControllerError } from "utils/connection"; import Controller from "utils/controller"; @@ -71,6 +72,7 @@ export const useUpgrade = (controller?: Controller): UpgradeInterface => { const [isSynced, setIsSynced] = useState(false); const [isUpgrading, setIsUpgrading] = useState(false); const [current, setCurrent] = useState(); + const [calls, setCalls] = useState([]); useEffect(() => { if (!controller) { @@ -105,12 +107,14 @@ export const useUpgrade = (controller?: Controller): UpgradeInterface => { .finally(() => setIsSynced(true)); }, [controller]); - const calls = useMemo(() => { + useEffect(() => { if (!controller || !LATEST_CONTROLLER) { - return []; + setCalls([]); + } else { + controller.upgrade(LATEST_CONTROLLER.hash).then((call) => { + setCalls([call]); + }); } - - return [controller.upgrade(LATEST_CONTROLLER.hash)]; }, [controller]); const onUpgrade = useCallback(async () => { diff --git a/packages/keychain/src/pages/index.tsx b/packages/keychain/src/pages/index.tsx index c2106386d..56893430d 100644 --- a/packages/keychain/src/pages/index.tsx +++ b/packages/keychain/src/pages/index.tsx @@ -10,11 +10,13 @@ import { ErrorPage } from "components/ErrorBoundary"; import { Settings } from "components/Settings"; import { Upgrade } from "components/connect/Upgrade"; import { PurchaseCredits } from "components/Funding/PurchaseCredits"; -import { useEffect } from "react"; +import { useEffect, useState } from "react"; import { usePostHog } from "posthog-js/react"; function Home() { const { context, controller, error, policies, upgrade } = useConnection(); + const [hasSessionForPolicies, setHasSessionForPolicies] = + useState(false); const posthog = usePostHog(); useEffect(() => { @@ -33,6 +35,16 @@ function Home() { } }, [context?.origin, posthog]); + useEffect(() => { + if (controller && policies) { + controller.session(policies).then((session) => { + setHasSessionForPolicies(!!session); + }); + } else { + setHasSessionForPolicies(false); + } + }, [controller, policies]); + if (window.self === window.top || !context?.origin) { return <>; } @@ -73,7 +85,7 @@ function Home() { return <>; } - if (controller.session(policies)) { + if (hasSessionForPolicies) { context.resolve({ code: ResponseCodes.SUCCESS, address: controller.address, diff --git a/packages/keychain/src/pages/session.tsx b/packages/keychain/src/pages/session.tsx index 630a6e194..fb0d94d97 100644 --- a/packages/keychain/src/pages/session.tsx +++ b/packages/keychain/src/pages/session.tsx @@ -125,19 +125,21 @@ export default function Session() { // If the requested policies has no mismatch with existing policies and public key already // registered then return the exising session - if (controller.session(policies, queries.public_key)) { - onCallback({ - username: controller.username(), - address: controller.address, - ownerGuid: controller.ownerGuid(), - alreadyRegistered: true, - expiresAt: String(SESSION_EXPIRATION), - }); + controller.session(policies, queries.public_key).then((session) => { + if (session) { + onCallback({ + username: controller.username(), + address: controller.address, + ownerGuid: controller.ownerGuid(), + alreadyRegistered: true, + expiresAt: String(SESSION_EXPIRATION), + }); - return; - } + return; + } - setIsLoading(false); + setIsLoading(false); + }); }, [controller, origin, policies, queries.public_key, onCallback]); if (!controller) { diff --git a/packages/keychain/src/utils/connection/execute.ts b/packages/keychain/src/utils/connection/execute.ts index f3c96648f..5609d0c46 100644 --- a/packages/keychain/src/utils/connection/execute.ts +++ b/packages/keychain/src/utils/connection/execute.ts @@ -76,7 +76,7 @@ export function execute({ // If a session call and there is no session available // fallback to manual apporval flow - if (!account.hasSession(calls)) { + if (!(await account.hasSession(calls))) { setContext({ type: "execute", origin, diff --git a/packages/keychain/src/utils/connection/index.ts b/packages/keychain/src/utils/connection/index.ts index 0427b12f5..6b7b99046 100644 --- a/packages/keychain/src/utils/connection/index.ts +++ b/packages/keychain/src/utils/connection/index.ts @@ -47,12 +47,14 @@ export function connectToController({ reset: () => () => setContext(undefined), fetchControllers: fetchControllers, disconnect: () => () => { - window.controller?.disconnect(); - setController(undefined); + window.controller?.disconnect().then(() => { + setController(undefined); + }); }, logout: () => () => { - window.controller?.disconnect(); - setController(undefined); + window.controller?.disconnect().then(() => { + setController(undefined); + }); }, username: () => () => window.controller?.username(), delegateAccount: () => () => window.controller?.delegateAccount(), diff --git a/packages/keychain/src/utils/connection/probe.ts b/packages/keychain/src/utils/connection/probe.ts index a792ff94a..812497dc9 100644 --- a/packages/keychain/src/utils/connection/probe.ts +++ b/packages/keychain/src/utils/connection/probe.ts @@ -18,8 +18,9 @@ export function probeFactory({ } if (rpcUrl !== controller.rpcUrl()) { - controller.disconnect(); - setController(undefined); + controller.disconnect().then(() => { + setController(undefined); + }); return Promise.reject({ code: ResponseCodes.NOT_CONNECTED, }); diff --git a/packages/keychain/src/utils/connection/sign.ts b/packages/keychain/src/utils/connection/sign.ts index cc25df3d7..2faa3f49c 100644 --- a/packages/keychain/src/utils/connection/sign.ts +++ b/packages/keychain/src/utils/connection/sign.ts @@ -35,7 +35,7 @@ export function signMessageFactory(setContext: (ctx: ConnectionCtx) => void) { async (resolve, reject) => { // If a session call and there is no session available // fallback to manual apporval flow - if (!controller.hasSessionForMessage(typedData)) { + if (!(await controller.hasSessionForMessage(typedData))) { setContext({ type: "sign-message", origin, diff --git a/packages/keychain/src/utils/controller.ts b/packages/keychain/src/utils/controller.ts index 681b258eb..f643abe97 100644 --- a/packages/keychain/src/utils/controller.ts +++ b/packages/keychain/src/utils/controller.ts @@ -92,8 +92,8 @@ export default class Controller extends Account { return this.cartridgeMeta.chainId(); } - disconnect() { - this.cartridge.disconnect(); + async disconnect() { + await this.cartridge.disconnect(); delete window.controller; } @@ -109,12 +109,12 @@ export default class Controller extends Account { await this.cartridge.createSession(toWasmPolicies(policies), expiresAt); } - registerSessionCalldata( + async registerSessionCalldata( expiresAt: bigint, policies: SessionPolicies, publicKey: string, - ): Array { - return this.cartridge.registerSessionCalldata( + ): Promise> { + return await this.cartridge.registerSessionCalldata( toWasmPolicies(policies), expiresAt, publicKey, @@ -139,8 +139,8 @@ export default class Controller extends Account { ); } - upgrade(new_class_hash: JsFelt): JsCall { - return this.cartridge.upgrade(new_class_hash); + async upgrade(new_class_hash: JsFelt): Promise { + return await this.cartridge.upgrade(new_class_hash); } async executeFromOutsideV2(calls: Call[]): Promise { @@ -169,19 +169,19 @@ export default class Controller extends Account { ); } - hasSession(calls: Call[]): boolean { - return this.cartridge.hasSession(toJsCalls(calls)); + async hasSession(calls: Call[]): Promise { + return await this.cartridge.hasSession(toJsCalls(calls)); } - hasSessionForMessage(typedData: TypedData): boolean { - return this.cartridge.hasSessionForMessage(JSON.stringify(typedData)); + async hasSessionForMessage(typedData: TypedData): Promise { + return await this.cartridge.hasSessionForMessage(JSON.stringify(typedData)); } - session( + async session( policies: SessionPolicies, public_key?: string, - ): SessionMetadata | undefined { - return this.cartridge.session(toWasmPolicies(policies), public_key); + ): Promise { + return await this.cartridge.session(toWasmPolicies(policies), public_key); } async estimateInvokeFee(