diff --git a/src/matrix/Session.js b/src/matrix/Session.js index 9510cbfd..d8595d12 100644 --- a/src/matrix/Session.js +++ b/src/matrix/Session.js @@ -76,6 +76,7 @@ export class Session { this._getSyncToken = () => this.syncToken; this._olmWorker = olmWorker; this._keyBackup = null; + this._keyBackupOperation = new ObservableValue(null); this._hasSecretStorageKey = new ObservableValue(null); this._observedRoomStatus = new Map(); @@ -270,6 +271,10 @@ export class Session { return this._keyBackup; } + get keyBackupOperation() { + return this._keyBackupOperation; + } + get hasIdentity() { return !!this._e2eeAccount; } @@ -559,7 +564,7 @@ export class Session { async writeSync(syncResponse, syncFilterId, preparation, txn, log) { const changes = { syncInfo: null, - e2eeAccountChanges: null, + e2eeAccountChanges: null }; const syncToken = syncResponse.next_batch; if (syncToken !== this.syncToken) { @@ -584,7 +589,7 @@ export class Session { // this should come after the deviceMessageHandler, so the room keys are already written and their // isBetter property has been checked if (this._keyBackup) { - this._keyBackup.writeKeys(preparation.newRoomKeys, txn, log); + changes.shouldFlushKeyBackup = this._keyBackup.writeKeys(preparation.newRoomKeys, txn, log); } } @@ -623,8 +628,18 @@ export class Session { await log.wrap("uploadKeys", log => this._e2eeAccount.uploadKeys(this._storage, false, log)); } } - if (this._keyBackup) { - this._keyBackup.flush(); + // should flush and not already flushing + if (changes.shouldFlushKeyBackup && this._keyBackup && !this._keyBackupOperation.get()) { + log.wrapDetached("flush key backup", async log => { + const operation = this._keyBackup.flush(log); + this._keyBackupOperation.set(operation); + try { + await operation.result; + } catch (err) { + log.catch(err); + } + this._keyBackupOperation.set(null); + }); } } diff --git a/src/matrix/e2ee/megolm/keybackup/KeyBackup.ts b/src/matrix/e2ee/megolm/keybackup/KeyBackup.ts index 38faa5e2..cd09423c 100644 --- a/src/matrix/e2ee/megolm/keybackup/KeyBackup.ts +++ b/src/matrix/e2ee/megolm/keybackup/KeyBackup.ts @@ -18,7 +18,10 @@ import {StoreNames} from "../../../storage/common"; import {keyFromStorage, keyFromBackup} from "../decryption/RoomKey"; import {MEGOLM_ALGORITHM} from "../../common"; import * as Curve25519 from "./Curve25519"; +import {AbortableOperation} from "../../../../utils/AbortableOperation"; +import {SetAbortableFn} from "../../../../utils/AbortableOperation"; +import type {BackupInfo, SessionData, SessionKeyInfo, SessionInfo, KeyBackupPayload} from "./types"; import type {HomeServerApi} from "../../../net/HomeServerApi"; import type {IncomingRoomKey, RoomKey} from "../decryption/RoomKey"; import type {KeyLoader} from "../decryption/KeyLoader"; @@ -27,6 +30,7 @@ import type {Storage} from "../../../storage/idb/Storage"; import type {ILogItem} from "../../../../logging/types"; import type {Platform} from "../../../../platform/web/Platform"; import type {Transaction} from "../../../storage/idb/Transaction"; +import type {BackupEntry} from "../../../storage/idb/stores/SessionNeedingBackupStore"; import type * as OlmNamespace from "@matrix-org/olm"; type Olm = typeof OlmNamespace; @@ -45,7 +49,7 @@ export class KeyBackup { if (!sessionResponse.session_data) { return; } - const sessionKeyInfo = this.crypto.decryptRoomKey(sessionResponse.session_data as Curve25519.SessionData); + const sessionKeyInfo = this.crypto.decryptRoomKey(sessionResponse.session_data as SessionData); if (sessionKeyInfo?.algorithm === MEGOLM_ALGORITHM) { return keyFromBackup(roomId, sessionId, sessionKeyInfo); } else if (sessionKeyInfo?.algorithm) { @@ -64,45 +68,69 @@ export class KeyBackup { return hasBetter; } - async flush() { - while (true) { - await this.platform.clock.createTimeout(this.platform.random() * 10000).elapsed(); - const txn = await this.storage.readTxn([ - StoreNames.sessionsNeedingBackup, - StoreNames.inboundGroupSessions, - ]); - const keysNeedingBackup = await txn.sessionsNeedingBackup.getFirstEntries(20); - if (keysNeedingBackup.length === 0) { - return; - } - const roomKeys = await Promise.all(keysNeedingBackup.map(k => keyFromStorage(k.roomId, k.senderKey, k.sessionId, txn))); - const payload: KeyBackupPayload = { rooms: {} }; - const payloadRooms = payload.rooms; - for (const key of roomKeys) { - if (key) { - let roomPayload = payloadRooms[key.roomId]; - if (!roomPayload) { - roomPayload = payloadRooms[key.roomId] = { sessions: {} }; - } - roomPayload.sessions[key.sessionId] = await this.encodeRoomKey(key); - } - } - await this.hsApi.uploadRoomKeysToBackup(this.backupInfo.version, payload).response(); - { - const txn = await this.storage.readWriteTxn([ + // TODO: protect against having multiple concurrent flushes + flush(log: ILogItem): AbortableOperation, Progress> { + return new AbortableOperation(async (setAbortable, setProgress) => { + let total = 0; + let amountFinished = 0; + while (true) { + const timeout = this.platform.clock.createTimeout(this.platform.random() * 10000); + setAbortable(timeout); + await timeout.elapsed(); + const txn = await this.storage.readTxn([ StoreNames.sessionsNeedingBackup, + StoreNames.inboundGroupSessions, ]); - try { - for (const key of keysNeedingBackup) { - txn.sessionsNeedingBackup.remove(key.roomId, key.senderKey, key.sessionId); - } - } catch (err) { - txn.abort(); - throw err; + setAbortable(txn); + // fetch total again on each iteration as while we are flushing, sync might be adding keys + total = await txn.sessionsNeedingBackup.count(); + setProgress(new Progress(total, amountFinished)); + const keysNeedingBackup = await txn.sessionsNeedingBackup.getFirstEntries(20); + if (keysNeedingBackup.length === 0) { + return; } - await txn.complete(); + const roomKeysOrNotFound = await Promise.all(keysNeedingBackup.map(k => keyFromStorage(k.roomId, k.senderKey, k.sessionId, txn))); + const roomKeys = roomKeysOrNotFound.filter(k => !!k) as RoomKey[]; + if (roomKeys.length) { + const payload = await this.encodeKeysForBackup(roomKeys); + const uploadRequest = this.hsApi.uploadRoomKeysToBackup(this.backupInfo.version, payload); + setAbortable(uploadRequest); + await uploadRequest.response(); + } + this.removeBackedUpKeys(keysNeedingBackup, setAbortable); + amountFinished += keysNeedingBackup.length; + setProgress(new Progress(total, amountFinished)); } + }); + } + + private async encodeKeysForBackup(roomKeys: RoomKey[]): Promise { + const payload: KeyBackupPayload = { rooms: {} }; + const payloadRooms = payload.rooms; + for (const key of roomKeys) { + let roomPayload = payloadRooms[key.roomId]; + if (!roomPayload) { + roomPayload = payloadRooms[key.roomId] = { sessions: {} }; + } + roomPayload.sessions[key.sessionId] = await this.encodeRoomKey(key); } + return payload; + } + + private async removeBackedUpKeys(keysNeedingBackup: BackupEntry[], setAbortable: SetAbortableFn) { + const txn = await this.storage.readWriteTxn([ + StoreNames.sessionsNeedingBackup, + ]); + setAbortable(txn); + try { + for (const key of keysNeedingBackup) { + txn.sessionsNeedingBackup.remove(key.roomId, key.senderKey, key.sessionId); + } + } catch (err) { + txn.abort(); + throw err; + } + await txn.complete(); } private async encodeRoomKey(roomKey: RoomKey): Promise { @@ -140,3 +168,10 @@ export class KeyBackup { } } } + +export class Progress { + constructor( + public readonly total: number, + public readonly finished: number + ) {} +} diff --git a/src/matrix/storage/idb/stores/SessionNeedingBackupStore.ts b/src/matrix/storage/idb/stores/SessionNeedingBackupStore.ts index 27104191..71a383bb 100644 --- a/src/matrix/storage/idb/stores/SessionNeedingBackupStore.ts +++ b/src/matrix/storage/idb/stores/SessionNeedingBackupStore.ts @@ -53,4 +53,8 @@ export class SessionNeedingBackupStore { remove(roomId: string, senderKey: string, sessionId: string): void { this.store.delete(encodeKey(roomId, senderKey, sessionId)); } + + count(): Promise { + return this.store.count(); + } }