Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decrypt and Import full backups in chunk with progress #4005

Merged
merged 10 commits into from
Jan 19, 2024
102 changes: 102 additions & 0 deletions spec/integ/crypto/megolm-backup.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,108 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe
expect(afterCache.imported).toStrictEqual(1);
});

it("Should import full backup in chunks", async function () {
fetchMock.get("path:/_matrix/client/v3/room_keys/version", testData.SIGNED_BACKUP_DATA);

aliceClient = await initTestClient();

const aliceCrypto = aliceClient.getCrypto()!;

// just mock this call
BillCarsonFr marked this conversation as resolved.
Show resolved Hide resolved
// @ts-ignore - mock a private method for testing purpose
aliceCrypto.importBackedUpRoomKeys = jest.fn();

await aliceClient.startClient();

// tell Alice to trust the dummy device that signed the backup
BillCarsonFr marked this conversation as resolved.
Show resolved Hide resolved
await waitForDeviceList();
await aliceCrypto.setDeviceVerified(testData.TEST_USER_ID, testData.TEST_DEVICE_ID);

const fullBackup: {
rooms: {
[roomId: string]: {
[sessionId: string]: any;
BillCarsonFr marked this conversation as resolved.
Show resolved Hide resolved
};
};
} = { rooms: {} };

// we need several rooms with several sessions to test chunking
BillCarsonFr marked this conversation as resolved.
Show resolved Hide resolved
const keysPerRoom = [45, 300, 345, 12, 130];
const expectedTotal = keysPerRoom.reduce((a, b) => a + b, 0);
for (let i = 0; i < keysPerRoom.length; i++) {
const roomId = `!room${i}:example.com`;
fullBackup.rooms[roomId] = { sessions: {} };
for (let j = 0; j < keysPerRoom[i]; j++) {
const sessionId = `session${j}`;
// put the same fake session data, not important for that tet
fullBackup.rooms[roomId].sessions[sessionId] = testData.CURVE25519_KEY_BACKUP_DATA;
}
}
andybalaam marked this conversation as resolved.
Show resolved Hide resolved

fetchMock.get("express:/_matrix/client/v3/room_keys/keys", fullBackup);

const check = await aliceCrypto.checkKeyBackupAndEnable();

// @ts-ignore spying internal method
const importSpy = jest.spyOn(aliceCrypto, "importBackedUpRoomKeys");

const progressCallback = jest.fn();
const result = await aliceClient.restoreKeyBackupWithRecoveryKey(
testData.BACKUP_DECRYPTION_KEY_BASE58,
undefined,
undefined,
check!.backupInfo!,
{
progressCallback,
},
);

expect(result.imported).toStrictEqual(expectedTotal);
// should be called 5 times: 200*4 plus one chunk with the remaining 32
BillCarsonFr marked this conversation as resolved.
Show resolved Hide resolved
expect(importSpy).toHaveBeenCalledTimes(5);

expect(progressCallback).toHaveBeenCalledWith({
// unfortunately there is no proper enum for stages :/
// react sdk expect some values though
andybalaam marked this conversation as resolved.
Show resolved Hide resolved
stage: "fetch",
});

expect(progressCallback).toHaveBeenCalledWith({
total: expectedTotal,
successes: 200,
stage: "load_keys",
failures: 0,
});

expect(progressCallback).toHaveBeenCalledWith({
total: expectedTotal,
successes: 400,
stage: "load_keys",
failures: 0,
});

expect(progressCallback).toHaveBeenCalledWith({
total: expectedTotal,
successes: 600,
stage: "load_keys",
failures: 0,
});

expect(progressCallback).toHaveBeenCalledWith({
total: expectedTotal,
successes: 800,
stage: "load_keys",
failures: 0,
});

expect(progressCallback).toHaveBeenCalledWith({
total: expectedTotal,
successes: 832,
stage: "load_keys",
failures: 0,
});
andybalaam marked this conversation as resolved.
Show resolved Hide resolved
});

it("recover specific session from backup", async function () {
fetchMock.get("path:/_matrix/client/v3/room_keys/version", testData.SIGNED_BACKUP_DATA);

Expand Down
2 changes: 1 addition & 1 deletion spec/unit/rust-crypto/rust-crypto.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ describe("RustCrypto", () => {
let importTotal = 0;
const opt: ImportRoomKeysOpts = {
progressCallback: (stage) => {
importTotal = stage.total;
importTotal = stage.total ?? 0;
},
};
await rustCrypto.importRoomKeys(someRoomKeys, opt);
Expand Down
159 changes: 138 additions & 21 deletions src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ import { IgnoredInvites } from "./models/invites-ignorer";
import { UIARequest, UIAResponse } from "./@types/uia";
import { LocalNotificationSettings } from "./@types/local_notifications";
import { buildFeatureSupportMap, Feature, ServerSupport } from "./feature";
import { CryptoBackend } from "./common-crypto/CryptoBackend";
import { BackupDecryptor, CryptoBackend } from "./common-crypto/CryptoBackend";
import { RUST_SDK_STORE_PREFIX } from "./rust-crypto/constants";
import { BootstrapCrossSigningOpts, CrossSigningKeyInfo, CryptoApi, ImportRoomKeysOpts } from "./crypto-api";
import { DeviceInfoMap } from "./crypto/DeviceList";
Expand Down Expand Up @@ -3898,7 +3898,7 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
}

let totalKeyCount = 0;
let keys: IMegolmSessionData[] = [];
let totalImported = 0;

const path = this.makeKeyBackupPath(targetRoomId, targetSessionId, backupInfo.version);

Expand Down Expand Up @@ -3934,25 +3934,63 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
{ prefix: ClientPrefix.V3 },
);

// We have finished fetching the backup, go to next step
if (progressCallback) {
progressCallback({
stage: "load_keys",
});
}

if ((res as IRoomsKeysResponse).rooms) {
const rooms = (res as IRoomsKeysResponse).rooms;
for (const [roomId, roomData] of Object.entries(rooms)) {
if (!roomData.sessions) continue;

totalKeyCount += Object.keys(roomData.sessions).length;
const roomKeys = await backupDecryptor.decryptSessions(roomData.sessions);
for (const k of roomKeys) {
k.room_id = roomId;
keys.push(k);
}
}
// We have a full backup here, it can get quite big, so we need to decrypt and import it in chunks.

// Get the total count as a first pass
totalKeyCount = this.getTotalKeyCount(res as IRoomsKeysResponse);
let decryptedKeyCount = 0;

// Now decrypt and import the keys in chunks
// We need to adapt the `progressCallback` to give accurate progress.
BillCarsonFr marked this conversation as resolved.
Show resolved Hide resolved
await this.handleDecryptionOfAFullBackup(
res as IRoomsKeysResponse,
backupDecryptor,
200,
async (chunk) => {
decryptedKeyCount += chunk.length;
// we have a chunk of decrypted keys, import them
BillCarsonFr marked this conversation as resolved.
Show resolved Hide resolved
try {
await this.cryptoBackend!.importBackedUpRoomKeys(chunk, {
untrusted,
});
totalImported += chunk.length;
if (progressCallback) {
progressCallback({
total: totalKeyCount,
successes: decryptedKeyCount,
stage: "load_keys",
failures: 0,
});
}
} catch (e) {
// We failed to import some keys, but we should still try to import the rest?
// Log the error and continue
andybalaam marked this conversation as resolved.
Show resolved Hide resolved
logger.error("Error importing keys from backup", e);
}
},
);
} else if ((res as IRoomKeysResponse).sessions) {
// For now we don't chunk for a single room backup, but we could in the future.
// Currently it is not used by the application.
const sessions = (res as IRoomKeysResponse).sessions;
totalKeyCount = Object.keys(sessions).length;
keys = await backupDecryptor.decryptSessions(sessions);
const keys = await backupDecryptor.decryptSessions(sessions);
for (const k of keys) {
k.room_id = targetRoomId!;
}
await this.cryptoBackend.importBackedUpRoomKeys(keys, {
progressCallback,
untrusted,
});
totalImported = keys.length;
} else {
totalKeyCount = 1;
try {
Expand All @@ -3961,7 +3999,12 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
});
key.room_id = targetRoomId!;
key.session_id = targetSessionId!;
keys.push(key);

await this.cryptoBackend.importBackedUpRoomKeys([key], {
progressCallback,
untrusted,
});
totalImported = 1;
} catch (e) {
this.logger.debug("Failed to decrypt megolm session from backup", e);
}
Expand All @@ -3970,15 +4013,89 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
backupDecryptor.free();
}

await this.cryptoBackend.importBackedUpRoomKeys(keys, {
progressCallback,
untrusted,
});

/// in case entering the passphrase would add a new signature?
await this.cryptoBackend.checkKeyBackupAndEnable();

return { total: totalKeyCount, imported: keys.length };
return { total: totalKeyCount, imported: totalImported };
}

/**
* This method calculates the total number of keys present in the response of a `/room_keys/keys` call.
*
* @param res - The response from the server containing the keys to be counted.
*
* @returns The total number of keys in the backup.
*/
private getTotalKeyCount(res: IRoomsKeysResponse): number {
const rooms = res.rooms;
let totalKeyCount = 0;
for (const entry of Object.entries(rooms)) {
BillCarsonFr marked this conversation as resolved.
Show resolved Hide resolved
const roomData = entry[1];
if (!roomData.sessions) continue;
totalKeyCount += Object.keys(roomData.sessions).length;
}
return totalKeyCount;
}

/**
* This method handles the decryption of a full backup, i.e a call to `/room_keys/keys`.
* It will decrypt the keys in chunks and call the `block` callback for each chunk.
*
* @param res - The response from the server containing the keys to be decrypted.
* @param backupDecryptor - An instance of the BackupDecryptor class used to decrypt the keys.
* @param chunkSize - The size of the chunks to be processed at a time.
* @param block - A callback function that is called for each chunk of keys.
*
* @returns A promise that resolves when the decryption is complete.
*/
private async handleDecryptionOfAFullBackup(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be a nice use of a generator function, potentially? Not a blocker though.

res: IRoomsKeysResponse,
backupDecryptor: BackupDecryptor,
chunkSize: number,
block: (chunk: IMegolmSessionData[]) => Promise<void>,
): Promise<void> {
const rooms = (res as IRoomsKeysResponse).rooms;

let groupChunkCount = 0;
let chunkGroupByRoom: Map<string, IKeyBackupRoomSessions> = new Map();

const handleChunkCallback = async (roomChunks: Map<string, IKeyBackupRoomSessions>): Promise<void> => {
const currentChunk: IMegolmSessionData[] = [];
for (const roomId of roomChunks.keys()) {
BillCarsonFr marked this conversation as resolved.
Show resolved Hide resolved
const decryptedSessions = await backupDecryptor.decryptSessions(roomChunks.get(roomId)!);
for (const sessionId in decryptedSessions) {
const k = decryptedSessions[sessionId];
k.room_id = roomId;
currentChunk.push(k);
}
}
await block(currentChunk);
};

for (const [roomId, roomData] of Object.entries(rooms)) {
if (!roomData.sessions) continue;

chunkGroupByRoom.set(roomId, {});

for (const sessions of Object.entries(roomData.sessions)) {
BillCarsonFr marked this conversation as resolved.
Show resolved Hide resolved
const sessionsForRoom = chunkGroupByRoom.get(roomId)!;
sessionsForRoom[sessions[0]] = sessions[1];
groupChunkCount += 1;
if (groupChunkCount >= chunkSize) {
// we have enough chunks to decrypt
BillCarsonFr marked this conversation as resolved.
Show resolved Hide resolved
await handleChunkCallback(chunkGroupByRoom);
chunkGroupByRoom = new Map();
// there might be remaining keys for that room, so add back an entry for the current room.
BillCarsonFr marked this conversation as resolved.
Show resolved Hide resolved
chunkGroupByRoom.set(roomId, {});
groupChunkCount = 0;
}
}
}

// handle remaining chunk if needed
BillCarsonFr marked this conversation as resolved.
Show resolved Hide resolved
if (groupChunkCount > 0) {
await handleChunkCallback(chunkGroupByRoom);
}
}

public deleteKeysFromBackup(roomId: undefined, sessionId: undefined, version?: string): Promise<void>;
Expand Down
6 changes: 3 additions & 3 deletions src/crypto-api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -586,9 +586,9 @@ export class DeviceVerificationStatus {
*/
export interface ImportRoomKeyProgressData {
stage: string; // TODO: Enum
successes: number;
failures: number;
total: number;
successes?: number;
failures?: number;
total?: number;
}

/**
Expand Down
4 changes: 3 additions & 1 deletion src/crypto/keybackup.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ limitations under the License.
*/

// Export for backward compatibility
import { ImportRoomKeyProgressData } from "../crypto-api";

export type {
Curve25519AuthData as ICurve25519AuthData,
Aes256AuthData as IAes256AuthData,
Expand All @@ -41,5 +43,5 @@ export interface IKeyBackupRestoreResult {

export interface IKeyBackupRestoreOpts {
cacheCompleteCallback?: () => void;
progressCallback?: (progress: { stage: string }) => void;
progressCallback?: (progress: ImportRoomKeyProgressData) => void;
}
5 changes: 1 addition & 4 deletions src/rust-crypto/backup.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import { logger } from "../logger";
import { ClientPrefix, IHttpOpts, MatrixError, MatrixHttpApi, Method } from "../http-api";
import { CryptoEvent, IMegolmSessionData } from "../crypto";
import { TypedEventEmitter } from "../models/typed-event-emitter";
import { encodeUri, immediate, logDuration } from "../utils";
import { encodeUri, logDuration } from "../utils";
import { OutgoingRequestProcessor } from "./OutgoingRequestProcessor";
import { sleep } from "../utils";
import { BackupDecryptor } from "../common-crypto/CryptoBackend";
Expand Down Expand Up @@ -534,9 +534,6 @@ export class RustBackupDecryptor implements BackupDecryptor {
);
decrypted.session_id = sessionId;
keys.push(decrypted);

// there might be lots of sessions, so don't hog the event loop
await immediate();
} catch (e) {
logger.log("Failed to decrypt megolm session from backup", e, sessionData);
}
Expand Down
Loading