diff --git a/packages/starknet-snap/src/switchNetwork.ts b/packages/starknet-snap/src/switchNetwork.ts index 1159710f..b3cd3a22 100644 --- a/packages/starknet-snap/src/switchNetwork.ts +++ b/packages/starknet-snap/src/switchNetwork.ts @@ -15,14 +15,16 @@ export async function switchNetwork(params: ApiParams) { } const components = getNetworkTxt(network); - const response = await wallet.request({ - method: 'snap_dialog', - params: { - type: DialogType.Confirmation, - content: panel([heading('Do you want to switch to this network?'), ...components]), - }, - }); - if (!response) return false; + if (requestParamsObj.enableAutherize === true) { + const response = await wallet.request({ + method: 'snap_dialog', + params: { + type: DialogType.Confirmation, + content: panel([heading('Do you want to switch to this network?'), ...components]), + }, + }); + if (!response) return false; + } logger.log(`switchNetwork: network:\n${toJson(network, 2)}`); await setCurrentNetwork(network, wallet, saveMutex, state); diff --git a/packages/starknet-snap/src/types/snapApi.ts b/packages/starknet-snap/src/types/snapApi.ts index 37f80539..d5ac25a1 100644 --- a/packages/starknet-snap/src/types/snapApi.ts +++ b/packages/starknet-snap/src/types/snapApi.ts @@ -75,7 +75,7 @@ export interface ExtractPublicKeyRequestParams extends BaseRequestParams { userAddress: string; } -export interface SignMessageRequestParams extends SignRequestParams, BaseRequestParams { +export interface SignMessageRequestParams extends Autherizeable, SignRequestParams, BaseRequestParams { typedDataMessage: typedData.TypedData; } @@ -181,25 +181,28 @@ export interface RpcV4GetTransactionReceiptResponse { finality_status?: string; } +export interface Autherizeable { + enableAutherize?: boolean; +} + export interface SignRequestParams { signerAddress: string; - enableAutherize?: boolean; } -export interface SignTransactionRequestParams extends SignRequestParams, BaseRequestParams { +export interface SignTransactionRequestParams extends Autherizeable, SignRequestParams, BaseRequestParams { transactions: Call[]; transactionsDetail: InvocationsSignerDetails; abis?: Abi[]; } -export interface SignDeployAccountTransactionRequestParams extends SignRequestParams, BaseRequestParams { +export interface SignDeployAccountTransactionRequestParams extends Autherizeable, SignRequestParams, BaseRequestParams { transaction: DeployAccountSignerDetails; } -export interface SignDeclareTransactionRequestParams extends SignRequestParams, BaseRequestParams { +export interface SignDeclareTransactionRequestParams extends Autherizeable, SignRequestParams, BaseRequestParams { transaction: DeclareSignerDetails; } -export interface SwitchNetworkRequestParams extends BaseRequestParams { +export interface SwitchNetworkRequestParams extends Autherizeable, BaseRequestParams { chainId: string; } diff --git a/packages/starknet-snap/test/src/switchNetwork.test.ts b/packages/starknet-snap/test/src/switchNetwork.test.ts index e8cc4573..f09b093e 100644 --- a/packages/starknet-snap/test/src/switchNetwork.test.ts +++ b/packages/starknet-snap/test/src/switchNetwork.test.ts @@ -44,6 +44,7 @@ describe('Test function: switchNetwork', function () { it('should switch the network correctly', async function () { const requestObject: SwitchNetworkRequestParams = { chainId: STARKNET_MAINNET_NETWORK.chainId, + enableAutherize: true, }; apiParams.requestParams = requestObject; const result = await switchNetwork(apiParams); @@ -54,9 +55,23 @@ describe('Test function: switchNetwork', function () { expect(state.networks.length).to.be.eql(2); }); + it('should skip autherize when enableAutherize is false or omit', async function () { + const requestObject: SwitchNetworkRequestParams = { + chainId: STARKNET_MAINNET_NETWORK.chainId, + }; + apiParams.requestParams = requestObject; + const result = await switchNetwork(apiParams); + expect(result).to.be.eql(true); + expect(stateStub).to.be.calledOnce; + expect(dialogStub).to.be.callCount(0); + expect(state.currentNetwork).to.be.eql(STARKNET_MAINNET_NETWORK); + expect(state.networks.length).to.be.eql(2); + }); + it('should throw an error if network not found', async function () { const requestObject: SwitchNetworkRequestParams = { chainId: '123', + enableAutherize: true, }; apiParams.requestParams = requestObject; let result; @@ -76,6 +91,7 @@ describe('Test function: switchNetwork', function () { sandbox.stub(snapUtils, 'setCurrentNetwork').throws(new Error()); const requestObject: SwitchNetworkRequestParams = { chainId: STARKNET_TESTNET_NETWORK.chainId, + enableAutherize: true, }; apiParams.requestParams = requestObject; let result; diff --git a/packages/wallet-ui/src/components/ui/organism/Menu/Menu.view.tsx b/packages/wallet-ui/src/components/ui/organism/Menu/Menu.view.tsx index 3be83421..cc0ee069 100644 --- a/packages/wallet-ui/src/components/ui/organism/Menu/Menu.view.tsx +++ b/packages/wallet-ui/src/components/ui/organism/Menu/Menu.view.tsx @@ -19,6 +19,7 @@ import { Menu } from '@headlessui/react'; import { theme } from 'theme/default'; import { Radio, Skeleton } from '@mui/material'; import { useAppDispatch, useAppSelector } from 'hooks/redux'; +import { useStarkNetSnap } from 'services'; import { setWalletConnection, setForceReconnect, resetWallet, clearAccounts } from 'slices/walletSlice'; import { resetNetwork, setActiveNetwork } from 'slices/networkSlice'; @@ -27,12 +28,16 @@ interface IProps extends HTMLAttributes { } export const MenuView = ({ connected, ...otherProps }: IProps) => { + const { switchNetwork } = useStarkNetSnap(); const networks = useAppSelector((state) => state.networks); const dispatch = useAppDispatch(); - const changeNetwork = (network: number) => { - dispatch(clearAccounts()); - dispatch(setActiveNetwork(network)); + const changeNetwork = async (network: number, chainId: string) => { + const result = await switchNetwork(chainId); + if (result) { + dispatch(clearAccounts()); + dispatch(setActiveNetwork(network)); + } }; /* There is no way to disconnect the snap from a dapp it must be done from MetaMask. @@ -80,7 +85,7 @@ export const MenuView = ({ connected, ...otherProps }: IProps) => { {networks.items.map((network, index) => ( - changeNetwork(index)}> + changeNetwork(index, network.chainId)}> { const dispatch = useAppDispatch(); const { loader } = useAppSelector((state) => state.UI); const { transactions, erc20TokenBalances, provider } = useAppSelector((state) => state.wallet); - const { activeNetwork } = useAppSelector((state) => state.networks); const snapId = process.env.REACT_APP_SNAP_ID ? process.env.REACT_APP_SNAP_ID : 'local:http://localhost:8081/'; const snapVersion = process.env.REACT_APP_SNAP_VERSION ? process.env.REACT_APP_SNAP_VERSION : '*'; const minSnapVersion = process.env.REACT_APP_MIN_SNAP_VERSION ? process.env.REACT_APP_MIN_SNAP_VERSION : '2.0.1'; @@ -199,7 +199,10 @@ export const useStarkNetSnap = () => { if (nets.length === 0) { return; } - const chainId = nets[activeNetwork].chainId; + const net = await getCurrentNetwork(); + const idx = nets.findIndex((e) => e.chainId === net.chainId); + dispatch(setActiveNetwork(idx)); + const chainId = net.chainId; await getWalletData(chainId, nets); } catch (err: any) { if (err.code && err.code === 4100) { @@ -529,6 +532,49 @@ export const useStarkNetSnap = () => { } }; + const switchNetwork = async (chainId: string) => { + dispatch(enableLoadingWithMessage('Switching Network...')); + try { + const result = await provider.request({ + method: 'wallet_invokeSnap', + params: { + snapId, + request: { + method: 'starkNet_switchNetwork', + params: { + ...defaultParam, + chainId, + }, + }, + }, + }); + dispatch(disableLoading()); + return result; + } catch (err) { + dispatch(disableLoading()); + return false; + } + }; + + const getCurrentNetwork = async () => { + try { + return await provider.request({ + method: 'wallet_invokeSnap', + params: { + snapId, + request: { + method: 'starkNet_getCurrentNetwork', + params: { + ...defaultParam, + }, + }, + }, + }); + } catch (err) { + throw err; + } + }; + return { connectToSnap, getNetworks, @@ -547,6 +593,8 @@ export const useStarkNetSnap = () => { initSnap, getWalletData, refreshTokensUSDPrice, + switchNetwork, + getCurrentNetwork, satisfiesVersion: oldVersionDetected, }; };