Skip to content

Commit

Permalink
Message policy support support chain switching
Browse files Browse the repository at this point in the history
  • Loading branch information
broody committed Jan 9, 2025
1 parent 2f22063 commit b6accf7
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 140 deletions.
57 changes: 31 additions & 26 deletions examples/next/src/components/providers/StarknetProvider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,35 @@ export const ETH_CONTRACT_ADDRESS =
export const STRK_CONTRACT_ADDRESS =
"0x04718f5a0Fc34cC1AF16A1cdee98fFB20C31f5cD61D6Ab07201858f4287c938D";

const messageForChain = (chainId: constants.StarknetChainId) => {
return {
types: {
StarknetDomain: [
{ name: "name", type: "shortstring" },
{ name: "version", type: "shortstring" },
{ name: "chainId", type: "shortstring" },
{ name: "revision", type: "shortstring" },
],
Person: [
{ name: "name", type: "felt" },
{ name: "wallet", type: "felt" },
],
Mail: [
{ name: "from", type: "Person" },
{ name: "to", type: "Person" },
{ name: "contents", type: "felt" },
],
},
primaryType: "Mail",
domain: {
name: "StarkNet Mail",
version: "1",
revision: "1",
chainId: chainId,
},
};
};

const policies: SessionPolicies = {
contracts: {
[ETH_CONTRACT_ADDRESS]: {
Expand Down Expand Up @@ -51,32 +80,8 @@ const policies: SessionPolicies = {
},
},
messages: [
// {
// types: {
// StarknetDomain: [
// { name: "name", type: "shortstring" },
// { name: "version", type: "shortstring" },
// { name: "chainId", type: "shortstring" },
// { name: "revision", type: "shortstring" },
// ],
// Person: [
// { name: "name", type: "felt" },
// { name: "wallet", type: "felt" },
// ],
// Mail: [
// { name: "from", type: "Person" },
// { name: "to", type: "Person" },
// { name: "contents", type: "felt" },
// ],
// },
// primaryType: "Mail",
// domain: {
// name: "StarkNet Mail",
// version: "1",
// revision: "1",
// chainId: "SN_SEPOLIA",
// },
// },
messageForChain(constants.StarknetChainId.SN_MAIN),
messageForChain(constants.StarknetChainId.SN_SEPOLIA),
],
};

Expand Down
13 changes: 12 additions & 1 deletion packages/controller/src/controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ export default class ControllerProvider extends BaseProvider {
let chainId: ChainId | undefined;
const url = new URL(chain.rpcUrl);
const parts = url.pathname.split("/");

console.log(chain.rpcUrl, parts);
if (parts.includes("starknet")) {
if (parts.includes("mainnet")) {
chainId = constants.StarknetChainId.SN_MAIN;
Expand All @@ -60,6 +60,17 @@ export default class ControllerProvider extends BaseProvider {
chains.set(chainId, chain);
}

if (
options.policies?.messages?.length &&
options.policies.messages.length !== chains.size
) {
console.warn(
"Each message policy is associated with a specific chain. " +
"The number of message policies does not match the number of chains specified - " +
"session signing may not work on some chains.",
);
}

this.chains = chains;
this.selectedChain = options.defaultChainId;

Expand Down
57 changes: 22 additions & 35 deletions packages/keychain/src/components/connect/CreateSession.tsx
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import { Container, Content, Footer } from "@/components/layout";
import { BigNumberish, shortString } from "starknet";
import { ControllerError } from "@/utils/connection";
import { useCallback, useEffect, useMemo, useState } from "react";
import { useCallback, useMemo, useState } from "react";
import { useConnection } from "@/hooks/connection";
import { ControllerErrorAlert } from "@/components/ErrorAlert";
import { SessionConsent } from "@/components/connect";
import { Upgrade } from "./Upgrade";
import { ErrorCode } from "@cartridge/account-wasm";
import { TypedDataPolicy } from "@cartridge/presets";
import { ParsedSessionPolicies } from "@/hooks/session";
import { UnverifiedSessionSummary } from "@/components/session/UnverifiedSessionSummary";
import { VerifiedSessionSummary } from "@/components/session/VerifiedSessionSummary";
Expand All @@ -33,7 +31,6 @@ export function CreateSession({
}) {
const { controller, upgrade, chainId, theme, logout } = useConnection();
const [isConnecting, setIsConnecting] = useState(false);
const [isDisabled, setIsDisabled] = useState(false);
const [isConsent, setIsConsent] = useState(false);
const [duration, setDuration] = useState<bigint>(DEFAULT_SESSION_DURATION);
const expiresAt = useMemo(
Expand All @@ -43,32 +40,17 @@ export function CreateSession({
const [maxFee] = useState<BigNumberish>();
const [error, setError] = useState<ControllerError | Error>();

useEffect(() => {
if (!chainId) return;
const normalizedChainId = normalizeChainId(chainId);

const violatingPolicy = policies.messages?.find(
(policy: TypedDataPolicy) =>
"domain" in policy &&
(!policy.domain.chainId ||
normalizeChainId(policy.domain.chainId) !== normalizedChainId),
);

if (violatingPolicy) {
setError({
code: ErrorCode.PolicyChainIdMismatch,
message: `Policy for ${
(violatingPolicy as TypedDataPolicy).domain.name
}.${
(violatingPolicy as TypedDataPolicy).primaryType
} has mismatched chain ID.`,
});
setIsDisabled(true);
} else {
setError(undefined);
setIsDisabled(false);
}
}, [chainId, policies]);
const chainSpecificMessages = useMemo(() => {
if (!policies.messages || !chainId) return [];
return policies.messages.filter((message) => {
return (
!("domain" in message) ||
(message.domain.chainId &&
normalizeChainId(message.domain.chainId) ===
normalizeChainId(chainId))
);
});
}, [policies.messages, chainId]);

const onCreateSession = useCallback(async () => {
if (!controller || !policies) return;
Expand Down Expand Up @@ -139,9 +121,16 @@ export function CreateSession({
<Content gap={6}>
<SessionConsent isVerified={policies?.verified} />
{policies?.verified ? (
<VerifiedSessionSummary game={theme.name} policies={policies} />
<VerifiedSessionSummary
game={theme.name}
contracts={policies.contracts}
messages={chainSpecificMessages}
/>
) : (
<UnverifiedSessionSummary policies={policies} />
<UnverifiedSessionSummary
contracts={policies.contracts}
messages={chainSpecificMessages}
/>
)}
</Content>
<Footer>
Expand Down Expand Up @@ -199,9 +188,7 @@ export function CreateSession({
</Button>
<Button
className="flex-1"
disabled={
isDisabled || isConnecting || (!policies?.verified && !isConsent)
}
disabled={isConnecting || (!policies?.verified && !isConsent)}
isLoading={isConnecting}
onClick={onCreateSession}
>
Expand Down
96 changes: 50 additions & 46 deletions packages/keychain/src/components/session/AggregateCard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,35 @@ import { useExplorer } from "@starknet-react/core";
import { constants } from "starknet";
import { Method } from "@cartridge/presets";
import { useChainId } from "@/hooks/connection";
import { ParsedSessionPolicies } from "@/hooks/session";
import { SessionContracts, SessionMessages } from "@/hooks/session";
import { Link } from "react-router-dom";
import { AccordionCard } from "./AccordionCard";
import { MessageContent } from "./MessageCard";

interface AggregateCardProps {
title: string;
icon: React.ReactNode;
policies: ParsedSessionPolicies;
contracts?: SessionContracts;
messages?: SessionMessages;
}

export function AggregateCard({ title, icon, policies }: AggregateCardProps) {
export function AggregateCard({
title,
icon,
contracts,
messages,
}: AggregateCardProps) {
const chainId = useChainId();
const explorer = useExplorer();

const totalMethods = Object.values(policies.contracts || {}).reduce(
const totalMethods = Object.values(contracts || {}).reduce(
(acc, contract) => {
return acc + (contract.methods?.length || 0);
},
0,
);

const totalMessages = policies.messages?.length ?? 0;
const totalMessages = messages?.length ?? 0;
const count = totalMethods + totalMessages;

return (
Expand All @@ -43,52 +49,50 @@ export function AggregateCard({ title, icon, policies }: AggregateCardProps) {
}
className="gap-2"
>
{Object.entries(policies.contracts || {}).map(
([address, { name, methods }]) => (
<div key={address} className="flex flex-col gap-2">
<div className="flex items-center justify-between bg-secondary text-xs">
<div className="py-2 font-bold">{name}</div>
<Link
to={
chainId === constants.StarknetChainId.SN_MAIN ||
chainId === constants.StarknetChainId.SN_SEPOLIA
? explorer.contract(address)
: `#` // TODO: Add explorer for worlds.dev
}
target="_blank"
className="text-muted-foreground hover:underline"
>
{formatAddress(address, { first: 5, last: 5 })}
</Link>
</div>
{Object.entries(contracts || {}).map(([address, { name, methods }]) => (
<div key={address} className="flex flex-col gap-2">
<div className="flex items-center justify-between bg-secondary text-xs">
<div className="py-2 font-bold">{name}</div>
<Link
to={
chainId === constants.StarknetChainId.SN_MAIN ||
chainId === constants.StarknetChainId.SN_SEPOLIA
? explorer.contract(address)
: `#` // TODO: Add explorer for worlds.dev
}
target="_blank"
className="text-muted-foreground hover:underline"
>
{formatAddress(address, { first: 5, last: 5 })}
</Link>
</div>

<div className="flex flex-col gap-px rounded overflow-auto border border-background">
{methods.map((method: Method) => (
<div
key={method.name}
className="flex flex-col p-3 gap-3 text-xs"
>
<div className="flex items-center justify-between">
<div className="font-bold text-accent-foreground">
{method.name}
</div>
<div className="text-muted-foreground">
{method.entrypoint}
</div>
<div className="flex flex-col gap-px rounded overflow-auto border border-background">
{methods.map((method: Method) => (
<div
key={method.name}
className="flex flex-col p-3 gap-3 text-xs"
>
<div className="flex items-center justify-between">
<div className="font-bold text-accent-foreground">
{method.name}
</div>
<div className="text-muted-foreground">
{method.entrypoint}
</div>
{method.description && (
<div className="text-muted-foreground">
{method.description}
</div>
)}
</div>
))}
</div>
{method.description && (
<div className="text-muted-foreground">
{method.description}
</div>
)}
</div>
))}
</div>
),
)}
</div>
))}

{policies.messages && <MessageContent messages={policies.messages} />}
{messages && <MessageContent messages={messages} />}
</AccordionCard>
);
}
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import { toArray } from "@cartridge/controller";
import { ParsedSessionPolicies } from "@/hooks/session";
import { SessionContracts, SessionMessages } from "@/hooks/session";

import { MessageCard } from "./MessageCard";
import { ContractCard } from "./ContractCard";

export function UnverifiedSessionSummary({
policies,
contracts,
messages,
}: {
policies: ParsedSessionPolicies;
contracts?: SessionContracts;
messages?: SessionMessages;
}) {
return (
<div className="flex flex-col gap-4">
{Object.entries(policies.contracts ?? {}).map(([address, contract]) => {
{Object.entries(contracts ?? {}).map(([address, contract]) => {
const methods = toArray(contract.methods);
const title = !contract.meta?.name ? "Contract" : contract.meta.name;
const icon = contract.meta?.icon;
Expand All @@ -28,8 +30,8 @@ export function UnverifiedSessionSummary({
);
})}

{policies.messages && policies.messages.length > 0 && (
<MessageCard messages={policies.messages} isExpanded />
{messages && messages.length > 0 && (
<MessageCard messages={messages} isExpanded />
)}
</div>
);
Expand Down
Loading

0 comments on commit b6accf7

Please sign in to comment.