From d0384bf9c9476c2168586cf7dc48fe6adb965bcb Mon Sep 17 00:00:00 2001 From: Stanley Yuen <102275989+stanleyyconsensys@users.noreply.github.com> Date: Thu, 10 Oct 2024 18:55:11 +0800 Subject: [PATCH] refactor(SwitchNetwork): revamp RPC `starkNet_switchNetwork` (#368) * chore: revamp switch network * chore: remove legacy code for switch network * fix: util `getCurrentNetwork ` * chore: update network state mgr with default network config * chore: lint fix * chore: lint fix * chore: rebase * chore: update comment * chore: update comment * chore: use new error format * chore: rollback snapstate change --------- Co-authored-by: khanti42 --- packages/starknet-snap/src/index.ts | 7 +- packages/starknet-snap/src/rpcs/index.ts | 1 + .../src/rpcs/switch-network.test.ts | 187 ++++++++++++++++++ .../starknet-snap/src/rpcs/switch-network.ts | 126 ++++++++++++ packages/starknet-snap/src/switchNetwork.ts | 52 ----- packages/starknet-snap/src/utils/rpc.ts | 2 +- .../test/src/switchNetwork.test.ts | 109 ---------- 7 files changed, 320 insertions(+), 164 deletions(-) create mode 100644 packages/starknet-snap/src/rpcs/switch-network.test.ts create mode 100644 packages/starknet-snap/src/rpcs/switch-network.ts delete mode 100644 packages/starknet-snap/src/switchNetwork.ts delete mode 100644 packages/starknet-snap/test/src/switchNetwork.test.ts diff --git a/packages/starknet-snap/src/index.ts b/packages/starknet-snap/src/index.ts index 46859407..724b0f9c 100644 --- a/packages/starknet-snap/src/index.ts +++ b/packages/starknet-snap/src/index.ts @@ -34,6 +34,7 @@ import type { SignTransactionParams, SignDeclareTransactionParams, VerifySignatureParams, + SwitchNetworkParams, } from './rpcs'; import { displayPrivateKey, @@ -43,10 +44,10 @@ import { signTransaction, signDeclareTransaction, verifySignature, + switchNetwork, } from './rpcs'; import { sendTransaction } from './sendTransaction'; import { signDeployAccountTransaction } from './signDeployAccountTransaction'; -import { switchNetwork } from './switchNetwork'; import type { ApiParams, ApiParamsWithKeyDeriver, @@ -230,7 +231,9 @@ export const onRpcRequest: OnRpcRequestHandler = async ({ request }) => { return await addNetwork(apiParams); case 'starkNet_switchNetwork': - return await switchNetwork(apiParams); + return await switchNetwork.execute( + apiParams.requestParams as unknown as SwitchNetworkParams, + ); case 'starkNet_getCurrentNetwork': return await getCurrentNetwork(apiParams); diff --git a/packages/starknet-snap/src/rpcs/index.ts b/packages/starknet-snap/src/rpcs/index.ts index e3bc2fb0..13140791 100644 --- a/packages/starknet-snap/src/rpcs/index.ts +++ b/packages/starknet-snap/src/rpcs/index.ts @@ -5,3 +5,4 @@ export * from './signMessage'; export * from './signTransaction'; export * from './sign-declare-transaction'; export * from './verify-signature'; +export * from './switch-network'; diff --git a/packages/starknet-snap/src/rpcs/switch-network.test.ts b/packages/starknet-snap/src/rpcs/switch-network.test.ts new file mode 100644 index 00000000..5058a6c0 --- /dev/null +++ b/packages/starknet-snap/src/rpcs/switch-network.test.ts @@ -0,0 +1,187 @@ +import type { constants } from 'starknet'; + +import { Config } from '../config'; +import { NetworkStateManager } from '../state/network-state-manager'; +import type { Network } from '../types/snapState'; +import { + STARKNET_SEPOLIA_TESTNET_NETWORK, + STARKNET_MAINNET_NETWORK, +} from '../utils/constants'; +import { + InvalidNetworkError, + InvalidRequestParamsError, + UserRejectedOpError, +} from '../utils/exceptions'; +import { prepareConfirmDialog } from './__tests__/helper'; +import { switchNetwork } from './switch-network'; +import type { SwitchNetworkParams } from './switch-network'; + +jest.mock('../utils/logger'); + +describe('switchNetwork', () => { + const createRequestParam = ( + chainId: constants.StarknetChainId | string, + enableAuthorize?: boolean, + ): SwitchNetworkParams => { + const request: SwitchNetworkParams = { + chainId: chainId as constants.StarknetChainId, + }; + if (enableAuthorize) { + request.enableAuthorize = enableAuthorize; + } + return request; + }; + + const mockNetworkStateManager = ({ + network = STARKNET_SEPOLIA_TESTNET_NETWORK, + currentNetwork = STARKNET_MAINNET_NETWORK, + }: { + network?: Network | null; + currentNetwork?: Network; + }) => { + const txStateSpy = jest.spyOn( + NetworkStateManager.prototype, + 'withTransaction', + ); + const getNetworkSpy = jest.spyOn( + NetworkStateManager.prototype, + 'getNetwork', + ); + const setCurrentNetworkSpy = jest.spyOn( + NetworkStateManager.prototype, + 'setCurrentNetwork', + ); + const getCurrentNetworkSpy = jest.spyOn( + NetworkStateManager.prototype, + 'getCurrentNetwork', + ); + + getNetworkSpy.mockResolvedValue(network); + getCurrentNetworkSpy.mockResolvedValue(currentNetwork); + txStateSpy.mockImplementation(async (fn) => { + return await fn({ + accContracts: [], + erc20Tokens: [], + networks: Config.availableNetworks, + transactions: [], + }); + }); + + return { getNetworkSpy, setCurrentNetworkSpy, getCurrentNetworkSpy }; + }; + + it('switchs a network correctly', async () => { + const currentNetwork = STARKNET_MAINNET_NETWORK; + const requestNetwork = STARKNET_SEPOLIA_TESTNET_NETWORK; + const { getNetworkSpy, setCurrentNetworkSpy, getCurrentNetworkSpy } = + mockNetworkStateManager({ + currentNetwork, + network: requestNetwork, + }); + const request = createRequestParam(requestNetwork.chainId); + + const result = await switchNetwork.execute(request); + + expect(result).toBe(true); + expect(getCurrentNetworkSpy).toHaveBeenCalled(); + expect(getNetworkSpy).toHaveBeenCalledWith( + { chainId: requestNetwork.chainId }, + expect.anything(), + ); + expect(setCurrentNetworkSpy).toHaveBeenCalledWith(requestNetwork); + }); + + it('returns `true` if the request chainId is the same with current network', async () => { + const currentNetwork = STARKNET_SEPOLIA_TESTNET_NETWORK; + const requestNetwork = STARKNET_SEPOLIA_TESTNET_NETWORK; + const { getNetworkSpy, setCurrentNetworkSpy, getCurrentNetworkSpy } = + mockNetworkStateManager({ + currentNetwork, + network: requestNetwork, + }); + const request = createRequestParam(requestNetwork.chainId); + + const result = await switchNetwork.execute(request); + + expect(result).toBe(true); + expect(getCurrentNetworkSpy).toHaveBeenCalled(); + expect(getNetworkSpy).not.toHaveBeenCalled(); + expect(setCurrentNetworkSpy).not.toHaveBeenCalled(); + }); + + it('renders confirmation dialog', async () => { + const currentNetwork = STARKNET_MAINNET_NETWORK; + const requestNetwork = STARKNET_SEPOLIA_TESTNET_NETWORK; + mockNetworkStateManager({ + currentNetwork, + network: requestNetwork, + }); + const { confirmDialogSpy } = prepareConfirmDialog(); + const request = createRequestParam(requestNetwork.chainId, true); + + await switchNetwork.execute(request); + + expect(confirmDialogSpy).toHaveBeenCalledWith([ + { type: 'heading', value: 'Do you want to switch to this network?' }, + { + type: 'row', + label: 'Chain Name', + value: { + value: requestNetwork.name, + markdown: false, + type: 'text', + }, + }, + { + type: 'divider', + }, + { + type: 'row', + label: 'Chain ID', + value: { + value: requestNetwork.chainId, + markdown: false, + type: 'text', + }, + }, + ]); + }); + + it('throws `UserRejectedRequestError` if user denied the operation', async () => { + const currentNetwork = STARKNET_MAINNET_NETWORK; + const requestNetwork = STARKNET_SEPOLIA_TESTNET_NETWORK; + mockNetworkStateManager({ + currentNetwork, + network: requestNetwork, + }); + const { confirmDialogSpy } = prepareConfirmDialog(); + confirmDialogSpy.mockResolvedValue(false); + const request = createRequestParam(requestNetwork.chainId, true); + + await expect(switchNetwork.execute(request)).rejects.toThrow( + UserRejectedOpError, + ); + }); + + it('throws `Network not supported` error if the request network is not support', async () => { + const currentNetwork = STARKNET_MAINNET_NETWORK; + const requestNetwork = STARKNET_SEPOLIA_TESTNET_NETWORK; + // Mock the network state manager to return null network + // even if the request chain id is not block by the superstruct + mockNetworkStateManager({ + currentNetwork, + network: null, + }); + const request = createRequestParam(requestNetwork.chainId); + + await expect(switchNetwork.execute(request)).rejects.toThrow( + InvalidNetworkError, + ); + }); + + it('throws `InvalidRequestParamsError` when request parameter is not correct', async () => { + await expect( + switchNetwork.execute({} as unknown as SwitchNetworkParams), + ).rejects.toThrow(InvalidRequestParamsError); + }); +}); diff --git a/packages/starknet-snap/src/rpcs/switch-network.ts b/packages/starknet-snap/src/rpcs/switch-network.ts new file mode 100644 index 00000000..61bc0fbf --- /dev/null +++ b/packages/starknet-snap/src/rpcs/switch-network.ts @@ -0,0 +1,126 @@ +import type { Component } from '@metamask/snaps-sdk'; +import { divider, heading, row, text } from '@metamask/snaps-sdk'; +import type { Infer } from 'superstruct'; +import { assign, boolean } from 'superstruct'; + +import { NetworkStateManager } from '../state/network-state-manager'; +import { + confirmDialog, + AuthorizableStruct, + BaseRequestStruct, + RpcController, +} from '../utils'; +import { InvalidNetworkError, UserRejectedOpError } from '../utils/exceptions'; + +export const SwitchNetworkRequestStruct = assign( + AuthorizableStruct, + BaseRequestStruct, +); + +export const SwitchNetworkResponseStruct = boolean(); + +export type SwitchNetworkParams = Infer; + +export type SwitchNetworkResponse = Infer; + +/** + * The RPC handler to switch the network. + */ +export class SwitchNetworkRpc extends RpcController< + SwitchNetworkParams, + SwitchNetworkResponse +> { + protected requestStruct = SwitchNetworkRequestStruct; + + protected responseStruct = SwitchNetworkResponseStruct; + + /** + * Execute the switching network request handler. + * It switch to a supported network based on the chain id. + * It will show a confirmation dialog to the user before switching a network. + * + * @param params - The parameters of the request. + * @param [params.enableAuthorize] - Optional, a flag to enable or display the confirmation dialog to the user. + * @param params.chainId - The chain id of the network to switch. + * @returns the response of the switching a network in boolean. + * @throws {UserRejectedRequestError} If the user rejects the request. + * @throws {Error} If the network with the chain id is not supported. + */ + async execute(params: SwitchNetworkParams): Promise { + return super.execute(params); + } + + protected async handleRequest( + params: SwitchNetworkParams, + ): Promise { + const { enableAuthorize, chainId } = params; + const networkStateMgr = new NetworkStateManager(); + + // Using transactional state interaction to ensure that the state is updated atomically + // To avoid a use case while 2 requests are trying to update/read the state at the same time + return await networkStateMgr.withTransaction(async (state) => { + const currentNetwork = await networkStateMgr.getCurrentNetwork(state); + + // Return early if the current network is the same as the requested network + if (currentNetwork.chainId === chainId) { + return true; + } + + const network = await networkStateMgr.getNetwork( + { + chainId, + }, + state, + ); + + // if the network is not in the list of networks that we support, we throw an error + if (!network) { + throw new InvalidNetworkError() as unknown as Error; + } + + if ( + // Get Starknet expected show the confirm dialog, while the companion doesnt needed, + // therefore, `enableAuthorize` is to enable/disable the confirmation + enableAuthorize && + !(await this.getSwitchNetworkConsensus(network.name, network.chainId)) + ) { + throw new UserRejectedOpError() as unknown as Error; + } + + await networkStateMgr.setCurrentNetwork(network); + + return true; + }); + } + + protected async getSwitchNetworkConsensus( + networkName: string, + networkChainId: string, + ) { + const components: Component[] = []; + components.push(heading('Do you want to switch to this network?')); + components.push( + row( + 'Chain Name', + text({ + value: networkName, + markdown: false, + }), + ), + ); + components.push(divider()); + components.push( + row( + 'Chain ID', + text({ + value: networkChainId, + markdown: false, + }), + ), + ); + + return await confirmDialog(components); + } +} + +export const switchNetwork = new SwitchNetworkRpc(); diff --git a/packages/starknet-snap/src/switchNetwork.ts b/packages/starknet-snap/src/switchNetwork.ts deleted file mode 100644 index 258d51fa..00000000 --- a/packages/starknet-snap/src/switchNetwork.ts +++ /dev/null @@ -1,52 +0,0 @@ -import { panel, heading, DialogType } from '@metamask/snaps-sdk'; - -import type { ApiParams, SwitchNetworkRequestParams } from './types/snapApi'; -import { logger } from './utils/logger'; -import { toJson } from './utils/serializer'; -import { - getNetwork, - setCurrentNetwork, - getNetworkTxt, -} from './utils/snapUtils'; - -/** - * - * @param params - */ -export async function switchNetwork(params: ApiParams) { - try { - const { state, wallet, saveMutex, requestParams } = params; - const requestParamsObj = requestParams as SwitchNetworkRequestParams; - const network = getNetwork(state, requestParamsObj.chainId); - if (!network) { - throw new Error( - `The given chainId is invalid: ${requestParamsObj.chainId}`, - ); - } - const components = getNetworkTxt(network); - - if (requestParamsObj.enableAuthorize) { - const response = await wallet.request({ - method: 'snap_dialog', - params: { - type: DialogType.Confirmation, - content: panel([ - heading('Do you want to switch to this network?'), - ...components, - ]), - }, - }); - if (!response) { - return false; - } - } - - logger.log(`switchNetwork: network:\n${toJson(network, 2)}`); - await setCurrentNetwork(network, wallet, saveMutex, state); - - return true; - } catch (error) { - logger.error(`Problem found:`, error); - throw error; - } -} diff --git a/packages/starknet-snap/src/utils/rpc.ts b/packages/starknet-snap/src/utils/rpc.ts index 70571a92..4d88e347 100644 --- a/packages/starknet-snap/src/utils/rpc.ts +++ b/packages/starknet-snap/src/utils/rpc.ts @@ -86,7 +86,7 @@ export abstract class RpcController< } // TODO: the Type should be moved to a common place -export type AccountRpcParams = Json & { +export type AccountRpcParams = { chainId: string; address: string; }; diff --git a/packages/starknet-snap/test/src/switchNetwork.test.ts b/packages/starknet-snap/test/src/switchNetwork.test.ts deleted file mode 100644 index 2678911b..00000000 --- a/packages/starknet-snap/test/src/switchNetwork.test.ts +++ /dev/null @@ -1,109 +0,0 @@ -import chai, { expect } from 'chai'; -import sinon from 'sinon'; -import sinonChai from 'sinon-chai'; -import { WalletMock } from '../wallet.mock.test'; -import { SnapState } from '../../src/types/snapState'; -import * as snapUtils from '../../src/utils/snapUtils'; -import { - STARKNET_MAINNET_NETWORK, - STARKNET_SEPOLIA_TESTNET_NETWORK, -} from '../../src/utils/constants'; -import { Mutex } from 'async-mutex'; -import { SwitchNetworkRequestParams, ApiParams } from '../../src/types/snapApi'; -import { switchNetwork } from '../../src/switchNetwork'; - -chai.use(sinonChai); -const sandbox = sinon.createSandbox(); - -describe('Test function: switchNetwork', function () { - const walletStub = new WalletMock(); - const state: SnapState = { - accContracts: [], - erc20Tokens: [], - networks: [STARKNET_MAINNET_NETWORK, STARKNET_SEPOLIA_TESTNET_NETWORK], - transactions: [], - currentNetwork: STARKNET_SEPOLIA_TESTNET_NETWORK, - }; - const apiParams: ApiParams = { - state, - requestParams: {}, - wallet: walletStub, - saveMutex: new Mutex(), - }; - let stateStub: sinon.SinonStub; - let dialogStub: sinon.SinonStub; - beforeEach(function () { - stateStub = walletStub.rpcStubs.snap_manageState; - dialogStub = walletStub.rpcStubs.snap_dialog; - stateStub.resolves(state); - dialogStub.resolves(true); - }); - - afterEach(function () { - walletStub.reset(); - sandbox.restore(); - }); - - it('should switch the network correctly', async function () { - const requestObject: SwitchNetworkRequestParams = { - chainId: STARKNET_MAINNET_NETWORK.chainId, - enableAuthorize: true, - }; - apiParams.requestParams = requestObject; - const result = await switchNetwork(apiParams); - expect(result).to.be.eql(true); - expect(stateStub).to.be.calledOnce; - expect(dialogStub).to.be.calledOnce; - expect(state.currentNetwork).to.be.eql(STARKNET_MAINNET_NETWORK); - }); - - it('should skip authorize when enableAuthorize is false or omit', async function () { - const requestObject: SwitchNetworkRequestParams = { - chainId: STARKNET_MAINNET_NETWORK.chainId, - }; - apiParams.requestParams = requestObject; - const result = await switchNetwork(apiParams); - expect(result).to.be.eql(true); - expect(stateStub).to.be.calledOnce; - expect(dialogStub).to.be.callCount(0); - expect(state.currentNetwork).to.be.eql(STARKNET_MAINNET_NETWORK); - }); - - it('should throw an error if network not found', async function () { - const requestObject: SwitchNetworkRequestParams = { - chainId: '123', - enableAuthorize: true, - }; - apiParams.requestParams = requestObject; - let result; - try { - await switchNetwork(apiParams); - } catch (err) { - result = err; - } finally { - expect(result).to.be.an('Error'); - expect(stateStub).to.be.callCount(0); - expect(dialogStub).to.be.callCount(0); - expect(state.currentNetwork).to.be.eql(STARKNET_MAINNET_NETWORK); - } - }); - - it('should throw an error if setCurrentNetwork failed', async function () { - sandbox.stub(snapUtils, 'setCurrentNetwork').throws(new Error()); - const requestObject: SwitchNetworkRequestParams = { - chainId: STARKNET_SEPOLIA_TESTNET_NETWORK.chainId, - enableAuthorize: true, - }; - apiParams.requestParams = requestObject; - let result; - try { - await switchNetwork(apiParams); - } catch (err) { - result = err; - } finally { - expect(result).to.be.an('Error'); - expect(dialogStub).to.be.callCount(1); - expect(state.currentNetwork).to.be.eql(STARKNET_MAINNET_NETWORK); - } - }); -});