diff --git a/src/matrix/DeviceMessageHandler.js b/src/matrix/DeviceMessageHandler.js index 27f1d386..0a606841 100644 --- a/src/matrix/DeviceMessageHandler.js +++ b/src/matrix/DeviceMessageHandler.js @@ -15,7 +15,7 @@ limitations under the License. */ import {OLM_ALGORITHM} from "./e2ee/common.js"; -import {countBy, groupBy} from "../utils/groupBy.js"; +import {countBy, groupBy} from "../utils/groupBy"; export class DeviceMessageHandler { constructor({storage}) { @@ -67,12 +67,4 @@ class SyncPreparation { this.newRoomKeys = newRoomKeys; this.newKeysByRoom = groupBy(newRoomKeys, r => r.roomId); } - - dispose() { - if (this.newRoomKeys) { - for (const k of this.newRoomKeys) { - k.dispose(); - } - } - } } diff --git a/src/matrix/Session.js b/src/matrix/Session.js index 36c8e084..7047171e 100644 --- a/src/matrix/Session.js +++ b/src/matrix/Session.js @@ -26,14 +26,15 @@ import {DeviceMessageHandler} from "./DeviceMessageHandler.js"; import {Account as E2EEAccount} from "./e2ee/Account.js"; import {Decryption as OlmDecryption} from "./e2ee/olm/Decryption.js"; import {Encryption as OlmEncryption} from "./e2ee/olm/Encryption.js"; -import {Decryption as MegOlmDecryption} from "./e2ee/megolm/Decryption.js"; +import {Decryption as MegOlmDecryption} from "./e2ee/megolm/Decryption"; +import {KeyLoader as MegOlmKeyLoader} from "./e2ee/megolm/decryption/KeyLoader"; import {SessionBackup} from "./e2ee/megolm/SessionBackup.js"; import {Encryption as MegOlmEncryption} from "./e2ee/megolm/Encryption.js"; import {MEGOLM_ALGORITHM} from "./e2ee/common.js"; import {RoomEncryption} from "./e2ee/RoomEncryption.js"; import {DeviceTracker} from "./e2ee/DeviceTracker.js"; import {LockMap} from "../utils/LockMap.js"; -import {groupBy} from "../utils/groupBy.js"; +import {groupBy} from "../utils/groupBy"; import { keyFromCredential as ssssKeyFromCredential, readKey as ssssReadKey, @@ -137,11 +138,8 @@ export class Session { now: this._platform.clock.now, ownDeviceId: this._sessionInfo.deviceId, }); - this._megolmDecryption = new MegOlmDecryption({ - pickleKey: PICKLE_KEY, - olm: this._olm, - olmWorker: this._olmWorker, - }); + const keyLoader = new MegOlmKeyLoader(this._olm, PICKLE_KEY, 20); + this._megolmDecryption = new MegOlmDecryption(keyLoader, this._olmWorker); this._deviceMessageHandler.enableEncryption({olmDecryption, megolmDecryption: this._megolmDecryption}); } @@ -319,6 +317,7 @@ export class Session { dispose() { this._olmWorker?.dispose(); this._sessionBackup?.dispose(); + this._megolmDecryption.dispose(); for (const room of this._rooms.values()) { room.dispose(); } diff --git a/src/matrix/Sync.js b/src/matrix/Sync.js index e010f90d..0be48007 100644 --- a/src/matrix/Sync.js +++ b/src/matrix/Sync.js @@ -464,7 +464,6 @@ class SessionSyncProcessState { dispose() { this.lock?.release(); - this.preparation?.dispose(); } } diff --git a/src/matrix/e2ee/DecryptionResult.js b/src/matrix/e2ee/DecryptionResult.js index c109e689..e1c2bcc4 100644 --- a/src/matrix/e2ee/DecryptionResult.js +++ b/src/matrix/e2ee/DecryptionResult.js @@ -29,10 +29,10 @@ limitations under the License. export class DecryptionResult { - constructor(event, senderCurve25519Key, claimedKeys) { + constructor(event, senderCurve25519Key, claimedEd25519Key) { this.event = event; this.senderCurve25519Key = senderCurve25519Key; - this.claimedEd25519Key = claimedKeys.ed25519; + this.claimedEd25519Key = claimedEd25519Key; this._device = null; this._roomTracked = true; } diff --git a/src/matrix/e2ee/RoomEncryption.js b/src/matrix/e2ee/RoomEncryption.js index aba7d07d..66b3366f 100644 --- a/src/matrix/e2ee/RoomEncryption.js +++ b/src/matrix/e2ee/RoomEncryption.js @@ -15,9 +15,9 @@ limitations under the License. */ import {MEGOLM_ALGORITHM, DecryptionSource} from "./common.js"; -import {groupEventsBySession} from "./megolm/decryption/utils.js"; +import {groupEventsBySession} from "./megolm/decryption/utils"; import {mergeMap} from "../../utils/mergeMap.js"; -import {groupBy} from "../../utils/groupBy.js"; +import {groupBy} from "../../utils/groupBy"; import {makeTxnId} from "../common.js"; const ENCRYPTED_TYPE = "m.room.encrypted"; @@ -36,8 +36,6 @@ export class RoomEncryption { this._megolmDecryption = megolmDecryption; // content of the m.room.encryption event this._encryptionParams = encryptionParams; - this._megolmBackfillCache = this._megolmDecryption.createSessionCache(); - this._megolmSyncCache = this._megolmDecryption.createSessionCache(1); // caches devices to verify events this._senderDeviceCache = new Map(); this._storage = storage; @@ -76,9 +74,6 @@ export class RoomEncryption { } notifyTimelineClosed() { - // empty the backfill cache when closing the timeline - this._megolmBackfillCache.dispose(); - this._megolmBackfillCache = this._megolmDecryption.createSessionCache(); this._senderDeviceCache = new Map(); // purge the sender device cache } @@ -112,27 +107,8 @@ export class RoomEncryption { } validEvents.push(event); } - let customCache; - let sessionCache; - // we have different caches so we can keep them small but still - // have backfill and sync not invalidate each other - if (source === DecryptionSource.Sync) { - sessionCache = this._megolmSyncCache; - } else if (source === DecryptionSource.Timeline) { - sessionCache = this._megolmBackfillCache; - } else if (source === DecryptionSource.Retry) { - // when retrying, we could have mixed events from at the bottom of the timeline (sync) - // and somewhere else, so create a custom cache we use just for this operation. - customCache = this._megolmDecryption.createSessionCache(); - sessionCache = customCache; - } else { - throw new Error("Unknown source: " + source); - } const preparation = await this._megolmDecryption.prepareDecryptAll( - this._room.id, validEvents, newKeys, sessionCache, txn); - if (customCache) { - customCache.dispose(); - } + this._room.id, validEvents, newKeys, txn); return new DecryptionPreparation(preparation, errors, source, this, events); } @@ -204,37 +180,31 @@ export class RoomEncryption { return; } log.set("id", sessionId); - log.set("senderKey", senderKey); + log.set("senderKey", senderKey); try { const session = await this._sessionBackup.getSession(this._room.id, sessionId, log); if (session?.algorithm === MEGOLM_ALGORITHM) { - if (session["sender_key"] !== senderKey) { - log.set("wrong_sender_key", session["sender_key"]); - log.logLevel = log.level.Warn; - return; - } let roomKey = this._megolmDecryption.roomKeyFromBackup(this._room.id, sessionId, session); if (roomKey) { + if (roomKey.senderKey !== senderKey) { + log.set("wrong_sender_key", roomKey.senderKey); + log.logLevel = log.level.Warn; + return; + } let keyIsBestOne = false; let retryEventIds; + const txn = await this._storage.readWriteTxn([this._storage.storeNames.inboundGroupSessions]); try { - const txn = await this._storage.readWriteTxn([this._storage.storeNames.inboundGroupSessions]); - try { - keyIsBestOne = await this._megolmDecryption.writeRoomKey(roomKey, txn); - log.set("isBetter", keyIsBestOne); - if (keyIsBestOne) { - retryEventIds = roomKey.eventIds; - } - } catch (err) { - txn.abort(); - throw err; + keyIsBestOne = await this._megolmDecryption.writeRoomKey(roomKey, txn); + log.set("isBetter", keyIsBestOne); + if (keyIsBestOne) { + retryEventIds = roomKey.eventIds; } - await txn.complete(); - } finally { - // can still access properties on it afterwards - // this is just clearing the internal sessionInfo - roomKey.dispose(); + } catch (err) { + txn.abort(); + throw err; } + await txn.complete(); if (keyIsBestOne) { await log.wrap("retryDecryption", log => this._room.notifyRoomKey(roomKey, retryEventIds || [], log)); } @@ -466,8 +436,6 @@ export class RoomEncryption { dispose() { this._disposed = true; - this._megolmBackfillCache.dispose(); - this._megolmSyncCache.dispose(); } } diff --git a/src/matrix/e2ee/megolm/Decryption.js b/src/matrix/e2ee/megolm/Decryption.ts similarity index 60% rename from src/matrix/e2ee/megolm/Decryption.js rename to src/matrix/e2ee/megolm/Decryption.ts index 8f4714ea..842d423d 100644 --- a/src/matrix/e2ee/megolm/Decryption.js +++ b/src/matrix/e2ee/megolm/Decryption.ts @@ -15,23 +15,26 @@ limitations under the License. */ import {DecryptionError} from "../common.js"; -import * as RoomKey from "./decryption/RoomKey.js"; -import {SessionInfo} from "./decryption/SessionInfo.js"; import {DecryptionPreparation} from "./decryption/DecryptionPreparation.js"; -import {SessionDecryption} from "./decryption/SessionDecryption.js"; -import {SessionCache} from "./decryption/SessionCache.js"; +import {SessionDecryption} from "./decryption/SessionDecryption"; import {MEGOLM_ALGORITHM} from "../common.js"; -import {validateEvent, groupEventsBySession} from "./decryption/utils.js"; +import {validateEvent, groupEventsBySession} from "./decryption/utils"; +import {keyFromStorage, keyFromDeviceMessage, keyFromBackup} from "./decryption/RoomKey"; +import type {RoomKey, IncomingRoomKey} from "./decryption/RoomKey"; +import type {KeyLoader} from "./decryption/KeyLoader"; +import type {OlmWorker} from "../OlmWorker"; +import type {Transaction} from "../../storage/idb/Transaction"; +import type {TimelineEvent} from "../../storage/types"; +import type {DecryptionResult} from "../DecryptionResult"; +import type {LogItem} from "../../../logging/LogItem"; export class Decryption { - constructor({pickleKey, olm, olmWorker}) { - this._pickleKey = pickleKey; - this._olm = olm; - this._olmWorker = olmWorker; - } + private keyLoader: KeyLoader; + private olmWorker?: OlmWorker; - createSessionCache(size) { - return new SessionCache(size); + constructor(keyLoader: KeyLoader, olmWorker: OlmWorker | undefined) { + this.keyLoader = keyLoader; + this.olmWorker = olmWorker; } async addMissingKeyEventIds(roomId, senderKey, sessionId, eventIds, txn) { @@ -75,9 +78,9 @@ export class Decryption { * @param {[type]} txn [description] * @return {DecryptionPreparation} */ - async prepareDecryptAll(roomId, events, newKeys, sessionCache, txn) { + async prepareDecryptAll(roomId: string, events: TimelineEvent[], newKeys: IncomingRoomKey[] | undefined, txn: Transaction) { const errors = new Map(); - const validEvents = []; + const validEvents: TimelineEvent[] = []; for (const event of events) { if (validateEvent(event)) { @@ -89,11 +92,11 @@ export class Decryption { const eventsBySession = groupEventsBySession(validEvents); - const sessionDecryptions = []; + const sessionDecryptions: SessionDecryption[] = []; await Promise.all(Array.from(eventsBySession.values()).map(async group => { - const sessionInfo = await this._getSessionInfo(roomId, group.senderKey, group.sessionId, newKeys, sessionCache, txn); - if (sessionInfo) { - sessionDecryptions.push(new SessionDecryption(sessionInfo, group.events, this._olmWorker)); + const key = await this.getRoomKey(roomId, group.senderKey!, group.sessionId!, newKeys, txn); + if (key) { + sessionDecryptions.push(new SessionDecryption(key, group.events, this.olmWorker, this.keyLoader)); } else { for (const event of group.events) { errors.set(event.event_id, new DecryptionError("MEGOLM_NO_SESSION", event)); @@ -104,63 +107,43 @@ export class Decryption { return new DecryptionPreparation(roomId, sessionDecryptions, errors); } - async _getSessionInfo(roomId, senderKey, sessionId, newKeys, sessionCache, txn) { - let sessionInfo; + private async getRoomKey(roomId: string, senderKey: string, sessionId: string, newKeys: IncomingRoomKey[] | undefined, txn: Transaction): Promise { if (newKeys) { - const key = newKeys.find(k => k.roomId === roomId && k.senderKey === senderKey && k.sessionId === sessionId); - if (key) { - sessionInfo = await key.createSessionInfo(this._olm, this._pickleKey, txn); - if (sessionInfo) { - sessionCache.add(sessionInfo); - } + const key = newKeys.find(k => k.isForSession(roomId, senderKey, sessionId)); + if (key && await key.checkBetterThanKeyInStorage(this.keyLoader, txn)) { + return key; } } // look only in the cache after looking into newKeys as it may contains that are better - if (!sessionInfo) { - sessionInfo = sessionCache.get(roomId, senderKey, sessionId); + const cachedKey = this.keyLoader.getCachedKey(roomId, senderKey, sessionId); + if (cachedKey) { + return cachedKey; } - if (!sessionInfo) { - const sessionEntry = await txn.inboundGroupSessions.get(roomId, senderKey, sessionId); - if (sessionEntry && sessionEntry.session) { - let session = new this._olm.InboundGroupSession(); - try { - session.unpickle(this._pickleKey, sessionEntry.session); - sessionInfo = new SessionInfo(roomId, senderKey, session, sessionEntry.claimedKeys); - } catch (err) { - session.free(); - throw err; - } - sessionCache.add(sessionInfo); - } + const storageKey = await keyFromStorage(roomId, senderKey, sessionId, txn); + if (storageKey && storageKey.serializationKey) { + return storageKey; } - return sessionInfo; } /** * Writes the key as an inbound group session if there is not already a better key in the store - * @param {RoomKey} key - * @param {Transaction} txn a storage transaction with read/write on inboundGroupSessions - * @return {Promise} whether the key was the best for the sessio id and was written */ - writeRoomKey(key, txn) { - return key.write(this._olm, this._pickleKey, txn); + writeRoomKey(key: IncomingRoomKey, txn: Transaction): Promise { + return key.write(this.keyLoader, txn); } /** * Extracts room keys from decrypted device messages. * The key won't be persisted yet, you need to call RoomKey.write for that. - * - * @param {Array} decryptionResults, any non megolm m.room_key messages will be ignored. - * @return {Array} an array with validated RoomKey's. Note that it is possible we already have a better version of this key in storage though; writing the key will tell you so. */ - roomKeysFromDeviceMessages(decryptionResults, log) { - let keys = []; + roomKeysFromDeviceMessages(decryptionResults: DecryptionResult[], log: LogItem): IncomingRoomKey[] { + const keys: IncomingRoomKey[] = []; for (const dr of decryptionResults) { if (dr.event?.type !== "m.room_key" || dr.event.content?.algorithm !== MEGOLM_ALGORITHM) { continue; } log.wrap("room_key", log => { - const key = RoomKey.fromDeviceMessage(dr); + const key = keyFromDeviceMessage(dr); if (key) { log.set("roomId", key.roomId); log.set("id", key.sessionId); @@ -174,8 +157,11 @@ export class Decryption { return keys; } - roomKeyFromBackup(roomId, sessionId, sessionInfo) { - return RoomKey.fromBackup(roomId, sessionId, sessionInfo); + roomKeyFromBackup(roomId: string, sessionId: string, sessionInfo: string): IncomingRoomKey | undefined { + return keyFromBackup(roomId, sessionId, sessionInfo); + } + + dispose() { + this.keyLoader.dispose(); } } - diff --git a/src/matrix/e2ee/megolm/decryption/KeyLoader.ts b/src/matrix/e2ee/megolm/decryption/KeyLoader.ts new file mode 100644 index 00000000..58f968c8 --- /dev/null +++ b/src/matrix/e2ee/megolm/decryption/KeyLoader.ts @@ -0,0 +1,433 @@ +/* +Copyright 2021 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import {isBetterThan, IncomingRoomKey} from "./RoomKey"; +import {BaseLRUCache} from "../../../../utils/LRUCache"; +import type {RoomKey} from "./RoomKey"; + +export declare class OlmDecryptionResult { + readonly plaintext: string; + readonly message_index: number; +} + +export declare class OlmInboundGroupSession { + constructor(); + free(): void; + pickle(key: string | Uint8Array): string; + unpickle(key: string | Uint8Array, pickle: string); + create(session_key: string): string; + import_session(session_key: string): string; + decrypt(message: string): OlmDecryptionResult; + session_id(): string; + first_known_index(): number; + export_session(message_index: number): string; +} + +/* +Because Olm only has very limited memory available when compiled to wasm, +we limit the amount of sessions held in memory. +*/ +export class KeyLoader extends BaseLRUCache { + + private pickleKey: string; + private olm: any; + private resolveUnusedOperation?: () => void; + private operationBecomesUnusedPromise?: Promise; + + constructor(olm: any, pickleKey: string, limit: number) { + super(limit); + this.pickleKey = pickleKey; + this.olm = olm; + } + + getCachedKey(roomId: string, senderKey: string, sessionId: string): RoomKey | undefined { + const idx = this.findCachedKeyIndex(roomId, senderKey, sessionId); + if (idx !== -1) { + return this._getByIndexAndMoveUp(idx)!.key; + } + } + + async useKey(key: RoomKey, callback: (session: OlmInboundGroupSession, pickleKey: string) => Promise | T): Promise { + const keyOp = await this.allocateOperation(key); + try { + return await callback(keyOp.session, this.pickleKey); + } finally { + this.releaseOperation(keyOp); + } + } + + get running() { + return this._entries.some(op => op.refCount !== 0); + } + + dispose() { + for (let i = 0; i < this._entries.length; i += 1) { + this._entries[i].dispose(); + } + // remove all entries + this._entries.splice(0, this._entries.length); + } + + private async allocateOperation(key: RoomKey): Promise { + let idx; + while((idx = this.findIndexForAllocation(key)) === -1) { + await this.operationBecomesUnused(); + } + if (idx < this.size) { + const op = this._getByIndexAndMoveUp(idx)!; + // cache hit + if (op.isForKey(key)) { + op.refCount += 1; + return op; + } else { + // refCount should be 0 here + op.refCount = 1; + op.key = key; + key.loadInto(op.session, this.pickleKey); + } + return op; + } else { + // create new operation + const session = new this.olm.InboundGroupSession(); + key.loadInto(session, this.pickleKey); + const op = new KeyOperation(key, session); + this._set(op); + return op; + } + } + + private releaseOperation(op: KeyOperation) { + op.refCount -= 1; + if (op.refCount <= 0 && this.resolveUnusedOperation) { + this.resolveUnusedOperation(); + // promise is resolved now, we'll need a new one for next await so clear + this.operationBecomesUnusedPromise = this.resolveUnusedOperation = undefined; + } + } + + private operationBecomesUnused(): Promise { + if (!this.operationBecomesUnusedPromise) { + this.operationBecomesUnusedPromise = new Promise(resolve => { + this.resolveUnusedOperation = resolve; + }); + } + return this.operationBecomesUnusedPromise; + } + + private findIndexForAllocation(key: RoomKey) { + let idx = this.findIndexSameKey(key); // cache hit + if (idx === -1) { + if (this.size < this.limit) { + idx = this.size; + } else { + idx = this.findIndexSameSessionUnused(key); + if (idx === -1) { + idx = this.findIndexOldestUnused(); + } + } + } + return idx; + } + + private findCachedKeyIndex(roomId: string, senderKey: string, sessionId: string): number { + return this._entries.reduce((bestIdx, op, i, arr) => { + const bestOp = bestIdx === -1 ? undefined : arr[bestIdx]; + // only operations that are the "best" for their session can be used, see comment on isBest + if (op.isBest === true && op.isForSameSession(roomId, senderKey, sessionId)) { + if (!bestOp || op.isBetter(bestOp)) { + return i; + } + } + return bestIdx; + }, -1); + } + + private findIndexSameKey(key: RoomKey): number { + return this._entries.findIndex(op => { + return op.isForSameSession(key.roomId, key.senderKey, key.sessionId) && op.isForKey(key); + }); + } + + private findIndexSameSessionUnused(key: RoomKey): number { + return this._entries.reduce((worstIdx, op, i, arr) => { + const worst = worstIdx === -1 ? undefined : arr[worstIdx]; + // we try to pick the worst operation to overwrite, so the best one stays in the cache + if (op.refCount === 0 && op.isForSameSession(key.roomId, key.senderKey, key.sessionId)) { + if (!worst || !op.isBetter(worst)) { + return i; + } + } + return worstIdx; + }, -1); + } + + private findIndexOldestUnused(): number { + for (let i = this._entries.length - 1; i >= 0; i -= 1) { + const op = this._entries[i]; + if (op.refCount === 0) { + return i; + } + } + return -1; + } +} + +class KeyOperation { + session: OlmInboundGroupSession; + key: RoomKey; + refCount: number; + + constructor(key: RoomKey, session: OlmInboundGroupSession) { + this.key = key; + this.session = session; + this.refCount = 1; + } + + isForSameSession(roomId: string, senderKey: string, sessionId: string): boolean { + return this.key.roomId === roomId && this.key.senderKey === senderKey && this.key.sessionId === sessionId; + } + + // assumes isForSameSession is true + isBetter(other: KeyOperation) { + return isBetterThan(this.session, other.session); + } + + isForKey(key: RoomKey) { + return this.key.serializationKey === key.serializationKey && + this.key.serializationType === key.serializationType; + } + + dispose() { + this.session.free(); + } + + /** returns whether the key for this operation has been checked at some point against storage + * and was determined to be the better key, undefined if it hasn't been checked yet. + * Only keys that are the best keys can be returned by getCachedKey as returning a cache hit + * will usually not check for a better session in storage. Also see RoomKey.isBetter. */ + get isBest(): boolean | undefined { + return this.key.isBetter; + } +} + +export function tests() { + let instances = 0; + + class MockRoomKey extends IncomingRoomKey { + private _roomId: string; + private _senderKey: string; + private _sessionId: string; + private _firstKnownIndex: number; + + constructor(roomId: string, senderKey: string, sessionId: string, firstKnownIndex: number) { + super(); + this._roomId = roomId; + this._senderKey = senderKey; + this._sessionId = sessionId; + this._firstKnownIndex = firstKnownIndex; + } + + get roomId(): string { return this._roomId; } + get senderKey(): string { return this._senderKey; } + get sessionId(): string { return this._sessionId; } + get claimedEd25519Key(): string { return "claimedEd25519Key"; } + get serializationKey(): string { return `key-${this.sessionId}-${this._firstKnownIndex}`; } + get serializationType(): string { return "type"; } + get eventIds(): string[] | undefined { return undefined; } + loadInto(session: OlmInboundGroupSession) { + const mockSession = session as MockInboundSession; + mockSession.sessionId = this.sessionId; + mockSession.firstKnownIndex = this._firstKnownIndex; + } + } + + class MockInboundSession { + public sessionId: string = ""; + public firstKnownIndex: number = 0; + + constructor() { + instances += 1; + } + + free(): void { instances -= 1; } + pickle(key: string | Uint8Array): string { return `${this.sessionId}-pickled-session`; } + unpickle(key: string | Uint8Array, pickle: string) {} + create(session_key: string): string { return `${this.sessionId}-created-session`; } + import_session(session_key: string): string { return ""; } + decrypt(message: string): OlmDecryptionResult { return {} as OlmDecryptionResult; } + session_id(): string { return this.sessionId; } + first_known_index(): number { return this.firstKnownIndex; } + export_session(message_index: number): string { return `${this.sessionId}-exported-session`; } + } + + const PICKLE_KEY = "🥒🔑"; + const olm = {InboundGroupSession: MockInboundSession}; + const roomId = "!abc:hs.tld"; + const aliceSenderKey = "abc"; + const bobSenderKey = "def"; + const sessionId1 = "s123"; + const sessionId2 = "s456"; + + return { + "load key gives correct session": async assert => { + const loader = new KeyLoader(olm, PICKLE_KEY, 2); + let callback1Called = false; + let callback2Called = false; + const p1 = loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => { + callback1Called = true; + assert.equal(session.session_id(), sessionId1); + assert.equal(session.first_known_index(), 1); + await Promise.resolve(); // make sure they are busy in parallel + }); + const p2 = loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId2, 2), async session => { + callback2Called = true; + assert.equal(session.session_id(), sessionId2); + assert.equal(session.first_known_index(), 2); + await Promise.resolve(); // make sure they are busy in parallel + }); + assert.equal(loader.size, 2); + await Promise.all([p1, p2]); + assert(callback1Called); + assert(callback2Called); + }, + "keys with different first index are kept separate": async assert => { + const loader = new KeyLoader(olm, PICKLE_KEY, 2); + let callback1Called = false; + let callback2Called = false; + const p1 = loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => { + callback1Called = true; + assert.equal(session.session_id(), sessionId1); + assert.equal(session.first_known_index(), 1); + await Promise.resolve(); // make sure they are busy in parallel + }); + const p2 = loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2), async session => { + callback2Called = true; + assert.equal(session.session_id(), sessionId1); + assert.equal(session.first_known_index(), 2); + await Promise.resolve(); // make sure they are busy in parallel + }); + assert.equal(loader.size, 2); + await Promise.all([p1, p2]); + assert(callback1Called); + assert(callback2Called); + }, + "useKey blocks as long as no free sessions are available": async assert => { + const loader = new KeyLoader(olm, PICKLE_KEY, 1); + let resolve; + let callbackCalled = false; + loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => { + await new Promise(r => resolve = r); + }); + await Promise.resolve(); + assert.equal(loader.size, 1); + const promise = loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId2, 1), session => { + callbackCalled = true; + }); + assert.equal(callbackCalled, false); + resolve(); + await promise; + assert.equal(callbackCalled, true); + }, + "cache hit while key in use, then replace (check refCount works properly)": async assert => { + const loader = new KeyLoader(olm, PICKLE_KEY, 1); + let resolve1, resolve2; + const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1); + const p1 = loader.useKey(key1, async session => { + await new Promise(r => resolve1 = r); + }); + const p2 = loader.useKey(key1, async session => { + await new Promise(r => resolve2 = r); + }); + await Promise.resolve(); + assert.equal(loader.size, 1); + assert.equal(loader.running, true); + resolve1(); + await p1; + assert.equal(loader.running, true); + resolve2(); + await p2; + assert.equal(loader.running, false); + let callbackCalled = false; + await loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId2, 1), async session => { + callbackCalled = true; + assert.equal(session.session_id(), sessionId2); + assert.equal(session.first_known_index(), 1); + }); + assert.equal(loader.size, 1); + assert.equal(callbackCalled, true); + }, + "cache hit while key not in use": async assert => { + const loader = new KeyLoader(olm, PICKLE_KEY, 2); + let resolve1, resolve2, invocations = 0; + const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1); + await loader.useKey(key1, async session => { invocations += 1; }); + key1.isBetter = true; + assert.equal(loader.size, 1); + const cachedKey = loader.getCachedKey(roomId, aliceSenderKey, sessionId1)!; + assert.equal(cachedKey, key1); + await loader.useKey(cachedKey, async session => { invocations += 1; }); + assert.equal(loader.size, 1); + assert.equal(invocations, 2); + }, + "dispose calls free on all sessions": async assert => { + instances = 0; + const loader = new KeyLoader(olm, PICKLE_KEY, 2); + await loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => {}); + await loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId2, 1), async session => {}); + assert.equal(instances, 2); + assert.equal(loader.size, 2); + loader.dispose(); + assert.strictEqual(instances, 0, "instances"); + assert.strictEqual(loader.size, 0, "loader.size"); + }, + "checkBetterThanKeyInStorage false with cache": async assert => { + const loader = new KeyLoader(olm, PICKLE_KEY, 2); + const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2); + await loader.useKey(key1, async session => {}); + // fake we've checked with storage that this is the best key, + // and as long is it remains the best key with newly added keys, + // it will be returned from getCachedKey (as called from checkBetterThanKeyInStorage) + key1.isBetter = true; + const key2 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 3); + // this will hit cache of key 1 so we pass in null as txn + const isBetter = await key2.checkBetterThanKeyInStorage(loader, null as any); + assert.strictEqual(isBetter, false); + assert.strictEqual(key2.isBetter, false); + }, + "checkBetterThanKeyInStorage true with cache": async assert => { + const loader = new KeyLoader(olm, PICKLE_KEY, 2); + const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2); + key1.isBetter = true; // fake we've check with storage so far (not including key2) this is the best key + await loader.useKey(key1, async session => {}); + const key2 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1); + // this will hit cache of key 1 so we pass in null as txn + const isBetter = await key2.checkBetterThanKeyInStorage(loader, null as any); + assert.strictEqual(isBetter, true); + assert.strictEqual(key2.isBetter, true); + }, + "prefer to remove worst key for a session from cache": async assert => { + const loader = new KeyLoader(olm, PICKLE_KEY, 2); + const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2); + await loader.useKey(key1, async session => {}); + key1.isBetter = true; // set to true just so it gets returned from getCachedKey + const key2 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 4); + await loader.useKey(key2, async session => {}); + const key3 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 3); + await loader.useKey(key3, async session => {}); + assert.strictEqual(loader.getCachedKey(roomId, aliceSenderKey, sessionId1), key1); + }, + } +} diff --git a/src/matrix/e2ee/megolm/decryption/ReplayDetectionEntry.js b/src/matrix/e2ee/megolm/decryption/ReplayDetectionEntry.ts similarity index 61% rename from src/matrix/e2ee/megolm/decryption/ReplayDetectionEntry.js rename to src/matrix/e2ee/megolm/decryption/ReplayDetectionEntry.ts index e5ce2845..f3807c39 100644 --- a/src/matrix/e2ee/megolm/decryption/ReplayDetectionEntry.js +++ b/src/matrix/e2ee/megolm/decryption/ReplayDetectionEntry.ts @@ -14,11 +14,24 @@ See the License for the specific language governing permissions and limitations under the License. */ +import type {TimelineEvent} from "../../../storage/types"; + export class ReplayDetectionEntry { - constructor(sessionId, messageIndex, event) { + public readonly sessionId: string; + public readonly messageIndex: number; + public readonly event: TimelineEvent; + + constructor(sessionId: string, messageIndex: number, event: TimelineEvent) { this.sessionId = sessionId; this.messageIndex = messageIndex; - this.eventId = event.event_id; - this.timestamp = event.origin_server_ts; + this.event = event; + } + + get eventId(): string { + return this.event.event_id; + } + + get timestamp(): number { + return this.event.origin_server_ts; } } diff --git a/src/matrix/e2ee/megolm/decryption/RoomKey.js b/src/matrix/e2ee/megolm/decryption/RoomKey.js deleted file mode 100644 index 6fcd738b..00000000 --- a/src/matrix/e2ee/megolm/decryption/RoomKey.js +++ /dev/null @@ -1,166 +0,0 @@ -import {SessionInfo} from "./SessionInfo.js"; - -export class BaseRoomKey { - constructor() { - this._sessionInfo = null; - this._isBetter = null; - this._eventIds = null; - } - - async createSessionInfo(olm, pickleKey, txn) { - if (this._isBetter === false) { - return; - } - const session = new olm.InboundGroupSession(); - try { - this._loadSessionKey(session); - this._isBetter = await this._isBetterThanKnown(session, olm, pickleKey, txn); - if (this._isBetter) { - const claimedKeys = {ed25519: this.claimedEd25519Key}; - this._sessionInfo = new SessionInfo(this.roomId, this.senderKey, session, claimedKeys); - // retain the session so we don't have to create a new session during write. - this._sessionInfo.retain(); - return this._sessionInfo; - } else { - session.free(); - return; - } - } catch (err) { - this._sessionInfo = null; - session.free(); - throw err; - } - } - - async _isBetterThanKnown(session, olm, pickleKey, txn) { - let isBetter = true; - // TODO: we could potentially have a small speedup here if we looked first in the SessionCache here... - const existingSessionEntry = await txn.inboundGroupSessions.get(this.roomId, this.senderKey, this.sessionId); - if (existingSessionEntry?.session) { - const existingSession = new olm.InboundGroupSession(); - try { - existingSession.unpickle(pickleKey, existingSessionEntry.session); - isBetter = session.first_known_index() < existingSession.first_known_index(); - } finally { - existingSession.free(); - } - } - // store the event ids that can be decrypted with this key - // before we overwrite them if called from `write`. - if (existingSessionEntry?.eventIds) { - this._eventIds = existingSessionEntry.eventIds; - } - return isBetter; - } - - async write(olm, pickleKey, txn) { - // we checked already and we had a better session in storage, so don't write - if (this._isBetter === false) { - return false; - } - if (!this._sessionInfo) { - await this.createSessionInfo(olm, pickleKey, txn); - } - if (this._sessionInfo) { - const session = this._sessionInfo.session; - const sessionEntry = { - roomId: this.roomId, - senderKey: this.senderKey, - sessionId: this.sessionId, - session: session.pickle(pickleKey), - claimedKeys: this._sessionInfo.claimedKeys, - }; - txn.inboundGroupSessions.set(sessionEntry); - this.dispose(); - return true; - } - return false; - } - - get eventIds() { - return this._eventIds; - } - - dispose() { - if (this._sessionInfo) { - this._sessionInfo.release(); - this._sessionInfo = null; - } - } -} - -class DeviceMessageRoomKey extends BaseRoomKey { - constructor(decryptionResult) { - super(); - this._decryptionResult = decryptionResult; - } - - get roomId() { return this._decryptionResult.event.content?.["room_id"]; } - get senderKey() { return this._decryptionResult.senderCurve25519Key; } - get sessionId() { return this._decryptionResult.event.content?.["session_id"]; } - get claimedEd25519Key() { return this._decryptionResult.claimedEd25519Key; } - - _loadSessionKey(session) { - const sessionKey = this._decryptionResult.event.content?.["session_key"]; - session.create(sessionKey); - } -} - -class BackupRoomKey extends BaseRoomKey { - constructor(roomId, sessionId, backupInfo) { - super(); - this._roomId = roomId; - this._sessionId = sessionId; - this._backupInfo = backupInfo; - } - - get roomId() { return this._roomId; } - get senderKey() { return this._backupInfo["sender_key"]; } - get sessionId() { return this._sessionId; } - get claimedEd25519Key() { return this._backupInfo["sender_claimed_keys"]?.["ed25519"]; } - - _loadSessionKey(session) { - const sessionKey = this._backupInfo["session_key"]; - session.import_session(sessionKey); - } -} - -export function fromDeviceMessage(dr) { - const roomId = dr.event.content?.["room_id"]; - const sessionId = dr.event.content?.["session_id"]; - const sessionKey = dr.event.content?.["session_key"]; - if ( - typeof roomId === "string" || - typeof sessionId === "string" || - typeof senderKey === "string" || - typeof sessionKey === "string" - ) { - return new DeviceMessageRoomKey(dr); - } -} - -/* -sessionInfo is a response from key backup and has the following keys: - algorithm - forwarding_curve25519_key_chain - sender_claimed_keys - sender_key - session_key - */ -export function fromBackup(roomId, sessionId, sessionInfo) { - const sessionKey = sessionInfo["session_key"]; - const senderKey = sessionInfo["sender_key"]; - // TODO: can we just trust this? - const claimedEd25519Key = sessionInfo["sender_claimed_keys"]?.["ed25519"]; - - if ( - typeof roomId === "string" && - typeof sessionId === "string" && - typeof senderKey === "string" && - typeof sessionKey === "string" && - typeof claimedEd25519Key === "string" - ) { - return new BackupRoomKey(roomId, sessionId, sessionInfo); - } -} - diff --git a/src/matrix/e2ee/megolm/decryption/RoomKey.ts b/src/matrix/e2ee/megolm/decryption/RoomKey.ts new file mode 100644 index 00000000..81f1a9be --- /dev/null +++ b/src/matrix/e2ee/megolm/decryption/RoomKey.ts @@ -0,0 +1,245 @@ +/* +Copyright 2021 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import type {InboundGroupSessionEntry} from "../../../storage/idb/stores/InboundGroupSessionStore"; +import type {Transaction} from "../../../storage/idb/Transaction"; +import type {DecryptionResult} from "../../DecryptionResult"; +import type {KeyLoader, OlmInboundGroupSession} from "./KeyLoader"; + +export abstract class RoomKey { + private _isBetter: boolean | undefined; + + isForSession(roomId: string, senderKey: string, sessionId: string) { + return this.roomId === roomId && this.senderKey === senderKey && this.sessionId === sessionId; + } + + abstract get roomId(): string; + abstract get senderKey(): string; + abstract get sessionId(): string; + abstract get claimedEd25519Key(): string; + abstract get serializationKey(): string; + abstract get serializationType(): string; + abstract get eventIds(): string[] | undefined; + abstract loadInto(session: OlmInboundGroupSession, pickleKey: string): void; + /* Whether the key has been checked against storage (or is from storage) + * to be the better key for a given session. Given that all keys are checked to be better + * as part of writing, we can trust that when this returns true, it really is the best key + * available between storage and cached keys in memory. This is why keys with this field set to + * true are used by the key loader to return cached keys. Also see KeyOperation.isBest there. */ + get isBetter(): boolean | undefined { return this._isBetter; } + // should only be set in key.checkBetterThanKeyInStorage + set isBetter(value: boolean | undefined) { this._isBetter = value; } +} + +export function isBetterThan(newSession: OlmInboundGroupSession, existingSession: OlmInboundGroupSession) { + return newSession.first_known_index() < existingSession.first_known_index(); +} + +export abstract class IncomingRoomKey extends RoomKey { + private _eventIds?: string[]; + + checkBetterThanKeyInStorage(loader: KeyLoader, txn: Transaction): Promise { + return this._checkBetterThanKeyInStorage(loader, undefined, txn); + } + + async write(loader: KeyLoader, txn: Transaction): Promise { + // we checked already and we had a better session in storage, so don't write + let pickledSession; + if (this.isBetter === undefined) { + // if this key wasn't used to decrypt any messages in the same sync, + // we haven't checked if this is the best key yet, + // so do that now to not overwrite a better key. + // while we have the key deserialized, also pickle it to store it later on here. + await this._checkBetterThanKeyInStorage(loader, (session, pickleKey) => { + pickledSession = session.pickle(pickleKey); + }, txn); + } + if (this.isBetter === false) { + return false; + } + // before calling write in parallel, we need to check loader.running is false so we are sure our transaction will not be closed + if (!pickledSession) { + pickledSession = await loader.useKey(this, (session, pickleKey) => session.pickle(pickleKey)); + } + const sessionEntry = { + roomId: this.roomId, + senderKey: this.senderKey, + sessionId: this.sessionId, + session: pickledSession, + claimedKeys: {"ed25519": this.claimedEd25519Key}, + }; + txn.inboundGroupSessions.set(sessionEntry); + return true; + } + + get eventIds() { return this._eventIds; } + + private async _checkBetterThanKeyInStorage(loader: KeyLoader, callback: (((session: OlmInboundGroupSession, pickleKey: string) => void) | undefined), txn: Transaction): Promise { + if (this.isBetter !== undefined) { + return this.isBetter; + } + let existingKey = loader.getCachedKey(this.roomId, this.senderKey, this.sessionId); + if (!existingKey) { + const storageKey = await keyFromStorage(this.roomId, this.senderKey, this.sessionId, txn); + // store the event ids that can be decrypted with this key + // before we overwrite them if called from `write`. + if (storageKey) { + if (storageKey.hasSession) { + existingKey = storageKey; + } else if (storageKey.eventIds) { + this._eventIds = storageKey.eventIds; + } + } + } + if (existingKey) { + const key = existingKey; + await loader.useKey(this, async newSession => { + await loader.useKey(key, (existingSession, pickleKey) => { + // set isBetter as soon as possible, on both keys compared, + // as it is is used to determine whether a key can be used for the cache + this.isBetter = isBetterThan(newSession, existingSession); + key.isBetter = !this.isBetter; + if (this.isBetter && callback) { + callback(newSession, pickleKey); + } + }); + }); + } else { + // no previous key, so we're the best \o/ + this.isBetter = true; + } + return this.isBetter!; + } +} + +class DeviceMessageRoomKey extends IncomingRoomKey { + private _decryptionResult: DecryptionResult; + + constructor(decryptionResult: DecryptionResult) { + super(); + this._decryptionResult = decryptionResult; + } + + get roomId() { return this._decryptionResult.event.content?.["room_id"]; } + get senderKey() { return this._decryptionResult.senderCurve25519Key; } + get sessionId() { return this._decryptionResult.event.content?.["session_id"]; } + get claimedEd25519Key() { return this._decryptionResult.claimedEd25519Key; } + get serializationKey(): string { return this._decryptionResult.event.content?.["session_key"]; } + get serializationType(): string { return "create"; } + + loadInto(session) { + session.create(this.serializationKey); + } +} + +class BackupRoomKey extends IncomingRoomKey { + private _roomId: string; + private _sessionId: string; + private _backupInfo: string; + + constructor(roomId, sessionId, backupInfo) { + super(); + this._roomId = roomId; + this._sessionId = sessionId; + this._backupInfo = backupInfo; + } + + get roomId() { return this._roomId; } + get senderKey() { return this._backupInfo["sender_key"]; } + get sessionId() { return this._sessionId; } + get claimedEd25519Key() { return this._backupInfo["sender_claimed_keys"]?.["ed25519"]; } + get serializationKey(): string { return this._backupInfo["session_key"]; } + get serializationType(): string { return "import_session"; } + + loadInto(session) { + session.import_session(this.serializationKey); + } +} + +class StoredRoomKey extends RoomKey { + private storageEntry: InboundGroupSessionEntry; + + constructor(storageEntry: InboundGroupSessionEntry) { + super(); + this.isBetter = true; // usually the key in storage is the best until checks prove otherwise + this.storageEntry = storageEntry; + } + + get roomId() { return this.storageEntry.roomId; } + get senderKey() { return this.storageEntry.senderKey; } + get sessionId() { return this.storageEntry.sessionId; } + get claimedEd25519Key() { return this.storageEntry.claimedKeys!["ed25519"]; } + get eventIds() { return this.storageEntry.eventIds; } + get serializationKey(): string { return this.storageEntry.session || ""; } + get serializationType(): string { return "unpickle"; } + + loadInto(session, pickleKey) { + session.unpickle(pickleKey, this.serializationKey); + } + + get hasSession() { + // sessions are stored before they are received + // to keep track of events that need it to be decrypted. + // This is used to retry decryption of those events once the session is received. + return !!this.serializationKey; + } +} + +export function keyFromDeviceMessage(dr: DecryptionResult): DeviceMessageRoomKey | undefined { + const sessionKey = dr.event.content?.["session_key"]; + const key = new DeviceMessageRoomKey(dr); + if ( + typeof key.roomId === "string" && + typeof key.sessionId === "string" && + typeof key.senderKey === "string" && + typeof sessionKey === "string" + ) { + return key; + } +} + +/* +sessionInfo is a response from key backup and has the following keys: + algorithm + forwarding_curve25519_key_chain + sender_claimed_keys + sender_key + session_key + */ +export function keyFromBackup(roomId, sessionId, backupInfo): BackupRoomKey | undefined { + const sessionKey = backupInfo["session_key"]; + const senderKey = backupInfo["sender_key"]; + // TODO: can we just trust this? + const claimedEd25519Key = backupInfo["sender_claimed_keys"]?.["ed25519"]; + + if ( + typeof roomId === "string" && + typeof sessionId === "string" && + typeof senderKey === "string" && + typeof sessionKey === "string" && + typeof claimedEd25519Key === "string" + ) { + return new BackupRoomKey(roomId, sessionId, backupInfo); + } +} + +export async function keyFromStorage(roomId: string, senderKey: string, sessionId: string, txn: Transaction): Promise { + const existingSessionEntry = await txn.inboundGroupSessions.get(roomId, senderKey, sessionId); + if (existingSessionEntry) { + return new StoredRoomKey(existingSessionEntry); + } + return; +} diff --git a/src/matrix/e2ee/megolm/decryption/SessionCache.js b/src/matrix/e2ee/megolm/decryption/SessionCache.js deleted file mode 100644 index c5b2c0fb..00000000 --- a/src/matrix/e2ee/megolm/decryption/SessionCache.js +++ /dev/null @@ -1,61 +0,0 @@ -/* -Copyright 2020 The Matrix.org Foundation C.I.C. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -import {BaseLRUCache} from "../../../../utils/LRUCache.js"; -const DEFAULT_CACHE_SIZE = 10; - -/** - * Cache of unpickled inbound megolm session. - */ -export class SessionCache extends BaseLRUCache { - constructor(limit) { - limit = typeof limit === "number" ? limit : DEFAULT_CACHE_SIZE; - super(limit); - } - - /** - * @param {string} roomId - * @param {string} senderKey - * @param {string} sessionId - * @return {SessionInfo?} - */ - get(roomId, senderKey, sessionId) { - return this._get(s => { - return s.roomId === roomId && - s.senderKey === senderKey && - sessionId === s.sessionId; - }); - } - - add(sessionInfo) { - sessionInfo.retain(); - this._set(sessionInfo, s => { - return s.roomId === sessionInfo.roomId && - s.senderKey === sessionInfo.senderKey && - s.sessionId === sessionInfo.sessionId; - }); - } - - _onEvictEntry(sessionInfo) { - sessionInfo.release(); - } - - dispose() { - for (const sessionInfo of this._entries) { - sessionInfo.release(); - } - } -} diff --git a/src/matrix/e2ee/megolm/decryption/SessionDecryption.js b/src/matrix/e2ee/megolm/decryption/SessionDecryption.js deleted file mode 100644 index 137ae9f8..00000000 --- a/src/matrix/e2ee/megolm/decryption/SessionDecryption.js +++ /dev/null @@ -1,90 +0,0 @@ -/* -Copyright 2020 The Matrix.org Foundation C.I.C. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -import {DecryptionResult} from "../../DecryptionResult.js"; -import {DecryptionError} from "../../common.js"; -import {ReplayDetectionEntry} from "./ReplayDetectionEntry.js"; - -/** - * Does the actual decryption of all events for a given megolm session in a batch - */ -export class SessionDecryption { - constructor(sessionInfo, events, olmWorker) { - sessionInfo.retain(); - this._sessionInfo = sessionInfo; - this._events = events; - this._olmWorker = olmWorker; - this._decryptionRequests = olmWorker ? [] : null; - } - - async decryptAll() { - const replayEntries = []; - const results = new Map(); - let errors; - const roomId = this._sessionInfo.roomId; - - await Promise.all(this._events.map(async event => { - try { - const {session} = this._sessionInfo; - const ciphertext = event.content.ciphertext; - let decryptionResult; - if (this._olmWorker) { - const request = this._olmWorker.megolmDecrypt(session, ciphertext); - this._decryptionRequests.push(request); - decryptionResult = await request.response(); - } else { - decryptionResult = session.decrypt(ciphertext); - } - const plaintext = decryptionResult.plaintext; - const messageIndex = decryptionResult.message_index; - let payload; - try { - payload = JSON.parse(plaintext); - } catch (err) { - throw new DecryptionError("PLAINTEXT_NOT_JSON", event, {plaintext, err}); - } - if (payload.room_id !== roomId) { - throw new DecryptionError("MEGOLM_WRONG_ROOM", event, - {encryptedRoomId: payload.room_id, eventRoomId: roomId}); - } - replayEntries.push(new ReplayDetectionEntry(session.session_id(), messageIndex, event)); - const result = new DecryptionResult(payload, this._sessionInfo.senderKey, this._sessionInfo.claimedKeys); - results.set(event.event_id, result); - } catch (err) { - // ignore AbortError from cancelling decryption requests in dispose method - if (err.name === "AbortError") { - return; - } - if (!errors) { - errors = new Map(); - } - errors.set(event.event_id, err); - } - })); - - return {results, errors, replayEntries}; - } - - dispose() { - if (this._decryptionRequests) { - for (const r of this._decryptionRequests) { - r.abort(); - } - } - // TODO: cancel decryptions here - this._sessionInfo.release(); - } -} diff --git a/src/matrix/e2ee/megolm/decryption/SessionDecryption.ts b/src/matrix/e2ee/megolm/decryption/SessionDecryption.ts new file mode 100644 index 00000000..7e466806 --- /dev/null +++ b/src/matrix/e2ee/megolm/decryption/SessionDecryption.ts @@ -0,0 +1,103 @@ +/* +Copyright 2020 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import {DecryptionResult} from "../../DecryptionResult.js"; +import {DecryptionError} from "../../common.js"; +import {ReplayDetectionEntry} from "./ReplayDetectionEntry"; +import type {RoomKey} from "./RoomKey.js"; +import type {KeyLoader, OlmDecryptionResult} from "./KeyLoader"; +import type {OlmWorker} from "../../OlmWorker"; +import type {TimelineEvent} from "../../../storage/types"; + +interface DecryptAllResult { + readonly results: Map; + readonly errors?: Map; + readonly replayEntries: ReplayDetectionEntry[]; +} +/** + * Does the actual decryption of all events for a given megolm session in a batch + */ +export class SessionDecryption { + private key: RoomKey; + private events: TimelineEvent[]; + private keyLoader: KeyLoader; + private olmWorker?: OlmWorker; + private decryptionRequests?: any[]; + + constructor(key: RoomKey, events: TimelineEvent[], olmWorker: OlmWorker | undefined, keyLoader: KeyLoader) { + this.key = key; + this.events = events; + this.olmWorker = olmWorker; + this.keyLoader = keyLoader; + this.decryptionRequests = olmWorker ? [] : undefined; + } + + async decryptAll(): Promise { + const replayEntries: ReplayDetectionEntry[] = []; + const results: Map = new Map(); + let errors: Map | undefined; + + await this.keyLoader.useKey(this.key, async session => { + for (const event of this.events) { + try { + const ciphertext = event.content.ciphertext as string; + let decryptionResult: OlmDecryptionResult | undefined; + // TODO: pass all cipthertexts in one go to the megolm worker and don't deserialize the key until in the worker? + if (this.olmWorker) { + const request = this.olmWorker.megolmDecrypt(session, ciphertext); + this.decryptionRequests!.push(request); + decryptionResult = await request.response(); + } else { + decryptionResult = session.decrypt(ciphertext); + } + const {plaintext} = decryptionResult!; + let payload; + try { + payload = JSON.parse(plaintext); + } catch (err) { + throw new DecryptionError("PLAINTEXT_NOT_JSON", event, {plaintext, err}); + } + if (payload.room_id !== this.key.roomId) { + throw new DecryptionError("MEGOLM_WRONG_ROOM", event, + {encryptedRoomId: payload.room_id, eventRoomId: this.key.roomId}); + } + replayEntries.push(new ReplayDetectionEntry(this.key.sessionId, decryptionResult!.message_index, event)); + const result = new DecryptionResult(payload, this.key.senderKey, this.key.claimedEd25519Key); + results.set(event.event_id, result); + } catch (err) { + // ignore AbortError from cancelling decryption requests in dispose method + if (err.name === "AbortError") { + return; + } + if (!errors) { + errors = new Map(); + } + errors.set(event.event_id, err); + } + } + }); + + return {results, errors, replayEntries}; + } + + dispose() { + if (this.decryptionRequests) { + for (const r of this.decryptionRequests) { + r.abort(); + } + } + } +} diff --git a/src/matrix/e2ee/megolm/decryption/SessionInfo.js b/src/matrix/e2ee/megolm/decryption/SessionInfo.js deleted file mode 100644 index 098bc3de..00000000 --- a/src/matrix/e2ee/megolm/decryption/SessionInfo.js +++ /dev/null @@ -1,49 +0,0 @@ -/* -Copyright 2020 The Matrix.org Foundation C.I.C. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -/** - * session loaded in memory with everything needed to create DecryptionResults - * and to store/retrieve it in the SessionCache - */ -export class SessionInfo { - constructor(roomId, senderKey, session, claimedKeys) { - this.roomId = roomId; - this.senderKey = senderKey; - this.session = session; - this.claimedKeys = claimedKeys; - this._refCounter = 0; - } - - get sessionId() { - return this.session?.session_id(); - } - - retain() { - this._refCounter += 1; - } - - release() { - this._refCounter -= 1; - if (this._refCounter <= 0) { - this.dispose(); - } - } - - dispose() { - this.session.free(); - this.session = null; - } -} diff --git a/src/matrix/e2ee/megolm/decryption/utils.js b/src/matrix/e2ee/megolm/decryption/utils.ts similarity index 50% rename from src/matrix/e2ee/megolm/decryption/utils.js rename to src/matrix/e2ee/megolm/decryption/utils.ts index c38b1416..4207006b 100644 --- a/src/matrix/e2ee/megolm/decryption/utils.js +++ b/src/matrix/e2ee/megolm/decryption/utils.ts @@ -14,44 +14,46 @@ See the License for the specific language governing permissions and limitations under the License. */ -import {groupByWithCreator} from "../../../../utils/groupBy.js"; +import {groupByWithCreator} from "../../../../utils/groupBy"; +import type {TimelineEvent} from "../../../storage/types"; -function getSenderKey(event) { +function getSenderKey(event: TimelineEvent): string | undefined { return event.content?.["sender_key"]; } -function getSessionId(event) { +function getSessionId(event: TimelineEvent): string | undefined { return event.content?.["session_id"]; } -function getCiphertext(event) { +function getCiphertext(event: TimelineEvent): string | undefined { return event.content?.ciphertext; } -export function validateEvent(event) { +export function validateEvent(event: TimelineEvent) { return typeof getSenderKey(event) === "string" && typeof getSessionId(event) === "string" && typeof getCiphertext(event) === "string"; } -class SessionKeyGroup { +export class SessionKeyGroup { + public readonly events: TimelineEvent[]; constructor() { this.events = []; } - get senderKey() { - return getSenderKey(this.events[0]); + get senderKey(): string | undefined { + return getSenderKey(this.events[0]!); } - get sessionId() { - return getSessionId(this.events[0]); + get sessionId(): string | undefined { + return getSessionId(this.events[0]!); } } -export function groupEventsBySession(events) { - return groupByWithCreator(events, - event => `${getSenderKey(event)}|${getSessionId(event)}`, +export function groupEventsBySession(events: TimelineEvent[]): Map { + return groupByWithCreator(events, + (event: TimelineEvent) => `${getSenderKey(event)}|${getSessionId(event)}`, () => new SessionKeyGroup(), - (group, event) => group.events.push(event) + (group: SessionKeyGroup, event: TimelineEvent) => group.events.push(event) ); } diff --git a/src/matrix/e2ee/olm/Decryption.js b/src/matrix/e2ee/olm/Decryption.js index 7556c367..0af3bd23 100644 --- a/src/matrix/e2ee/olm/Decryption.js +++ b/src/matrix/e2ee/olm/Decryption.js @@ -15,7 +15,7 @@ limitations under the License. */ import {DecryptionError} from "../common.js"; -import {groupBy} from "../../../utils/groupBy.js"; +import {groupBy} from "../../../utils/groupBy"; import {MultiLock} from "../../../utils/Lock.js"; import {Session} from "./Session.js"; import {DecryptionResult} from "../DecryptionResult.js"; @@ -150,7 +150,7 @@ export class Decryption { throw new DecryptionError("PLAINTEXT_NOT_JSON", event, {plaintext, error}); } this._validatePayload(payload, event); - return new DecryptionResult(payload, senderKey, payload.keys); + return new DecryptionResult(payload, senderKey, payload.keys.ed25519); } else { throw new DecryptionError("OLM_NO_MATCHING_SESSION", event, {knownSessionIds: senderKeyDecryption.sessions.map(s => s.id)}); diff --git a/src/matrix/e2ee/olm/Encryption.js b/src/matrix/e2ee/olm/Encryption.js index 3bc66ec3..1b720ae7 100644 --- a/src/matrix/e2ee/olm/Encryption.js +++ b/src/matrix/e2ee/olm/Encryption.js @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -import {groupByWithCreator} from "../../../utils/groupBy.js"; +import {groupByWithCreator} from "../../../utils/groupBy"; import {verifyEd25519Signature, OLM_ALGORITHM} from "../common.js"; import {createSessionEntry} from "./Session.js"; diff --git a/src/matrix/room/Room.js b/src/matrix/room/Room.js index cab0e13b..a0d7c693 100644 --- a/src/matrix/room/Room.js +++ b/src/matrix/room/Room.js @@ -76,6 +76,10 @@ export class Room extends BaseRoom { let eventsToDecrypt = roomResponse?.timeline?.events || []; // when new keys arrive, also see if any older events can now be retried to decrypt if (newKeys) { + // TODO: if a key is considered by roomEncryption.prepareDecryptAll to use for decryption, + // key.eventIds will be set. We could somehow try to reuse that work, but retrying also needs + // to happen if a key is not needed to decrypt this sync or there are indeed no encrypted messages + // in this sync at all. retryEntries = await this._getSyncRetryDecryptEntries(newKeys, roomEncryption, txn); if (retryEntries.length) { log.set("retry", retryEntries.length); diff --git a/src/matrix/room/timeline/persistence/MemberWriter.js b/src/matrix/room/timeline/persistence/MemberWriter.js index b21a4461..1cdcb7d5 100644 --- a/src/matrix/room/timeline/persistence/MemberWriter.js +++ b/src/matrix/room/timeline/persistence/MemberWriter.js @@ -15,7 +15,7 @@ limitations under the License. */ import {MemberChange, RoomMember, EVENT_TYPE as MEMBER_EVENT_TYPE} from "../../members/RoomMember.js"; -import {LRUCache} from "../../../../utils/LRUCache.js"; +import {LRUCache} from "../../../../utils/LRUCache"; export class MemberWriter { constructor(roomId) { diff --git a/src/matrix/storage/idb/stores/InboundGroupSessionStore.ts b/src/matrix/storage/idb/stores/InboundGroupSessionStore.ts index 5dc0205f..22093884 100644 --- a/src/matrix/storage/idb/stores/InboundGroupSessionStore.ts +++ b/src/matrix/storage/idb/stores/InboundGroupSessionStore.ts @@ -17,24 +17,26 @@ limitations under the License. import {MIN_UNICODE, MAX_UNICODE} from "./common"; import {Store} from "../Store"; -interface InboundGroupSession { +export interface InboundGroupSessionEntry { roomId: string; senderKey: string; sessionId: string; session?: string; claimedKeys?: { [algorithm : string] : string }; eventIds?: string[]; - key: string; } +type InboundGroupSessionStorageEntry = InboundGroupSessionEntry & { key: string }; + + function encodeKey(roomId: string, senderKey: string, sessionId: string): string { return `${roomId}|${senderKey}|${sessionId}`; } export class InboundGroupSessionStore { - private _store: Store; + private _store: Store; - constructor(store: Store) { + constructor(store: Store) { this._store = store; } @@ -44,13 +46,14 @@ export class InboundGroupSessionStore { return key === fetchedKey; } - get(roomId: string, senderKey: string, sessionId: string): Promise { + get(roomId: string, senderKey: string, sessionId: string): Promise { return this._store.get(encodeKey(roomId, senderKey, sessionId)); } - set(session: InboundGroupSession): void { - session.key = encodeKey(session.roomId, session.senderKey, session.sessionId); - this._store.put(session); + set(session: InboundGroupSessionEntry): void { + const storageEntry = session as InboundGroupSessionStorageEntry; + storageEntry.key = encodeKey(session.roomId, session.senderKey, session.sessionId); + this._store.put(storageEntry); } removeAllForRoom(roomId: string) { diff --git a/src/utils/LRUCache.js b/src/utils/LRUCache.ts similarity index 63% rename from src/utils/LRUCache.js rename to src/utils/LRUCache.ts index 185e5aeb..c5a7cd06 100644 --- a/src/utils/LRUCache.js +++ b/src/utils/LRUCache.ts @@ -14,18 +14,29 @@ See the License for the specific language governing permissions and limitations under the License. */ + +type FindCallback = (value: T) => boolean; /** * Very simple least-recently-used cache implementation * that should be fast enough for very small cache sizes */ -export class BaseLRUCache { - constructor(limit) { - this._limit = limit; +export class BaseLRUCache { + + public readonly limit: number; + protected _entries: T[]; + + constructor(limit: number) { + this.limit = limit; this._entries = []; } - _get(findEntryFn) { - const idx = this._entries.findIndex(findEntryFn); + get size() { return this._entries.length; } + + protected _get(findEntryFn: FindCallback) { + return this._getByIndexAndMoveUp(this._entries.findIndex(findEntryFn)); + } + + protected _getByIndexAndMoveUp(idx: number) { if (idx !== -1) { const entry = this._entries[idx]; // move to top @@ -37,11 +48,11 @@ export class BaseLRUCache { } } - _set(value, findEntryFn) { - let indexToRemove = this._entries.findIndex(findEntryFn); + protected _set(value: T, findEntryFn?: FindCallback) { + let indexToRemove = findEntryFn ? this._entries.findIndex(findEntryFn) : -1; this._entries.unshift(value); if (indexToRemove === -1) { - if (this._entries.length > this._limit) { + if (this._entries.length > this.limit) { indexToRemove = this._entries.length - 1; } } else { @@ -49,75 +60,82 @@ export class BaseLRUCache { indexToRemove += 1; } if (indexToRemove !== -1) { - this._onEvictEntry(this._entries[indexToRemove]); + this.onEvictEntry(this._entries[indexToRemove]); this._entries.splice(indexToRemove, 1); } } - _onEvictEntry() {} + protected onEvictEntry(entry: T) {} } -export class LRUCache extends BaseLRUCache { - constructor(limit, keyFn) { +export class LRUCache extends BaseLRUCache { + private _keyFn: (T) => K; + + constructor(limit, keyFn: (T) => K) { super(limit); this._keyFn = keyFn; } - get(key) { + get(key: K): T | undefined { return this._get(e => this._keyFn(e) === key); } - set(value) { + set(value: T) { const key = this._keyFn(value); this._set(value, e => this._keyFn(e) === key); } } export function tests() { + interface NameTuple { + id: number; + name: string; + } + return { "can retrieve added entries": assert => { - const cache = new LRUCache(2, e => e.id); + const cache = new LRUCache(2, e => e.id); cache.set({id: 1, name: "Alice"}); cache.set({id: 2, name: "Bob"}); - assert.equal(cache.get(1).name, "Alice"); - assert.equal(cache.get(2).name, "Bob"); + assert.equal(cache.get(1)!.name, "Alice"); + assert.equal(cache.get(2)!.name, "Bob"); }, "first entry is evicted first": assert => { - const cache = new LRUCache(2, e => e.id); + const cache = new LRUCache(2, e => e.id); cache.set({id: 1, name: "Alice"}); cache.set({id: 2, name: "Bob"}); cache.set({id: 3, name: "Charly"}); assert.equal(cache.get(1), undefined); - assert.equal(cache.get(2).name, "Bob"); - assert.equal(cache.get(3).name, "Charly"); - assert.equal(cache._entries.length, 2); + assert.equal(cache.get(2)!.name, "Bob"); + assert.equal(cache.get(3)!.name, "Charly"); + assert.equal(cache.size, 2); }, "second entry is evicted if first is requested": assert => { - const cache = new LRUCache(2, e => e.id); + const cache = new LRUCache(2, e => e.id); cache.set({id: 1, name: "Alice"}); cache.set({id: 2, name: "Bob"}); cache.get(1); cache.set({id: 3, name: "Charly"}); - assert.equal(cache.get(1).name, "Alice"); + assert.equal(cache.get(1)!.name, "Alice"); assert.equal(cache.get(2), undefined); - assert.equal(cache.get(3).name, "Charly"); - assert.equal(cache._entries.length, 2); + assert.equal(cache.get(3)!.name, "Charly"); + assert.equal(cache.size, 2); }, "setting an entry twice removes the first": assert => { - const cache = new LRUCache(2, e => e.id); + const cache = new LRUCache(2, e => e.id); cache.set({id: 1, name: "Alice"}); cache.set({id: 2, name: "Bob"}); cache.set({id: 1, name: "Al Ice"}); cache.set({id: 3, name: "Charly"}); - assert.equal(cache.get(1).name, "Al Ice"); + assert.equal(cache.get(1)!.name, "Al Ice"); assert.equal(cache.get(2), undefined); - assert.equal(cache.get(3).name, "Charly"); - assert.equal(cache._entries.length, 2); + assert.equal(cache.get(3)!.name, "Charly"); + assert.equal(cache.size, 2); }, "evict callback is called": assert => { let evictions = 0; - class CustomCache extends LRUCache { - _onEvictEntry(entry) { + class CustomCache extends LRUCache { + onEvictEntry(entry) { assert.equal(entry.name, "Alice"); evictions += 1; } @@ -130,8 +148,8 @@ export function tests() { }, "evict callback is called when replacing entry with same identity": assert => { let evictions = 0; - class CustomCache extends LRUCache { - _onEvictEntry(entry) { + class CustomCache extends LRUCache { + onEvictEntry(entry) { assert.equal(entry.name, "Alice"); evictions += 1; } diff --git a/src/utils/groupBy.js b/src/utils/groupBy.ts similarity index 79% rename from src/utils/groupBy.js rename to src/utils/groupBy.ts index 9bed5298..2d91b209 100644 --- a/src/utils/groupBy.js +++ b/src/utils/groupBy.ts @@ -14,14 +14,14 @@ See the License for the specific language governing permissions and limitations under the License. */ -export function groupBy(array, groupFn) { - return groupByWithCreator(array, groupFn, +export function groupBy(array: V[], groupFn: (V) => K): Map { + return groupByWithCreator(array, groupFn, () => {return [];}, (array, value) => array.push(value) ); } -export function groupByWithCreator(array, groupFn, createCollectionFn, addCollectionFn) { +export function groupByWithCreator(array: V[], groupFn: (V) => K, createCollectionFn: () => C, addCollectionFn: (C, V) => void): Map { return array.reduce((map, value) => { const key = groupFn(value); let collection = map.get(key); @@ -31,10 +31,10 @@ export function groupByWithCreator(array, groupFn, createCollectionFn, addCollec } addCollectionFn(collection, value); return map; - }, new Map()); + }, new Map()); } -export function countBy(events, mapper) { +export function countBy(events: V[], mapper: (V) => string | number): { [key: string]: number } { return events.reduce((counts, event) => { const mappedValue = mapper(event); if (!counts[mappedValue]) {