From 2b832db3e45933321d59a07aa83d4d1524a1e467 Mon Sep 17 00:00:00 2001 From: Hugh Nimmo-Smith Date: Fri, 29 Sep 2023 16:48:07 +0100 Subject: [PATCH] Make MSC3906 implementation compatible with Rust Crypto --- spec/unit/rendezvous/rendezvous.spec.ts | 202 +++++++++++++++--------- src/rendezvous/MSC3906Rendezvous.ts | 49 +++--- 2 files changed, 154 insertions(+), 97 deletions(-) diff --git a/spec/unit/rendezvous/rendezvous.spec.ts b/spec/unit/rendezvous/rendezvous.spec.ts index eb8a72c26ae..04a504b3968 100644 --- a/spec/unit/rendezvous/rendezvous.spec.ts +++ b/spec/unit/rendezvous/rendezvous.spec.ts @@ -23,7 +23,7 @@ import { MSC3903ECDHPayload, MSC3903ECDHv2RendezvousChannel as MSC3903ECDHRendezvousChannel, } from "../../../src/rendezvous/channels"; -import { MatrixClient } from "../../../src"; +import { Device, MatrixClient } from "../../../src"; import { MSC3886SimpleHttpRendezvousTransport, MSC3886SimpleHttpRendezvousTransportDetails, @@ -31,16 +31,59 @@ import { import { DummyTransport } from "./DummyTransport"; import { decodeBase64 } from "../../../src/base64"; import { logger } from "../../../src/logger"; -import { DeviceInfo } from "../../../src/crypto/deviceinfo"; +import { CrossSigningKey } from "../../../src/crypto-api"; + +type UserID = string; +type DeviceID = string; +type Fingerprint = string; +type PartialUserDevices = Map>; +type PartialDeviceMap = Map; +type SimpleDeviceMap = Record>; + +function mockDevice(userId: UserID, deviceId: DeviceID, fingerprint: Fingerprint): Partial { + return { + deviceId, + userId, + getFingerprint: () => fingerprint, + }; +} + +function mockDeviceMap( + userId: UserID, + deviceId: DeviceID, + deviceKey?: Fingerprint, + otherDevices: SimpleDeviceMap = {}, +): PartialDeviceMap { + const deviceMap: PartialDeviceMap = new Map(); + + const myDevices: PartialUserDevices = new Map(); + if (deviceKey) { + myDevices.set(deviceId, mockDevice(userId, deviceId, deviceKey)); + } + deviceMap.set(userId, myDevices); + + for (const u in otherDevices) { + let userDevices = deviceMap.get(u); + if (!userDevices) { + userDevices = new Map(); + deviceMap.set(u, userDevices); + } + for (const d in otherDevices[u]) { + userDevices.set(d, mockDevice(u, d, otherDevices[u][d])); + } + } + + return deviceMap; +} function makeMockClient(opts: { - userId: string; - deviceId: string; - deviceKey?: string; + userId: UserID; + deviceId: DeviceID; + deviceKey?: Fingerprint; getLoginTokenEnabled: boolean; msc3882r0Only: boolean; msc3886Enabled: boolean; - devices?: Record>; + devices?: SimpleDeviceMap; verificationFunction?: ( userId: string, deviceId: string, @@ -48,50 +91,56 @@ function makeMockClient(opts: { blocked: boolean, known: boolean, ) => void; - crossSigningIds?: Record; -}): MatrixClient { - return { - getVersions() { - return { - unstable_features: { - "org.matrix.msc3882": opts.getLoginTokenEnabled, - "org.matrix.msc3886": opts.msc3886Enabled, - }, - }; - }, - getCapabilities() { - return opts.msc3882r0Only - ? {} - : { - capabilities: { - "m.get_login_token": { - enabled: opts.getLoginTokenEnabled, + crossSigningIds?: Partial>; +}): [MatrixClient, PartialDeviceMap] { + const deviceMap = mockDeviceMap(opts.userId, opts.deviceId, opts.deviceKey, opts.devices); + return [ + { + getVersions() { + return { + unstable_features: { + "org.matrix.msc3882": opts.getLoginTokenEnabled, + "org.matrix.msc3886": opts.msc3886Enabled, + }, + }; + }, + getCapabilities() { + return opts.msc3882r0Only + ? {} + : { + capabilities: { + "m.get_login_token": { + enabled: opts.getLoginTokenEnabled, + }, }, - }, - }; - }, - getUserId() { - return opts.userId; - }, - getDeviceId() { - return opts.deviceId; - }, - getDeviceEd25519Key() { - return opts.deviceKey; - }, - baseUrl: "https://example.com", - crypto: { - getStoredDevice(userId: string, deviceId: string) { - return opts.devices?.[deviceId] ?? null; + }; }, - setDeviceVerification: opts.verificationFunction, - crossSigningInfo: { - getId(key: string) { - return opts.crossSigningIds?.[key]; - }, + getUserId() { + return opts.userId; + }, + getDeviceId() { + return opts.deviceId; + }, + getDeviceEd25519Key() { + return opts.deviceKey; }, - }, - } as unknown as MatrixClient; + baseUrl: "https://example.com", + getCrypto() { + return { + getUserDeviceInfo([userId]: string[], deviceId: string): Promise { + return Promise.resolve(deviceMap); + }, + getCrossSigningKeyId(key: CrossSigningKey): string | null { + return opts.crossSigningIds?.[key] ?? null; + }, + }; + }, + crypto: { + setDeviceVerification: opts.verificationFunction, + }, + } as unknown as MatrixClient, + deviceMap, + ]; } function makeTransport(name: string, uri = "https://test.rz/123456") { @@ -106,6 +155,7 @@ describe("Rendezvous", function () { let httpBackend: MockHttpBackend; let fetchFn: typeof global.fetch; let transports: DummyTransport[]; + const userId: UserID = "@user:example.com"; beforeEach(function () { httpBackend = new MockHttpBackend(); @@ -118,9 +168,9 @@ describe("Rendezvous", function () { }); it("generate and cancel", async function () { - const alice = makeMockClient({ - userId: "@alice:example.com", - deviceId: "DEVICEID", + const [alice] = makeMockClient({ + userId, + deviceId: "ALICE", msc3886Enabled: false, getLoginTokenEnabled: true, msc3882r0Only: true, @@ -194,8 +244,8 @@ describe("Rendezvous", function () { // alice is already signs in and generates a code const aliceOnFailure = jest.fn(); - const alice = makeMockClient({ - userId: "alice", + const [alice] = makeMockClient({ + userId, deviceId: "ALICE", msc3886Enabled: false, getLoginTokenEnabled, @@ -257,8 +307,8 @@ describe("Rendezvous", function () { // alice is already signs in and generates a code const aliceOnFailure = jest.fn(); - const alice = makeMockClient({ - userId: "alice", + const [alice] = makeMockClient({ + userId, deviceId: "ALICE", getLoginTokenEnabled: true, msc3882r0Only: false, @@ -316,8 +366,8 @@ describe("Rendezvous", function () { // alice is already signs in and generates a code const aliceOnFailure = jest.fn(); - const alice = makeMockClient({ - userId: "alice", + const [alice] = makeMockClient({ + userId, deviceId: "ALICE", getLoginTokenEnabled: true, msc3882r0Only: false, @@ -375,7 +425,7 @@ describe("Rendezvous", function () { // alice is already signs in and generates a code const aliceOnFailure = jest.fn(); - const alice = makeMockClient({ + const [alice] = makeMockClient({ userId: "alice", deviceId: "ALICE", getLoginTokenEnabled: true, @@ -436,7 +486,7 @@ describe("Rendezvous", function () { // alice is already signs in and generates a code const aliceOnFailure = jest.fn(); - const alice = makeMockClient({ + const [alice] = makeMockClient({ userId: "alice", deviceId: "ALICE", getLoginTokenEnabled: true, @@ -495,7 +545,7 @@ describe("Rendezvous", function () { await bobCompleteProm; }); - async function completeLogin(devices: Record>) { + async function completeLogin(devices: SimpleDeviceMap) { const aliceTransport = makeTransport("Alice", "https://test.rz/123456"); const bobTransport = makeTransport("Bob", "https://test.rz/999999"); transports.push(aliceTransport, bobTransport); @@ -505,8 +555,8 @@ describe("Rendezvous", function () { // alice is already signs in and generates a code const aliceOnFailure = jest.fn(); const aliceVerification = jest.fn(); - const alice = makeMockClient({ - userId: "alice", + const [alice, deviceMap] = makeMockClient({ + userId, deviceId: "ALICE", getLoginTokenEnabled: true, msc3882r0Only: false, @@ -575,13 +625,15 @@ describe("Rendezvous", function () { aliceRz, bobTransport, bobEcdh, + devices, + deviceMap, }; } it("approve on existing device + verification", async function () { const { bobEcdh, aliceRz } = await completeLogin({ - BOB: { - getFingerprint: () => "bbbb", + [userId]: { + BOB: "bbbb", }, }); const verifyProm = aliceRz.verifyNewDeviceOnExistingDevice(); @@ -607,33 +659,29 @@ describe("Rendezvous", function () { }); it("device appears online within timeout", async function () { - const devices: Record> = {}; - const { aliceRz } = await completeLogin(devices); - // device appears after 1 second + const devices: SimpleDeviceMap = {}; + const { aliceRz, deviceMap } = await completeLogin(devices); + // device appears before the timeout setTimeout(() => { - devices.BOB = { - getFingerprint: () => "bbbb", - }; + deviceMap.get(userId)?.set("BOB", mockDevice(userId, "BOB", "bbbb")); }, 1000); await aliceRz.verifyNewDeviceOnExistingDevice(2000); }); it("device appears online after timeout", async function () { - const devices: Record> = {}; - const { aliceRz } = await completeLogin(devices); - // device appears after 1 second + const devices: SimpleDeviceMap = {}; + const { aliceRz, deviceMap } = await completeLogin(devices); + // device appears after the timeout setTimeout(() => { - devices.BOB = { - getFingerprint: () => "bbbb", - }; + deviceMap.get(userId)?.set("BOB", mockDevice(userId, "BOB", "bbbb")); }, 1500); await expect(aliceRz.verifyNewDeviceOnExistingDevice(1000)).rejects.toThrow(); }); it("mismatched device key", async function () { const { aliceRz } = await completeLogin({ - BOB: { - getFingerprint: () => "XXXX", + [userId]: { + BOB: "XXXX", }, }); await expect(aliceRz.verifyNewDeviceOnExistingDevice(1000)).rejects.toThrow(/different key/); diff --git a/src/rendezvous/MSC3906Rendezvous.ts b/src/rendezvous/MSC3906Rendezvous.ts index 4a92bbd7797..21331272cff 100644 --- a/src/rendezvous/MSC3906Rendezvous.ts +++ b/src/rendezvous/MSC3906Rendezvous.ts @@ -17,12 +17,12 @@ limitations under the License. import { UnstableValue } from "matrix-events-sdk"; import { RendezvousChannel, RendezvousFailureListener, RendezvousFailureReason, RendezvousIntent } from "."; -import { ICrossSigningKey, IGetLoginTokenCapability, MatrixClient, GET_LOGIN_TOKEN_CAPABILITY } from "../client"; -import { CrossSigningInfo } from "../crypto/CrossSigning"; -import { DeviceInfo } from "../crypto/deviceinfo"; +import { IGetLoginTokenCapability, MatrixClient, GET_LOGIN_TOKEN_CAPABILITY } from "../client"; import { buildFeatureSupportMap, Feature, ServerSupport } from "../feature"; import { logger } from "../logger"; import { sleep } from "../utils"; +import { CrossSigningKey } from "../crypto-api"; +import { Device } from "../matrix"; enum PayloadType { Start = "m.login.start", @@ -178,10 +178,9 @@ export class MSC3906Rendezvous { return deviceId; } - private async verifyAndCrossSignDevice( - deviceInfo: DeviceInfo, - ): Promise { - if (!this.client.crypto) { + private async verifyAndCrossSignDevice(deviceInfo: Device): Promise { + const crypto = this.client.getCrypto(); + if (!crypto) { throw new Error("Crypto not available on client"); } @@ -203,19 +202,21 @@ export class MSC3906Rendezvous { } // mark the device as verified locally + cross sign logger.info(`Marking device ${this.newDeviceId} as verified`); - const info = await this.client.crypto.setDeviceVerification(userId, this.newDeviceId, true, false, true); + // TODO: this function isn't available with rust crypto backend + await this.client.crypto!.setDeviceVerification(userId, this.newDeviceId, true, false, true); - const masterPublicKey = this.client.crypto.crossSigningInfo.getId("master")!; + const masterPublicKey = (await crypto.getCrossSigningKeyId(CrossSigningKey.Master)) ?? undefined; + + const ourDeviceId = this.client.getDeviceId(); + const ourDevice = ourDeviceId ? await this.getOwnDevice(ourDeviceId) : undefined; await this.send({ type: PayloadType.Finish, outcome: Outcome.Verified, - verifying_device_id: this.client.getDeviceId()!, - verifying_device_key: this.client.getDeviceEd25519Key()!, + verifying_device_id: ourDevice?.deviceId, + verifying_device_key: ourDevice?.getFingerprint(), master_key: masterPublicKey, }); - - return info; } /** @@ -223,9 +224,7 @@ export class MSC3906Rendezvous { * @param timeout - time in milliseconds to wait for device to come online * @returns the new device info if the device was verified */ - public async verifyNewDeviceOnExistingDevice( - timeout = 10 * 1000, - ): Promise { + public async verifyNewDeviceOnExistingDevice(timeout = 10 * 1000): Promise { if (!this.newDeviceId) { throw new Error("No new device to sign"); } @@ -235,7 +234,8 @@ export class MSC3906Rendezvous { return undefined; } - if (!this.client.crypto) { + const crypto = this.client.getCrypto(); + if (!crypto) { throw new Error("Crypto not available on client"); } @@ -245,21 +245,30 @@ export class MSC3906Rendezvous { throw new Error("No user ID set"); } - let deviceInfo = this.client.crypto.getStoredDevice(userId, this.newDeviceId); + let deviceInfo = await this.getOwnDevice(this.newDeviceId); if (!deviceInfo) { logger.info("Going to wait for new device to be online"); await sleep(timeout); - deviceInfo = this.client.crypto.getStoredDevice(userId, this.newDeviceId); + deviceInfo = await this.getOwnDevice(this.newDeviceId); } if (deviceInfo) { - return await this.verifyAndCrossSignDevice(deviceInfo); + await this.verifyAndCrossSignDevice(deviceInfo); + return; } throw new Error("Device not online within timeout"); } + private async getOwnDevice(deviceId: string): Promise { + const userId = this.client.getUserId(); + if (!userId) { + return undefined; + } + return (await this.client.getCrypto()?.getUserDeviceInfo([userId], true))?.get(userId)?.get(deviceId); + } + public async cancel(reason: RendezvousFailureReason): Promise { this.onFailure?.(reason); await this.channel.cancel(reason);