diff --git a/src/matrix/e2ee/megolm/decryption/KeyLoader.ts b/src/matrix/e2ee/megolm/decryption/KeyLoader.ts index d72ec546..bde3065f 100644 --- a/src/matrix/e2ee/megolm/decryption/KeyLoader.ts +++ b/src/matrix/e2ee/megolm/decryption/KeyLoader.ts @@ -15,7 +15,8 @@ limitations under the License. */ import {SessionCache} from "./SessionCache"; -import {IRoomKey} from "./RoomKey"; +import {IRoomKey, isBetterThan} from "./RoomKey"; +import {BaseLRUCache} from "../../../../utils/LRUCache"; export declare class OlmInboundGroupSession { constructor(); @@ -30,82 +31,193 @@ export declare class OlmInboundGroupSession { export_session(message_index: number): string; } +// this is what cache.get(...) should return +function findIndexBestForSession(ops: KeyOperation[], roomId: string, senderKey: string, sessionId: string): number { + return ops.reduce((bestIdx, op, i, arr) => { + const bestOp = bestIdx === -1 ? undefined : arr[bestIdx]; + if (op.isForSameSession(roomId, senderKey, sessionId)) { + if (!bestOp || op.isBetter(bestOp)) { + return i; + } + } + return bestIdx; + }, -1); +} + + /* Because Olm only has very limited memory available when compiled to wasm, we limit the amount of sessions held in memory. */ -export class KeyLoader { +export class KeyLoader extends BaseLRUCache { - public readonly cache: SessionCache; + private runningOps: Set; + private unusedOps: Set; private pickleKey: string; private olm: any; - private resolveUnusedEntry?: () => void; - private entryBecomesUnusedPromise?: Promise; + private resolveUnusedOperation?: () => void; + private operationBecomesUnusedPromise?: Promise; constructor(olm: any, pickleKey: string, limit: number) { - this.cache = new SessionCache(limit); + super(limit); this.pickleKey = pickleKey; this.olm = olm; } + getCachedKey(roomId: string, senderKey: string, sessionId: string): IRoomKey | undefined { + const idx = this.findIndexBestForSession(roomId, senderKey, sessionId); + if (idx !== -1) { + return this._getByIndexAndMoveUp(idx)!.key; + } + } + async useKey(key: IRoomKey, callback: (session: OlmInboundGroupSession, pickleKey: string) => Promise | T): Promise { - const cacheEntry = await this.allocateEntry(key); + const keyOp = await this.allocateOperation(key); try { - const {session} = cacheEntry; - key.loadInto(session, this.pickleKey); - return await callback(session, this.pickleKey); + return await callback(keyOp.session, this.pickleKey); } finally { - this.freeEntry(cacheEntry); + this.releaseOperation(keyOp); } } get running() { - return !!this.cache.find(entry => entry.inUse); + return this._entries.some(op => op.refCount !== 0); } - private async allocateEntry(key: IRoomKey): Promise { - let entry; - if (this.cache.size >= this.cache.limit) { - while(!(entry = this.cache.find(entry => !entry.inUse))) { - await this.entryBecomesUnused(); + dispose() { + for (let i = 0; i < this._entries.length; i += 1) { + this._entries[i].dispose(); + } + // remove all entries + this._entries.splice(0, this._entries.length); + } + + private async allocateOperation(key: IRoomKey): Promise { + let idx; + while((idx = this.findIndexForAllocation(key)) === -1) { + await this.operationBecomesUnused(); + } + if (idx < this.size) { + const op = this._getByIndexAndMoveUp(idx)!; + // cache hit + if (op.isForKey(key)) { + op.refCount += 1; + return op; + } else { + // refCount should be 0 here + op.refCount = 1; + op.key = key; + key.loadInto(op.session, this.pickleKey); } - entry.inUse = true; - entry.key = key; + return op; } else { - const session: OlmInboundGroupSession = new this.olm.InboundGroupSession(); - const entry = new CacheEntry(key, session); - this.cache.add(entry); + // create new operation + const session = new this.olm.InboundGroupSession(); + key.loadInto(session, this.pickleKey); + const op = new KeyOperation(key, session); + this._set(op); + return op; } - return entry; } - private freeEntry(entry: CacheEntry) { - entry.inUse = false; - if (this.resolveUnusedEntry) { - this.resolveUnusedEntry(); + private releaseOperation(op: KeyOperation) { + op.refCount -= 1; + if (op.refCount <= 0 && this.resolveUnusedOperation) { + this.resolveUnusedOperation(); // promise is resolved now, we'll need a new one for next await so clear - this.entryBecomesUnusedPromise = this.resolveUnusedEntry = undefined; + this.operationBecomesUnusedPromise = this.resolveUnusedOperation = undefined; } } - private entryBecomesUnused(): Promise { - if (!this.entryBecomesUnusedPromise) { - this.entryBecomesUnusedPromise = new Promise(resolve => { - this.resolveUnusedEntry = resolve; + private operationBecomesUnused(): Promise { + if (!this.operationBecomesUnusedPromise) { + this.operationBecomesUnusedPromise = new Promise(resolve => { + this.resolveUnusedOperation = resolve; }); } - return this.entryBecomesUnusedPromise; + return this.operationBecomesUnusedPromise; + } + + private findIndexForAllocation(key: IRoomKey) { + let idx = this.findIndexSameKey(key); // cache hit + if (idx === -1) { + idx = this.findIndexSameSessionUnused(key); + if (idx === -1) { + if (this.size < this.limit) { + idx = this.size; + } else { + idx = this.findIndexOldestUnused(); + } + } + } + return idx; + } + + private findIndexBestForSession(roomId: string, senderKey: string, sessionId: string): number { + return this._entries.reduce((bestIdx, op, i, arr) => { + const bestOp = bestIdx === -1 ? undefined : arr[bestIdx]; + if (op.isForSameSession(roomId, senderKey, sessionId)) { + if (!bestOp || op.isBetter(bestOp)) { + return i; + } + } + return bestIdx; + }, -1); + } + + private findIndexSameKey(key: IRoomKey): number { + return this._entries.findIndex(op => { + return op.isForKey(key); + }); + } + + private findIndexSameSessionUnused(key: IRoomKey): number { + for (let i = this._entries.length - 1; i >= 0; i -= 1) { + const op = this._entries[i]; + if (op.refCount === 0 && op.isForSameSession(key.roomId, key.senderKey, key.sessionId)) { + return i; + } + } + return -1; + } + + private findIndexOldestUnused(): number { + for (let i = this._entries.length - 1; i >= 0; i -= 1) { + const op = this._entries[i]; + if (op.refCount === 0) { + return i; + } + } + return -1; } } -class CacheEntry { - inUse: boolean; +class KeyOperation { session: OlmInboundGroupSession; key: IRoomKey; + refCount: number; - constructor(key, session) { + constructor(key: IRoomKey, session: OlmInboundGroupSession) { this.key = key; this.session = session; - this.inUse = true; + this.refCount = 1; + } + + isForSameSession(roomId: string, senderKey: string, sessionId: string): boolean { + return this.key.roomId === roomId && this.key.senderKey === senderKey && this.key.sessionId === sessionId; + } + + // assumes isForSameSession is true + isBetter(other: KeyOperation) { + return isBetterThan(this.session, other.session); + } + + isForKey(key: IRoomKey) { + return this.key.serializationKey === key.serializationKey && + this.key.serializationType === key.serializationType; + } + + dispose() { + this.session.free(); } } diff --git a/src/matrix/e2ee/megolm/decryption/RoomKey.ts b/src/matrix/e2ee/megolm/decryption/RoomKey.ts index c4c00d99..14f94613 100644 --- a/src/matrix/e2ee/megolm/decryption/RoomKey.ts +++ b/src/matrix/e2ee/megolm/decryption/RoomKey.ts @@ -18,20 +18,25 @@ import type {InboundGroupSessionEntry} from "../../../storage/idb/stores/Inbound import type {Transaction} from "../../../storage/idb/Transaction"; import type {DecryptionResult} from "../../DecryptionResult"; import type {KeyLoader, OlmInboundGroupSession} from "./KeyLoader"; -import {SessionCache} from "./SessionCache"; export interface IRoomKey { get roomId(): string; get senderKey(): string; get sessionId(): string; get claimedEd25519Key(): string; + get serializationKey(): string; + get serializationType(): string; get eventIds(): string[] | undefined; loadInto(session: OlmInboundGroupSession, pickleKey: string): void; } +export function isBetterThan(newSession: OlmInboundGroupSession, existingSession: OlmInboundGroupSession) { + return newSession.first_known_index() < existingSession.first_known_index(); +} + export interface IIncomingRoomKey extends IRoomKey { get isBetter(): boolean | undefined; - checkBetterKeyInStorage(loader: KeyLoader, txn: Transaction): Promise; + checkBetterThanKeyInStorage(loader: KeyLoader, txn: Transaction): Promise; write(loader: KeyLoader, txn: Transaction): Promise; } @@ -39,8 +44,8 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey { private _eventIds?: string[]; private _isBetter?: boolean; - checkBetterKeyInStorage(loader: KeyLoader, txn: Transaction): Promise { - return this._checkBetterKeyInStorage(loader, undefined, txn); + checkBetterThanKeyInStorage(loader: KeyLoader, txn: Transaction): Promise { + return this._checkBetterThanKeyInStorage(loader, undefined, txn); } async write(loader: KeyLoader, txn: Transaction): Promise { @@ -51,7 +56,7 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey { // we haven't checked if this is the best key yet, // so do that now to not overwrite a better key. // while we have the key deserialized, also pickle it to store it later on here. - await this._checkBetterKeyInStorage(loader, (session, pickleKey) => { + await this._checkBetterThanKeyInStorage(loader, (session, pickleKey) => { pickledSession = session.pickle(pickleKey); }, txn); } @@ -76,7 +81,7 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey { get eventIds() { return this._eventIds; } get isBetter() { return this._isBetter; } - private async _checkBetterKeyInStorage(loader: KeyLoader, callback: (((session: OlmInboundGroupSession, pickleKey: string) => void) | undefined), txn: Transaction): Promise { + private async _checkBetterThanKeyInStorage(loader: KeyLoader, callback: (((session: OlmInboundGroupSession, pickleKey: string) => void) | undefined), txn: Transaction): Promise { if (this._isBetter !== undefined) { return this._isBetter; } @@ -96,7 +101,7 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey { if (existingKey) { this._isBetter = await loader.useKey(this, newSession => { return loader.useKey(existingKey, (existingSession, pickleKey) => { - const isBetter = newSession.first_known_index() < existingSession.first_known_index(); + const isBetter = isBetterThan(newSession, existingSession); if (isBetter && callback) { callback(newSession, pickleKey); } @@ -114,6 +119,8 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey { abstract get senderKey(): string; abstract get sessionId(): string; abstract get claimedEd25519Key(): string; + abstract get serializationKey(): string; + abstract get serializationType(): string; abstract loadInto(session: OlmInboundGroupSession, pickleKey: string): void; } @@ -129,10 +136,11 @@ class DeviceMessageRoomKey extends BaseIncomingRoomKey { get senderKey() { return this._decryptionResult.senderCurve25519Key; } get sessionId() { return this._decryptionResult.event.content?.["session_id"]; } get claimedEd25519Key() { return this._decryptionResult.claimedEd25519Key; } + get serializationKey(): string { return this._decryptionResult.event.content?.["session_key"]; } + get serializationType(): string { return "create"; } loadInto(session) { - const sessionKey = this._decryptionResult.event.content?.["session_key"]; - session.create(sessionKey); + session.create(this.serializationKey); } } @@ -152,10 +160,11 @@ class BackupRoomKey extends BaseIncomingRoomKey { get senderKey() { return this._backupInfo["sender_key"]; } get sessionId() { return this._sessionId; } get claimedEd25519Key() { return this._backupInfo["sender_claimed_keys"]?.["ed25519"]; } - + get serializationKey(): string { return this._backupInfo["session_key"]; } + get serializationType(): string { return "import_session"; } + loadInto(session) { - const sessionKey = this._backupInfo["session_key"]; - session.import_session(sessionKey); + session.import_session(this.serializationKey); } } @@ -171,9 +180,11 @@ class StoredRoomKey implements IRoomKey { get sessionId() { return this.storageEntry.sessionId; } get claimedEd25519Key() { return this.storageEntry.claimedKeys!["ed25519"]; } get eventIds() { return this.storageEntry.eventIds; } - + get serializationKey(): string { return this.storageEntry.session || ""; } + get serializationType(): string { return "unpickle"; } + loadInto(session, pickleKey) { - session.unpickle(pickleKey, this.storageEntry.session); + session.unpickle(pickleKey, this.serializationKey); } get hasSession() { diff --git a/src/matrix/e2ee/megolm/decryption/SessionCache.js b/src/matrix/e2ee/megolm/decryption/SessionCache.js deleted file mode 100644 index cbb868ea..00000000 --- a/src/matrix/e2ee/megolm/decryption/SessionCache.js +++ /dev/null @@ -1,63 +0,0 @@ -/* -Copyright 2020 The Matrix.org Foundation C.I.C. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -import {BaseLRUCache} from "../../../../utils/LRUCache"; -const DEFAULT_CACHE_SIZE = 10; - -/** - * Cache of unpickled inbound megolm session. - */ -export class SessionCache extends BaseLRUCache { - constructor(limit) { - limit = typeof limit === "number" ? limit : DEFAULT_CACHE_SIZE; - super(limit); - } - - /** - * @param {string} roomId - * @param {string} senderKey - * @param {string} sessionId - * @return {SessionInfo?} - */ - get(roomId, senderKey, sessionId) { - const sessionInfo = this._get(s => { - return s.roomId === roomId && - s.senderKey === senderKey && - sessionId === s.sessionId; - }); - sessionInfo?.retain(); - return sessionInfo; - } - - add(sessionInfo) { - sessionInfo.retain(); - this._set(sessionInfo, s => { - return s.roomId === sessionInfo.roomId && - s.senderKey === sessionInfo.senderKey && - s.sessionId === sessionInfo.sessionId; - }); - } - - _onEvictEntry(sessionInfo) { - sessionInfo.release(); - } - - dispose() { - for (const sessionInfo of this._entries) { - sessionInfo.release(); - } - } -} diff --git a/src/utils/LRUCache.ts b/src/utils/LRUCache.ts index d275b9f1..05d8fef5 100644 --- a/src/utils/LRUCache.ts +++ b/src/utils/LRUCache.ts @@ -14,25 +14,29 @@ See the License for the specific language governing permissions and limitations under the License. */ + +type FindCallback = (value: T) => boolean; /** * Very simple least-recently-used cache implementation * that should be fast enough for very small cache sizes */ export class BaseLRUCache { - private _limit: number; - private _entries: T[]; + public readonly limit: number; + protected _entries: T[]; constructor(limit: number) { - this._limit = limit; + this.limit = limit; this._entries = []; } get size() { return this._entries.length; } - get limit() { return this._limit; } - _get(findEntryFn: (T) => boolean) { - const idx = this._entries.findIndex(findEntryFn); + protected _get(findEntryFn: FindCallback) { + return this._getByIndexAndMoveUp(this._entries.findIndex(findEntryFn)); + } + + protected _getByIndexAndMoveUp(idx: number) { if (idx !== -1) { const entry = this._entries[idx]; // move to top @@ -44,11 +48,11 @@ export class BaseLRUCache { } } - _set(value: T, findEntryFn: (T) => boolean) { - let indexToRemove = this._entries.findIndex(findEntryFn); + protected _set(value: T, findEntryFn?: FindCallback) { + let indexToRemove = findEntryFn ? this._entries.findIndex(findEntryFn) : -1; this._entries.unshift(value); if (indexToRemove === -1) { - if (this._entries.length > this._limit) { + if (this._entries.length > this.limit) { indexToRemove = this._entries.length - 1; } } else { @@ -56,22 +60,12 @@ export class BaseLRUCache { indexToRemove += 1; } if (indexToRemove !== -1) { - this._onEvictEntry(this._entries[indexToRemove]); + this.onEvictEntry(this._entries[indexToRemove]); this._entries.splice(indexToRemove, 1); } } - find(callback: (T) => boolean) { - // iterate backwards so least recently used items are found first - for (let i = this._entries.length - 1; i >= 0; i -= 1) { - const entry = this._entries[i]; - if (callback(entry)) { - return entry; - } - } - } - - _onEvictEntry(entry: T) {} + protected onEvictEntry(entry: T) {} } export class LRUCache extends BaseLRUCache {