Skip to content

Commit

Permalink
solana: Multi Transceiver Support (#528)
Browse files Browse the repository at this point in the history
* Refactor `src/transceivers/*` and `src/messages.rs` from Manager into `ntt-transceiver` program
* Add zero-copy deserialization helpers for `ValidatedTransceiverMessage` in Manager
* Update transceiver's `release_outbound` ix to CPI into Manager's `mark_outbox_item_as_released` ix
* Update `Makefile` to remove generics (if exists) from all programs
* Update SDK interfaces to be more generic and add backwards-compatibility wrappers
* Bump IDL to version 3.0.0
  • Loading branch information
nvsriram authored Dec 3, 2024
1 parent 5e7ceae commit 738c67b
Show file tree
Hide file tree
Showing 57 changed files with 14,627 additions and 3,080 deletions.
11 changes: 5 additions & 6 deletions cli/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ yargs(hideBin(process.argv))
await signSendWait(ctx, tx, signer.signer)
}
for (const transceiver of missingConfig.transceiverPeers) {
const tx = ntt.setWormholeTransceiverPeer(transceiver, signer.address.address)
const tx = ntt.setTransceiverPeer(0, transceiver, signer.address.address)
await signSendWait(ctx, tx, signer.signer)
}
for (const evmChain of missingConfig.evmChains) {
Expand All @@ -696,10 +696,9 @@ yargs(hideBin(process.argv))
continue;
}
const solanaNtt = ntt as SolanaNtt<Network, SolanaChains>;
const tx = solanaNtt.registerTransceiver({
const tx = solanaNtt.registerWormholeTransceiver({
payer: signer.address.address as AccountAddress<SolanaChains>,
owner: signer.address.address as AccountAddress<SolanaChains>,
transceiver: solanaNtt.program.programId
})
try {
await signSendWait(ctx, tx, signer.signer)
Expand Down Expand Up @@ -1291,7 +1290,7 @@ async function deploySolana<N extends Network, C extends SolanaChains>(
// time by checking it here and failing early (not to mention better
// diagnostics).

const emitter = NTT.pdas(providedProgramId).emitterAccount().toBase58();
const emitter = NTT.transceiverPdas(providedProgramId).emitterAccount().toBase58();
const payerKeypair = Keypair.fromSecretKey(new Uint8Array(JSON.parse(fs.readFileSync(payer).toString())));

// this is not super pretty... I want to initialise the 'ntt' object, but
Expand Down Expand Up @@ -1787,7 +1786,7 @@ async function getPdas<N extends Network, C extends Chain>(chain: C, ntt: Ntt<N,
}
const solanaNtt = ntt as SolanaNtt<N, SolanaChains>;
const config = solanaNtt.pdas.configAccount();
const emitter = solanaNtt.pdas.emitterAccount();
const emitter = NTT.transceiverPdas(solanaNtt.program.programId).emitterAccount();
const outboxRateLimit = solanaNtt.pdas.outboxRateLimitAccount();
const tokenAuthority = solanaNtt.pdas.tokenAuthority();
const lutAccount = solanaNtt.pdas.lutAccount();
Expand Down Expand Up @@ -1826,7 +1825,7 @@ async function nttFromManager<N extends Network, C extends Chain>(
ntt: {
manager: nativeManagerAddress,
token: null,
transceiver: { wormhole: null },
transceiver: {},
}
});
const diff = await onlyManager.verifyAddresses();
Expand Down
145 changes: 104 additions & 41 deletions evm/ts/src/ntt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ import {
toUniversal,
universalAddress,
} from "@wormhole-foundation/sdk-definitions";
import type { AnyEvmAddress, EvmChains, EvmPlatformType } from "@wormhole-foundation/sdk-evm";
import type {
AnyEvmAddress,
EvmChains,
EvmPlatformType,
} from "@wormhole-foundation/sdk-evm";
import {
EvmAddress,
EvmPlatform,
Expand All @@ -26,6 +30,7 @@ import {
import "@wormhole-foundation/sdk-evm-core";

import {
EvmNttTransceiver,
Ntt,
NttTransceiver,
WormholeNttTransceiver,
Expand All @@ -39,7 +44,10 @@ import {
} from "./bindings.js";

export class EvmNttWormholeTranceiver<N extends Network, C extends EvmChains>
implements NttTransceiver<N, C, WormholeNttTransceiver.VAA> {
implements
WormholeNttTransceiver<N, C>,
EvmNttTransceiver<N, C, WormholeNttTransceiver.VAA>
{
transceiver: NttTransceiverBindings.NttTransceiver;
constructor(
readonly manager: EvmNtt<N, C>,
Expand All @@ -52,15 +60,26 @@ export class EvmNttWormholeTranceiver<N extends Network, C extends EvmChains>
);
}

async getTransceiverType(): Promise<string> {
// NOTE: We hardcode the type here as transceiver type is only available for versions >1.1.0
// For those versions, we can return `this.transceiver.getTransceiverType()` directly
return "wormhole";
}

getAddress(): ChainAddress<C> {
return { chain: this.manager.chain, address: toUniversal(this.manager.chain, this.address) };
return {
chain: this.manager.chain,
address: toUniversal(this.manager.chain, this.address),
};
}

encodeFlags(flags: { skipRelay: boolean }): Uint8Array {
return new Uint8Array([flags.skipRelay ? 1 : 0]);
}

async *setPeer<P extends Chain>(peer: ChainAddress<P>): AsyncGenerator<EvmUnsignedTransaction<N, C>> {
async *setPeer<P extends Chain>(
peer: ChainAddress<P>
): AsyncGenerator<EvmUnsignedTransaction<N, C>> {
const tx = await this.transceiver.setWormholePeer.populateTransaction(
toChainId(peer.chain),
universalAddress(peer)
Expand All @@ -74,8 +93,14 @@ export class EvmNttWormholeTranceiver<N extends Network, C extends EvmChains>
}

async *setPauser(pauser: AccountAddress<C>) {
const canonicalPauser = canonicalAddress({chain: this.manager.chain, address: pauser});
const tx = await this.transceiver.transferPauserCapability.populateTransaction(canonicalPauser);
const canonicalPauser = canonicalAddress({
chain: this.manager.chain,
address: pauser,
});
const tx =
await this.transceiver.transferPauserCapability.populateTransaction(
canonicalPauser
);
yield this.manager.createUnsignedTx(tx, "WormholeTransceiver.setPauser");
}

Expand All @@ -102,7 +127,10 @@ export class EvmNttWormholeTranceiver<N extends Network, C extends EvmChains>
toChainId(chain),
isEvm
);
yield this.manager.createUnsignedTx(tx, "WormholeTransceiver.setIsEvmChain");
yield this.manager.createUnsignedTx(
tx,
"WormholeTransceiver.setIsEvmChain"
);
}

async *receive(attestation: WormholeNttTransceiver.VAA) {
Expand All @@ -122,10 +150,11 @@ export class EvmNttWormholeTranceiver<N extends Network, C extends EvmChains>
}

async *setIsWormholeRelayingEnabled(destChain: Chain, enabled: boolean) {
const tx = await this.transceiver.setIsWormholeRelayingEnabled.populateTransaction(
toChainId(destChain),
enabled
);
const tx =
await this.transceiver.setIsWormholeRelayingEnabled.populateTransaction(
toChainId(destChain),
enabled
);
yield this.manager.createUnsignedTx(
tx,
"WormholeTransceiver.setWormholeRelayingEnabled"
Expand All @@ -139,10 +168,11 @@ export class EvmNttWormholeTranceiver<N extends Network, C extends EvmChains>
}

async *setIsSpecialRelayingEnabled(destChain: Chain, enabled: boolean) {
const tx = await this.transceiver.setIsSpecialRelayingEnabled.populateTransaction(
toChainId(destChain),
enabled
);
const tx =
await this.transceiver.setIsSpecialRelayingEnabled.populateTransaction(
toChainId(destChain),
enabled
);
yield this.manager.createUnsignedTx(
tx,
"WormholeTransceiver.setSpecialRelayingEnabled"
Expand All @@ -151,7 +181,8 @@ export class EvmNttWormholeTranceiver<N extends Network, C extends EvmChains>
}

export class EvmNtt<N extends Network, C extends EvmChains>
implements Ntt<N, C> {
implements Ntt<N, C>
{
tokenAddress: string;
readonly chainId: bigint;
manager: NttManagerBindings.NttManager;
Expand Down Expand Up @@ -182,17 +213,32 @@ export class EvmNtt<N extends Network, C extends EvmChains>
this.provider
);

if (contracts.ntt.transceiver.wormhole != null) {
this.xcvrs = [
// Enable more Transceivers here
new EvmNttWormholeTranceiver(
this,
contracts.ntt.transceiver.wormhole,
abiBindings!
),
this.xcvrs = [];
if (
"wormhole" in contracts.ntt.transceiver &&
contracts.ntt.transceiver["wormhole"]
) {
const transceiverTypes = [
"wormhole", // wormhole xcvr should be ix 0
...Object.keys(contracts.ntt.transceiver).filter((transceiverType) => {
transceiverType !== "wormhole";
}),
];
} else {
this.xcvrs = [];
transceiverTypes.map((transceiverType) => {
// we currently only support wormhole transceivers
if (transceiverType !== "wormhole") {
throw new Error(`Unsupported transceiver type: ${transceiverType}`);
}

// Enable more Transceivers here
this.xcvrs.push(
new EvmNttWormholeTranceiver(
this,
contracts.ntt!.transceiver[transceiverType]!,
abiBindings!
)
);
});
}
}

Expand All @@ -211,12 +257,12 @@ export class EvmNtt<N extends Network, C extends EvmChains>
}

async *pause() {
const tx = await this.manager.pause.populateTransaction()
const tx = await this.manager.pause.populateTransaction();
yield this.createUnsignedTx(tx, "Ntt.pause");
}

async *unpause() {
const tx = await this.manager.unpause.populateTransaction()
const tx = await this.manager.unpause.populateTransaction();
yield this.createUnsignedTx(tx, "Ntt.unpause");
}

Expand All @@ -230,13 +276,17 @@ export class EvmNtt<N extends Network, C extends EvmChains>

async *setOwner(owner: AnyEvmAddress) {
const canonicalOwner = new EvmAddress(owner).toString();
const tx = await this.manager.transferOwnership.populateTransaction(canonicalOwner);
const tx = await this.manager.transferOwnership.populateTransaction(
canonicalOwner
);
yield this.createUnsignedTx(tx, "Ntt.setOwner");
}

async *setPauser(pauser: AnyEvmAddress) {
const canonicalPauser = new EvmAddress(pauser).toString();
const tx = await this.manager.transferPauserCapability.populateTransaction(canonicalPauser);
const tx = await this.manager.transferPauserCapability.populateTransaction(
canonicalPauser
);
yield this.createUnsignedTx(tx, "Ntt.setPauser");
}

Expand Down Expand Up @@ -398,9 +448,14 @@ export class EvmNtt<N extends Network, C extends EvmChains>
}

async *setWormholeTransceiverPeer(peer: ChainAddress<C>) {
// TODO: we only have one right now, so just set the peer on that one
// in the future, these should(?) be keyed by attestation type
yield* this.xcvrs[0]!.setPeer(peer);
yield* this.setTransceiverPeer(0, peer);
}

async *setTransceiverPeer(ix: number, peer: ChainAddress<C>) {
if (ix >= this.xcvrs.length) {
throw new Error("Transceiver not found");
}
yield* this.xcvrs[ix]!.setPeer(peer);
}

async *transfer(
Expand Down Expand Up @@ -475,7 +530,9 @@ export class EvmNtt<N extends Network, C extends EvmChains>
}

async getOutboundLimit(): Promise<bigint> {
const encoded: EncodedTrimmedAmount = (await this.manager.getOutboundLimitParams()).limit;
const encoded: EncodedTrimmedAmount = (
await this.manager.getOutboundLimitParams()
).limit;
const trimmedAmount: TrimmedAmount = decodeTrimmedAmount(encoded);
const tokenDecimals = await this.getTokenDecimals();

Expand All @@ -492,7 +549,9 @@ export class EvmNtt<N extends Network, C extends EvmChains>
}

async getInboundLimit(fromChain: Chain): Promise<bigint> {
const encoded: EncodedTrimmedAmount = (await this.manager.getInboundLimitParams(toChainId(fromChain))).limit;
const encoded: EncodedTrimmedAmount = (
await this.manager.getInboundLimitParams(toChainId(fromChain))
).limit;
const trimmedAmount: TrimmedAmount = decodeTrimmedAmount(encoded);
const tokenDecimals = await this.getTokenDecimals();

Expand Down Expand Up @@ -547,7 +606,7 @@ export class EvmNtt<N extends Network, C extends EvmChains>
manager: this.managerAddress,
token: this.tokenAddress,
transceiver: {
wormhole: this.xcvrs[0]?.address,
...(this.xcvrs.length > 0 && { wormhole: this.xcvrs[0]!.address }),
},
// TODO: what about the quoter?
};
Expand All @@ -556,7 +615,7 @@ export class EvmNtt<N extends Network, C extends EvmChains>
manager: this.managerAddress,
token: await this.manager.token(),
transceiver: {
wormhole: (await this.manager.getTransceivers())[0]! // TODO: make this more generic
wormhole: (await this.manager.getTransceivers())[0]!, // TODO: make this more generic
},
};

Expand All @@ -569,7 +628,7 @@ export class EvmNtt<N extends Network, C extends EvmChains>
delete a[k];
}
}
}
};

deleteMatching(remote, local);

Expand Down Expand Up @@ -612,14 +671,18 @@ function untrim(trimmed: TrimmedAmount, toDecimals: number): bigint {
return scale(amount, fromDecimals, toDecimals);
}

function scale(amount: bigint, fromDecimals: number, toDecimals: number): bigint {
function scale(
amount: bigint,
fromDecimals: number,
toDecimals: number
): bigint {
if (fromDecimals == toDecimals) {
return amount;
}

if (fromDecimals > toDecimals) {
return amount / (10n ** BigInt(fromDecimals - toDecimals));
return amount / 10n ** BigInt(fromDecimals - toDecimals);
} else {
return amount * (10n ** BigInt(toDecimals - fromDecimals));
return amount * 10n ** BigInt(toDecimals - fromDecimals);
}
}
Loading

0 comments on commit 738c67b

Please sign in to comment.