From dd2b41ff95f3527bd313d87bb1c11dbc2a5163f2 Mon Sep 17 00:00:00 2001 From: Bruno Windels Date: Thu, 27 Jan 2022 16:07:18 +0100 Subject: [PATCH] use backup flag in key backup rather than separate store --- src/matrix/DeviceMessageHandler.js | 3 +- src/matrix/Session.js | 9 +-- src/matrix/e2ee/megolm/decryption/RoomKey.ts | 12 +++- src/matrix/e2ee/megolm/keybackup/KeyBackup.ts | 60 +++++++------------ 4 files changed, 35 insertions(+), 49 deletions(-) diff --git a/src/matrix/DeviceMessageHandler.js b/src/matrix/DeviceMessageHandler.js index 0a606841..6ac5ac07 100644 --- a/src/matrix/DeviceMessageHandler.js +++ b/src/matrix/DeviceMessageHandler.js @@ -57,7 +57,8 @@ export class DeviceMessageHandler { async writeSync(prep, txn) { // write olm changes prep.olmDecryptChanges.write(txn); - await Promise.all(prep.newRoomKeys.map(key => this._megolmDecryption.writeRoomKey(key, txn))); + const didWriteValues = await Promise.all(prep.newRoomKeys.map(key => this._megolmDecryption.writeRoomKey(key, txn))); + return didWriteValues.some(didWrite => !!didWrite); } } diff --git a/src/matrix/Session.js b/src/matrix/Session.js index 9b9d013e..dd7a0e1c 100644 --- a/src/matrix/Session.js +++ b/src/matrix/Session.js @@ -597,12 +597,7 @@ export class Session { } if (preparation) { - await log.wrap("deviceMsgs", log => this._deviceMessageHandler.writeSync(preparation, txn, log)); - // this should come after the deviceMessageHandler, so the room keys are already written and their - // isBetter property has been checked - if (this._keyBackup) { - changes.shouldFlushKeyBackup = this._keyBackup.writeKeys(preparation.newRoomKeys, txn, log); - } + changes.hasNewRoomKeys = await log.wrap("deviceMsgs", log => this._deviceMessageHandler.writeSync(preparation, txn, log)); } // store account data @@ -641,7 +636,7 @@ export class Session { } } // should flush and not already flushing - if (changes.shouldFlushKeyBackup && this._keyBackup && !this._keyBackupOperation.get()) { + if (changes.hasNewRoomKeys && this._keyBackup && !this._keyBackupOperation.get()) { log.wrapDetached("flush key backup", async log => { const operation = this._keyBackup.flush(log); this._keyBackupOperation.set(operation); diff --git a/src/matrix/e2ee/megolm/decryption/RoomKey.ts b/src/matrix/e2ee/megolm/decryption/RoomKey.ts index 299b0a81..6112ccef 100644 --- a/src/matrix/e2ee/megolm/decryption/RoomKey.ts +++ b/src/matrix/e2ee/megolm/decryption/RoomKey.ts @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ +import {BackupStatus} from "../../../storage/idb/stores/InboundGroupSessionStore"; import type {InboundGroupSessionEntry} from "../../../storage/idb/stores/InboundGroupSessionStore"; import type {Transaction} from "../../../storage/idb/Transaction"; import type {DecryptionResult} from "../../DecryptionResult"; @@ -81,6 +82,7 @@ export abstract class IncomingRoomKey extends RoomKey { senderKey: this.senderKey, sessionId: this.sessionId, session: pickledSession, + backup: this.backupStatus, claimedKeys: {"ed25519": this.claimedEd25519Key}, }; txn.inboundGroupSessions.set(sessionEntry); @@ -125,6 +127,10 @@ export abstract class IncomingRoomKey extends RoomKey { } return this.isBetter!; } + + protected get backupStatus(): BackupStatus { + return BackupStatus.NotBackedUp; + } } class DeviceMessageRoomKey extends IncomingRoomKey { @@ -162,9 +168,13 @@ class BackupRoomKey extends IncomingRoomKey { loadInto(session) { session.import_session(this.serializationKey); } + + protected get backupStatus(): BackupStatus { + return BackupStatus.BackedUp; + } } -class StoredRoomKey extends RoomKey { +export class StoredRoomKey extends RoomKey { private storageEntry: InboundGroupSessionEntry; constructor(storageEntry: InboundGroupSessionEntry) { diff --git a/src/matrix/e2ee/megolm/keybackup/KeyBackup.ts b/src/matrix/e2ee/megolm/keybackup/KeyBackup.ts index 3792ea96..49771d15 100644 --- a/src/matrix/e2ee/megolm/keybackup/KeyBackup.ts +++ b/src/matrix/e2ee/megolm/keybackup/KeyBackup.ts @@ -15,7 +15,7 @@ limitations under the License. */ import {StoreNames} from "../../../storage/common"; -import {keyFromStorage, keyFromBackup} from "../decryption/RoomKey"; +import {StoredRoomKey, keyFromBackup} from "../decryption/RoomKey"; import {MEGOLM_ALGORITHM} from "../../common"; import * as Curve25519 from "./Curve25519"; import {AbortableOperation} from "../../../../utils/AbortableOperation"; @@ -30,7 +30,6 @@ import type {Storage} from "../../../storage/idb/Storage"; import type {ILogItem} from "../../../../logging/types"; import type {Platform} from "../../../../platform/web/Platform"; import type {Transaction} from "../../../storage/idb/Transaction"; -import type {BackupEntry} from "../../../storage/idb/stores/SessionNeedingBackupStore"; import type * as OlmNamespace from "@matrix-org/olm"; type Olm = typeof OlmNamespace; @@ -57,17 +56,6 @@ export class KeyBackup { } } - writeKeys(roomKeys: IncomingRoomKey[], txn: Transaction): boolean { - let hasBetter = false; - for (const key of roomKeys) { - if (key.isBetter) { - txn.sessionsNeedingBackup.set(key.roomId, key.senderKey, key.sessionId); - hasBetter = true; - } - } - return hasBetter; - } - // TODO: protect against having multiple concurrent flushes flush(log: ILogItem): AbortableOperation, Progress> { return new AbortableOperation(async (setAbortable, setProgress) => { @@ -77,36 +65,30 @@ export class KeyBackup { const timeout = this.platform.clock.createTimeout(this.platform.random() * 10000); setAbortable(timeout); await timeout.elapsed(); - const txn = await this.storage.readTxn([ - StoreNames.sessionsNeedingBackup, - StoreNames.inboundGroupSessions, - ]); + const txn = await this.storage.readTxn([StoreNames.inboundGroupSessions]); setAbortable(txn); // fetch total again on each iteration as while we are flushing, sync might be adding keys - total = await txn.sessionsNeedingBackup.count(); + total = await txn.inboundGroupSessions.countNonBackedUpSessions(); setProgress(new Progress(total, amountFinished)); - const keysNeedingBackup = await txn.sessionsNeedingBackup.getFirstEntries(20); + const keysNeedingBackup = (await txn.inboundGroupSessions.getFirstNonBackedUpSessions(20)) + .map(entry => new StoredRoomKey(entry)); if (keysNeedingBackup.length === 0) { return true; } - const roomKeysOrNotFound = await Promise.all(keysNeedingBackup.map(k => keyFromStorage(k.roomId, k.senderKey, k.sessionId, txn))); - const roomKeys = roomKeysOrNotFound.filter(k => !!k) as RoomKey[]; - if (roomKeys.length) { - const payload = await this.encodeKeysForBackup(roomKeys); - const uploadRequest = this.hsApi.uploadRoomKeysToBackup(this.backupInfo.version, payload, {log}); - setAbortable(uploadRequest); - try { - await uploadRequest.response(); - } catch (err) { - if (err.name === "HomeServerError" && err.errcode === "M_WRONG_ROOM_KEYS_VERSION") { - log.set("wrong_version", true); - return false; - } else { - throw err; - } + const payload = await this.encodeKeysForBackup(keysNeedingBackup); + const uploadRequest = this.hsApi.uploadRoomKeysToBackup(this.backupInfo.version, payload, {log}); + setAbortable(uploadRequest); + try { + await uploadRequest.response(); + } catch (err) { + if (err.name === "HomeServerError" && err.errcode === "M_WRONG_ROOM_KEYS_VERSION") { + log.set("wrong_version", true); + return false; + } else { + throw err; } } - this.removeBackedUpKeys(keysNeedingBackup, setAbortable); + this.markKeysAsBackedUp(keysNeedingBackup, setAbortable); amountFinished += keysNeedingBackup.length; setProgress(new Progress(total, amountFinished)); } @@ -126,15 +108,13 @@ export class KeyBackup { return payload; } - private async removeBackedUpKeys(keysNeedingBackup: BackupEntry[], setAbortable: SetAbortableFn) { + private async markKeysAsBackedUp(roomKeys: RoomKey[], setAbortable: SetAbortableFn) { const txn = await this.storage.readWriteTxn([ - StoreNames.sessionsNeedingBackup, + StoreNames.inboundGroupSessions, ]); setAbortable(txn); try { - for (const key of keysNeedingBackup) { - txn.sessionsNeedingBackup.remove(key.roomId, key.senderKey, key.sessionId); - } + await Promise.all(roomKeys.map(key => txn.inboundGroupSessions.markAsBackedUp(key.roomId, key.senderKey, key.sessionId))); } catch (err) { txn.abort(); throw err;