Merge pull request #554 from vector-im/bwindels/fix-551

Only keep a limited amount of olm InboundGroupSession objects in memory to prevent OOM
This commit is contained in:
Bruno Windels 2021-10-26 11:30:10 +02:00 committed by GitHub
commit fae4493abc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 957 additions and 558 deletions

View file

@ -15,7 +15,7 @@ limitations under the License.
*/ */
import {OLM_ALGORITHM} from "./e2ee/common.js"; import {OLM_ALGORITHM} from "./e2ee/common.js";
import {countBy, groupBy} from "../utils/groupBy.js"; import {countBy, groupBy} from "../utils/groupBy";
export class DeviceMessageHandler { export class DeviceMessageHandler {
constructor({storage}) { constructor({storage}) {
@ -67,12 +67,4 @@ class SyncPreparation {
this.newRoomKeys = newRoomKeys; this.newRoomKeys = newRoomKeys;
this.newKeysByRoom = groupBy(newRoomKeys, r => r.roomId); this.newKeysByRoom = groupBy(newRoomKeys, r => r.roomId);
} }
dispose() {
if (this.newRoomKeys) {
for (const k of this.newRoomKeys) {
k.dispose();
}
}
}
} }

View file

@ -26,14 +26,15 @@ import {DeviceMessageHandler} from "./DeviceMessageHandler.js";
import {Account as E2EEAccount} from "./e2ee/Account.js"; import {Account as E2EEAccount} from "./e2ee/Account.js";
import {Decryption as OlmDecryption} from "./e2ee/olm/Decryption.js"; import {Decryption as OlmDecryption} from "./e2ee/olm/Decryption.js";
import {Encryption as OlmEncryption} from "./e2ee/olm/Encryption.js"; import {Encryption as OlmEncryption} from "./e2ee/olm/Encryption.js";
import {Decryption as MegOlmDecryption} from "./e2ee/megolm/Decryption.js"; import {Decryption as MegOlmDecryption} from "./e2ee/megolm/Decryption";
import {KeyLoader as MegOlmKeyLoader} from "./e2ee/megolm/decryption/KeyLoader";
import {SessionBackup} from "./e2ee/megolm/SessionBackup.js"; import {SessionBackup} from "./e2ee/megolm/SessionBackup.js";
import {Encryption as MegOlmEncryption} from "./e2ee/megolm/Encryption.js"; import {Encryption as MegOlmEncryption} from "./e2ee/megolm/Encryption.js";
import {MEGOLM_ALGORITHM} from "./e2ee/common.js"; import {MEGOLM_ALGORITHM} from "./e2ee/common.js";
import {RoomEncryption} from "./e2ee/RoomEncryption.js"; import {RoomEncryption} from "./e2ee/RoomEncryption.js";
import {DeviceTracker} from "./e2ee/DeviceTracker.js"; import {DeviceTracker} from "./e2ee/DeviceTracker.js";
import {LockMap} from "../utils/LockMap.js"; import {LockMap} from "../utils/LockMap.js";
import {groupBy} from "../utils/groupBy.js"; import {groupBy} from "../utils/groupBy";
import { import {
keyFromCredential as ssssKeyFromCredential, keyFromCredential as ssssKeyFromCredential,
readKey as ssssReadKey, readKey as ssssReadKey,
@ -137,11 +138,8 @@ export class Session {
now: this._platform.clock.now, now: this._platform.clock.now,
ownDeviceId: this._sessionInfo.deviceId, ownDeviceId: this._sessionInfo.deviceId,
}); });
this._megolmDecryption = new MegOlmDecryption({ const keyLoader = new MegOlmKeyLoader(this._olm, PICKLE_KEY, 20);
pickleKey: PICKLE_KEY, this._megolmDecryption = new MegOlmDecryption(keyLoader, this._olmWorker);
olm: this._olm,
olmWorker: this._olmWorker,
});
this._deviceMessageHandler.enableEncryption({olmDecryption, megolmDecryption: this._megolmDecryption}); this._deviceMessageHandler.enableEncryption({olmDecryption, megolmDecryption: this._megolmDecryption});
} }
@ -319,6 +317,7 @@ export class Session {
dispose() { dispose() {
this._olmWorker?.dispose(); this._olmWorker?.dispose();
this._sessionBackup?.dispose(); this._sessionBackup?.dispose();
this._megolmDecryption.dispose();
for (const room of this._rooms.values()) { for (const room of this._rooms.values()) {
room.dispose(); room.dispose();
} }

View file

@ -464,7 +464,6 @@ class SessionSyncProcessState {
dispose() { dispose() {
this.lock?.release(); this.lock?.release();
this.preparation?.dispose();
} }
} }

View file

@ -29,10 +29,10 @@ limitations under the License.
export class DecryptionResult { export class DecryptionResult {
constructor(event, senderCurve25519Key, claimedKeys) { constructor(event, senderCurve25519Key, claimedEd25519Key) {
this.event = event; this.event = event;
this.senderCurve25519Key = senderCurve25519Key; this.senderCurve25519Key = senderCurve25519Key;
this.claimedEd25519Key = claimedKeys.ed25519; this.claimedEd25519Key = claimedEd25519Key;
this._device = null; this._device = null;
this._roomTracked = true; this._roomTracked = true;
} }

View file

@ -15,9 +15,9 @@ limitations under the License.
*/ */
import {MEGOLM_ALGORITHM, DecryptionSource} from "./common.js"; import {MEGOLM_ALGORITHM, DecryptionSource} from "./common.js";
import {groupEventsBySession} from "./megolm/decryption/utils.js"; import {groupEventsBySession} from "./megolm/decryption/utils";
import {mergeMap} from "../../utils/mergeMap.js"; import {mergeMap} from "../../utils/mergeMap.js";
import {groupBy} from "../../utils/groupBy.js"; import {groupBy} from "../../utils/groupBy";
import {makeTxnId} from "../common.js"; import {makeTxnId} from "../common.js";
const ENCRYPTED_TYPE = "m.room.encrypted"; const ENCRYPTED_TYPE = "m.room.encrypted";
@ -36,8 +36,6 @@ export class RoomEncryption {
this._megolmDecryption = megolmDecryption; this._megolmDecryption = megolmDecryption;
// content of the m.room.encryption event // content of the m.room.encryption event
this._encryptionParams = encryptionParams; this._encryptionParams = encryptionParams;
this._megolmBackfillCache = this._megolmDecryption.createSessionCache();
this._megolmSyncCache = this._megolmDecryption.createSessionCache(1);
// caches devices to verify events // caches devices to verify events
this._senderDeviceCache = new Map(); this._senderDeviceCache = new Map();
this._storage = storage; this._storage = storage;
@ -76,9 +74,6 @@ export class RoomEncryption {
} }
notifyTimelineClosed() { notifyTimelineClosed() {
// empty the backfill cache when closing the timeline
this._megolmBackfillCache.dispose();
this._megolmBackfillCache = this._megolmDecryption.createSessionCache();
this._senderDeviceCache = new Map(); // purge the sender device cache this._senderDeviceCache = new Map(); // purge the sender device cache
} }
@ -112,27 +107,8 @@ export class RoomEncryption {
} }
validEvents.push(event); validEvents.push(event);
} }
let customCache;
let sessionCache;
// we have different caches so we can keep them small but still
// have backfill and sync not invalidate each other
if (source === DecryptionSource.Sync) {
sessionCache = this._megolmSyncCache;
} else if (source === DecryptionSource.Timeline) {
sessionCache = this._megolmBackfillCache;
} else if (source === DecryptionSource.Retry) {
// when retrying, we could have mixed events from at the bottom of the timeline (sync)
// and somewhere else, so create a custom cache we use just for this operation.
customCache = this._megolmDecryption.createSessionCache();
sessionCache = customCache;
} else {
throw new Error("Unknown source: " + source);
}
const preparation = await this._megolmDecryption.prepareDecryptAll( const preparation = await this._megolmDecryption.prepareDecryptAll(
this._room.id, validEvents, newKeys, sessionCache, txn); this._room.id, validEvents, newKeys, txn);
if (customCache) {
customCache.dispose();
}
return new DecryptionPreparation(preparation, errors, source, this, events); return new DecryptionPreparation(preparation, errors, source, this, events);
} }
@ -208,33 +184,27 @@ export class RoomEncryption {
try { try {
const session = await this._sessionBackup.getSession(this._room.id, sessionId, log); const session = await this._sessionBackup.getSession(this._room.id, sessionId, log);
if (session?.algorithm === MEGOLM_ALGORITHM) { if (session?.algorithm === MEGOLM_ALGORITHM) {
if (session["sender_key"] !== senderKey) {
log.set("wrong_sender_key", session["sender_key"]);
log.logLevel = log.level.Warn;
return;
}
let roomKey = this._megolmDecryption.roomKeyFromBackup(this._room.id, sessionId, session); let roomKey = this._megolmDecryption.roomKeyFromBackup(this._room.id, sessionId, session);
if (roomKey) { if (roomKey) {
if (roomKey.senderKey !== senderKey) {
log.set("wrong_sender_key", roomKey.senderKey);
log.logLevel = log.level.Warn;
return;
}
let keyIsBestOne = false; let keyIsBestOne = false;
let retryEventIds; let retryEventIds;
const txn = await this._storage.readWriteTxn([this._storage.storeNames.inboundGroupSessions]);
try { try {
const txn = await this._storage.readWriteTxn([this._storage.storeNames.inboundGroupSessions]); keyIsBestOne = await this._megolmDecryption.writeRoomKey(roomKey, txn);
try { log.set("isBetter", keyIsBestOne);
keyIsBestOne = await this._megolmDecryption.writeRoomKey(roomKey, txn); if (keyIsBestOne) {
log.set("isBetter", keyIsBestOne); retryEventIds = roomKey.eventIds;
if (keyIsBestOne) {
retryEventIds = roomKey.eventIds;
}
} catch (err) {
txn.abort();
throw err;
} }
await txn.complete(); } catch (err) {
} finally { txn.abort();
// can still access properties on it afterwards throw err;
// this is just clearing the internal sessionInfo
roomKey.dispose();
} }
await txn.complete();
if (keyIsBestOne) { if (keyIsBestOne) {
await log.wrap("retryDecryption", log => this._room.notifyRoomKey(roomKey, retryEventIds || [], log)); await log.wrap("retryDecryption", log => this._room.notifyRoomKey(roomKey, retryEventIds || [], log));
} }
@ -466,8 +436,6 @@ export class RoomEncryption {
dispose() { dispose() {
this._disposed = true; this._disposed = true;
this._megolmBackfillCache.dispose();
this._megolmSyncCache.dispose();
} }
} }

View file

@ -15,23 +15,26 @@ limitations under the License.
*/ */
import {DecryptionError} from "../common.js"; import {DecryptionError} from "../common.js";
import * as RoomKey from "./decryption/RoomKey.js";
import {SessionInfo} from "./decryption/SessionInfo.js";
import {DecryptionPreparation} from "./decryption/DecryptionPreparation.js"; import {DecryptionPreparation} from "./decryption/DecryptionPreparation.js";
import {SessionDecryption} from "./decryption/SessionDecryption.js"; import {SessionDecryption} from "./decryption/SessionDecryption";
import {SessionCache} from "./decryption/SessionCache.js";
import {MEGOLM_ALGORITHM} from "../common.js"; import {MEGOLM_ALGORITHM} from "../common.js";
import {validateEvent, groupEventsBySession} from "./decryption/utils.js"; import {validateEvent, groupEventsBySession} from "./decryption/utils";
import {keyFromStorage, keyFromDeviceMessage, keyFromBackup} from "./decryption/RoomKey";
import type {RoomKey, IncomingRoomKey} from "./decryption/RoomKey";
import type {KeyLoader} from "./decryption/KeyLoader";
import type {OlmWorker} from "../OlmWorker";
import type {Transaction} from "../../storage/idb/Transaction";
import type {TimelineEvent} from "../../storage/types";
import type {DecryptionResult} from "../DecryptionResult";
import type {LogItem} from "../../../logging/LogItem";
export class Decryption { export class Decryption {
constructor({pickleKey, olm, olmWorker}) { private keyLoader: KeyLoader;
this._pickleKey = pickleKey; private olmWorker?: OlmWorker;
this._olm = olm;
this._olmWorker = olmWorker;
}
createSessionCache(size) { constructor(keyLoader: KeyLoader, olmWorker: OlmWorker | undefined) {
return new SessionCache(size); this.keyLoader = keyLoader;
this.olmWorker = olmWorker;
} }
async addMissingKeyEventIds(roomId, senderKey, sessionId, eventIds, txn) { async addMissingKeyEventIds(roomId, senderKey, sessionId, eventIds, txn) {
@ -75,9 +78,9 @@ export class Decryption {
* @param {[type]} txn [description] * @param {[type]} txn [description]
* @return {DecryptionPreparation} * @return {DecryptionPreparation}
*/ */
async prepareDecryptAll(roomId, events, newKeys, sessionCache, txn) { async prepareDecryptAll(roomId: string, events: TimelineEvent[], newKeys: IncomingRoomKey[] | undefined, txn: Transaction) {
const errors = new Map(); const errors = new Map();
const validEvents = []; const validEvents: TimelineEvent[] = [];
for (const event of events) { for (const event of events) {
if (validateEvent(event)) { if (validateEvent(event)) {
@ -89,11 +92,11 @@ export class Decryption {
const eventsBySession = groupEventsBySession(validEvents); const eventsBySession = groupEventsBySession(validEvents);
const sessionDecryptions = []; const sessionDecryptions: SessionDecryption[] = [];
await Promise.all(Array.from(eventsBySession.values()).map(async group => { await Promise.all(Array.from(eventsBySession.values()).map(async group => {
const sessionInfo = await this._getSessionInfo(roomId, group.senderKey, group.sessionId, newKeys, sessionCache, txn); const key = await this.getRoomKey(roomId, group.senderKey!, group.sessionId!, newKeys, txn);
if (sessionInfo) { if (key) {
sessionDecryptions.push(new SessionDecryption(sessionInfo, group.events, this._olmWorker)); sessionDecryptions.push(new SessionDecryption(key, group.events, this.olmWorker, this.keyLoader));
} else { } else {
for (const event of group.events) { for (const event of group.events) {
errors.set(event.event_id, new DecryptionError("MEGOLM_NO_SESSION", event)); errors.set(event.event_id, new DecryptionError("MEGOLM_NO_SESSION", event));
@ -104,63 +107,43 @@ export class Decryption {
return new DecryptionPreparation(roomId, sessionDecryptions, errors); return new DecryptionPreparation(roomId, sessionDecryptions, errors);
} }
async _getSessionInfo(roomId, senderKey, sessionId, newKeys, sessionCache, txn) { private async getRoomKey(roomId: string, senderKey: string, sessionId: string, newKeys: IncomingRoomKey[] | undefined, txn: Transaction): Promise<RoomKey | undefined> {
let sessionInfo;
if (newKeys) { if (newKeys) {
const key = newKeys.find(k => k.roomId === roomId && k.senderKey === senderKey && k.sessionId === sessionId); const key = newKeys.find(k => k.isForSession(roomId, senderKey, sessionId));
if (key) { if (key && await key.checkBetterThanKeyInStorage(this.keyLoader, txn)) {
sessionInfo = await key.createSessionInfo(this._olm, this._pickleKey, txn); return key;
if (sessionInfo) {
sessionCache.add(sessionInfo);
}
} }
} }
// look only in the cache after looking into newKeys as it may contains that are better // look only in the cache after looking into newKeys as it may contains that are better
if (!sessionInfo) { const cachedKey = this.keyLoader.getCachedKey(roomId, senderKey, sessionId);
sessionInfo = sessionCache.get(roomId, senderKey, sessionId); if (cachedKey) {
return cachedKey;
} }
if (!sessionInfo) { const storageKey = await keyFromStorage(roomId, senderKey, sessionId, txn);
const sessionEntry = await txn.inboundGroupSessions.get(roomId, senderKey, sessionId); if (storageKey && storageKey.serializationKey) {
if (sessionEntry && sessionEntry.session) { return storageKey;
let session = new this._olm.InboundGroupSession();
try {
session.unpickle(this._pickleKey, sessionEntry.session);
sessionInfo = new SessionInfo(roomId, senderKey, session, sessionEntry.claimedKeys);
} catch (err) {
session.free();
throw err;
}
sessionCache.add(sessionInfo);
}
} }
return sessionInfo;
} }
/** /**
* Writes the key as an inbound group session if there is not already a better key in the store * Writes the key as an inbound group session if there is not already a better key in the store
* @param {RoomKey} key
* @param {Transaction} txn a storage transaction with read/write on inboundGroupSessions
* @return {Promise<boolean>} whether the key was the best for the sessio id and was written
*/ */
writeRoomKey(key, txn) { writeRoomKey(key: IncomingRoomKey, txn: Transaction): Promise<boolean> {
return key.write(this._olm, this._pickleKey, txn); return key.write(this.keyLoader, txn);
} }
/** /**
* Extracts room keys from decrypted device messages. * Extracts room keys from decrypted device messages.
* The key won't be persisted yet, you need to call RoomKey.write for that. * The key won't be persisted yet, you need to call RoomKey.write for that.
*
* @param {Array<OlmDecryptionResult>} decryptionResults, any non megolm m.room_key messages will be ignored.
* @return {Array<RoomKey>} an array with validated RoomKey's. Note that it is possible we already have a better version of this key in storage though; writing the key will tell you so.
*/ */
roomKeysFromDeviceMessages(decryptionResults, log) { roomKeysFromDeviceMessages(decryptionResults: DecryptionResult[], log: LogItem): IncomingRoomKey[] {
let keys = []; const keys: IncomingRoomKey[] = [];
for (const dr of decryptionResults) { for (const dr of decryptionResults) {
if (dr.event?.type !== "m.room_key" || dr.event.content?.algorithm !== MEGOLM_ALGORITHM) { if (dr.event?.type !== "m.room_key" || dr.event.content?.algorithm !== MEGOLM_ALGORITHM) {
continue; continue;
} }
log.wrap("room_key", log => { log.wrap("room_key", log => {
const key = RoomKey.fromDeviceMessage(dr); const key = keyFromDeviceMessage(dr);
if (key) { if (key) {
log.set("roomId", key.roomId); log.set("roomId", key.roomId);
log.set("id", key.sessionId); log.set("id", key.sessionId);
@ -174,8 +157,11 @@ export class Decryption {
return keys; return keys;
} }
roomKeyFromBackup(roomId, sessionId, sessionInfo) { roomKeyFromBackup(roomId: string, sessionId: string, sessionInfo: string): IncomingRoomKey | undefined {
return RoomKey.fromBackup(roomId, sessionId, sessionInfo); return keyFromBackup(roomId, sessionId, sessionInfo);
}
dispose() {
this.keyLoader.dispose();
} }
} }

View file

@ -0,0 +1,433 @@
/*
Copyright 2021 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import {isBetterThan, IncomingRoomKey} from "./RoomKey";
import {BaseLRUCache} from "../../../../utils/LRUCache";
import type {RoomKey} from "./RoomKey";
export declare class OlmDecryptionResult {
readonly plaintext: string;
readonly message_index: number;
}
export declare class OlmInboundGroupSession {
constructor();
free(): void;
pickle(key: string | Uint8Array): string;
unpickle(key: string | Uint8Array, pickle: string);
create(session_key: string): string;
import_session(session_key: string): string;
decrypt(message: string): OlmDecryptionResult;
session_id(): string;
first_known_index(): number;
export_session(message_index: number): string;
}
/*
Because Olm only has very limited memory available when compiled to wasm,
we limit the amount of sessions held in memory.
*/
export class KeyLoader extends BaseLRUCache<KeyOperation> {
private pickleKey: string;
private olm: any;
private resolveUnusedOperation?: () => void;
private operationBecomesUnusedPromise?: Promise<void>;
constructor(olm: any, pickleKey: string, limit: number) {
super(limit);
this.pickleKey = pickleKey;
this.olm = olm;
}
getCachedKey(roomId: string, senderKey: string, sessionId: string): RoomKey | undefined {
const idx = this.findCachedKeyIndex(roomId, senderKey, sessionId);
if (idx !== -1) {
return this._getByIndexAndMoveUp(idx)!.key;
}
}
async useKey<T>(key: RoomKey, callback: (session: OlmInboundGroupSession, pickleKey: string) => Promise<T> | T): Promise<T> {
const keyOp = await this.allocateOperation(key);
try {
return await callback(keyOp.session, this.pickleKey);
} finally {
this.releaseOperation(keyOp);
}
}
get running() {
return this._entries.some(op => op.refCount !== 0);
}
dispose() {
for (let i = 0; i < this._entries.length; i += 1) {
this._entries[i].dispose();
}
// remove all entries
this._entries.splice(0, this._entries.length);
}
private async allocateOperation(key: RoomKey): Promise<KeyOperation> {
let idx;
while((idx = this.findIndexForAllocation(key)) === -1) {
await this.operationBecomesUnused();
}
if (idx < this.size) {
const op = this._getByIndexAndMoveUp(idx)!;
// cache hit
if (op.isForKey(key)) {
op.refCount += 1;
return op;
} else {
// refCount should be 0 here
op.refCount = 1;
op.key = key;
key.loadInto(op.session, this.pickleKey);
}
return op;
} else {
// create new operation
const session = new this.olm.InboundGroupSession();
key.loadInto(session, this.pickleKey);
const op = new KeyOperation(key, session);
this._set(op);
return op;
}
}
private releaseOperation(op: KeyOperation) {
op.refCount -= 1;
if (op.refCount <= 0 && this.resolveUnusedOperation) {
this.resolveUnusedOperation();
// promise is resolved now, we'll need a new one for next await so clear
this.operationBecomesUnusedPromise = this.resolveUnusedOperation = undefined;
}
}
private operationBecomesUnused(): Promise<void> {
if (!this.operationBecomesUnusedPromise) {
this.operationBecomesUnusedPromise = new Promise(resolve => {
this.resolveUnusedOperation = resolve;
});
}
return this.operationBecomesUnusedPromise;
}
private findIndexForAllocation(key: RoomKey) {
let idx = this.findIndexSameKey(key); // cache hit
if (idx === -1) {
if (this.size < this.limit) {
idx = this.size;
} else {
idx = this.findIndexSameSessionUnused(key);
if (idx === -1) {
idx = this.findIndexOldestUnused();
}
}
}
return idx;
}
private findCachedKeyIndex(roomId: string, senderKey: string, sessionId: string): number {
return this._entries.reduce((bestIdx, op, i, arr) => {
const bestOp = bestIdx === -1 ? undefined : arr[bestIdx];
// only operations that are the "best" for their session can be used, see comment on isBest
if (op.isBest === true && op.isForSameSession(roomId, senderKey, sessionId)) {
if (!bestOp || op.isBetter(bestOp)) {
return i;
}
}
return bestIdx;
}, -1);
}
private findIndexSameKey(key: RoomKey): number {
return this._entries.findIndex(op => {
return op.isForSameSession(key.roomId, key.senderKey, key.sessionId) && op.isForKey(key);
});
}
private findIndexSameSessionUnused(key: RoomKey): number {
return this._entries.reduce((worstIdx, op, i, arr) => {
const worst = worstIdx === -1 ? undefined : arr[worstIdx];
// we try to pick the worst operation to overwrite, so the best one stays in the cache
if (op.refCount === 0 && op.isForSameSession(key.roomId, key.senderKey, key.sessionId)) {
if (!worst || !op.isBetter(worst)) {
return i;
}
}
return worstIdx;
}, -1);
}
private findIndexOldestUnused(): number {
for (let i = this._entries.length - 1; i >= 0; i -= 1) {
const op = this._entries[i];
if (op.refCount === 0) {
return i;
}
}
return -1;
}
}
class KeyOperation {
session: OlmInboundGroupSession;
key: RoomKey;
refCount: number;
constructor(key: RoomKey, session: OlmInboundGroupSession) {
this.key = key;
this.session = session;
this.refCount = 1;
}
isForSameSession(roomId: string, senderKey: string, sessionId: string): boolean {
return this.key.roomId === roomId && this.key.senderKey === senderKey && this.key.sessionId === sessionId;
}
// assumes isForSameSession is true
isBetter(other: KeyOperation) {
return isBetterThan(this.session, other.session);
}
isForKey(key: RoomKey) {
return this.key.serializationKey === key.serializationKey &&
this.key.serializationType === key.serializationType;
}
dispose() {
this.session.free();
}
/** returns whether the key for this operation has been checked at some point against storage
* and was determined to be the better key, undefined if it hasn't been checked yet.
* Only keys that are the best keys can be returned by getCachedKey as returning a cache hit
* will usually not check for a better session in storage. Also see RoomKey.isBetter. */
get isBest(): boolean | undefined {
return this.key.isBetter;
}
}
export function tests() {
let instances = 0;
class MockRoomKey extends IncomingRoomKey {
private _roomId: string;
private _senderKey: string;
private _sessionId: string;
private _firstKnownIndex: number;
constructor(roomId: string, senderKey: string, sessionId: string, firstKnownIndex: number) {
super();
this._roomId = roomId;
this._senderKey = senderKey;
this._sessionId = sessionId;
this._firstKnownIndex = firstKnownIndex;
}
get roomId(): string { return this._roomId; }
get senderKey(): string { return this._senderKey; }
get sessionId(): string { return this._sessionId; }
get claimedEd25519Key(): string { return "claimedEd25519Key"; }
get serializationKey(): string { return `key-${this.sessionId}-${this._firstKnownIndex}`; }
get serializationType(): string { return "type"; }
get eventIds(): string[] | undefined { return undefined; }
loadInto(session: OlmInboundGroupSession) {
const mockSession = session as MockInboundSession;
mockSession.sessionId = this.sessionId;
mockSession.firstKnownIndex = this._firstKnownIndex;
}
}
class MockInboundSession {
public sessionId: string = "";
public firstKnownIndex: number = 0;
constructor() {
instances += 1;
}
free(): void { instances -= 1; }
pickle(key: string | Uint8Array): string { return `${this.sessionId}-pickled-session`; }
unpickle(key: string | Uint8Array, pickle: string) {}
create(session_key: string): string { return `${this.sessionId}-created-session`; }
import_session(session_key: string): string { return ""; }
decrypt(message: string): OlmDecryptionResult { return {} as OlmDecryptionResult; }
session_id(): string { return this.sessionId; }
first_known_index(): number { return this.firstKnownIndex; }
export_session(message_index: number): string { return `${this.sessionId}-exported-session`; }
}
const PICKLE_KEY = "🥒🔑";
const olm = {InboundGroupSession: MockInboundSession};
const roomId = "!abc:hs.tld";
const aliceSenderKey = "abc";
const bobSenderKey = "def";
const sessionId1 = "s123";
const sessionId2 = "s456";
return {
"load key gives correct session": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 2);
let callback1Called = false;
let callback2Called = false;
const p1 = loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => {
callback1Called = true;
assert.equal(session.session_id(), sessionId1);
assert.equal(session.first_known_index(), 1);
await Promise.resolve(); // make sure they are busy in parallel
});
const p2 = loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId2, 2), async session => {
callback2Called = true;
assert.equal(session.session_id(), sessionId2);
assert.equal(session.first_known_index(), 2);
await Promise.resolve(); // make sure they are busy in parallel
});
assert.equal(loader.size, 2);
await Promise.all([p1, p2]);
assert(callback1Called);
assert(callback2Called);
},
"keys with different first index are kept separate": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 2);
let callback1Called = false;
let callback2Called = false;
const p1 = loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => {
callback1Called = true;
assert.equal(session.session_id(), sessionId1);
assert.equal(session.first_known_index(), 1);
await Promise.resolve(); // make sure they are busy in parallel
});
const p2 = loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2), async session => {
callback2Called = true;
assert.equal(session.session_id(), sessionId1);
assert.equal(session.first_known_index(), 2);
await Promise.resolve(); // make sure they are busy in parallel
});
assert.equal(loader.size, 2);
await Promise.all([p1, p2]);
assert(callback1Called);
assert(callback2Called);
},
"useKey blocks as long as no free sessions are available": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 1);
let resolve;
let callbackCalled = false;
loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => {
await new Promise(r => resolve = r);
});
await Promise.resolve();
assert.equal(loader.size, 1);
const promise = loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId2, 1), session => {
callbackCalled = true;
});
assert.equal(callbackCalled, false);
resolve();
await promise;
assert.equal(callbackCalled, true);
},
"cache hit while key in use, then replace (check refCount works properly)": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 1);
let resolve1, resolve2;
const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1);
const p1 = loader.useKey(key1, async session => {
await new Promise(r => resolve1 = r);
});
const p2 = loader.useKey(key1, async session => {
await new Promise(r => resolve2 = r);
});
await Promise.resolve();
assert.equal(loader.size, 1);
assert.equal(loader.running, true);
resolve1();
await p1;
assert.equal(loader.running, true);
resolve2();
await p2;
assert.equal(loader.running, false);
let callbackCalled = false;
await loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId2, 1), async session => {
callbackCalled = true;
assert.equal(session.session_id(), sessionId2);
assert.equal(session.first_known_index(), 1);
});
assert.equal(loader.size, 1);
assert.equal(callbackCalled, true);
},
"cache hit while key not in use": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 2);
let resolve1, resolve2, invocations = 0;
const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1);
await loader.useKey(key1, async session => { invocations += 1; });
key1.isBetter = true;
assert.equal(loader.size, 1);
const cachedKey = loader.getCachedKey(roomId, aliceSenderKey, sessionId1)!;
assert.equal(cachedKey, key1);
await loader.useKey(cachedKey, async session => { invocations += 1; });
assert.equal(loader.size, 1);
assert.equal(invocations, 2);
},
"dispose calls free on all sessions": async assert => {
instances = 0;
const loader = new KeyLoader(olm, PICKLE_KEY, 2);
await loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => {});
await loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId2, 1), async session => {});
assert.equal(instances, 2);
assert.equal(loader.size, 2);
loader.dispose();
assert.strictEqual(instances, 0, "instances");
assert.strictEqual(loader.size, 0, "loader.size");
},
"checkBetterThanKeyInStorage false with cache": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 2);
const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2);
await loader.useKey(key1, async session => {});
// fake we've checked with storage that this is the best key,
// and as long is it remains the best key with newly added keys,
// it will be returned from getCachedKey (as called from checkBetterThanKeyInStorage)
key1.isBetter = true;
const key2 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 3);
// this will hit cache of key 1 so we pass in null as txn
const isBetter = await key2.checkBetterThanKeyInStorage(loader, null as any);
assert.strictEqual(isBetter, false);
assert.strictEqual(key2.isBetter, false);
},
"checkBetterThanKeyInStorage true with cache": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 2);
const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2);
key1.isBetter = true; // fake we've check with storage so far (not including key2) this is the best key
await loader.useKey(key1, async session => {});
const key2 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1);
// this will hit cache of key 1 so we pass in null as txn
const isBetter = await key2.checkBetterThanKeyInStorage(loader, null as any);
assert.strictEqual(isBetter, true);
assert.strictEqual(key2.isBetter, true);
},
"prefer to remove worst key for a session from cache": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 2);
const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2);
await loader.useKey(key1, async session => {});
key1.isBetter = true; // set to true just so it gets returned from getCachedKey
const key2 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 4);
await loader.useKey(key2, async session => {});
const key3 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 3);
await loader.useKey(key3, async session => {});
assert.strictEqual(loader.getCachedKey(roomId, aliceSenderKey, sessionId1), key1);
},
}
}

View file

@ -14,11 +14,24 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
import type {TimelineEvent} from "../../../storage/types";
export class ReplayDetectionEntry { export class ReplayDetectionEntry {
constructor(sessionId, messageIndex, event) { public readonly sessionId: string;
public readonly messageIndex: number;
public readonly event: TimelineEvent;
constructor(sessionId: string, messageIndex: number, event: TimelineEvent) {
this.sessionId = sessionId; this.sessionId = sessionId;
this.messageIndex = messageIndex; this.messageIndex = messageIndex;
this.eventId = event.event_id; this.event = event;
this.timestamp = event.origin_server_ts; }
get eventId(): string {
return this.event.event_id;
}
get timestamp(): number {
return this.event.origin_server_ts;
} }
} }

View file

@ -1,166 +0,0 @@
import {SessionInfo} from "./SessionInfo.js";
export class BaseRoomKey {
constructor() {
this._sessionInfo = null;
this._isBetter = null;
this._eventIds = null;
}
async createSessionInfo(olm, pickleKey, txn) {
if (this._isBetter === false) {
return;
}
const session = new olm.InboundGroupSession();
try {
this._loadSessionKey(session);
this._isBetter = await this._isBetterThanKnown(session, olm, pickleKey, txn);
if (this._isBetter) {
const claimedKeys = {ed25519: this.claimedEd25519Key};
this._sessionInfo = new SessionInfo(this.roomId, this.senderKey, session, claimedKeys);
// retain the session so we don't have to create a new session during write.
this._sessionInfo.retain();
return this._sessionInfo;
} else {
session.free();
return;
}
} catch (err) {
this._sessionInfo = null;
session.free();
throw err;
}
}
async _isBetterThanKnown(session, olm, pickleKey, txn) {
let isBetter = true;
// TODO: we could potentially have a small speedup here if we looked first in the SessionCache here...
const existingSessionEntry = await txn.inboundGroupSessions.get(this.roomId, this.senderKey, this.sessionId);
if (existingSessionEntry?.session) {
const existingSession = new olm.InboundGroupSession();
try {
existingSession.unpickle(pickleKey, existingSessionEntry.session);
isBetter = session.first_known_index() < existingSession.first_known_index();
} finally {
existingSession.free();
}
}
// store the event ids that can be decrypted with this key
// before we overwrite them if called from `write`.
if (existingSessionEntry?.eventIds) {
this._eventIds = existingSessionEntry.eventIds;
}
return isBetter;
}
async write(olm, pickleKey, txn) {
// we checked already and we had a better session in storage, so don't write
if (this._isBetter === false) {
return false;
}
if (!this._sessionInfo) {
await this.createSessionInfo(olm, pickleKey, txn);
}
if (this._sessionInfo) {
const session = this._sessionInfo.session;
const sessionEntry = {
roomId: this.roomId,
senderKey: this.senderKey,
sessionId: this.sessionId,
session: session.pickle(pickleKey),
claimedKeys: this._sessionInfo.claimedKeys,
};
txn.inboundGroupSessions.set(sessionEntry);
this.dispose();
return true;
}
return false;
}
get eventIds() {
return this._eventIds;
}
dispose() {
if (this._sessionInfo) {
this._sessionInfo.release();
this._sessionInfo = null;
}
}
}
class DeviceMessageRoomKey extends BaseRoomKey {
constructor(decryptionResult) {
super();
this._decryptionResult = decryptionResult;
}
get roomId() { return this._decryptionResult.event.content?.["room_id"]; }
get senderKey() { return this._decryptionResult.senderCurve25519Key; }
get sessionId() { return this._decryptionResult.event.content?.["session_id"]; }
get claimedEd25519Key() { return this._decryptionResult.claimedEd25519Key; }
_loadSessionKey(session) {
const sessionKey = this._decryptionResult.event.content?.["session_key"];
session.create(sessionKey);
}
}
class BackupRoomKey extends BaseRoomKey {
constructor(roomId, sessionId, backupInfo) {
super();
this._roomId = roomId;
this._sessionId = sessionId;
this._backupInfo = backupInfo;
}
get roomId() { return this._roomId; }
get senderKey() { return this._backupInfo["sender_key"]; }
get sessionId() { return this._sessionId; }
get claimedEd25519Key() { return this._backupInfo["sender_claimed_keys"]?.["ed25519"]; }
_loadSessionKey(session) {
const sessionKey = this._backupInfo["session_key"];
session.import_session(sessionKey);
}
}
export function fromDeviceMessage(dr) {
const roomId = dr.event.content?.["room_id"];
const sessionId = dr.event.content?.["session_id"];
const sessionKey = dr.event.content?.["session_key"];
if (
typeof roomId === "string" ||
typeof sessionId === "string" ||
typeof senderKey === "string" ||
typeof sessionKey === "string"
) {
return new DeviceMessageRoomKey(dr);
}
}
/*
sessionInfo is a response from key backup and has the following keys:
algorithm
forwarding_curve25519_key_chain
sender_claimed_keys
sender_key
session_key
*/
export function fromBackup(roomId, sessionId, sessionInfo) {
const sessionKey = sessionInfo["session_key"];
const senderKey = sessionInfo["sender_key"];
// TODO: can we just trust this?
const claimedEd25519Key = sessionInfo["sender_claimed_keys"]?.["ed25519"];
if (
typeof roomId === "string" &&
typeof sessionId === "string" &&
typeof senderKey === "string" &&
typeof sessionKey === "string" &&
typeof claimedEd25519Key === "string"
) {
return new BackupRoomKey(roomId, sessionId, sessionInfo);
}
}

View file

@ -0,0 +1,245 @@
/*
Copyright 2021 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import type {InboundGroupSessionEntry} from "../../../storage/idb/stores/InboundGroupSessionStore";
import type {Transaction} from "../../../storage/idb/Transaction";
import type {DecryptionResult} from "../../DecryptionResult";
import type {KeyLoader, OlmInboundGroupSession} from "./KeyLoader";
export abstract class RoomKey {
private _isBetter: boolean | undefined;
isForSession(roomId: string, senderKey: string, sessionId: string) {
return this.roomId === roomId && this.senderKey === senderKey && this.sessionId === sessionId;
}
abstract get roomId(): string;
abstract get senderKey(): string;
abstract get sessionId(): string;
abstract get claimedEd25519Key(): string;
abstract get serializationKey(): string;
abstract get serializationType(): string;
abstract get eventIds(): string[] | undefined;
abstract loadInto(session: OlmInboundGroupSession, pickleKey: string): void;
/* Whether the key has been checked against storage (or is from storage)
* to be the better key for a given session. Given that all keys are checked to be better
* as part of writing, we can trust that when this returns true, it really is the best key
* available between storage and cached keys in memory. This is why keys with this field set to
* true are used by the key loader to return cached keys. Also see KeyOperation.isBest there. */
get isBetter(): boolean | undefined { return this._isBetter; }
// should only be set in key.checkBetterThanKeyInStorage
set isBetter(value: boolean | undefined) { this._isBetter = value; }
}
export function isBetterThan(newSession: OlmInboundGroupSession, existingSession: OlmInboundGroupSession) {
return newSession.first_known_index() < existingSession.first_known_index();
}
export abstract class IncomingRoomKey extends RoomKey {
private _eventIds?: string[];
checkBetterThanKeyInStorage(loader: KeyLoader, txn: Transaction): Promise<boolean> {
return this._checkBetterThanKeyInStorage(loader, undefined, txn);
}
async write(loader: KeyLoader, txn: Transaction): Promise<boolean> {
// we checked already and we had a better session in storage, so don't write
let pickledSession;
if (this.isBetter === undefined) {
// if this key wasn't used to decrypt any messages in the same sync,
// we haven't checked if this is the best key yet,
// so do that now to not overwrite a better key.
// while we have the key deserialized, also pickle it to store it later on here.
await this._checkBetterThanKeyInStorage(loader, (session, pickleKey) => {
pickledSession = session.pickle(pickleKey);
}, txn);
}
if (this.isBetter === false) {
return false;
}
// before calling write in parallel, we need to check loader.running is false so we are sure our transaction will not be closed
if (!pickledSession) {
pickledSession = await loader.useKey(this, (session, pickleKey) => session.pickle(pickleKey));
}
const sessionEntry = {
roomId: this.roomId,
senderKey: this.senderKey,
sessionId: this.sessionId,
session: pickledSession,
claimedKeys: {"ed25519": this.claimedEd25519Key},
};
txn.inboundGroupSessions.set(sessionEntry);
return true;
}
get eventIds() { return this._eventIds; }
private async _checkBetterThanKeyInStorage(loader: KeyLoader, callback: (((session: OlmInboundGroupSession, pickleKey: string) => void) | undefined), txn: Transaction): Promise<boolean> {
if (this.isBetter !== undefined) {
return this.isBetter;
}
let existingKey = loader.getCachedKey(this.roomId, this.senderKey, this.sessionId);
if (!existingKey) {
const storageKey = await keyFromStorage(this.roomId, this.senderKey, this.sessionId, txn);
// store the event ids that can be decrypted with this key
// before we overwrite them if called from `write`.
if (storageKey) {
if (storageKey.hasSession) {
existingKey = storageKey;
} else if (storageKey.eventIds) {
this._eventIds = storageKey.eventIds;
}
}
}
if (existingKey) {
const key = existingKey;
await loader.useKey(this, async newSession => {
await loader.useKey(key, (existingSession, pickleKey) => {
// set isBetter as soon as possible, on both keys compared,
// as it is is used to determine whether a key can be used for the cache
this.isBetter = isBetterThan(newSession, existingSession);
key.isBetter = !this.isBetter;
if (this.isBetter && callback) {
callback(newSession, pickleKey);
}
});
});
} else {
// no previous key, so we're the best \o/
this.isBetter = true;
}
return this.isBetter!;
}
}
class DeviceMessageRoomKey extends IncomingRoomKey {
private _decryptionResult: DecryptionResult;
constructor(decryptionResult: DecryptionResult) {
super();
this._decryptionResult = decryptionResult;
}
get roomId() { return this._decryptionResult.event.content?.["room_id"]; }
get senderKey() { return this._decryptionResult.senderCurve25519Key; }
get sessionId() { return this._decryptionResult.event.content?.["session_id"]; }
get claimedEd25519Key() { return this._decryptionResult.claimedEd25519Key; }
get serializationKey(): string { return this._decryptionResult.event.content?.["session_key"]; }
get serializationType(): string { return "create"; }
loadInto(session) {
session.create(this.serializationKey);
}
}
class BackupRoomKey extends IncomingRoomKey {
private _roomId: string;
private _sessionId: string;
private _backupInfo: string;
constructor(roomId, sessionId, backupInfo) {
super();
this._roomId = roomId;
this._sessionId = sessionId;
this._backupInfo = backupInfo;
}
get roomId() { return this._roomId; }
get senderKey() { return this._backupInfo["sender_key"]; }
get sessionId() { return this._sessionId; }
get claimedEd25519Key() { return this._backupInfo["sender_claimed_keys"]?.["ed25519"]; }
get serializationKey(): string { return this._backupInfo["session_key"]; }
get serializationType(): string { return "import_session"; }
loadInto(session) {
session.import_session(this.serializationKey);
}
}
class StoredRoomKey extends RoomKey {
private storageEntry: InboundGroupSessionEntry;
constructor(storageEntry: InboundGroupSessionEntry) {
super();
this.isBetter = true; // usually the key in storage is the best until checks prove otherwise
this.storageEntry = storageEntry;
}
get roomId() { return this.storageEntry.roomId; }
get senderKey() { return this.storageEntry.senderKey; }
get sessionId() { return this.storageEntry.sessionId; }
get claimedEd25519Key() { return this.storageEntry.claimedKeys!["ed25519"]; }
get eventIds() { return this.storageEntry.eventIds; }
get serializationKey(): string { return this.storageEntry.session || ""; }
get serializationType(): string { return "unpickle"; }
loadInto(session, pickleKey) {
session.unpickle(pickleKey, this.serializationKey);
}
get hasSession() {
// sessions are stored before they are received
// to keep track of events that need it to be decrypted.
// This is used to retry decryption of those events once the session is received.
return !!this.serializationKey;
}
}
export function keyFromDeviceMessage(dr: DecryptionResult): DeviceMessageRoomKey | undefined {
const sessionKey = dr.event.content?.["session_key"];
const key = new DeviceMessageRoomKey(dr);
if (
typeof key.roomId === "string" &&
typeof key.sessionId === "string" &&
typeof key.senderKey === "string" &&
typeof sessionKey === "string"
) {
return key;
}
}
/*
sessionInfo is a response from key backup and has the following keys:
algorithm
forwarding_curve25519_key_chain
sender_claimed_keys
sender_key
session_key
*/
export function keyFromBackup(roomId, sessionId, backupInfo): BackupRoomKey | undefined {
const sessionKey = backupInfo["session_key"];
const senderKey = backupInfo["sender_key"];
// TODO: can we just trust this?
const claimedEd25519Key = backupInfo["sender_claimed_keys"]?.["ed25519"];
if (
typeof roomId === "string" &&
typeof sessionId === "string" &&
typeof senderKey === "string" &&
typeof sessionKey === "string" &&
typeof claimedEd25519Key === "string"
) {
return new BackupRoomKey(roomId, sessionId, backupInfo);
}
}
export async function keyFromStorage(roomId: string, senderKey: string, sessionId: string, txn: Transaction): Promise<StoredRoomKey | undefined> {
const existingSessionEntry = await txn.inboundGroupSessions.get(roomId, senderKey, sessionId);
if (existingSessionEntry) {
return new StoredRoomKey(existingSessionEntry);
}
return;
}

View file

@ -1,61 +0,0 @@
/*
Copyright 2020 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import {BaseLRUCache} from "../../../../utils/LRUCache.js";
const DEFAULT_CACHE_SIZE = 10;
/**
* Cache of unpickled inbound megolm session.
*/
export class SessionCache extends BaseLRUCache {
constructor(limit) {
limit = typeof limit === "number" ? limit : DEFAULT_CACHE_SIZE;
super(limit);
}
/**
* @param {string} roomId
* @param {string} senderKey
* @param {string} sessionId
* @return {SessionInfo?}
*/
get(roomId, senderKey, sessionId) {
return this._get(s => {
return s.roomId === roomId &&
s.senderKey === senderKey &&
sessionId === s.sessionId;
});
}
add(sessionInfo) {
sessionInfo.retain();
this._set(sessionInfo, s => {
return s.roomId === sessionInfo.roomId &&
s.senderKey === sessionInfo.senderKey &&
s.sessionId === sessionInfo.sessionId;
});
}
_onEvictEntry(sessionInfo) {
sessionInfo.release();
}
dispose() {
for (const sessionInfo of this._entries) {
sessionInfo.release();
}
}
}

View file

@ -1,90 +0,0 @@
/*
Copyright 2020 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import {DecryptionResult} from "../../DecryptionResult.js";
import {DecryptionError} from "../../common.js";
import {ReplayDetectionEntry} from "./ReplayDetectionEntry.js";
/**
* Does the actual decryption of all events for a given megolm session in a batch
*/
export class SessionDecryption {
constructor(sessionInfo, events, olmWorker) {
sessionInfo.retain();
this._sessionInfo = sessionInfo;
this._events = events;
this._olmWorker = olmWorker;
this._decryptionRequests = olmWorker ? [] : null;
}
async decryptAll() {
const replayEntries = [];
const results = new Map();
let errors;
const roomId = this._sessionInfo.roomId;
await Promise.all(this._events.map(async event => {
try {
const {session} = this._sessionInfo;
const ciphertext = event.content.ciphertext;
let decryptionResult;
if (this._olmWorker) {
const request = this._olmWorker.megolmDecrypt(session, ciphertext);
this._decryptionRequests.push(request);
decryptionResult = await request.response();
} else {
decryptionResult = session.decrypt(ciphertext);
}
const plaintext = decryptionResult.plaintext;
const messageIndex = decryptionResult.message_index;
let payload;
try {
payload = JSON.parse(plaintext);
} catch (err) {
throw new DecryptionError("PLAINTEXT_NOT_JSON", event, {plaintext, err});
}
if (payload.room_id !== roomId) {
throw new DecryptionError("MEGOLM_WRONG_ROOM", event,
{encryptedRoomId: payload.room_id, eventRoomId: roomId});
}
replayEntries.push(new ReplayDetectionEntry(session.session_id(), messageIndex, event));
const result = new DecryptionResult(payload, this._sessionInfo.senderKey, this._sessionInfo.claimedKeys);
results.set(event.event_id, result);
} catch (err) {
// ignore AbortError from cancelling decryption requests in dispose method
if (err.name === "AbortError") {
return;
}
if (!errors) {
errors = new Map();
}
errors.set(event.event_id, err);
}
}));
return {results, errors, replayEntries};
}
dispose() {
if (this._decryptionRequests) {
for (const r of this._decryptionRequests) {
r.abort();
}
}
// TODO: cancel decryptions here
this._sessionInfo.release();
}
}

View file

@ -0,0 +1,103 @@
/*
Copyright 2020 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import {DecryptionResult} from "../../DecryptionResult.js";
import {DecryptionError} from "../../common.js";
import {ReplayDetectionEntry} from "./ReplayDetectionEntry";
import type {RoomKey} from "./RoomKey.js";
import type {KeyLoader, OlmDecryptionResult} from "./KeyLoader";
import type {OlmWorker} from "../../OlmWorker";
import type {TimelineEvent} from "../../../storage/types";
interface DecryptAllResult {
readonly results: Map<string, DecryptionResult>;
readonly errors?: Map<string, Error>;
readonly replayEntries: ReplayDetectionEntry[];
}
/**
* Does the actual decryption of all events for a given megolm session in a batch
*/
export class SessionDecryption {
private key: RoomKey;
private events: TimelineEvent[];
private keyLoader: KeyLoader;
private olmWorker?: OlmWorker;
private decryptionRequests?: any[];
constructor(key: RoomKey, events: TimelineEvent[], olmWorker: OlmWorker | undefined, keyLoader: KeyLoader) {
this.key = key;
this.events = events;
this.olmWorker = olmWorker;
this.keyLoader = keyLoader;
this.decryptionRequests = olmWorker ? [] : undefined;
}
async decryptAll(): Promise<DecryptAllResult> {
const replayEntries: ReplayDetectionEntry[] = [];
const results: Map<string, DecryptionResult> = new Map();
let errors: Map<string, Error> | undefined;
await this.keyLoader.useKey(this.key, async session => {
for (const event of this.events) {
try {
const ciphertext = event.content.ciphertext as string;
let decryptionResult: OlmDecryptionResult | undefined;
// TODO: pass all cipthertexts in one go to the megolm worker and don't deserialize the key until in the worker?
if (this.olmWorker) {
const request = this.olmWorker.megolmDecrypt(session, ciphertext);
this.decryptionRequests!.push(request);
decryptionResult = await request.response();
} else {
decryptionResult = session.decrypt(ciphertext);
}
const {plaintext} = decryptionResult!;
let payload;
try {
payload = JSON.parse(plaintext);
} catch (err) {
throw new DecryptionError("PLAINTEXT_NOT_JSON", event, {plaintext, err});
}
if (payload.room_id !== this.key.roomId) {
throw new DecryptionError("MEGOLM_WRONG_ROOM", event,
{encryptedRoomId: payload.room_id, eventRoomId: this.key.roomId});
}
replayEntries.push(new ReplayDetectionEntry(this.key.sessionId, decryptionResult!.message_index, event));
const result = new DecryptionResult(payload, this.key.senderKey, this.key.claimedEd25519Key);
results.set(event.event_id, result);
} catch (err) {
// ignore AbortError from cancelling decryption requests in dispose method
if (err.name === "AbortError") {
return;
}
if (!errors) {
errors = new Map();
}
errors.set(event.event_id, err);
}
}
});
return {results, errors, replayEntries};
}
dispose() {
if (this.decryptionRequests) {
for (const r of this.decryptionRequests) {
r.abort();
}
}
}
}

View file

@ -1,49 +0,0 @@
/*
Copyright 2020 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
/**
* session loaded in memory with everything needed to create DecryptionResults
* and to store/retrieve it in the SessionCache
*/
export class SessionInfo {
constructor(roomId, senderKey, session, claimedKeys) {
this.roomId = roomId;
this.senderKey = senderKey;
this.session = session;
this.claimedKeys = claimedKeys;
this._refCounter = 0;
}
get sessionId() {
return this.session?.session_id();
}
retain() {
this._refCounter += 1;
}
release() {
this._refCounter -= 1;
if (this._refCounter <= 0) {
this.dispose();
}
}
dispose() {
this.session.free();
this.session = null;
}
}

View file

@ -14,44 +14,46 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
import {groupByWithCreator} from "../../../../utils/groupBy.js"; import {groupByWithCreator} from "../../../../utils/groupBy";
import type {TimelineEvent} from "../../../storage/types";
function getSenderKey(event) { function getSenderKey(event: TimelineEvent): string | undefined {
return event.content?.["sender_key"]; return event.content?.["sender_key"];
} }
function getSessionId(event) { function getSessionId(event: TimelineEvent): string | undefined {
return event.content?.["session_id"]; return event.content?.["session_id"];
} }
function getCiphertext(event) { function getCiphertext(event: TimelineEvent): string | undefined {
return event.content?.ciphertext; return event.content?.ciphertext;
} }
export function validateEvent(event) { export function validateEvent(event: TimelineEvent) {
return typeof getSenderKey(event) === "string" && return typeof getSenderKey(event) === "string" &&
typeof getSessionId(event) === "string" && typeof getSessionId(event) === "string" &&
typeof getCiphertext(event) === "string"; typeof getCiphertext(event) === "string";
} }
class SessionKeyGroup { export class SessionKeyGroup {
public readonly events: TimelineEvent[];
constructor() { constructor() {
this.events = []; this.events = [];
} }
get senderKey() { get senderKey(): string | undefined {
return getSenderKey(this.events[0]); return getSenderKey(this.events[0]!);
} }
get sessionId() { get sessionId(): string | undefined {
return getSessionId(this.events[0]); return getSessionId(this.events[0]!);
} }
} }
export function groupEventsBySession(events) { export function groupEventsBySession(events: TimelineEvent[]): Map<string, SessionKeyGroup> {
return groupByWithCreator(events, return groupByWithCreator<string, TimelineEvent, SessionKeyGroup>(events,
event => `${getSenderKey(event)}|${getSessionId(event)}`, (event: TimelineEvent) => `${getSenderKey(event)}|${getSessionId(event)}`,
() => new SessionKeyGroup(), () => new SessionKeyGroup(),
(group, event) => group.events.push(event) (group: SessionKeyGroup, event: TimelineEvent) => group.events.push(event)
); );
} }

View file

@ -15,7 +15,7 @@ limitations under the License.
*/ */
import {DecryptionError} from "../common.js"; import {DecryptionError} from "../common.js";
import {groupBy} from "../../../utils/groupBy.js"; import {groupBy} from "../../../utils/groupBy";
import {MultiLock} from "../../../utils/Lock.js"; import {MultiLock} from "../../../utils/Lock.js";
import {Session} from "./Session.js"; import {Session} from "./Session.js";
import {DecryptionResult} from "../DecryptionResult.js"; import {DecryptionResult} from "../DecryptionResult.js";
@ -150,7 +150,7 @@ export class Decryption {
throw new DecryptionError("PLAINTEXT_NOT_JSON", event, {plaintext, error}); throw new DecryptionError("PLAINTEXT_NOT_JSON", event, {plaintext, error});
} }
this._validatePayload(payload, event); this._validatePayload(payload, event);
return new DecryptionResult(payload, senderKey, payload.keys); return new DecryptionResult(payload, senderKey, payload.keys.ed25519);
} else { } else {
throw new DecryptionError("OLM_NO_MATCHING_SESSION", event, throw new DecryptionError("OLM_NO_MATCHING_SESSION", event,
{knownSessionIds: senderKeyDecryption.sessions.map(s => s.id)}); {knownSessionIds: senderKeyDecryption.sessions.map(s => s.id)});

View file

@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
import {groupByWithCreator} from "../../../utils/groupBy.js"; import {groupByWithCreator} from "../../../utils/groupBy";
import {verifyEd25519Signature, OLM_ALGORITHM} from "../common.js"; import {verifyEd25519Signature, OLM_ALGORITHM} from "../common.js";
import {createSessionEntry} from "./Session.js"; import {createSessionEntry} from "./Session.js";

View file

@ -76,6 +76,10 @@ export class Room extends BaseRoom {
let eventsToDecrypt = roomResponse?.timeline?.events || []; let eventsToDecrypt = roomResponse?.timeline?.events || [];
// when new keys arrive, also see if any older events can now be retried to decrypt // when new keys arrive, also see if any older events can now be retried to decrypt
if (newKeys) { if (newKeys) {
// TODO: if a key is considered by roomEncryption.prepareDecryptAll to use for decryption,
// key.eventIds will be set. We could somehow try to reuse that work, but retrying also needs
// to happen if a key is not needed to decrypt this sync or there are indeed no encrypted messages
// in this sync at all.
retryEntries = await this._getSyncRetryDecryptEntries(newKeys, roomEncryption, txn); retryEntries = await this._getSyncRetryDecryptEntries(newKeys, roomEncryption, txn);
if (retryEntries.length) { if (retryEntries.length) {
log.set("retry", retryEntries.length); log.set("retry", retryEntries.length);

View file

@ -15,7 +15,7 @@ limitations under the License.
*/ */
import {MemberChange, RoomMember, EVENT_TYPE as MEMBER_EVENT_TYPE} from "../../members/RoomMember.js"; import {MemberChange, RoomMember, EVENT_TYPE as MEMBER_EVENT_TYPE} from "../../members/RoomMember.js";
import {LRUCache} from "../../../../utils/LRUCache.js"; import {LRUCache} from "../../../../utils/LRUCache";
export class MemberWriter { export class MemberWriter {
constructor(roomId) { constructor(roomId) {

View file

@ -17,24 +17,26 @@ limitations under the License.
import {MIN_UNICODE, MAX_UNICODE} from "./common"; import {MIN_UNICODE, MAX_UNICODE} from "./common";
import {Store} from "../Store"; import {Store} from "../Store";
interface InboundGroupSession { export interface InboundGroupSessionEntry {
roomId: string; roomId: string;
senderKey: string; senderKey: string;
sessionId: string; sessionId: string;
session?: string; session?: string;
claimedKeys?: { [algorithm : string] : string }; claimedKeys?: { [algorithm : string] : string };
eventIds?: string[]; eventIds?: string[];
key: string;
} }
type InboundGroupSessionStorageEntry = InboundGroupSessionEntry & { key: string };
function encodeKey(roomId: string, senderKey: string, sessionId: string): string { function encodeKey(roomId: string, senderKey: string, sessionId: string): string {
return `${roomId}|${senderKey}|${sessionId}`; return `${roomId}|${senderKey}|${sessionId}`;
} }
export class InboundGroupSessionStore { export class InboundGroupSessionStore {
private _store: Store<InboundGroupSession>; private _store: Store<InboundGroupSessionStorageEntry>;
constructor(store: Store<InboundGroupSession>) { constructor(store: Store<InboundGroupSessionStorageEntry>) {
this._store = store; this._store = store;
} }
@ -44,13 +46,14 @@ export class InboundGroupSessionStore {
return key === fetchedKey; return key === fetchedKey;
} }
get(roomId: string, senderKey: string, sessionId: string): Promise<InboundGroupSession | null> { get(roomId: string, senderKey: string, sessionId: string): Promise<InboundGroupSessionEntry | null> {
return this._store.get(encodeKey(roomId, senderKey, sessionId)); return this._store.get(encodeKey(roomId, senderKey, sessionId));
} }
set(session: InboundGroupSession): void { set(session: InboundGroupSessionEntry): void {
session.key = encodeKey(session.roomId, session.senderKey, session.sessionId); const storageEntry = session as InboundGroupSessionStorageEntry;
this._store.put(session); storageEntry.key = encodeKey(session.roomId, session.senderKey, session.sessionId);
this._store.put(storageEntry);
} }
removeAllForRoom(roomId: string) { removeAllForRoom(roomId: string) {

View file

@ -14,18 +14,29 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
type FindCallback<T> = (value: T) => boolean;
/** /**
* Very simple least-recently-used cache implementation * Very simple least-recently-used cache implementation
* that should be fast enough for very small cache sizes * that should be fast enough for very small cache sizes
*/ */
export class BaseLRUCache { export class BaseLRUCache<T> {
constructor(limit) {
this._limit = limit; public readonly limit: number;
protected _entries: T[];
constructor(limit: number) {
this.limit = limit;
this._entries = []; this._entries = [];
} }
_get(findEntryFn) { get size() { return this._entries.length; }
const idx = this._entries.findIndex(findEntryFn);
protected _get(findEntryFn: FindCallback<T>) {
return this._getByIndexAndMoveUp(this._entries.findIndex(findEntryFn));
}
protected _getByIndexAndMoveUp(idx: number) {
if (idx !== -1) { if (idx !== -1) {
const entry = this._entries[idx]; const entry = this._entries[idx];
// move to top // move to top
@ -37,11 +48,11 @@ export class BaseLRUCache {
} }
} }
_set(value, findEntryFn) { protected _set(value: T, findEntryFn?: FindCallback<T>) {
let indexToRemove = this._entries.findIndex(findEntryFn); let indexToRemove = findEntryFn ? this._entries.findIndex(findEntryFn) : -1;
this._entries.unshift(value); this._entries.unshift(value);
if (indexToRemove === -1) { if (indexToRemove === -1) {
if (this._entries.length > this._limit) { if (this._entries.length > this.limit) {
indexToRemove = this._entries.length - 1; indexToRemove = this._entries.length - 1;
} }
} else { } else {
@ -49,75 +60,82 @@ export class BaseLRUCache {
indexToRemove += 1; indexToRemove += 1;
} }
if (indexToRemove !== -1) { if (indexToRemove !== -1) {
this._onEvictEntry(this._entries[indexToRemove]); this.onEvictEntry(this._entries[indexToRemove]);
this._entries.splice(indexToRemove, 1); this._entries.splice(indexToRemove, 1);
} }
} }
_onEvictEntry() {} protected onEvictEntry(entry: T) {}
} }
export class LRUCache extends BaseLRUCache { export class LRUCache<T, K> extends BaseLRUCache<T> {
constructor(limit, keyFn) { private _keyFn: (T) => K;
constructor(limit, keyFn: (T) => K) {
super(limit); super(limit);
this._keyFn = keyFn; this._keyFn = keyFn;
} }
get(key) { get(key: K): T | undefined {
return this._get(e => this._keyFn(e) === key); return this._get(e => this._keyFn(e) === key);
} }
set(value) { set(value: T) {
const key = this._keyFn(value); const key = this._keyFn(value);
this._set(value, e => this._keyFn(e) === key); this._set(value, e => this._keyFn(e) === key);
} }
} }
export function tests() { export function tests() {
interface NameTuple {
id: number;
name: string;
}
return { return {
"can retrieve added entries": assert => { "can retrieve added entries": assert => {
const cache = new LRUCache(2, e => e.id); const cache = new LRUCache<NameTuple, number>(2, e => e.id);
cache.set({id: 1, name: "Alice"}); cache.set({id: 1, name: "Alice"});
cache.set({id: 2, name: "Bob"}); cache.set({id: 2, name: "Bob"});
assert.equal(cache.get(1).name, "Alice"); assert.equal(cache.get(1)!.name, "Alice");
assert.equal(cache.get(2).name, "Bob"); assert.equal(cache.get(2)!.name, "Bob");
}, },
"first entry is evicted first": assert => { "first entry is evicted first": assert => {
const cache = new LRUCache(2, e => e.id); const cache = new LRUCache<NameTuple, number>(2, e => e.id);
cache.set({id: 1, name: "Alice"}); cache.set({id: 1, name: "Alice"});
cache.set({id: 2, name: "Bob"}); cache.set({id: 2, name: "Bob"});
cache.set({id: 3, name: "Charly"}); cache.set({id: 3, name: "Charly"});
assert.equal(cache.get(1), undefined); assert.equal(cache.get(1), undefined);
assert.equal(cache.get(2).name, "Bob"); assert.equal(cache.get(2)!.name, "Bob");
assert.equal(cache.get(3).name, "Charly"); assert.equal(cache.get(3)!.name, "Charly");
assert.equal(cache._entries.length, 2); assert.equal(cache.size, 2);
}, },
"second entry is evicted if first is requested": assert => { "second entry is evicted if first is requested": assert => {
const cache = new LRUCache(2, e => e.id); const cache = new LRUCache<NameTuple, number>(2, e => e.id);
cache.set({id: 1, name: "Alice"}); cache.set({id: 1, name: "Alice"});
cache.set({id: 2, name: "Bob"}); cache.set({id: 2, name: "Bob"});
cache.get(1); cache.get(1);
cache.set({id: 3, name: "Charly"}); cache.set({id: 3, name: "Charly"});
assert.equal(cache.get(1).name, "Alice"); assert.equal(cache.get(1)!.name, "Alice");
assert.equal(cache.get(2), undefined); assert.equal(cache.get(2), undefined);
assert.equal(cache.get(3).name, "Charly"); assert.equal(cache.get(3)!.name, "Charly");
assert.equal(cache._entries.length, 2); assert.equal(cache.size, 2);
}, },
"setting an entry twice removes the first": assert => { "setting an entry twice removes the first": assert => {
const cache = new LRUCache(2, e => e.id); const cache = new LRUCache<NameTuple, number>(2, e => e.id);
cache.set({id: 1, name: "Alice"}); cache.set({id: 1, name: "Alice"});
cache.set({id: 2, name: "Bob"}); cache.set({id: 2, name: "Bob"});
cache.set({id: 1, name: "Al Ice"}); cache.set({id: 1, name: "Al Ice"});
cache.set({id: 3, name: "Charly"}); cache.set({id: 3, name: "Charly"});
assert.equal(cache.get(1).name, "Al Ice"); assert.equal(cache.get(1)!.name, "Al Ice");
assert.equal(cache.get(2), undefined); assert.equal(cache.get(2), undefined);
assert.equal(cache.get(3).name, "Charly"); assert.equal(cache.get(3)!.name, "Charly");
assert.equal(cache._entries.length, 2); assert.equal(cache.size, 2);
}, },
"evict callback is called": assert => { "evict callback is called": assert => {
let evictions = 0; let evictions = 0;
class CustomCache extends LRUCache { class CustomCache extends LRUCache<NameTuple, number> {
_onEvictEntry(entry) { onEvictEntry(entry) {
assert.equal(entry.name, "Alice"); assert.equal(entry.name, "Alice");
evictions += 1; evictions += 1;
} }
@ -130,8 +148,8 @@ export function tests() {
}, },
"evict callback is called when replacing entry with same identity": assert => { "evict callback is called when replacing entry with same identity": assert => {
let evictions = 0; let evictions = 0;
class CustomCache extends LRUCache { class CustomCache extends LRUCache<NameTuple, number> {
_onEvictEntry(entry) { onEvictEntry(entry) {
assert.equal(entry.name, "Alice"); assert.equal(entry.name, "Alice");
evictions += 1; evictions += 1;
} }

View file

@ -14,14 +14,14 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
export function groupBy(array, groupFn) { export function groupBy<K, V>(array: V[], groupFn: (V) => K): Map<K, V[]> {
return groupByWithCreator(array, groupFn, return groupByWithCreator<K, V, V[]>(array, groupFn,
() => {return [];}, () => {return [];},
(array, value) => array.push(value) (array, value) => array.push(value)
); );
} }
export function groupByWithCreator(array, groupFn, createCollectionFn, addCollectionFn) { export function groupByWithCreator<K, V, C>(array: V[], groupFn: (V) => K, createCollectionFn: () => C, addCollectionFn: (C, V) => void): Map<K, C> {
return array.reduce((map, value) => { return array.reduce((map, value) => {
const key = groupFn(value); const key = groupFn(value);
let collection = map.get(key); let collection = map.get(key);
@ -31,10 +31,10 @@ export function groupByWithCreator(array, groupFn, createCollectionFn, addCollec
} }
addCollectionFn(collection, value); addCollectionFn(collection, value);
return map; return map;
}, new Map()); }, new Map<K, C>());
} }
export function countBy(events, mapper) { export function countBy<V>(events: V[], mapper: (V) => string | number): { [key: string]: number } {
return events.reduce((counts, event) => { return events.reduce((counts, event) => {
const mappedValue = mapper(event); const mappedValue = mapper(event);
if (!counts[mappedValue]) { if (!counts[mappedValue]) {