Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: unsynchronized concurrent wasm calls #1155

Merged
merged 6 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 179 additions & 56 deletions packages/account-wasm/src/account.rs

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions packages/account-wasm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ pub mod account;
pub mod session;

mod errors;
mod sync;
mod types;
mod utils;
70 changes: 70 additions & 0 deletions packages/account-wasm/src/sync.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use std::{
ops::{Deref, DerefMut},
sync::{Mutex as StdMutex, MutexGuard as StdMutexGuard},
};

use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::js_sys::Function;

/// A mutex implementation backed by JavaScript `Promise`s.
///
/// This type wraps a simple JavaScript `Mutex` implementation but exposes an idiomatic Rust API.
pub struct WasmMutex<T> {
js_lock: Mutex,
rs_lock: StdMutex<T>,
}

impl<T> WasmMutex<T> {
pub fn new(value: T) -> Self {
Self {
js_lock: Mutex::new(),
rs_lock: std::sync::Mutex::new(value),
}
}

pub async fn lock(&self) -> WasmMutexGuard<T> {
WasmMutexGuard {
js_release: self.js_lock.obtain().await,
// This never actually blocks as it's guarded by the JS lock. This field exists only to
// provide internal mutability for the underlying value.
rs_guard: self.rs_lock.lock().unwrap(),
}
}
}

/// A handle to the underlying guarded value. The lock is released when the instance is dropped.
pub struct WasmMutexGuard<'a, T> {
js_release: Function,
rs_guard: StdMutexGuard<'a, T>,
}

impl<T> Deref for WasmMutexGuard<'_, T> {
type Target = T;

fn deref(&self) -> &T {
std::sync::MutexGuard::deref(&self.rs_guard)
}
}

impl<T> DerefMut for WasmMutexGuard<'_, T> {
fn deref_mut(&mut self) -> &mut T {
std::sync::MutexGuard::deref_mut(&mut self.rs_guard)
}
}

impl<T> Drop for WasmMutexGuard<'_, T> {
fn drop(&mut self) {
self.js_release.call0(&JsValue::null()).unwrap();
}
}

#[wasm_bindgen(module = "/src/wasm-mutex.js")]
extern "C" {
type Mutex;

#[wasm_bindgen(constructor)]
fn new() -> Mutex;

#[wasm_bindgen(method)]
async fn obtain(this: &Mutex) -> Function;
}
13 changes: 13 additions & 0 deletions packages/account-wasm/src/wasm-mutex.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
function releaseStub() {}

export class Mutex {
lastPromise = Promise.resolve();

async obtain() {
let release = releaseStub;
const lastPromise = this.lastPromise;
this.lastPromise = new Promise((resolve) => (release = resolve));
await lastPromise;
return release;
}
}
18 changes: 14 additions & 4 deletions packages/keychain/src/components/ConfirmTransaction.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { useMemo, useState } from "react";
import { useEffect, useMemo, useState } from "react";
import { ResponseCodes, SessionPolicies, toArray } from "@cartridge/controller";
import { Content, FOOTER_MIN_HEIGHT } from "components/layout";
import { TransactionDuoIcon } from "@cartridge/ui";
Expand All @@ -12,6 +12,7 @@ import { CreateSession } from "./connect";
export function ConfirmTransaction() {
const { controller, context, origin, policies, setContext } = useConnection();
const [policiesUpdated, setIsPoliciesUpdated] = useState<boolean>(false);
const [updateSession, setUpdateSession] = useState<boolean>(false);
const ctx = context as ExecuteCtx;
const account = controller;

Expand Down Expand Up @@ -46,8 +47,11 @@ export function ConfirmTransaction() {
[ctx.transactions],
);

const updateSession = useMemo(() => {
if (policiesUpdated) return false;
useEffect(() => {
if (policiesUpdated) {
setUpdateSession(false);
return;
}

const entries = Object.entries(callPolicies.contracts || {});
const txnsApproved = entries.every(([target, policy]) => {
Expand All @@ -60,7 +64,13 @@ export function ConfirmTransaction() {

// If calls are approved by dapp specified policies but not stored session
// then prompt user to update session. This also accounts for expired sessions.
return txnsApproved && !account?.session(callPolicies);
if (txnsApproved && account) {
account.session(callPolicies).then((hasSession) => {
setUpdateSession(!hasSession);
});
} else {
setUpdateSession(false);
}
}, [callPolicies, policiesUpdated, policies, account]);

if (updateSession && policies) {
Expand Down
42 changes: 25 additions & 17 deletions packages/keychain/src/components/connect/RegisterSession.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Content } from "components/layout";
import { useCallback, useMemo, useState } from "react";
import { useCallback, useEffect, useState } from "react";
import { useConnection } from "hooks/connection";
import { SessionConsent } from "components/connect";
import { ExecutionContainer } from "components/ExecutionContainer";
Expand All @@ -23,23 +23,31 @@ export function RegisterSession({
}) {
const { controller, theme } = useConnection();
const [expiresAt] = useState<bigint>(SESSION_EXPIRATION);
const [transactions, setTransactions] = useState<
| {
contractAddress: string;
entrypoint: string;
calldata: string[];
}[]
| undefined
>(undefined);

const transactions = useMemo(() => {
if (!publicKey || !controller) return;

const calldata = controller.registerSessionCalldata(
expiresAt,
policies,
publicKey,
);

return [
{
contractAddress: controller.address,
entrypoint: "register_session",
calldata,
},
];
useEffect(() => {
if (!publicKey || !controller) {
setTransactions(undefined);
} else {
controller
.registerSessionCalldata(expiresAt, policies, publicKey)
.then((calldata) => {
setTransactions([
{
contractAddress: controller.address,
entrypoint: "register_session",
calldata,
},
]);
});
}
}, [controller, expiresAt, policies, publicKey]);

const onRegisterSession = useCallback(
Expand Down
13 changes: 7 additions & 6 deletions packages/keychain/src/hooks/connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ export function useConnectionValue() {
if (controller) {
posthog.identify(controller.username(), {
address: controller.address,
class: controller.cartridge.classHash,
class: controller.classHash(),
chainId: controller.chainId,
});
} else {
Expand Down Expand Up @@ -224,12 +224,13 @@ export function useConnectionValue() {
}, [rpcUrl, controller]);

const logout = useCallback(() => {
window.controller?.disconnect();
setController(undefined);
window.controller?.disconnect().then(() => {
setController(undefined);

context?.resolve?.({
code: ResponseCodes.NOT_CONNECTED,
message: "User logged out",
context?.resolve?.({
code: ResponseCodes.NOT_CONNECTED,
message: "User logged out",
});
});
}, [context, setController]);

Expand Down
7 changes: 2 additions & 5 deletions packages/keychain/src/hooks/deploy.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import { useCallback, useState } from "react";
import { num } from "starknet";
import { useConnection } from "./connection";

type TransactionHash = string;

interface DeployInterface {
deploySelf: (maxFee: string) => Promise<TransactionHash>;
deploySelf: (maxFee: string) => Promise<TransactionHash | undefined>;
isDeploying: boolean;
}

Expand All @@ -18,9 +17,7 @@ export const useDeploy = (): DeployInterface => {
if (!controller) return;
try {
setIsDeploying(true);
const { transaction_hash } = await controller.cartridge.deploySelf(
num.toHex(maxFee),
);
const { transaction_hash } = await controller.selfDeploy(maxFee);

return transaction_hash;
} catch (e) {
Expand Down
16 changes: 10 additions & 6 deletions packages/keychain/src/hooks/upgrade.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { useCallback, useEffect, useMemo, useState } from "react";
import { JsCall } from "@cartridge/account-wasm";
import { useCallback, useEffect, useState } from "react";
import { addAddressPadding, Call } from "starknet";
import { ControllerError } from "utils/connection";
import Controller from "utils/controller";
Expand Down Expand Up @@ -71,6 +72,7 @@ export const useUpgrade = (controller?: Controller): UpgradeInterface => {
const [isSynced, setIsSynced] = useState<boolean>(false);
const [isUpgrading, setIsUpgrading] = useState<boolean>(false);
const [current, setCurrent] = useState<ControllerVersionInfo>();
const [calls, setCalls] = useState<JsCall[]>([]);

useEffect(() => {
if (!controller) {
Expand All @@ -93,7 +95,7 @@ export const useUpgrade = (controller?: Controller): UpgradeInterface => {
const current = CONTROLLER_VERSIONS.find(
(v) =>
addAddressPadding(v.hash) ===
addAddressPadding(controller.cartridge.classHash()),
addAddressPadding(controller.classHash()),
);
setCurrent(current);
setAvailable(current?.version !== LATEST_CONTROLLER.version);
Expand All @@ -105,12 +107,14 @@ export const useUpgrade = (controller?: Controller): UpgradeInterface => {
.finally(() => setIsSynced(true));
}, [controller]);

const calls = useMemo(() => {
useEffect(() => {
if (!controller || !LATEST_CONTROLLER) {
return [];
setCalls([]);
} else {
controller.upgrade(LATEST_CONTROLLER.hash).then((call) => {
setCalls([call]);
});
}

return [controller.cartridge.upgrade(LATEST_CONTROLLER.hash)];
}, [controller]);

const onUpgrade = useCallback(async () => {
Expand Down
21 changes: 19 additions & 2 deletions packages/keychain/src/pages/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,15 @@ import { ErrorPage } from "components/ErrorBoundary";
import { Settings } from "components/Settings";
import { Upgrade } from "components/connect/Upgrade";
import { PurchaseCredits } from "components/Funding/PurchaseCredits";
import { useEffect } from "react";
import { useEffect, useState } from "react";
import { usePostHog } from "posthog-js/react";
import { PageLoading } from "components/Loading";

function Home() {
const { context, controller, error, policies, upgrade } = useConnection();
const [hasSessionForPolicies, setHasSessionForPolicies] = useState<
boolean | undefined
>(undefined);
const posthog = usePostHog();

useEffect(() => {
Expand All @@ -33,6 +37,16 @@ function Home() {
}
}, [context?.origin, posthog]);

useEffect(() => {
if (controller && policies) {
controller.session(policies).then((session) => {
setHasSessionForPolicies(!!session);
});
} else {
setHasSessionForPolicies(undefined);
}
}, [controller, policies]);

if (window.self === window.top || !context?.origin) {
return <></>;
}
Expand Down Expand Up @@ -73,7 +87,10 @@ function Home() {
return <></>;
}

if (controller.session(policies)) {
if (hasSessionForPolicies === undefined) {
// This is likely never observable in a real application but just in case.
return <PageLoading />;
} else if (hasSessionForPolicies) {
context.resolve({
code: ResponseCodes.SUCCESS,
address: controller.address,
Expand Down
26 changes: 14 additions & 12 deletions packages/keychain/src/pages/session.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ export default function Session() {
onCallback({
username: controller.username(),
address: controller.address,
ownerGuid: controller.cartridge.ownerGuid(),
ownerGuid: controller.ownerGuid(),
transactionHash: transaction_hash,
expiresAt: String(SESSION_EXPIRATION),
});
Expand All @@ -125,19 +125,21 @@ export default function Session() {

// If the requested policies has no mismatch with existing policies and public key already
// registered then return the exising session
if (controller.session(policies, queries.public_key)) {
onCallback({
username: controller.username(),
address: controller.address,
ownerGuid: controller.cartridge.ownerGuid(),
alreadyRegistered: true,
expiresAt: String(SESSION_EXPIRATION),
});
controller.session(policies, queries.public_key).then((session) => {
if (session) {
onCallback({
username: controller.username(),
address: controller.address,
ownerGuid: controller.ownerGuid(),
alreadyRegistered: true,
expiresAt: String(SESSION_EXPIRATION),
});

return;
}
return;
}

setIsLoading(false);
setIsLoading(false);
tarrencev marked this conversation as resolved.
Show resolved Hide resolved
});
}, [controller, origin, policies, queries.public_key, onCallback]);

if (!controller) {
Expand Down
Loading
Loading