diff --git a/src/matrix/e2ee/megolm/decryption/KeyLoader.ts b/src/matrix/e2ee/megolm/decryption/KeyLoader.ts new file mode 100644 index 00000000..d72ec546 --- /dev/null +++ b/src/matrix/e2ee/megolm/decryption/KeyLoader.ts @@ -0,0 +1,111 @@ +/* +Copyright 2021 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import {SessionCache} from "./SessionCache"; +import {IRoomKey} from "./RoomKey"; + +export declare class OlmInboundGroupSession { + constructor(); + free(): void; + pickle(key: string | Uint8Array): string; + unpickle(key: string | Uint8Array, pickle: string); + create(session_key: string): string; + import_session(session_key: string): string; + decrypt(message: string): object; + session_id(): string; + first_known_index(): number; + export_session(message_index: number): string; +} + +/* +Because Olm only has very limited memory available when compiled to wasm, +we limit the amount of sessions held in memory. +*/ +export class KeyLoader { + + public readonly cache: SessionCache; + private pickleKey: string; + private olm: any; + private resolveUnusedEntry?: () => void; + private entryBecomesUnusedPromise?: Promise; + + constructor(olm: any, pickleKey: string, limit: number) { + this.cache = new SessionCache(limit); + this.pickleKey = pickleKey; + this.olm = olm; + } + + async useKey(key: IRoomKey, callback: (session: OlmInboundGroupSession, pickleKey: string) => Promise | T): Promise { + const cacheEntry = await this.allocateEntry(key); + try { + const {session} = cacheEntry; + key.loadInto(session, this.pickleKey); + return await callback(session, this.pickleKey); + } finally { + this.freeEntry(cacheEntry); + } + } + + get running() { + return !!this.cache.find(entry => entry.inUse); + } + + private async allocateEntry(key: IRoomKey): Promise { + let entry; + if (this.cache.size >= this.cache.limit) { + while(!(entry = this.cache.find(entry => !entry.inUse))) { + await this.entryBecomesUnused(); + } + entry.inUse = true; + entry.key = key; + } else { + const session: OlmInboundGroupSession = new this.olm.InboundGroupSession(); + const entry = new CacheEntry(key, session); + this.cache.add(entry); + } + return entry; + } + + private freeEntry(entry: CacheEntry) { + entry.inUse = false; + if (this.resolveUnusedEntry) { + this.resolveUnusedEntry(); + // promise is resolved now, we'll need a new one for next await so clear + this.entryBecomesUnusedPromise = this.resolveUnusedEntry = undefined; + } + } + + private entryBecomesUnused(): Promise { + if (!this.entryBecomesUnusedPromise) { + this.entryBecomesUnusedPromise = new Promise(resolve => { + this.resolveUnusedEntry = resolve; + }); + } + return this.entryBecomesUnusedPromise; + } +} + +class CacheEntry { + inUse: boolean; + session: OlmInboundGroupSession; + key: IRoomKey; + + constructor(key, session) { + this.key = key; + this.session = session; + this.inUse = true; + } +} diff --git a/src/matrix/e2ee/megolm/decryption/RoomKey.ts b/src/matrix/e2ee/megolm/decryption/RoomKey.ts index ed799759..c4c00d99 100644 --- a/src/matrix/e2ee/megolm/decryption/RoomKey.ts +++ b/src/matrix/e2ee/megolm/decryption/RoomKey.ts @@ -14,22 +14,11 @@ See the License for the specific language governing permissions and limitations under the License. */ -import type {InboundGroupSession} from "../../../storage/idb/stores/InboundGroupSessionStore"; +import type {InboundGroupSessionEntry} from "../../../storage/idb/stores/InboundGroupSessionStore"; import type {Transaction} from "../../../storage/idb/Transaction"; import type {DecryptionResult} from "../../DecryptionResult"; - -declare class OlmInboundGroupSession { - constructor(); - free(): void; - pickle(key: string | Uint8Array): string; - unpickle(key: string | Uint8Array, pickle: string); - create(session_key: string): string; - import_session(session_key: string): string; - decrypt(message: string): object; - session_id(): string; - first_known_index(): number; - export_session(message_index: number): string; -} +import type {KeyLoader, OlmInboundGroupSession} from "./KeyLoader"; +import {SessionCache} from "./SessionCache"; export interface IRoomKey { get roomId(): string; @@ -37,24 +26,24 @@ export interface IRoomKey { get sessionId(): string; get claimedEd25519Key(): string; get eventIds(): string[] | undefined; - deserializeInto(session: OlmInboundGroupSession, pickleKey: string): void; + loadInto(session: OlmInboundGroupSession, pickleKey: string): void; } export interface IIncomingRoomKey extends IRoomKey { get isBetter(): boolean | undefined; - checkIsBetterThanStorage(keyDeserialization: KeyDeserialization, txn: Transaction): Promise; - write(keyDeserialization: KeyDeserialization, txn: Transaction): Promise; + checkBetterKeyInStorage(loader: KeyLoader, txn: Transaction): Promise; + write(loader: KeyLoader, txn: Transaction): Promise; } abstract class BaseIncomingRoomKey implements IIncomingRoomKey { private _eventIds?: string[]; private _isBetter?: boolean; - checkBetterKeyInStorage(keyDeserialization: KeyDeserialization, txn: Transaction): Promise { - return this._checkBetterKeyInStorage(keyDeserialization, undefined, txn); + checkBetterKeyInStorage(loader: KeyLoader, txn: Transaction): Promise { + return this._checkBetterKeyInStorage(loader, undefined, txn); } - async write(keyDeserialization: KeyDeserialization, pickleKey: string, txn: Transaction): Promise { + async write(loader: KeyLoader, txn: Transaction): Promise { // we checked already and we had a better session in storage, so don't write let pickledSession; if (this._isBetter === undefined) { @@ -62,23 +51,23 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey { // we haven't checked if this is the best key yet, // so do that now to not overwrite a better key. // while we have the key deserialized, also pickle it to store it later on here. - await this._checkBetterKeyInStorage(keyDeserialization, session => { + await this._checkBetterKeyInStorage(loader, (session, pickleKey) => { pickledSession = session.pickle(pickleKey); }, txn); } if (this._isBetter === false) { return false; } - // before calling write in parallel, we need to check keyDeserialization.running is false so we are sure our transaction will not be closed + // before calling write in parallel, we need to check loader.running is false so we are sure our transaction will not be closed if (!pickledSession) { - pickledSession = await keyDeserialization.useKey(this, session => session.pickle(pickleKey)); + pickledSession = await loader.useKey(this, (session, pickleKey) => session.pickle(pickleKey)); } const sessionEntry = { roomId: this.roomId, senderKey: this.senderKey, sessionId: this.sessionId, session: pickledSession, - claimedKeys: this._sessionInfo.claimedKeys, + claimedKeys: {"ed25519": this.claimedEd25519Key}, }; txn.inboundGroupSessions.set(sessionEntry); return true; @@ -87,11 +76,11 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey { get eventIds() { return this._eventIds; } get isBetter() { return this._isBetter; } - private async _checkBetterKeyInStorage(keyDeserialization: KeyDeserialization, callback?: (session: OlmInboundGroupSession) => void, txn: Transaction): Promise { + private async _checkBetterKeyInStorage(loader: KeyLoader, callback: (((session: OlmInboundGroupSession, pickleKey: string) => void) | undefined), txn: Transaction): Promise { if (this._isBetter !== undefined) { return this._isBetter; } - let existingKey = keyDeserialization.cache.get(this.roomId, this.senderKey, this.sessionId); + let existingKey = loader.cache.get(this.roomId, this.senderKey, this.sessionId); if (!existingKey) { const storageKey = await fromStorage(this.roomId, this.senderKey, this.sessionId, txn); // store the event ids that can be decrypted with this key @@ -105,11 +94,11 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey { } } if (existingKey) { - this._isBetter = await keyDeserialization.useKey(key, newSession => { - return keyDeserialization.useKey(existingKey, existingSession => { + this._isBetter = await loader.useKey(this, newSession => { + return loader.useKey(existingKey, (existingSession, pickleKey) => { const isBetter = newSession.first_known_index() < existingSession.first_known_index(); if (isBetter && callback) { - callback(newSession); + callback(newSession, pickleKey); } return isBetter; }); @@ -120,9 +109,15 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey { } return this._isBetter; } + + abstract get roomId(): string; + abstract get senderKey(): string; + abstract get sessionId(): string; + abstract get claimedEd25519Key(): string; + abstract loadInto(session: OlmInboundGroupSession, pickleKey: string): void; } -class DeviceMessageRoomKey extends BaseIncomingRoomKey implements IIncomingRoomKey { +class DeviceMessageRoomKey extends BaseIncomingRoomKey { private _decryptionResult: DecryptionResult; constructor(decryptionResult: DecryptionResult) { @@ -135,13 +130,13 @@ class DeviceMessageRoomKey extends BaseIncomingRoomKey implements IIncomingRoomK get sessionId() { return this._decryptionResult.event.content?.["session_id"]; } get claimedEd25519Key() { return this._decryptionResult.claimedEd25519Key; } - deserializeInto(session) { + loadInto(session) { const sessionKey = this._decryptionResult.event.content?.["session_key"]; session.create(sessionKey); } } -class BackupRoomKey extends BaseIncomingRoomKey implements IIncomingRoomKey { +class BackupRoomKey extends BaseIncomingRoomKey { private _roomId: string; private _sessionId: string; private _backupInfo: string; @@ -158,16 +153,16 @@ class BackupRoomKey extends BaseIncomingRoomKey implements IIncomingRoomKey { get sessionId() { return this._sessionId; } get claimedEd25519Key() { return this._backupInfo["sender_claimed_keys"]?.["ed25519"]; } - deserializeInto(session) { + loadInto(session) { const sessionKey = this._backupInfo["session_key"]; session.import_session(sessionKey); } } class StoredRoomKey implements IRoomKey { - private storageEntry: InboundGroupSession; + private storageEntry: InboundGroupSessionEntry; - constructor(storageEntry: InboundGroupSession) { + constructor(storageEntry: InboundGroupSessionEntry) { this.storageEntry = storageEntry; } @@ -177,7 +172,7 @@ class StoredRoomKey implements IRoomKey { get claimedEd25519Key() { return this.storageEntry.claimedKeys!["ed25519"]; } get eventIds() { return this.storageEntry.eventIds; } - deserializeInto(session, pickleKey) { + loadInto(session, pickleKey) { session.unpickle(pickleKey, this.storageEntry.session); } @@ -189,17 +184,16 @@ class StoredRoomKey implements IRoomKey { } } -export function fromDeviceMessage(dr) { - const roomId = dr.event.content?.["room_id"]; - const sessionId = dr.event.content?.["session_id"]; +export function fromDeviceMessage(dr: DecryptionResult): DeviceMessageRoomKey | undefined { const sessionKey = dr.event.content?.["session_key"]; + const key = new DeviceMessageRoomKey(dr); if ( - typeof roomId === "string" || - typeof sessionId === "string" || - typeof senderKey === "string" || + typeof key.roomId === "string" && + typeof key.sessionId === "string" && + typeof key.senderKey === "string" && typeof sessionKey === "string" ) { - return new DeviceMessageRoomKey(dr); + return key; } } @@ -211,11 +205,11 @@ sessionInfo is a response from key backup and has the following keys: sender_key session_key */ -export function fromBackup(roomId, sessionId, sessionInfo) { - const sessionKey = sessionInfo["session_key"]; - const senderKey = sessionInfo["sender_key"]; +export function fromBackup(roomId, sessionId, backupInfo): BackupRoomKey | undefined { + const sessionKey = backupInfo["session_key"]; + const senderKey = backupInfo["sender_key"]; // TODO: can we just trust this? - const claimedEd25519Key = sessionInfo["sender_claimed_keys"]?.["ed25519"]; + const claimedEd25519Key = backupInfo["sender_claimed_keys"]?.["ed25519"]; if ( typeof roomId === "string" && @@ -224,7 +218,7 @@ export function fromBackup(roomId, sessionId, sessionInfo) { typeof sessionKey === "string" && typeof claimedEd25519Key === "string" ) { - return new BackupRoomKey(roomId, sessionId, sessionInfo); + return new BackupRoomKey(roomId, sessionId, backupInfo); } } @@ -235,82 +229,3 @@ export async function fromStorage(roomId: string, senderKey: string, sessionId: } return; } -/* -Because Olm only has very limited memory available when compiled to wasm, -we limit the amount of sessions held in memory. -*/ -class KeyDeserialization { - - public readonly cache: SessionCache; - private pickleKey: string; - private olm: any; - private resolveUnusedEntry?: () => void; - private entryBecomesUnusedPromise?: Promise; - - constructor({olm, pickleKey, limit}) { - this.cache = new SessionCache(limit); - this.pickleKey = pickleKey; - this.olm = olm; - } - - async useKey(key: IRoomKey, callback: (session: OlmInboundGroupSession) => Promise | T): Promise { - const cacheEntry = await this.allocateEntry(key); - try { - const {session} = cacheEntry; - key.deserializeInto(session, this.pickleKey); - return await callback(session); - } finally { - this.freeEntry(cacheEntry); - } - } - - get running() { - return !!this.cache.find(entry => entry.inUse); - } - - private async allocateEntry(key): CacheEntry { - let entry; - if (this.cache.size >= MAX) { - while(!(entry = this.cache.find(entry => !entry.inUse))) { - await this.entryBecomesUnused(); - } - entry.inUse = true; - entry.key = key; - } else { - const session: OlmInboundGroupSession = new this.olm.InboundGroupSession(); - const entry = new CacheEntry(key, session); - this.cache.add(entry); - } - return entry; - } - - private freeEntry(entry) { - entry.inUse = false; - if (this.resolveUnusedEntry) { - this.resolveUnusedEntry(); - // promise is resolved now, we'll need a new one for next await so clear - this.entryBecomesUnusedPromise = this.resolveUnusedEntry = undefined; - } - } - - private entryBecomesUnused(): Promise { - if (!this.entryBecomesUnusedPromise) { - this.entryBecomesUnusedPromise = new Promise(resolve => { - this.resolveUnusedEntry = resolve; - }); - } - return this.entryBecomesUnusedPromise; - } -} - -class CacheEntry { - inUse: boolean; - session: OlmInboundGroupSession; - key: IRoomKey; - - constructor(key, session) { - this.key = key; - this.session = session; - this.inUse = true; - } -} diff --git a/src/matrix/storage/idb/stores/InboundGroupSessionStore.ts b/src/matrix/storage/idb/stores/InboundGroupSessionStore.ts index 65ce99ce..22093884 100644 --- a/src/matrix/storage/idb/stores/InboundGroupSessionStore.ts +++ b/src/matrix/storage/idb/stores/InboundGroupSessionStore.ts @@ -17,24 +17,26 @@ limitations under the License. import {MIN_UNICODE, MAX_UNICODE} from "./common"; import {Store} from "../Store"; -export interface InboundGroupSession { +export interface InboundGroupSessionEntry { roomId: string; senderKey: string; sessionId: string; session?: string; claimedKeys?: { [algorithm : string] : string }; eventIds?: string[]; - key: string; } +type InboundGroupSessionStorageEntry = InboundGroupSessionEntry & { key: string }; + + function encodeKey(roomId: string, senderKey: string, sessionId: string): string { return `${roomId}|${senderKey}|${sessionId}`; } export class InboundGroupSessionStore { - private _store: Store; + private _store: Store; - constructor(store: Store) { + constructor(store: Store) { this._store = store; } @@ -44,13 +46,14 @@ export class InboundGroupSessionStore { return key === fetchedKey; } - get(roomId: string, senderKey: string, sessionId: string): Promise { + get(roomId: string, senderKey: string, sessionId: string): Promise { return this._store.get(encodeKey(roomId, senderKey, sessionId)); } - set(session: InboundGroupSession): void { - session.key = encodeKey(session.roomId, session.senderKey, session.sessionId); - this._store.put(session); + set(session: InboundGroupSessionEntry): void { + const storageEntry = session as InboundGroupSessionStorageEntry; + storageEntry.key = encodeKey(session.roomId, session.senderKey, session.sessionId); + this._store.put(storageEntry); } removeAllForRoom(roomId: string) { diff --git a/src/utils/LRUCache.js b/src/utils/LRUCache.js index 183fc995..98ded41b 100644 --- a/src/utils/LRUCache.js +++ b/src/utils/LRUCache.js @@ -24,6 +24,9 @@ export class BaseLRUCache { this._entries = []; } + get size() { return this._entries.length; } + get limit() { return this._limit; } + _get(findEntryFn) { const idx = this._entries.findIndex(findEntryFn); if (idx !== -1) {