diff --git a/packages/starknet-snap/src/config.ts b/packages/starknet-snap/src/config.ts index 8fb6bc24..cec6227f 100644 --- a/packages/starknet-snap/src/config.ts +++ b/packages/starknet-snap/src/config.ts @@ -1,9 +1,16 @@ -import { SnapEnv } from './utils/constants'; +import type { Network } from './types/snapState'; +import { + SnapEnv, + STARKNET_MAINNET_NETWORK, + STARKNET_SEPOLIA_TESTNET_NETWORK, +} from './utils/constants'; import { LogLevel } from './utils/logger'; export type SnapConfig = { logLevel: string; snapEnv: SnapEnv; + defaultNetwork: Network; + availableNetworks: Network[]; }; export const Config: SnapConfig = { @@ -11,4 +18,11 @@ export const Config: SnapConfig = { logLevel: process.env.LOG_LEVEL ?? LogLevel.OFF.valueOf().toString(), // eslint-disable-next-line no-restricted-globals snapEnv: (process.env.SNAP_ENV ?? SnapEnv.Prod) as unknown as SnapEnv, + + defaultNetwork: STARKNET_MAINNET_NETWORK, + + availableNetworks: [ + STARKNET_MAINNET_NETWORK, + STARKNET_SEPOLIA_TESTNET_NETWORK, + ], }; diff --git a/packages/starknet-snap/src/state/network-state-manager.test.ts b/packages/starknet-snap/src/state/network-state-manager.test.ts index 6e199f83..3003d8ca 100644 --- a/packages/starknet-snap/src/state/network-state-manager.test.ts +++ b/packages/starknet-snap/src/state/network-state-manager.test.ts @@ -1,9 +1,11 @@ import { constants } from 'starknet'; +import { Config } from '../config'; import type { Network } from '../types/snapState'; import { STARKNET_MAINNET_NETWORK, STARKNET_SEPOLIA_TESTNET_NETWORK, + STARKNET_TESTNET_NETWORK, } from '../utils/constants'; import { mockState } from './__tests__/helper'; import { NetworkStateManager, ChainIdFilter } from './network-state-manager'; @@ -14,7 +16,7 @@ describe('NetworkStateManager', () => { it('returns the network', async () => { const chainId = constants.StarknetChainId.SN_SEPOLIA; await mockState({ - networks: [STARKNET_MAINNET_NETWORK, STARKNET_SEPOLIA_TESTNET_NETWORK], + networks: Config.availableNetworks, }); const stateManager = new NetworkStateManager(); @@ -25,15 +27,27 @@ describe('NetworkStateManager', () => { expect(result).toStrictEqual(STARKNET_SEPOLIA_TESTNET_NETWORK); }); - it('returns null if the network can not be found', async () => { - const chainId = constants.StarknetChainId.SN_SEPOLIA; + it('looks up the configuration if the network cant be found in state', async () => { await mockState({ networks: [STARKNET_MAINNET_NETWORK], }); const stateManager = new NetworkStateManager(); const result = await stateManager.getNetwork({ - chainId, + chainId: STARKNET_SEPOLIA_TESTNET_NETWORK.chainId, + }); + + expect(result).toStrictEqual(STARKNET_SEPOLIA_TESTNET_NETWORK); + }); + + it('returns null if the network can not be found', async () => { + await mockState({ + networks: Config.availableNetworks, + }); + + const stateManager = new NetworkStateManager(); + const result = await stateManager.getNetwork({ + chainId: '0x9999', }); expect(result).toBeNull(); @@ -103,7 +117,7 @@ describe('NetworkStateManager', () => { it('returns the list of network by chainId', async () => { const chainId = constants.StarknetChainId.SN_SEPOLIA; await mockState({ - networks: [STARKNET_MAINNET_NETWORK, STARKNET_SEPOLIA_TESTNET_NETWORK], + networks: Config.availableNetworks, }); const stateManager = new NetworkStateManager(); @@ -163,7 +177,7 @@ describe('NetworkStateManager', () => { describe('getCurrentNetwork', () => { it('get the current network', async () => { await mockState({ - networks: [STARKNET_MAINNET_NETWORK, STARKNET_SEPOLIA_TESTNET_NETWORK], + networks: Config.availableNetworks, currentNetwork: STARKNET_MAINNET_NETWORK, }); @@ -173,15 +187,27 @@ describe('NetworkStateManager', () => { expect(result).toStrictEqual(STARKNET_MAINNET_NETWORK); }); - it('returns null if the current network is null or undefined', async () => { + it(`returns default network if the current network is null or undefined`, async () => { await mockState({ - networks: [STARKNET_MAINNET_NETWORK, STARKNET_SEPOLIA_TESTNET_NETWORK], + networks: Config.availableNetworks, }); const stateManager = new NetworkStateManager(); const result = await stateManager.getCurrentNetwork(); - expect(result).toBeNull(); + expect(result).toStrictEqual(Config.defaultNetwork); + }); + + it(`returns default network if the current network is neither mainnet or sepolia testnet`, async () => { + await mockState({ + networks: Config.availableNetworks, + currentNetwork: STARKNET_TESTNET_NETWORK, + }); + + const stateManager = new NetworkStateManager(); + const result = await stateManager.getCurrentNetwork(); + + expect(result).toStrictEqual(Config.defaultNetwork); }); }); @@ -213,10 +239,7 @@ describe('NetworkStateManager', () => { updateTo: Network; }) => { const { state } = await mockState({ - networks: [ - STARKNET_MAINNET_NETWORK, - STARKNET_SEPOLIA_TESTNET_NETWORK, - ], + networks: Config.availableNetworks, currentNetwork, }); diff --git a/packages/starknet-snap/src/state/network-state-manager.ts b/packages/starknet-snap/src/state/network-state-manager.ts index a3588e2e..65a9312a 100644 --- a/packages/starknet-snap/src/state/network-state-manager.ts +++ b/packages/starknet-snap/src/state/network-state-manager.ts @@ -1,5 +1,6 @@ import { assert, string } from 'superstruct'; +import { Config } from '../config'; import type { Network, SnapState } from '../types/snapState'; import type { IFilter } from './filter'; import { ChainIdFilter as BaseChainIdFilter } from './filter'; @@ -59,6 +60,9 @@ export class NetworkStateManager extends StateManager { /** * Finds a network based on the given chainId. + * The query will first be looked up in the state. If the result is false, it will then fallback to the available Networks constants. + * + * (Note) Due to the returned network object may not exist in the state, it may failed to execute `updateNetwork` with the returned network object. * * @param param - The param object. * @param param.chainId - The chainId to search for. @@ -74,7 +78,12 @@ export class NetworkStateManager extends StateManager { state?: SnapState, ): Promise { const filters: INetworkFilter[] = [new ChainIdFilter([chainId])]; - return this.find(filters, state); + // in case the network not found from the state, try to get the network from the available Networks constants + return ( + (await this.find(filters, state)) ?? + Config.availableNetworks.find((network) => network.chainId === chainId) ?? + null + ); } /** @@ -88,10 +97,9 @@ export class NetworkStateManager extends StateManager { async updateNetwork(data: Network): Promise { try { await this.update(async (state: SnapState) => { - const dataInState = await this.getNetwork( - { - chainId: data.chainId, - }, + // Use underlying function `find` to avoid searching network from constants + const dataInState = await this.find( + [new ChainIdFilter([data.chainId])], state, ); @@ -111,8 +119,20 @@ export class NetworkStateManager extends StateManager { * @param [state] - The optional SnapState object. * @returns A Promise that resolves with the current Network object if found, or null if not found. */ - async getCurrentNetwork(state?: SnapState): Promise { - return (state ?? (await this.get())).currentNetwork ?? null; + async getCurrentNetwork(state?: SnapState): Promise { + const { currentNetwork } = state ?? (await this.get()); + + // Make sure the current network is either Sepolia testnet or Mainnet. By default it will be Mainnet. + if ( + !currentNetwork || + !Config.availableNetworks.find( + (network) => network.chainId === currentNetwork.chainId, + ) + ) { + return Config.defaultNetwork; + } + + return currentNetwork; } /** diff --git a/packages/starknet-snap/src/utils/snapUtils.ts b/packages/starknet-snap/src/utils/snapUtils.ts index ea6a78ab..92d23744 100644 --- a/packages/starknet-snap/src/utils/snapUtils.ts +++ b/packages/starknet-snap/src/utils/snapUtils.ts @@ -14,6 +14,7 @@ import type { UniversalDetails, } from 'starknet'; +import { Config } from '../config'; import { FeeToken, type AddErc20TokenRequestParams, @@ -34,7 +35,6 @@ import { MAXIMUM_TOKEN_SYMBOL_LENGTH, PRELOADED_NETWORKS, PRELOADED_TOKENS, - STARKNET_MAINNET_NETWORK, STARKNET_SEPOLIA_TESTNET_NETWORK, } from './constants'; import { DeployRequiredError, UpgradeRequiredError } from './exceptions'; @@ -855,7 +855,7 @@ export function getNetworkFromChainId( state: SnapState, targerChainId: string | undefined, ) { - const chainId = targerChainId ?? STARKNET_MAINNET_NETWORK.chainId; + const chainId = targerChainId ?? Config.defaultNetwork.chainId; const network = getNetwork(state, chainId); if (network === undefined) { throw new Error( @@ -1117,7 +1117,7 @@ export async function removeAcceptedTransaction( * @param state */ export function getCurrentNetwork(state: SnapState) { - return state.currentNetwork ?? STARKNET_MAINNET_NETWORK; + return state.currentNetwork ?? Config.defaultNetwork; } /**