Skip to content

Commit

Permalink
Decrypt and Import full backups in chunk with progress
Browse files Browse the repository at this point in the history
  • Loading branch information
BillCarsonFr committed Jan 12, 2024
1 parent 2ef3ebb commit f76e465
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 31 deletions.
102 changes: 102 additions & 0 deletions spec/integ/crypto/megolm-backup.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,108 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe
expect(afterCache.imported).toStrictEqual(1);
});

it("Should import full backup in chunks", async function () {
fetchMock.get("path:/_matrix/client/v3/room_keys/version", testData.SIGNED_BACKUP_DATA);

aliceClient = await initTestClient();

const aliceCrypto = aliceClient.getCrypto()!;

// just mock this call
// @ts-ignore - mock a private method for testing purpose
aliceCrypto.importBackedUpRoomKeys = jest.fn();

await aliceClient.startClient();

// tell Alice to trust the dummy device that signed the backup
await waitForDeviceList();
await aliceCrypto.setDeviceVerified(testData.TEST_USER_ID, testData.TEST_DEVICE_ID);

const fullBackup: {
rooms: {
[roomId: string]: {
[sessionId: string]: any;
};
};
} = { rooms: {} };

// we need several rooms with several sessions to test chunking
const keysPerRoom = [45, 300, 345, 12, 130];
const expectedTotal = keysPerRoom.reduce((a, b) => a + b, 0);
for (let i = 0; i < keysPerRoom.length; i++) {
const roomId = `!room${i}:example.com`;
fullBackup.rooms[roomId] = { sessions: {} };
for (let j = 0; j < keysPerRoom[i]; j++) {
const sessionId = `session${j}`;
// put the same fake session data, not important for that tet
fullBackup.rooms[roomId].sessions[sessionId] = testData.CURVE25519_KEY_BACKUP_DATA;
}
}

fetchMock.get("express:/_matrix/client/v3/room_keys/keys", fullBackup);

const check = await aliceCrypto.checkKeyBackupAndEnable();

// @ts-ignore spying internal method
const importSpy = jest.spyOn(aliceCrypto, "importBackedUpRoomKeys");

const progressCallback = jest.fn();
const result = await aliceClient.restoreKeyBackupWithRecoveryKey(
testData.BACKUP_DECRYPTION_KEY_BASE58,
undefined,
undefined,
check!.backupInfo!,
{
progressCallback,
},
);

expect(result.imported).toStrictEqual(expectedTotal);
// should be called 5 times: 200*4 plus one chunk with the remaining 32
expect(importSpy).toHaveBeenCalledTimes(5);

expect(progressCallback).toHaveBeenCalledWith({
// unfortunately there is no proper enum for stages :/
// react sdk expect some values though
stage: "fetch",
});

expect(progressCallback).toHaveBeenCalledWith({
total: expectedTotal,
successes: 200,
stage: "load_keys",
failures: 0,
});

expect(progressCallback).toHaveBeenCalledWith({
total: expectedTotal,
successes: 400,
stage: "load_keys",
failures: 0,
});

expect(progressCallback).toHaveBeenCalledWith({
total: expectedTotal,
successes: 600,
stage: "load_keys",
failures: 0,
});

expect(progressCallback).toHaveBeenCalledWith({
total: expectedTotal,
successes: 800,
stage: "load_keys",
failures: 0,
});

expect(progressCallback).toHaveBeenCalledWith({
total: expectedTotal,
successes: 832,
stage: "load_keys",
failures: 0,
});
});

it("recover specific session from backup", async function () {
fetchMock.get("path:/_matrix/client/v3/room_keys/version", testData.SIGNED_BACKUP_DATA);

Expand Down
2 changes: 1 addition & 1 deletion spec/unit/rust-crypto/rust-crypto.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ describe("RustCrypto", () => {
let importTotal = 0;
const opt: ImportRoomKeysOpts = {
progressCallback: (stage) => {
importTotal = stage.total;
importTotal = stage.total ?? 0;
},
};
await rustCrypto.importRoomKeys(someRoomKeys, opt);
Expand Down
140 changes: 118 additions & 22 deletions src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ import { IgnoredInvites } from "./models/invites-ignorer";
import { UIARequest, UIAResponse } from "./@types/uia";
import { LocalNotificationSettings } from "./@types/local_notifications";
import { buildFeatureSupportMap, Feature, ServerSupport } from "./feature";
import { CryptoBackend } from "./common-crypto/CryptoBackend";
import { BackupDecryptor, CryptoBackend } from "./common-crypto/CryptoBackend";
import { RUST_SDK_STORE_PREFIX } from "./rust-crypto/constants";
import { BootstrapCrossSigningOpts, CrossSigningKeyInfo, CryptoApi, ImportRoomKeysOpts } from "./crypto-api";
import { DeviceInfoMap } from "./crypto/DeviceList";
Expand Down Expand Up @@ -3898,7 +3898,7 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
}

let totalKeyCount = 0;
let keys: IMegolmSessionData[] = [];
let totalImported = 0;

const path = this.makeKeyBackupPath(targetRoomId, targetSessionId, backupInfo.version);

Expand Down Expand Up @@ -3934,25 +3934,63 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
{ prefix: ClientPrefix.V3 },
);

// We have finished fetching the backup, go to next step
if (progressCallback) {
progressCallback({
stage: "load_keys",
});
}

if ((res as IRoomsKeysResponse).rooms) {
const rooms = (res as IRoomsKeysResponse).rooms;
for (const [roomId, roomData] of Object.entries(rooms)) {
if (!roomData.sessions) continue;

totalKeyCount += Object.keys(roomData.sessions).length;
const roomKeys = await backupDecryptor.decryptSessions(roomData.sessions);
for (const k of roomKeys) {
k.room_id = roomId;
keys.push(k);
}
}
// We have a full backup here, it can get quite big, so we need to decrypt and import it in chunks.

// Get the total count as a first pass
totalKeyCount = this.getTotalKeyCount(res as IRoomsKeysResponse);
let decryptedKeyCount = 0;

// Now decrypt and import the keys in chunks
// We need to adapt the `progressCallback` to give accurate progress.
await this.handleDecryptionOfAFullBackup(
res as IRoomsKeysResponse,
backupDecryptor,
200,
async (chunk) => {
decryptedKeyCount += chunk.length;
// we have a chunk of decrypted keys, import them
try {
await this.cryptoBackend!.importBackedUpRoomKeys(chunk, {
untrusted,
});
totalImported += chunk.length;
if (progressCallback) {
progressCallback({
total: totalKeyCount,
successes: decryptedKeyCount,
stage: "load_keys",
failures: 0,
});
}
} catch (e) {
// We failed to import some keys, but we should still try to import the rest?
// Log the error and continue
logger.error("Error importing keys from backup", e);
}
},
);
} else if ((res as IRoomKeysResponse).sessions) {
// For now we don't chunk for a single room backup, but we could in the future.
// Currently it is not used by the application.
const sessions = (res as IRoomKeysResponse).sessions;
totalKeyCount = Object.keys(sessions).length;
keys = await backupDecryptor.decryptSessions(sessions);
const keys = await backupDecryptor.decryptSessions(sessions);
for (const k of keys) {
k.room_id = targetRoomId!;
}
await this.cryptoBackend.importBackedUpRoomKeys(keys, {
progressCallback,
untrusted,
});
totalImported = keys.length;
} else {
totalKeyCount = 1;
try {
Expand All @@ -3961,7 +3999,12 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
});
key.room_id = targetRoomId!;
key.session_id = targetSessionId!;
keys.push(key);

await this.cryptoBackend.importBackedUpRoomKeys([key], {
progressCallback,
untrusted,
});
totalImported = 1;
} catch (e) {
this.logger.debug("Failed to decrypt megolm session from backup", e);
}
Expand All @@ -3970,15 +4013,68 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
backupDecryptor.free();
}

await this.cryptoBackend.importBackedUpRoomKeys(keys, {
progressCallback,
untrusted,
});
return { total: totalKeyCount, imported: totalImported };
}

private getTotalKeyCount(res: IRoomsKeysResponse): number {
const rooms = res.rooms;
let totalKeyCount = 0;
for (const entry of Object.entries(rooms)) {
const roomData = entry[1];
if (!roomData.sessions) continue;
totalKeyCount += Object.keys(roomData.sessions).length;
}
return totalKeyCount;
}

private async handleDecryptionOfAFullBackup(
res: IRoomsKeysResponse,
backupDecryptor: BackupDecryptor,
chunkSize: number,
block: (chunk: IMegolmSessionData[]) => Promise<void>,
): Promise<void> {
const rooms = (res as IRoomsKeysResponse).rooms;

let groupChunkCount = 0;
let chunkGroupByRoom: Map<string, IKeyBackupRoomSessions> = new Map();

const handleChunkCallback = async (roomChunks: Map<string, IKeyBackupRoomSessions>): Promise<void> => {
const currentChunk: IMegolmSessionData[] = [];
for (const roomId of roomChunks.keys()) {
const decryptedSessions = await backupDecryptor.decryptSessions(roomChunks.get(roomId)!);
for (const sessionId in decryptedSessions) {
const k = decryptedSessions[sessionId];
k.room_id = roomId;
currentChunk.push(k);
}
}
await block(currentChunk);
};

/// in case entering the passphrase would add a new signature?
await this.cryptoBackend.checkKeyBackupAndEnable();
for (const [roomId, roomData] of Object.entries(rooms)) {
if (!roomData.sessions) continue;

chunkGroupByRoom.set(roomId, {});

for (const sessions of Object.entries(roomData.sessions)) {
const sessionsForRoom = chunkGroupByRoom.get(roomId)!;
sessionsForRoom[sessions[0]] = sessions[1];
groupChunkCount += 1;
if (groupChunkCount >= chunkSize) {
// we have enough chunks to decrypt
await handleChunkCallback(chunkGroupByRoom);
chunkGroupByRoom = new Map();
// there might be remaining keys for that room, so add back an entry for the current room.
chunkGroupByRoom.set(roomId, {});
groupChunkCount = 0;
}
}
}

return { total: totalKeyCount, imported: keys.length };
// handle remaining chunk if needed
if (groupChunkCount > 0) {
await handleChunkCallback(chunkGroupByRoom);
}
}

public deleteKeysFromBackup(roomId: undefined, sessionId: undefined, version?: string): Promise<void>;
Expand Down
6 changes: 3 additions & 3 deletions src/crypto-api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -586,9 +586,9 @@ export class DeviceVerificationStatus {
*/
export interface ImportRoomKeyProgressData {
stage: string; // TODO: Enum
successes: number;
failures: number;
total: number;
successes?: number;
failures?: number;
total?: number;
}

/**
Expand Down
4 changes: 3 additions & 1 deletion src/crypto/keybackup.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ limitations under the License.
*/

// Export for backward compatibility
import { ImportRoomKeyProgressData } from "../crypto-api";

export type {
Curve25519AuthData as ICurve25519AuthData,
Aes256AuthData as IAes256AuthData,
Expand All @@ -41,5 +43,5 @@ export interface IKeyBackupRestoreResult {

export interface IKeyBackupRestoreOpts {
cacheCompleteCallback?: () => void;
progressCallback?: (progress: { stage: string }) => void;
progressCallback?: (progress: ImportRoomKeyProgressData) => void;
}
5 changes: 1 addition & 4 deletions src/rust-crypto/backup.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import { logger } from "../logger";
import { ClientPrefix, IHttpOpts, MatrixError, MatrixHttpApi, Method } from "../http-api";
import { CryptoEvent, IMegolmSessionData } from "../crypto";
import { TypedEventEmitter } from "../models/typed-event-emitter";
import { encodeUri, immediate, logDuration } from "../utils";
import { encodeUri, logDuration } from "../utils";
import { OutgoingRequestProcessor } from "./OutgoingRequestProcessor";
import { sleep } from "../utils";
import { BackupDecryptor } from "../common-crypto/CryptoBackend";
Expand Down Expand Up @@ -534,9 +534,6 @@ export class RustBackupDecryptor implements BackupDecryptor {
);
decrypted.session_id = sessionId;
keys.push(decrypted);

// there might be lots of sessions, so don't hog the event loop
await immediate();
} catch (e) {
logger.log("Failed to decrypt megolm session from backup", e, sessionData);
}
Expand Down

0 comments on commit f76e465

Please sign in to comment.