From 4cddc7397d44ae5fa2c9d003e8a906fe86381b93 Mon Sep 17 00:00:00 2001 From: Valere Date: Fri, 19 Jan 2024 11:08:45 +0100 Subject: [PATCH] Decrypt and Import full backups in chunk with progress (#4005) * Decrypt and Import full backups in chunk with progress * backup chunk decryption jsdoc * Review: fix capitalization * review: better var name * review: fix better iterate on object * review: extract utility function * review: Improve test, ensure mock calls * review: Add more test for decryption or import failures * Review: fix typo Co-authored-by: Andy Balaam --------- Co-authored-by: Andy Balaam --- spec/integ/crypto/megolm-backup.spec.ts | 210 +++++++++++++++++++--- spec/unit/rust-crypto/rust-crypto.spec.ts | 2 +- src/client.ts | 157 +++++++++++++--- src/crypto-api.ts | 6 +- src/crypto/keybackup.ts | 4 +- src/rust-crypto/backup.ts | 5 +- 6 files changed, 332 insertions(+), 52 deletions(-) diff --git a/spec/integ/crypto/megolm-backup.spec.ts b/spec/integ/crypto/megolm-backup.spec.ts index d7f8644c8e0..f6bd53ef476 100644 --- a/spec/integ/crypto/megolm-backup.spec.ts +++ b/spec/integ/crypto/megolm-backup.spec.ts @@ -17,8 +17,18 @@ limitations under the License. import fetchMock from "fetch-mock-jest"; import "fake-indexeddb/auto"; import { IDBFactory } from "fake-indexeddb"; +import { Mocked } from "jest-mock"; -import { createClient, CryptoEvent, ICreateClientOpts, IEvent, MatrixClient, TypedEventEmitter } from "../../../src"; +import { + createClient, + CryptoApi, + CryptoEvent, + ICreateClientOpts, + IEvent, + IMegolmSessionData, + MatrixClient, + TypedEventEmitter, +} from "../../../src"; import { SyncResponder } from "../../test-utils/SyncResponder"; import { E2EKeyReceiver } from "../../test-utils/E2EKeyReceiver"; import { E2EKeyResponder } from "../../test-utils/E2EKeyResponder"; @@ -31,7 +41,7 @@ import { syncPromise, } from "../../test-utils/test-utils"; import * as testData from "../../test-utils/test-data"; -import { KeyBackupInfo } from "../../../src/crypto-api/keybackup"; +import { KeyBackupInfo, KeyBackupSession } from "../../../src/crypto-api/keybackup"; import { IKeyBackup } from "../../../src/crypto/backup"; import { flushPromises } from "../../test-utils/flushPromises"; import { defer, IDeferred } from "../../../src/utils"; @@ -286,17 +296,21 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe }); describe("recover from backup", () => { - it("can restore from backup (Curve25519 version)", async function () { + let aliceCrypto: CryptoApi; + + beforeEach(async () => { fetchMock.get("path:/_matrix/client/v3/room_keys/version", testData.SIGNED_BACKUP_DATA); aliceClient = await initTestClient(); - const aliceCrypto = aliceClient.getCrypto()!; + aliceCrypto = aliceClient.getCrypto()!; await aliceClient.startClient(); // tell Alice to trust the dummy device that signed the backup await waitForDeviceList(); await aliceCrypto.setDeviceVerified(testData.TEST_USER_ID, testData.TEST_DEVICE_ID); + }); + it("can restore from backup (Curve25519 version)", async function () { const fullBackup = { rooms: { [ROOM_ID]: { @@ -340,17 +354,179 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe expect(afterCache.imported).toStrictEqual(1); }); - it("recover specific session from backup", async function () { - fetchMock.get("path:/_matrix/client/v3/room_keys/version", testData.SIGNED_BACKUP_DATA); + /** + * Creates a mock backup response of a GET `room_keys/keys` with a given number of keys per room. + * @param keysPerRoom The number of keys per room + */ + function createBackupDownloadResponse(keysPerRoom: number[]) { + const response: { + rooms: { + [roomId: string]: { + sessions: { + [sessionId: string]: KeyBackupSession; + }; + }; + }; + } = { rooms: {} }; + + const expectedTotal = keysPerRoom.reduce((a, b) => a + b, 0); + for (let i = 0; i < keysPerRoom.length; i++) { + const roomId = `!room${i}:example.com`; + response.rooms[roomId] = { sessions: {} }; + for (let j = 0; j < keysPerRoom[i]; j++) { + const sessionId = `session${j}`; + // Put the same fake session data, not important for that test + response.rooms[roomId].sessions[sessionId] = testData.CURVE25519_KEY_BACKUP_DATA; + } + } + return { response, expectedTotal }; + } - aliceClient = await initTestClient(); - const aliceCrypto = aliceClient.getCrypto()!; - await aliceClient.startClient(); + it("Should import full backup in chunks", async function () { + const importMockImpl = jest.fn(); + // @ts-ignore - mock a private method for testing purpose + aliceCrypto.importBackedUpRoomKeys = importMockImpl; - // tell Alice to trust the dummy device that signed the backup - await waitForDeviceList(); - await aliceCrypto.setDeviceVerified(testData.TEST_USER_ID, testData.TEST_DEVICE_ID); + // We need several rooms with several sessions to test chunking + const { response, expectedTotal } = createBackupDownloadResponse([45, 300, 345, 12, 130]); + + fetchMock.get("express:/_matrix/client/v3/room_keys/keys", response); + + const check = await aliceCrypto.checkKeyBackupAndEnable(); + + const progressCallback = jest.fn(); + const result = await aliceClient.restoreKeyBackupWithRecoveryKey( + testData.BACKUP_DECRYPTION_KEY_BASE58, + undefined, + undefined, + check!.backupInfo!, + { + progressCallback, + }, + ); + + expect(result.imported).toStrictEqual(expectedTotal); + // Should be called 5 times: 200*4 plus one chunk with the remaining 32 + expect(importMockImpl).toHaveBeenCalledTimes(5); + for (let i = 0; i < 4; i++) { + expect(importMockImpl.mock.calls[i][0].length).toEqual(200); + } + expect(importMockImpl.mock.calls[4][0].length).toEqual(32); + + expect(progressCallback).toHaveBeenCalledWith({ + stage: "fetch", + }); + + // Should be called 4 times and report 200/400/600/800 + for (let i = 0; i < 4; i++) { + expect(progressCallback).toHaveBeenCalledWith({ + total: expectedTotal, + successes: (i + 1) * 200, + stage: "load_keys", + failures: 0, + }); + } + + // The last chunk + expect(progressCallback).toHaveBeenCalledWith({ + total: expectedTotal, + successes: 832, + stage: "load_keys", + failures: 0, + }); + }); + + it("Should continue to process backup if a chunk import fails and report failures", async function () { + // @ts-ignore - mock a private method for testing purpose + aliceCrypto.importBackedUpRoomKeys = jest + .fn() + .mockImplementationOnce(() => { + // Fail to import first chunk + throw new Error("test error"); + }) + // Ok for other chunks + .mockResolvedValue(undefined); + + const { response, expectedTotal } = createBackupDownloadResponse([100, 300]); + + fetchMock.get("express:/_matrix/client/v3/room_keys/keys", response); + + const check = await aliceCrypto.checkKeyBackupAndEnable(); + + const progressCallback = jest.fn(); + const result = await aliceClient.restoreKeyBackupWithRecoveryKey( + testData.BACKUP_DECRYPTION_KEY_BASE58, + undefined, + undefined, + check!.backupInfo!, + { + progressCallback, + }, + ); + + expect(result.total).toStrictEqual(expectedTotal); + // A chunk failed to import + expect(result.imported).toStrictEqual(200); + + expect(progressCallback).toHaveBeenCalledWith({ + total: expectedTotal, + successes: 0, + stage: "load_keys", + failures: 200, + }); + + expect(progressCallback).toHaveBeenCalledWith({ + total: expectedTotal, + successes: 200, + stage: "load_keys", + failures: 200, + }); + }); + + it("Should continue if some keys fails to decrypt", async function () { + // @ts-ignore - mock a private method for testing purpose + aliceCrypto.importBackedUpRoomKeys = jest.fn(); + + const decryptionFailureCount = 2; + + const mockDecryptor = { + // DecryptSessions does not reject on decryption failure, but just skip the key + decryptSessions: jest.fn().mockImplementation((sessions) => { + // simulate fail to decrypt 2 keys out of all + const decrypted = []; + const keys = Object.keys(sessions); + for (let i = 0; i < keys.length - decryptionFailureCount; i++) { + decrypted.push({ + session_id: keys[i], + } as unknown as Mocked); + } + return decrypted; + }), + free: jest.fn(), + }; + + // @ts-ignore - mock a private method for testing purpose + aliceCrypto.getBackupDecryptor = jest.fn().mockResolvedValue(mockDecryptor); + const { response, expectedTotal } = createBackupDownloadResponse([100]); + + fetchMock.get("express:/_matrix/client/v3/room_keys/keys", response); + + const check = await aliceCrypto.checkKeyBackupAndEnable(); + + const result = await aliceClient.restoreKeyBackupWithRecoveryKey( + testData.BACKUP_DECRYPTION_KEY_BASE58, + undefined, + undefined, + check!.backupInfo!, + ); + + expect(result.total).toStrictEqual(expectedTotal); + // A chunk failed to import + expect(result.imported).toStrictEqual(expectedTotal - decryptionFailureCount); + }); + + it("recover specific session from backup", async function () { fetchMock.get( "express:/_matrix/client/v3/room_keys/keys/:room_id/:session_id", testData.CURVE25519_KEY_BACKUP_DATA, @@ -371,16 +547,6 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe }); it("Fails on bad recovery key", async function () { - fetchMock.get("path:/_matrix/client/v3/room_keys/version", testData.SIGNED_BACKUP_DATA); - - aliceClient = await initTestClient(); - const aliceCrypto = aliceClient.getCrypto()!; - await aliceClient.startClient(); - - // tell Alice to trust the dummy device that signed the backup - await waitForDeviceList(); - await aliceCrypto.setDeviceVerified(testData.TEST_USER_ID, testData.TEST_DEVICE_ID); - const fullBackup = { rooms: { [ROOM_ID]: { diff --git a/spec/unit/rust-crypto/rust-crypto.spec.ts b/spec/unit/rust-crypto/rust-crypto.spec.ts index 28f3c653e4f..6d3ecfd81c8 100644 --- a/spec/unit/rust-crypto/rust-crypto.spec.ts +++ b/spec/unit/rust-crypto/rust-crypto.spec.ts @@ -341,7 +341,7 @@ describe("RustCrypto", () => { let importTotal = 0; const opt: ImportRoomKeysOpts = { progressCallback: (stage) => { - importTotal = stage.total; + importTotal = stage.total ?? 0; }, }; await rustCrypto.importRoomKeys(someRoomKeys, opt); diff --git a/src/client.ts b/src/client.ts index f0de776d0a6..3acc0135209 100644 --- a/src/client.ts +++ b/src/client.ts @@ -209,7 +209,7 @@ import { IgnoredInvites } from "./models/invites-ignorer"; import { UIARequest, UIAResponse } from "./@types/uia"; import { LocalNotificationSettings } from "./@types/local_notifications"; import { buildFeatureSupportMap, Feature, ServerSupport } from "./feature"; -import { CryptoBackend } from "./common-crypto/CryptoBackend"; +import { BackupDecryptor, CryptoBackend } from "./common-crypto/CryptoBackend"; import { RUST_SDK_STORE_PREFIX } from "./rust-crypto/constants"; import { BootstrapCrossSigningOpts, CrossSigningKeyInfo, CryptoApi, ImportRoomKeysOpts } from "./crypto-api"; import { DeviceInfoMap } from "./crypto/DeviceList"; @@ -3905,7 +3905,8 @@ export class MatrixClient extends TypedEventEmitter { + // We have a chunk of decrypted keys: import them + try { + await this.cryptoBackend!.importBackedUpRoomKeys(chunk, { + untrusted, + }); + totalImported += chunk.length; + } catch (e) { + totalFailures += chunk.length; + // We failed to import some keys, but we should still try to import the rest? + // Log the error and continue + logger.error("Error importing keys from backup", e); + } + + if (progressCallback) { + progressCallback({ + total: totalKeyCount, + successes: totalImported, + stage: "load_keys", + failures: totalFailures, + }); + } + }, + ); } else if ((res as IRoomKeysResponse).sessions) { + // For now we don't chunk for a single room backup, but we could in the future. + // Currently it is not used by the application. const sessions = (res as IRoomKeysResponse).sessions; totalKeyCount = Object.keys(sessions).length; - keys = await backupDecryptor.decryptSessions(sessions); + const keys = await backupDecryptor.decryptSessions(sessions); for (const k of keys) { k.room_id = targetRoomId!; } + await this.cryptoBackend.importBackedUpRoomKeys(keys, { + progressCallback, + untrusted, + }); + totalImported = keys.length; } else { totalKeyCount = 1; try { @@ -3968,7 +4005,12 @@ export class MatrixClient extends TypedEventEmitter Promise, + ): Promise { + const rooms = (res as IRoomsKeysResponse).rooms; + + let groupChunkCount = 0; + let chunkGroupByRoom: Map = new Map(); + + const handleChunkCallback = async (roomChunks: Map): Promise => { + const currentChunk: IMegolmSessionData[] = []; + for (const roomId of roomChunks.keys()) { + const decryptedSessions = await backupDecryptor.decryptSessions(roomChunks.get(roomId)!); + for (const sessionId in decryptedSessions) { + const k = decryptedSessions[sessionId]; + k.room_id = roomId; + currentChunk.push(k); + } + } + await block(currentChunk); + }; + + for (const [roomId, roomData] of Object.entries(rooms)) { + if (!roomData.sessions) continue; + + chunkGroupByRoom.set(roomId, {}); + + for (const [sessionId, session] of Object.entries(roomData.sessions)) { + const sessionsForRoom = chunkGroupByRoom.get(roomId)!; + sessionsForRoom[sessionId] = session; + groupChunkCount += 1; + if (groupChunkCount >= chunkSize) { + // We have enough chunks to decrypt + await handleChunkCallback(chunkGroupByRoom); + chunkGroupByRoom = new Map(); + // There might be remaining keys for that room, so add back an entry for the current room. + chunkGroupByRoom.set(roomId, {}); + groupChunkCount = 0; + } + } + } + + // Handle remaining chunk if needed + if (groupChunkCount > 0) { + await handleChunkCallback(chunkGroupByRoom); + } } public deleteKeysFromBackup(roomId: undefined, sessionId: undefined, version?: string): Promise; diff --git a/src/crypto-api.ts b/src/crypto-api.ts index 163c491974d..744b20f2b7e 100644 --- a/src/crypto-api.ts +++ b/src/crypto-api.ts @@ -586,9 +586,9 @@ export class DeviceVerificationStatus { */ export interface ImportRoomKeyProgressData { stage: string; // TODO: Enum - successes: number; - failures: number; - total: number; + successes?: number; + failures?: number; + total?: number; } /** diff --git a/src/crypto/keybackup.ts b/src/crypto/keybackup.ts index 24c73b85cd5..8ef04176cb2 100644 --- a/src/crypto/keybackup.ts +++ b/src/crypto/keybackup.ts @@ -15,6 +15,8 @@ limitations under the License. */ // Export for backward compatibility +import { ImportRoomKeyProgressData } from "../crypto-api"; + export type { Curve25519AuthData as ICurve25519AuthData, Aes256AuthData as IAes256AuthData, @@ -41,5 +43,5 @@ export interface IKeyBackupRestoreResult { export interface IKeyBackupRestoreOpts { cacheCompleteCallback?: () => void; - progressCallback?: (progress: { stage: string }) => void; + progressCallback?: (progress: ImportRoomKeyProgressData) => void; } diff --git a/src/rust-crypto/backup.ts b/src/rust-crypto/backup.ts index b969013eb10..289b197b280 100644 --- a/src/rust-crypto/backup.ts +++ b/src/rust-crypto/backup.ts @@ -29,7 +29,7 @@ import { logger } from "../logger"; import { ClientPrefix, IHttpOpts, MatrixError, MatrixHttpApi, Method } from "../http-api"; import { CryptoEvent, IMegolmSessionData } from "../crypto"; import { TypedEventEmitter } from "../models/typed-event-emitter"; -import { encodeUri, immediate, logDuration } from "../utils"; +import { encodeUri, logDuration } from "../utils"; import { OutgoingRequestProcessor } from "./OutgoingRequestProcessor"; import { sleep } from "../utils"; import { BackupDecryptor } from "../common-crypto/CryptoBackend"; @@ -534,9 +534,6 @@ export class RustBackupDecryptor implements BackupDecryptor { ); decrypted.session_id = sessionId; keys.push(decrypted); - - // there might be lots of sessions, so don't hog the event loop - await immediate(); } catch (e) { logger.log("Failed to decrypt megolm session from backup", e, sessionData); }