add progress notification and cancellation to key backup flush

This commit is contained in:
Bruno Windels 2022-01-26 15:19:31 +01:00
parent 554aa45d48
commit 60ed276b8a
3 changed files with 93 additions and 39 deletions

View file

@ -76,6 +76,7 @@ export class Session {
this._getSyncToken = () => this.syncToken; this._getSyncToken = () => this.syncToken;
this._olmWorker = olmWorker; this._olmWorker = olmWorker;
this._keyBackup = null; this._keyBackup = null;
this._keyBackupOperation = new ObservableValue(null);
this._hasSecretStorageKey = new ObservableValue(null); this._hasSecretStorageKey = new ObservableValue(null);
this._observedRoomStatus = new Map(); this._observedRoomStatus = new Map();
@ -270,6 +271,10 @@ export class Session {
return this._keyBackup; return this._keyBackup;
} }
get keyBackupOperation() {
return this._keyBackupOperation;
}
get hasIdentity() { get hasIdentity() {
return !!this._e2eeAccount; return !!this._e2eeAccount;
} }
@ -559,7 +564,7 @@ export class Session {
async writeSync(syncResponse, syncFilterId, preparation, txn, log) { async writeSync(syncResponse, syncFilterId, preparation, txn, log) {
const changes = { const changes = {
syncInfo: null, syncInfo: null,
e2eeAccountChanges: null, e2eeAccountChanges: null
}; };
const syncToken = syncResponse.next_batch; const syncToken = syncResponse.next_batch;
if (syncToken !== this.syncToken) { if (syncToken !== this.syncToken) {
@ -584,7 +589,7 @@ export class Session {
// this should come after the deviceMessageHandler, so the room keys are already written and their // this should come after the deviceMessageHandler, so the room keys are already written and their
// isBetter property has been checked // isBetter property has been checked
if (this._keyBackup) { if (this._keyBackup) {
this._keyBackup.writeKeys(preparation.newRoomKeys, txn, log); changes.shouldFlushKeyBackup = this._keyBackup.writeKeys(preparation.newRoomKeys, txn, log);
} }
} }
@ -623,8 +628,18 @@ export class Session {
await log.wrap("uploadKeys", log => this._e2eeAccount.uploadKeys(this._storage, false, log)); await log.wrap("uploadKeys", log => this._e2eeAccount.uploadKeys(this._storage, false, log));
} }
} }
if (this._keyBackup) { // should flush and not already flushing
this._keyBackup.flush(); if (changes.shouldFlushKeyBackup && this._keyBackup && !this._keyBackupOperation.get()) {
log.wrapDetached("flush key backup", async log => {
const operation = this._keyBackup.flush(log);
this._keyBackupOperation.set(operation);
try {
await operation.result;
} catch (err) {
log.catch(err);
}
this._keyBackupOperation.set(null);
});
} }
} }

View file

@ -18,7 +18,10 @@ import {StoreNames} from "../../../storage/common";
import {keyFromStorage, keyFromBackup} from "../decryption/RoomKey"; import {keyFromStorage, keyFromBackup} from "../decryption/RoomKey";
import {MEGOLM_ALGORITHM} from "../../common"; import {MEGOLM_ALGORITHM} from "../../common";
import * as Curve25519 from "./Curve25519"; import * as Curve25519 from "./Curve25519";
import {AbortableOperation} from "../../../../utils/AbortableOperation";
import {SetAbortableFn} from "../../../../utils/AbortableOperation";
import type {BackupInfo, SessionData, SessionKeyInfo, SessionInfo, KeyBackupPayload} from "./types";
import type {HomeServerApi} from "../../../net/HomeServerApi"; import type {HomeServerApi} from "../../../net/HomeServerApi";
import type {IncomingRoomKey, RoomKey} from "../decryption/RoomKey"; import type {IncomingRoomKey, RoomKey} from "../decryption/RoomKey";
import type {KeyLoader} from "../decryption/KeyLoader"; import type {KeyLoader} from "../decryption/KeyLoader";
@ -27,6 +30,7 @@ import type {Storage} from "../../../storage/idb/Storage";
import type {ILogItem} from "../../../../logging/types"; import type {ILogItem} from "../../../../logging/types";
import type {Platform} from "../../../../platform/web/Platform"; import type {Platform} from "../../../../platform/web/Platform";
import type {Transaction} from "../../../storage/idb/Transaction"; import type {Transaction} from "../../../storage/idb/Transaction";
import type {BackupEntry} from "../../../storage/idb/stores/SessionNeedingBackupStore";
import type * as OlmNamespace from "@matrix-org/olm"; import type * as OlmNamespace from "@matrix-org/olm";
type Olm = typeof OlmNamespace; type Olm = typeof OlmNamespace;
@ -45,7 +49,7 @@ export class KeyBackup {
if (!sessionResponse.session_data) { if (!sessionResponse.session_data) {
return; return;
} }
const sessionKeyInfo = this.crypto.decryptRoomKey(sessionResponse.session_data as Curve25519.SessionData); const sessionKeyInfo = this.crypto.decryptRoomKey(sessionResponse.session_data as SessionData);
if (sessionKeyInfo?.algorithm === MEGOLM_ALGORITHM) { if (sessionKeyInfo?.algorithm === MEGOLM_ALGORITHM) {
return keyFromBackup(roomId, sessionId, sessionKeyInfo); return keyFromBackup(roomId, sessionId, sessionKeyInfo);
} else if (sessionKeyInfo?.algorithm) { } else if (sessionKeyInfo?.algorithm) {
@ -64,45 +68,69 @@ export class KeyBackup {
return hasBetter; return hasBetter;
} }
async flush() { // TODO: protect against having multiple concurrent flushes
while (true) { flush(log: ILogItem): AbortableOperation<Promise<void>, Progress> {
await this.platform.clock.createTimeout(this.platform.random() * 10000).elapsed(); return new AbortableOperation(async (setAbortable, setProgress) => {
const txn = await this.storage.readTxn([ let total = 0;
StoreNames.sessionsNeedingBackup, let amountFinished = 0;
StoreNames.inboundGroupSessions, while (true) {
]); const timeout = this.platform.clock.createTimeout(this.platform.random() * 10000);
const keysNeedingBackup = await txn.sessionsNeedingBackup.getFirstEntries(20); setAbortable(timeout);
if (keysNeedingBackup.length === 0) { await timeout.elapsed();
return; const txn = await this.storage.readTxn([
}
const roomKeys = await Promise.all(keysNeedingBackup.map(k => keyFromStorage(k.roomId, k.senderKey, k.sessionId, txn)));
const payload: KeyBackupPayload = { rooms: {} };
const payloadRooms = payload.rooms;
for (const key of roomKeys) {
if (key) {
let roomPayload = payloadRooms[key.roomId];
if (!roomPayload) {
roomPayload = payloadRooms[key.roomId] = { sessions: {} };
}
roomPayload.sessions[key.sessionId] = await this.encodeRoomKey(key);
}
}
await this.hsApi.uploadRoomKeysToBackup(this.backupInfo.version, payload).response();
{
const txn = await this.storage.readWriteTxn([
StoreNames.sessionsNeedingBackup, StoreNames.sessionsNeedingBackup,
StoreNames.inboundGroupSessions,
]); ]);
try { setAbortable(txn);
for (const key of keysNeedingBackup) { // fetch total again on each iteration as while we are flushing, sync might be adding keys
txn.sessionsNeedingBackup.remove(key.roomId, key.senderKey, key.sessionId); total = await txn.sessionsNeedingBackup.count();
} setProgress(new Progress(total, amountFinished));
} catch (err) { const keysNeedingBackup = await txn.sessionsNeedingBackup.getFirstEntries(20);
txn.abort(); if (keysNeedingBackup.length === 0) {
throw err; return;
} }
await txn.complete(); const roomKeysOrNotFound = await Promise.all(keysNeedingBackup.map(k => keyFromStorage(k.roomId, k.senderKey, k.sessionId, txn)));
const roomKeys = roomKeysOrNotFound.filter(k => !!k) as RoomKey[];
if (roomKeys.length) {
const payload = await this.encodeKeysForBackup(roomKeys);
const uploadRequest = this.hsApi.uploadRoomKeysToBackup(this.backupInfo.version, payload);
setAbortable(uploadRequest);
await uploadRequest.response();
}
this.removeBackedUpKeys(keysNeedingBackup, setAbortable);
amountFinished += keysNeedingBackup.length;
setProgress(new Progress(total, amountFinished));
} }
});
}
private async encodeKeysForBackup(roomKeys: RoomKey[]): Promise<KeyBackupPayload> {
const payload: KeyBackupPayload = { rooms: {} };
const payloadRooms = payload.rooms;
for (const key of roomKeys) {
let roomPayload = payloadRooms[key.roomId];
if (!roomPayload) {
roomPayload = payloadRooms[key.roomId] = { sessions: {} };
}
roomPayload.sessions[key.sessionId] = await this.encodeRoomKey(key);
} }
return payload;
}
private async removeBackedUpKeys(keysNeedingBackup: BackupEntry[], setAbortable: SetAbortableFn) {
const txn = await this.storage.readWriteTxn([
StoreNames.sessionsNeedingBackup,
]);
setAbortable(txn);
try {
for (const key of keysNeedingBackup) {
txn.sessionsNeedingBackup.remove(key.roomId, key.senderKey, key.sessionId);
}
} catch (err) {
txn.abort();
throw err;
}
await txn.complete();
} }
private async encodeRoomKey(roomKey: RoomKey): Promise<SessionInfo> { private async encodeRoomKey(roomKey: RoomKey): Promise<SessionInfo> {
@ -140,3 +168,10 @@ export class KeyBackup {
} }
} }
} }
export class Progress {
constructor(
public readonly total: number,
public readonly finished: number
) {}
}

View file

@ -53,4 +53,8 @@ export class SessionNeedingBackupStore {
remove(roomId: string, senderKey: string, sessionId: string): void { remove(roomId: string, senderKey: string, sessionId: string): void {
this.store.delete(encodeKey(roomId, senderKey, sessionId)); this.store.delete(encodeKey(roomId, senderKey, sessionId));
} }
count(): Promise<number> {
return this.store.count();
}
} }