From 9cb69cbbf28cdd5692ab380ba247ffece78b9fae Mon Sep 17 00:00:00 2001 From: Robert Chu Date: Wed, 21 Aug 2024 18:49:09 -0700 Subject: [PATCH] Adds waffle-compatibility layer. --- hardhat.config.ts | 2 +- src/compat/waffle.ts | 292 +++++++++++++++++++++++++++++++++++++ test/waffle-compat.test.ts | 122 ++++++++++++++++ 3 files changed, 415 insertions(+), 1 deletion(-) create mode 100644 src/compat/waffle.ts create mode 100644 test/waffle-compat.test.ts diff --git a/hardhat.config.ts b/hardhat.config.ts index d7ffa55..ee26ece 100644 --- a/hardhat.config.ts +++ b/hardhat.config.ts @@ -1,4 +1,4 @@ -import type { HardhatUserConfig } from "hardhat/config"; +import { type HardhatUserConfig } from "hardhat/config"; import "@nomicfoundation/hardhat-toolbox-viem"; const config: HardhatUserConfig = { diff --git a/src/compat/waffle.ts b/src/compat/waffle.ts new file mode 100644 index 0000000..89b2447 --- /dev/null +++ b/src/compat/waffle.ts @@ -0,0 +1,292 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ +import { + Abi, + AbiFunction, + AbiParametersToPrimitiveTypes, + ExtractAbiFunctionNames, + formatAbiItem, +} from "abitype"; +import { + deployMock, + MockCallExpectation, + MockContractController, +} from "../mock-contract"; +import { PublicClient, WalletClient } from "viem"; + +export const doppelgangerAbi = [ + { + "stateMutability": "payable", + "type": "fallback" + }, + { + "inputs": [ + { + "internalType": "bytes", + "name": "data", + "type": "bytes" + }, + { + "internalType": "bytes", + "name": "value", + "type": "bytes" + } + ], + "name": "__doppelganger__mockReturns", + "outputs": [], + "stateMutability": "nonpayable", + "type": "function" + }, + { + "inputs": [ + { + "internalType": "bytes", + "name": "data", + "type": "bytes" + }, + { + "internalType": "string", + "name": "reason", + "type": "string" + } + ], + "name": "__doppelganger__mockReverts", + "outputs": [], + "stateMutability": "nonpayable", + "type": "function" + }, + { + "inputs": [ + { + "internalType": "bytes", + "name": "data", + "type": "bytes" + }, + { + "internalType": "bytes", + "name": "value", + "type": "bytes" + } + ], + "name": "__doppelganger__queueReturn", + "outputs": [], + "stateMutability": "nonpayable", + "type": "function" + }, + { + "inputs": [ + { + "internalType": "bytes", + "name": "data", + "type": "bytes" + }, + { + "internalType": "string", + "name": "reason", + "type": "string" + } + ], + "name": "__doppelganger__queueRevert", + "outputs": [], + "stateMutability": "nonpayable", + "type": "function" + }, + { + "inputs": [ + { + "internalType": "string", + "name": "reason", + "type": "string" + } + ], + "name": "__doppelganger__receiveReverts", + "outputs": [], + "stateMutability": "nonpayable", + "type": "function" + }, + { + "stateMutability": "payable", + "type": "receive" + } +] as const; + +interface StubInterface extends Pick, "then"> { + returns(...args: any): StubInterface; + reverts(): StubInterface; + revertsWithReason(reason: string): StubInterface; + withArgs(...args: any[]): StubInterface; +} + +export interface MockContract { + mock: { + [key in ExtractAbiFunctionNames | "receive"]: StubInterface; + }; + address: `0x${string}`; +} + +class Stub implements StubInterface { + calls: MockCallExpectation[] = []; + inputs: AbiParametersToPrimitiveTypes | undefined = undefined; + + revertSet = false; + argsSet = false; + + constructor( + private mockContract: MockContractController, + private func: T, + ) {} + + private err(reason: string): never { + this.revertSet = false; + this.argsSet = false; + throw new Error(reason); + } + + returns(...args: AbiParametersToPrimitiveTypes) { + if (this.revertSet) this.err("Revert must be the last call"); + if (!this.func.outputs) + this.err("Cannot mock return values from a void function"); + + this.calls.push({ + kind: "read", + abi: this.func, + inputs: this.inputs, + outputs: args, + }); + + return this; + } + + reverts() { + if (this.revertSet) this.err("Revert must be the last call"); + + this.calls.push({ + kind: "revert", + abi: this.func, + inputs: this.inputs, + reason: "Mock revert", + }); + + this.revertSet = true; + return this; + } + + revertsWithReason(reason: string) { + if (this.revertSet) this.err("Revert must be the last call"); + + this.calls.push({ + kind: "revert", + abi: this.func, + inputs: this.inputs, + reason, + }); + + this.revertSet = true; + return this; + } + + withArgs(...params: AbiParametersToPrimitiveTypes) { + if (this.argsSet) this.err("withArgs can be called only once"); + this.inputs = params; + this.argsSet = true; + return this; + } + + async then( + resolve?: + | ((value: void) => TResult1 | PromiseLike) + | null + | undefined, + reject?: + | ((reason: any) => TResult2 | PromiseLike) + | null + | undefined, + ): Promise { + if (this.argsSet) { + this.calls.push({ + kind: "write", + abi: this.func, + inputs: this.inputs, + }); + } + + try { + await this.mockContract.setup(...this.calls); + } catch (e) { + this.argsSet = false; + this.revertSet = false; + reject?.(e); + return undefined as never; + } + this.argsSet = false; + this.revertSet = false; + resolve?.(); + return undefined as never; + } +} + +function createMock( + abi: T, + mockContractInstance: MockContractController, + // wallet: WalletClient, +): MockContract["mock"] { + const functions = abi.filter( + (f) => f.type === "function", + ) as AbiFunction[]; + const mockedAbi = Object.values(functions).reduce( + (acc, func) => { + const stubbed = new Stub(mockContractInstance, func); + return { + ...acc, + [func.name]: stubbed, + [formatAbiItem(func)]: stubbed, + }; + }, + {} as MockContract["mock"], + ); + + // (mockedAbi as any).receive = { + // returns: () => { + // throw new Error("Receive function return is not implemented."); + // }, + // withArgs: () => { + // throw new Error("Receive function return is not implemented."); + // }, + // reverts: () => wallet.writeContract({ + // address: mockContractInstance.address, + // abi: doppelgangerAbi, + // functionName: "__doppelganger__receiveReverts", + // account: wallet.account!, + // chain: wallet.chain, + // args: ["Mock Revert"], + // }), + // revertsWithReason: (reason: string) => wallet.writeContract({ + // address: mockContractInstance.address, + // abi: doppelgangerAbi, + // functionName: "__doppelganger__receiveReverts", + // account: wallet.account!, + // chain: wallet.chain, + // args: [reason], + // }), + // }; + + return mockedAbi; +} + +export async function deployMockContract( + wallet: WalletClient, + reader: PublicClient, + abi: T, +): Promise> { + const mockContractInstance = await deployMock(wallet, reader); + + const mock = createMock( + abi, + mockContractInstance as unknown as MockContractController, + // wallet, + ); + + return { + mock, + address: mockContractInstance.address, + }; +} diff --git a/test/waffle-compat.test.ts b/test/waffle-compat.test.ts new file mode 100644 index 0000000..f3fc58a --- /dev/null +++ b/test/waffle-compat.test.ts @@ -0,0 +1,122 @@ +import { expect } from "chai"; +import { deployMockContract } from "../src/compat/waffle"; +import hre from "hardhat"; +import { zeroAddress } from "viem"; + +const erc20ABI = [ + { + type: "function", + name: "balanceOf", + stateMutability: "view", + inputs: [{ type: "address" }], + outputs: [{ type: "uint256" }], + }, + { + type: "function", + name: "decimals", + stateMutability: "view", + inputs: [], + outputs: [{ type: "uint8" }], + }, + { + type: "function", + name: "transfer", + stateMutability: "nonpayable", + inputs: [{ type: "address" }, { type: "uint256" }], + outputs: [], + }, + { + type: "event", + name: "Transfer", + inputs: [{ type: "address" }, { type: "address" }, { type: "uint256" }], + anonymous: false, + }, +] as const; + +describe("waffle", function () { + describe("compat", function () { + it("Should allow for the mocking of read calls", async function () { + const reader = await hre.viem.getPublicClient(); + const [signer] = await hre.viem.getWalletClients(); + const mock = await deployMockContract(signer, reader, erc20ABI); + console.log(`Deployed mock at ${mock.address}`); + + await mock.mock.balanceOf.withArgs(zeroAddress).returns(100n); + + expect(await reader.readContract({ + address: mock.address, + abi: erc20ABI, + functionName: "balanceOf", + args: [zeroAddress], + })).to.equal(100n); + }); + + it("Should allow for the mocking of write calls", async function () { + const reader = await hre.viem.getPublicClient(); + const [signer] = await hre.viem.getWalletClients(); + const mock = await deployMockContract(signer, reader, erc20ABI); + + await mock.mock.transfer.withArgs(zeroAddress, 100n); + + await signer.writeContract({ + address: mock.address, + abi: erc20ABI, + functionName: "transfer", + args: [zeroAddress, 100n], + }); + }); + + it("Should allow for the mocking of reverts on read calls", async function () { + const reader = await hre.viem.getPublicClient(); + const [signer] = await hre.viem.getWalletClients(); + const mock = await deployMockContract(signer, reader, erc20ABI); + + await mock.mock.balanceOf.withArgs(zeroAddress).revertsWithReason("Custom reason"); + + try { + await reader.readContract({ + address: mock.address, + abi: erc20ABI, + functionName: "balanceOf", + args: [zeroAddress], + }); + } catch (error) { + expect((error as Error).message).to.contain("Custom reason"); + } + }); + + it("Should fail if the mock is not set up", async function () { + const reader = await hre.viem.getPublicClient(); + const [signer] = await hre.viem.getWalletClients(); + const mock = await deployMockContract(signer, reader, erc20ABI); + + try { + await reader.readContract({ + address: mock.address, + abi: erc20ABI, + functionName: "balanceOf", + args: [zeroAddress], + }); + } catch (error) { + expect((error as Error).message).to.contain( + "Mock on the method is not initialized", + ); + } + }); + + it("Should allow undefined call.inputs for read calls", async function () { + const reader = await hre.viem.getPublicClient(); + const [signer] = await hre.viem.getWalletClients(); + const mock = await deployMockContract(signer, reader, erc20ABI); + + await mock.mock.balanceOf.returns(20998n); + + expect(await reader.readContract({ + address: mock.address, + abi: erc20ABI, + functionName: "balanceOf", + args: [zeroAddress], + })).to.equal(20998n); + }); + }); +});