diff --git a/src/assets/TokenListController.test.ts b/src/assets/TokenListController.test.ts index 500d83926e..82991887bc 100644 --- a/src/assets/TokenListController.test.ts +++ b/src/assets/TokenListController.test.ts @@ -631,6 +631,42 @@ describe('TokenListController', () => { ); }); + it('should update tokenList state when network updates are passed via onNetworkStateChange callback', async () => { + nock(TOKEN_END_POINT_API) + .get(`/tokens/${NetworksChainId.mainnet}`) + .reply(200, sampleMainnetTokenList) + .persist(); + + const controllerMessenger = getControllerMessenger(); + const { network } = setupNetworkController(controllerMessenger); + const messenger = getRestrictedMessenger(controllerMessenger); + + const controller = new TokenListController({ + chainId: NetworksChainId.mainnet, + onNetworkStateChange: (callback) => + controllerMessenger.subscribe( + 'NetworkController:providerChange', + callback, + ), + preventPollingOnNetworkRestart: false, + interval: 100, + messenger, + }); + controller.start(); + await new Promise((resolve) => setTimeout(() => resolve(), 150)); + expect(controller.state.tokenList).toStrictEqual( + sampleSingleChainState.tokenList, + ); + network.setProviderType('goerli'); + await new Promise((resolve) => setTimeout(() => resolve(), 500)); + + expect(controller.state.tokenList).toStrictEqual({}); + controller.destroy(); + controllerMessenger.clearEventSubscriptions( + 'NetworkController:providerChange', + ); + }); + it('should poll and update rate in the right interval', async () => { const tokenListMock = sinon.stub( TokenListController.prototype, diff --git a/src/assets/TokenListController.ts b/src/assets/TokenListController.ts index 9df6e93703..252ae063b0 100644 --- a/src/assets/TokenListController.ts +++ b/src/assets/TokenListController.ts @@ -5,7 +5,11 @@ import { BaseController } from '../BaseControllerV2'; import type { RestrictedControllerMessenger } from '../ControllerMessenger'; import { safelyExecute, isTokenListSupportedForNetwork } from '../util'; import { fetchTokenList } from '../apis/token-service'; -import { NetworkControllerProviderChangeEvent } from '../network/NetworkController'; +import { + NetworkControllerProviderChangeEvent, + NetworkState, + ProviderConfig, +} from '../network/NetworkController'; import { formatAggregatorNames, formatIconUrlWithProxy } from './assetsUtil'; const DEFAULT_INTERVAL = 24 * 60 * 60 * 1000; @@ -94,6 +98,7 @@ export class TokenListController extends BaseController< * * @param options - The controller options. * @param options.chainId - The chain ID of the current network. + * @param options.onNetworkStateChange - A function for registering an event handler for network state changes. * @param options.interval - The polling interval, in milliseconds. * @param options.cacheRefreshThreshold - The token cache expiry time, in milliseconds. * @param options.messenger - A restricted controller messenger. @@ -103,6 +108,7 @@ export class TokenListController extends BaseController< constructor({ chainId, preventPollingOnNetworkRestart = false, + onNetworkStateChange, interval = DEFAULT_INTERVAL, cacheRefreshThreshold = DEFAULT_THRESHOLD, messenger, @@ -110,6 +116,9 @@ export class TokenListController extends BaseController< }: { chainId: string; preventPollingOnNetworkRestart?: boolean; + onNetworkStateChange?: ( + listener: (networkState: NetworkState | ProviderConfig) => void, + ) => void; interval?: number; cacheRefreshThreshold?: number; messenger: TokenListMessenger; @@ -126,29 +135,54 @@ export class TokenListController extends BaseController< this.chainId = chainId; this.updatePreventPollingOnNetworkRestart(preventPollingOnNetworkRestart); this.abortController = new AbortController(); - this.messagingSystem.subscribe( - 'NetworkController:providerChange', - async (providerConfig) => { - if (this.chainId !== providerConfig.chainId) { - this.abortController.abort(); - this.abortController = new AbortController(); - this.chainId = providerConfig.chainId; - if (this.state.preventPollingOnNetworkRestart) { - this.clearingTokenListData(); - } else { - // Ensure tokenList is referencing data from correct network - this.update(() => { - return { - ...this.state, - tokenList: - this.state.tokensChainsCache[this.chainId]?.data || {}, - }; - }); - await this.restart(); - } + if (onNetworkStateChange) { + onNetworkStateChange(async (networkStateOrProviderConfig) => { + // this check for "provider" is for testing purposes, since in the extension this callback will receive + // an object typed as NetworkState but within repo we can only simulate as if the callback receives an + // object typed as ProviderConfig + if ('provider' in networkStateOrProviderConfig) { + await this.#onNetworkStateChangeCallback( + networkStateOrProviderConfig.provider, + ); + } else { + await this.#onNetworkStateChangeCallback( + networkStateOrProviderConfig, + ); } - }, - ); + }); + } else { + this.messagingSystem.subscribe( + 'NetworkController:providerChange', + async (providerConfig) => { + await this.#onNetworkStateChangeCallback(providerConfig); + }, + ); + } + } + + /** + * Updates state and restart polling when updates are received through NetworkController subscription. + * + * @param providerConfig - the configuration for a provider containing critical network info. + */ + async #onNetworkStateChangeCallback(providerConfig: ProviderConfig) { + if (this.chainId !== providerConfig.chainId) { + this.abortController.abort(); + this.abortController = new AbortController(); + this.chainId = providerConfig.chainId; + if (this.state.preventPollingOnNetworkRestart) { + this.clearingTokenListData(); + } else { + // Ensure tokenList is referencing data from correct network + this.update(() => { + return { + ...this.state, + tokenList: this.state.tokensChainsCache[this.chainId]?.data || {}, + }; + }); + await this.restart(); + } + } } /** diff --git a/src/gas/GasFeeController.ts b/src/gas/GasFeeController.ts index 61050664dc..200c6073b2 100644 --- a/src/gas/GasFeeController.ts +++ b/src/gas/GasFeeController.ts @@ -1,5 +1,6 @@ import type { Patch } from 'immer'; +import EthQuery from 'eth-query'; import { v1 as random } from 'uuid'; import { isHexString } from 'ethereumjs-util'; import { BaseController } from '../BaseControllerV2'; @@ -9,6 +10,8 @@ import type { NetworkControllerGetEthQueryAction, NetworkControllerGetProviderConfigAction, NetworkControllerProviderChangeEvent, + NetworkController, + NetworkState, } from '../network/NetworkController'; import { fetchGasEstimates, @@ -269,6 +272,10 @@ export class GasFeeController extends BaseController< * current network is compatible with the legacy gas price API. * @param options.getCurrentAccountEIP1559Compatibility - Determines whether or not the current * account is EIP-1559 compatible. + * @param options.getChainId - Returns the current chain ID. + * @param options.getProvider - Returns a network provider for the current network. + * @param options.onNetworkStateChange - A function for registering an event handler for the + * network state change event. * @param options.legacyAPIEndpoint - The legacy gas price API URL. This option is primarily for * testing purposes. * @param options.EIP1559APIEndpoint - The EIP-1559 gas price API URL. This option is primarily @@ -282,7 +289,10 @@ export class GasFeeController extends BaseController< state, getCurrentNetworkEIP1559Compatibility, getCurrentAccountEIP1559Compatibility, + getChainId, getCurrentNetworkLegacyGasAPICompatibility, + getProvider, + onNetworkStateChange, legacyAPIEndpoint = LEGACY_GAS_PRICES_API_URL, EIP1559APIEndpoint = GAS_FEE_API, clientId, @@ -293,6 +303,9 @@ export class GasFeeController extends BaseController< getCurrentNetworkEIP1559Compatibility: () => Promise; getCurrentNetworkLegacyGasAPICompatibility: () => boolean; getCurrentAccountEIP1559Compatibility?: () => boolean; + getChainId?: () => `0x${string}` | `${number}` | number; + getProvider?: () => NetworkController['provider']; + onNetworkStateChange?: (listener: (state: NetworkState) => void) => void; legacyAPIEndpoint?: string; EIP1559APIEndpoint?: string; clientId?: string; @@ -315,25 +328,41 @@ export class GasFeeController extends BaseController< getCurrentAccountEIP1559Compatibility; this.EIP1559APIEndpoint = EIP1559APIEndpoint; this.legacyAPIEndpoint = legacyAPIEndpoint; - const providerConfig = this.messagingSystem.call( - 'NetworkController:getProviderConfig', - ); - this.currentChainId = providerConfig.chainId; - this.ethQuery = this.messagingSystem.call('NetworkController:getEthQuery'); this.clientId = clientId; - this.messagingSystem.subscribe( - 'NetworkController:providerChange', - async (provider) => { - this.ethQuery = this.messagingSystem.call( - 'NetworkController:getEthQuery', - ); - - if (this.currentChainId !== provider.chainId) { - this.currentChainId = provider.chainId; + if (onNetworkStateChange && getChainId && getProvider) { + this.currentChainId = getChainId(); + onNetworkStateChange(async () => { + const newProvider = getProvider(); + const newChainId = getChainId(); + this.ethQuery = new EthQuery(newProvider); + if (this.currentChainId !== newChainId) { + this.currentChainId = newChainId; await this.resetPolling(); } - }, - ); + }); + } else { + const providerConfig = this.messagingSystem.call( + 'NetworkController:getProviderConfig', + ); + this.currentChainId = providerConfig.chainId; + this.ethQuery = this.messagingSystem.call( + 'NetworkController:getEthQuery', + ); + + this.messagingSystem.subscribe( + 'NetworkController:providerChange', + async (provider) => { + this.ethQuery = this.messagingSystem.call( + 'NetworkController:getEthQuery', + ); + + if (this.currentChainId !== provider.chainId) { + this.currentChainId = provider.chainId; + await this.resetPolling(); + } + }, + ); + } } async resetPolling() {