diff --git a/spec/integ/crypto/megolm-backup.spec.ts b/spec/integ/crypto/megolm-backup.spec.ts index d7f8644c8e0..405dfd96b88 100644 --- a/spec/integ/crypto/megolm-backup.spec.ts +++ b/spec/integ/crypto/megolm-backup.spec.ts @@ -340,6 +340,108 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe expect(afterCache.imported).toStrictEqual(1); }); + it("Should import full backup in chunks", async function () { + fetchMock.get("path:/_matrix/client/v3/room_keys/version", testData.SIGNED_BACKUP_DATA); + + aliceClient = await initTestClient(); + + const aliceCrypto = aliceClient.getCrypto()!; + + // just mock this call + // @ts-ignore - mock a private method for testing purpose + aliceCrypto.importBackedUpRoomKeys = jest.fn(); + + await aliceClient.startClient(); + + // tell Alice to trust the dummy device that signed the backup + await waitForDeviceList(); + await aliceCrypto.setDeviceVerified(testData.TEST_USER_ID, testData.TEST_DEVICE_ID); + + const fullBackup: { + rooms: { + [roomId: string]: { + [sessionId: string]: any; + }; + }; + } = { rooms: {} }; + + // we need several rooms with several sessions to test chunking + const keysPerRoom = [45, 300, 345, 12, 130]; + const expectedTotal = keysPerRoom.reduce((a, b) => a + b, 0); + for (let i = 0; i < keysPerRoom.length; i++) { + const roomId = `!room${i}:example.com`; + fullBackup.rooms[roomId] = { sessions: {} }; + for (let j = 0; j < keysPerRoom[i]; j++) { + const sessionId = `session${j}`; + // put the same fake session data, not important for that tet + fullBackup.rooms[roomId].sessions[sessionId] = testData.CURVE25519_KEY_BACKUP_DATA; + } + } + + fetchMock.get("express:/_matrix/client/v3/room_keys/keys", fullBackup); + + const check = await aliceCrypto.checkKeyBackupAndEnable(); + + // @ts-ignore spying internal method + const importSpy = jest.spyOn(aliceCrypto, "importBackedUpRoomKeys"); + + const progressCallback = jest.fn(); + const result = await aliceClient.restoreKeyBackupWithRecoveryKey( + testData.BACKUP_DECRYPTION_KEY_BASE58, + undefined, + undefined, + check!.backupInfo!, + { + progressCallback, + }, + ); + + expect(result.imported).toStrictEqual(expectedTotal); + // should be called 5 times: 200*4 plus one chunk with the remaining 32 + expect(importSpy).toHaveBeenCalledTimes(5); + + expect(progressCallback).toHaveBeenCalledWith({ + // unfortunately there is no proper enum for stages :/ + // react sdk expect some values though + stage: "fetch", + }); + + expect(progressCallback).toHaveBeenCalledWith({ + total: expectedTotal, + successes: 200, + stage: "load_keys", + failures: 0, + }); + + expect(progressCallback).toHaveBeenCalledWith({ + total: expectedTotal, + successes: 400, + stage: "load_keys", + failures: 0, + }); + + expect(progressCallback).toHaveBeenCalledWith({ + total: expectedTotal, + successes: 600, + stage: "load_keys", + failures: 0, + }); + + expect(progressCallback).toHaveBeenCalledWith({ + total: expectedTotal, + successes: 800, + stage: "load_keys", + failures: 0, + }); + + expect(progressCallback).toHaveBeenCalledWith({ + total: expectedTotal, + successes: 832, + stage: "load_keys", + failures: 0, + }); + }); + it("recover specific session from backup", async function () { fetchMock.get("path:/_matrix/client/v3/room_keys/version", testData.SIGNED_BACKUP_DATA); diff --git a/spec/unit/rust-crypto/rust-crypto.spec.ts b/spec/unit/rust-crypto/rust-crypto.spec.ts index c9146c24677..acf00dce321 100644 --- a/spec/unit/rust-crypto/rust-crypto.spec.ts +++ b/spec/unit/rust-crypto/rust-crypto.spec.ts @@ -165,7 +165,7 @@ describe("RustCrypto", () => { let importTotal = 0; const opt: ImportRoomKeysOpts = { progressCallback: (stage) => { - importTotal = stage.total; + importTotal = stage.total ?? 0; }, }; await rustCrypto.importRoomKeys(someRoomKeys, opt); diff --git a/src/client.ts b/src/client.ts index 27e70c43520..d727beec470 100644 --- a/src/client.ts +++ b/src/client.ts @@ -209,7 +209,7 @@ import { IgnoredInvites } from "./models/invites-ignorer"; import { UIARequest, UIAResponse } from "./@types/uia"; import { LocalNotificationSettings } from "./@types/local_notifications"; import { buildFeatureSupportMap, Feature, ServerSupport } from "./feature"; -import { CryptoBackend } from "./common-crypto/CryptoBackend"; +import { BackupDecryptor, CryptoBackend } from "./common-crypto/CryptoBackend"; import { RUST_SDK_STORE_PREFIX } from "./rust-crypto/constants"; import { BootstrapCrossSigningOpts, CrossSigningKeyInfo, CryptoApi, ImportRoomKeysOpts } from "./crypto-api"; import { DeviceInfoMap } from "./crypto/DeviceList"; @@ -3898,7 +3898,7 @@ export class MatrixClient extends TypedEventEmitter { + decryptedKeyCount += chunk.length; + // we have a chunk of decrypted keys, import them + try { + await this.cryptoBackend!.importBackedUpRoomKeys(chunk, { + untrusted, + }); + totalImported += chunk.length; + if (progressCallback) { + progressCallback({ + total: totalKeyCount, + successes: decryptedKeyCount, + stage: "load_keys", + failures: 0, + }); + } + } catch (e) { + // We failed to import some keys, but we should still try to import the rest? + // Log the error and continue + logger.error("Error importing keys from backup", e); + } + }, + ); } else if ((res as IRoomKeysResponse).sessions) { + // For now we don't chunk for a single room backup, but we could in the future. + // Currently it is not used by the application. const sessions = (res as IRoomKeysResponse).sessions; totalKeyCount = Object.keys(sessions).length; - keys = await backupDecryptor.decryptSessions(sessions); + const keys = await backupDecryptor.decryptSessions(sessions); for (const k of keys) { k.room_id = targetRoomId!; } + await this.cryptoBackend.importBackedUpRoomKeys(keys, { + progressCallback, + untrusted, + }); + totalImported = keys.length; } else { totalKeyCount = 1; try { @@ -3961,7 +3999,12 @@ export class MatrixClient extends TypedEventEmitter Promise, + ): Promise { + const rooms = (res as IRoomsKeysResponse).rooms; + + let groupChunkCount = 0; + let chunkGroupByRoom: Map = new Map(); + + const handleChunkCallback = async (roomChunks: Map): Promise => { + const currentChunk: IMegolmSessionData[] = []; + for (const roomId of roomChunks.keys()) { + const decryptedSessions = await backupDecryptor.decryptSessions(roomChunks.get(roomId)!); + for (const sessionId in decryptedSessions) { + const k = decryptedSessions[sessionId]; + k.room_id = roomId; + currentChunk.push(k); + } + } + await block(currentChunk); + }; - /// in case entering the passphrase would add a new signature? - await this.cryptoBackend.checkKeyBackupAndEnable(); + for (const [roomId, roomData] of Object.entries(rooms)) { + if (!roomData.sessions) continue; + + chunkGroupByRoom.set(roomId, {}); + + for (const sessions of Object.entries(roomData.sessions)) { + const sessionsForRoom = chunkGroupByRoom.get(roomId)!; + sessionsForRoom[sessions[0]] = sessions[1]; + groupChunkCount += 1; + if (groupChunkCount >= chunkSize) { + // we have enough chunks to decrypt + await handleChunkCallback(chunkGroupByRoom); + chunkGroupByRoom = new Map(); + // there might be remaining keys for that room, so add back an entry for the current room. + chunkGroupByRoom.set(roomId, {}); + groupChunkCount = 0; + } + } + } - return { total: totalKeyCount, imported: keys.length }; + // handle remaining chunk if needed + if (groupChunkCount > 0) { + await handleChunkCallback(chunkGroupByRoom); + } } public deleteKeysFromBackup(roomId: undefined, sessionId: undefined, version?: string): Promise; diff --git a/src/crypto-api.ts b/src/crypto-api.ts index 163c491974d..744b20f2b7e 100644 --- a/src/crypto-api.ts +++ b/src/crypto-api.ts @@ -586,9 +586,9 @@ export class DeviceVerificationStatus { */ export interface ImportRoomKeyProgressData { stage: string; // TODO: Enum - successes: number; - failures: number; - total: number; + successes?: number; + failures?: number; + total?: number; } /** diff --git a/src/crypto/keybackup.ts b/src/crypto/keybackup.ts index 24c73b85cd5..8ef04176cb2 100644 --- a/src/crypto/keybackup.ts +++ b/src/crypto/keybackup.ts @@ -15,6 +15,8 @@ limitations under the License. */ // Export for backward compatibility +import { ImportRoomKeyProgressData } from "../crypto-api"; + export type { Curve25519AuthData as ICurve25519AuthData, Aes256AuthData as IAes256AuthData, @@ -41,5 +43,5 @@ export interface IKeyBackupRestoreResult { export interface IKeyBackupRestoreOpts { cacheCompleteCallback?: () => void; - progressCallback?: (progress: { stage: string }) => void; + progressCallback?: (progress: ImportRoomKeyProgressData) => void; } diff --git a/src/rust-crypto/backup.ts b/src/rust-crypto/backup.ts index b969013eb10..289b197b280 100644 --- a/src/rust-crypto/backup.ts +++ b/src/rust-crypto/backup.ts @@ -29,7 +29,7 @@ import { logger } from "../logger"; import { ClientPrefix, IHttpOpts, MatrixError, MatrixHttpApi, Method } from "../http-api"; import { CryptoEvent, IMegolmSessionData } from "../crypto"; import { TypedEventEmitter } from "../models/typed-event-emitter"; -import { encodeUri, immediate, logDuration } from "../utils"; +import { encodeUri, logDuration } from "../utils"; import { OutgoingRequestProcessor } from "./OutgoingRequestProcessor"; import { sleep } from "../utils"; import { BackupDecryptor } from "../common-crypto/CryptoBackend"; @@ -534,9 +534,6 @@ export class RustBackupDecryptor implements BackupDecryptor { ); decrypted.session_id = sessionId; keys.push(decrypted); - - // there might be lots of sessions, so don't hog the event loop - await immediate(); } catch (e) { logger.log("Failed to decrypt megolm session from backup", e, sessionData); }