implement key caching in KeyLoader

merging session cache into it so we can better manage and recycle
keys without exposing too low-level public methods on BaseLRUCache.

Using refCount instead of inUse flag as a key can of course be used
by multiple useKey calls at the same time.
This commit is contained in:
Bruno Windels 2021-10-21 11:12:54 +02:00
parent 3bafc89855
commit 66a77519d7
4 changed files with 189 additions and 135 deletions

View file

@ -15,7 +15,8 @@ limitations under the License.
*/
import {SessionCache} from "./SessionCache";
import {IRoomKey} from "./RoomKey";
import {IRoomKey, isBetterThan} from "./RoomKey";
import {BaseLRUCache} from "../../../../utils/LRUCache";
export declare class OlmInboundGroupSession {
constructor();
@ -30,82 +31,193 @@ export declare class OlmInboundGroupSession {
export_session(message_index: number): string;
}
// this is what cache.get(...) should return
function findIndexBestForSession(ops: KeyOperation[], roomId: string, senderKey: string, sessionId: string): number {
return ops.reduce((bestIdx, op, i, arr) => {
const bestOp = bestIdx === -1 ? undefined : arr[bestIdx];
if (op.isForSameSession(roomId, senderKey, sessionId)) {
if (!bestOp || op.isBetter(bestOp)) {
return i;
}
}
return bestIdx;
}, -1);
}
/*
Because Olm only has very limited memory available when compiled to wasm,
we limit the amount of sessions held in memory.
*/
export class KeyLoader {
export class KeyLoader extends BaseLRUCache<KeyOperation> {
public readonly cache: SessionCache;
private runningOps: Set<KeyOperation>;
private unusedOps: Set<KeyOperation>;
private pickleKey: string;
private olm: any;
private resolveUnusedEntry?: () => void;
private entryBecomesUnusedPromise?: Promise<void>;
private resolveUnusedOperation?: () => void;
private operationBecomesUnusedPromise?: Promise<void>;
constructor(olm: any, pickleKey: string, limit: number) {
this.cache = new SessionCache(limit);
super(limit);
this.pickleKey = pickleKey;
this.olm = olm;
}
getCachedKey(roomId: string, senderKey: string, sessionId: string): IRoomKey | undefined {
const idx = this.findIndexBestForSession(roomId, senderKey, sessionId);
if (idx !== -1) {
return this._getByIndexAndMoveUp(idx)!.key;
}
}
async useKey<T>(key: IRoomKey, callback: (session: OlmInboundGroupSession, pickleKey: string) => Promise<T> | T): Promise<T> {
const cacheEntry = await this.allocateEntry(key);
const keyOp = await this.allocateOperation(key);
try {
const {session} = cacheEntry;
key.loadInto(session, this.pickleKey);
return await callback(session, this.pickleKey);
return await callback(keyOp.session, this.pickleKey);
} finally {
this.freeEntry(cacheEntry);
this.releaseOperation(keyOp);
}
}
get running() {
return !!this.cache.find(entry => entry.inUse);
return this._entries.some(op => op.refCount !== 0);
}
private async allocateEntry(key: IRoomKey): Promise<CacheEntry> {
let entry;
if (this.cache.size >= this.cache.limit) {
while(!(entry = this.cache.find(entry => !entry.inUse))) {
await this.entryBecomesUnused();
dispose() {
for (let i = 0; i < this._entries.length; i += 1) {
this._entries[i].dispose();
}
// remove all entries
this._entries.splice(0, this._entries.length);
}
private async allocateOperation(key: IRoomKey): Promise<KeyOperation> {
let idx;
while((idx = this.findIndexForAllocation(key)) === -1) {
await this.operationBecomesUnused();
}
if (idx < this.size) {
const op = this._getByIndexAndMoveUp(idx)!;
// cache hit
if (op.isForKey(key)) {
op.refCount += 1;
return op;
} else {
// refCount should be 0 here
op.refCount = 1;
op.key = key;
key.loadInto(op.session, this.pickleKey);
}
entry.inUse = true;
entry.key = key;
return op;
} else {
const session: OlmInboundGroupSession = new this.olm.InboundGroupSession();
const entry = new CacheEntry(key, session);
this.cache.add(entry);
// create new operation
const session = new this.olm.InboundGroupSession();
key.loadInto(session, this.pickleKey);
const op = new KeyOperation(key, session);
this._set(op);
return op;
}
return entry;
}
private freeEntry(entry: CacheEntry) {
entry.inUse = false;
if (this.resolveUnusedEntry) {
this.resolveUnusedEntry();
private releaseOperation(op: KeyOperation) {
op.refCount -= 1;
if (op.refCount <= 0 && this.resolveUnusedOperation) {
this.resolveUnusedOperation();
// promise is resolved now, we'll need a new one for next await so clear
this.entryBecomesUnusedPromise = this.resolveUnusedEntry = undefined;
this.operationBecomesUnusedPromise = this.resolveUnusedOperation = undefined;
}
}
private entryBecomesUnused(): Promise<void> {
if (!this.entryBecomesUnusedPromise) {
this.entryBecomesUnusedPromise = new Promise(resolve => {
this.resolveUnusedEntry = resolve;
private operationBecomesUnused(): Promise<void> {
if (!this.operationBecomesUnusedPromise) {
this.operationBecomesUnusedPromise = new Promise(resolve => {
this.resolveUnusedOperation = resolve;
});
}
return this.entryBecomesUnusedPromise;
return this.operationBecomesUnusedPromise;
}
private findIndexForAllocation(key: IRoomKey) {
let idx = this.findIndexSameKey(key); // cache hit
if (idx === -1) {
idx = this.findIndexSameSessionUnused(key);
if (idx === -1) {
if (this.size < this.limit) {
idx = this.size;
} else {
idx = this.findIndexOldestUnused();
}
}
}
return idx;
}
private findIndexBestForSession(roomId: string, senderKey: string, sessionId: string): number {
return this._entries.reduce((bestIdx, op, i, arr) => {
const bestOp = bestIdx === -1 ? undefined : arr[bestIdx];
if (op.isForSameSession(roomId, senderKey, sessionId)) {
if (!bestOp || op.isBetter(bestOp)) {
return i;
}
}
return bestIdx;
}, -1);
}
private findIndexSameKey(key: IRoomKey): number {
return this._entries.findIndex(op => {
return op.isForKey(key);
});
}
private findIndexSameSessionUnused(key: IRoomKey): number {
for (let i = this._entries.length - 1; i >= 0; i -= 1) {
const op = this._entries[i];
if (op.refCount === 0 && op.isForSameSession(key.roomId, key.senderKey, key.sessionId)) {
return i;
}
}
return -1;
}
private findIndexOldestUnused(): number {
for (let i = this._entries.length - 1; i >= 0; i -= 1) {
const op = this._entries[i];
if (op.refCount === 0) {
return i;
}
}
return -1;
}
}
class CacheEntry {
inUse: boolean;
class KeyOperation {
session: OlmInboundGroupSession;
key: IRoomKey;
refCount: number;
constructor(key, session) {
constructor(key: IRoomKey, session: OlmInboundGroupSession) {
this.key = key;
this.session = session;
this.inUse = true;
this.refCount = 1;
}
isForSameSession(roomId: string, senderKey: string, sessionId: string): boolean {
return this.key.roomId === roomId && this.key.senderKey === senderKey && this.key.sessionId === sessionId;
}
// assumes isForSameSession is true
isBetter(other: KeyOperation) {
return isBetterThan(this.session, other.session);
}
isForKey(key: IRoomKey) {
return this.key.serializationKey === key.serializationKey &&
this.key.serializationType === key.serializationType;
}
dispose() {
this.session.free();
}
}

View file

@ -18,20 +18,25 @@ import type {InboundGroupSessionEntry} from "../../../storage/idb/stores/Inbound
import type {Transaction} from "../../../storage/idb/Transaction";
import type {DecryptionResult} from "../../DecryptionResult";
import type {KeyLoader, OlmInboundGroupSession} from "./KeyLoader";
import {SessionCache} from "./SessionCache";
export interface IRoomKey {
get roomId(): string;
get senderKey(): string;
get sessionId(): string;
get claimedEd25519Key(): string;
get serializationKey(): string;
get serializationType(): string;
get eventIds(): string[] | undefined;
loadInto(session: OlmInboundGroupSession, pickleKey: string): void;
}
export function isBetterThan(newSession: OlmInboundGroupSession, existingSession: OlmInboundGroupSession) {
return newSession.first_known_index() < existingSession.first_known_index();
}
export interface IIncomingRoomKey extends IRoomKey {
get isBetter(): boolean | undefined;
checkBetterKeyInStorage(loader: KeyLoader, txn: Transaction): Promise<boolean>;
checkBetterThanKeyInStorage(loader: KeyLoader, txn: Transaction): Promise<boolean>;
write(loader: KeyLoader, txn: Transaction): Promise<boolean>;
}
@ -39,8 +44,8 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey {
private _eventIds?: string[];
private _isBetter?: boolean;
checkBetterKeyInStorage(loader: KeyLoader, txn: Transaction): Promise<boolean> {
return this._checkBetterKeyInStorage(loader, undefined, txn);
checkBetterThanKeyInStorage(loader: KeyLoader, txn: Transaction): Promise<boolean> {
return this._checkBetterThanKeyInStorage(loader, undefined, txn);
}
async write(loader: KeyLoader, txn: Transaction): Promise<boolean> {
@ -51,7 +56,7 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey {
// we haven't checked if this is the best key yet,
// so do that now to not overwrite a better key.
// while we have the key deserialized, also pickle it to store it later on here.
await this._checkBetterKeyInStorage(loader, (session, pickleKey) => {
await this._checkBetterThanKeyInStorage(loader, (session, pickleKey) => {
pickledSession = session.pickle(pickleKey);
}, txn);
}
@ -76,7 +81,7 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey {
get eventIds() { return this._eventIds; }
get isBetter() { return this._isBetter; }
private async _checkBetterKeyInStorage(loader: KeyLoader, callback: (((session: OlmInboundGroupSession, pickleKey: string) => void) | undefined), txn: Transaction): Promise<boolean> {
private async _checkBetterThanKeyInStorage(loader: KeyLoader, callback: (((session: OlmInboundGroupSession, pickleKey: string) => void) | undefined), txn: Transaction): Promise<boolean> {
if (this._isBetter !== undefined) {
return this._isBetter;
}
@ -96,7 +101,7 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey {
if (existingKey) {
this._isBetter = await loader.useKey(this, newSession => {
return loader.useKey(existingKey, (existingSession, pickleKey) => {
const isBetter = newSession.first_known_index() < existingSession.first_known_index();
const isBetter = isBetterThan(newSession, existingSession);
if (isBetter && callback) {
callback(newSession, pickleKey);
}
@ -114,6 +119,8 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey {
abstract get senderKey(): string;
abstract get sessionId(): string;
abstract get claimedEd25519Key(): string;
abstract get serializationKey(): string;
abstract get serializationType(): string;
abstract loadInto(session: OlmInboundGroupSession, pickleKey: string): void;
}
@ -129,10 +136,11 @@ class DeviceMessageRoomKey extends BaseIncomingRoomKey {
get senderKey() { return this._decryptionResult.senderCurve25519Key; }
get sessionId() { return this._decryptionResult.event.content?.["session_id"]; }
get claimedEd25519Key() { return this._decryptionResult.claimedEd25519Key; }
get serializationKey(): string { return this._decryptionResult.event.content?.["session_key"]; }
get serializationType(): string { return "create"; }
loadInto(session) {
const sessionKey = this._decryptionResult.event.content?.["session_key"];
session.create(sessionKey);
session.create(this.serializationKey);
}
}
@ -152,10 +160,11 @@ class BackupRoomKey extends BaseIncomingRoomKey {
get senderKey() { return this._backupInfo["sender_key"]; }
get sessionId() { return this._sessionId; }
get claimedEd25519Key() { return this._backupInfo["sender_claimed_keys"]?.["ed25519"]; }
get serializationKey(): string { return this._backupInfo["session_key"]; }
get serializationType(): string { return "import_session"; }
loadInto(session) {
const sessionKey = this._backupInfo["session_key"];
session.import_session(sessionKey);
session.import_session(this.serializationKey);
}
}
@ -171,9 +180,11 @@ class StoredRoomKey implements IRoomKey {
get sessionId() { return this.storageEntry.sessionId; }
get claimedEd25519Key() { return this.storageEntry.claimedKeys!["ed25519"]; }
get eventIds() { return this.storageEntry.eventIds; }
get serializationKey(): string { return this.storageEntry.session || ""; }
get serializationType(): string { return "unpickle"; }
loadInto(session, pickleKey) {
session.unpickle(pickleKey, this.storageEntry.session);
session.unpickle(pickleKey, this.serializationKey);
}
get hasSession() {

View file

@ -1,63 +0,0 @@
/*
Copyright 2020 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import {BaseLRUCache} from "../../../../utils/LRUCache";
const DEFAULT_CACHE_SIZE = 10;
/**
* Cache of unpickled inbound megolm session.
*/
export class SessionCache extends BaseLRUCache {
constructor(limit) {
limit = typeof limit === "number" ? limit : DEFAULT_CACHE_SIZE;
super(limit);
}
/**
* @param {string} roomId
* @param {string} senderKey
* @param {string} sessionId
* @return {SessionInfo?}
*/
get(roomId, senderKey, sessionId) {
const sessionInfo = this._get(s => {
return s.roomId === roomId &&
s.senderKey === senderKey &&
sessionId === s.sessionId;
});
sessionInfo?.retain();
return sessionInfo;
}
add(sessionInfo) {
sessionInfo.retain();
this._set(sessionInfo, s => {
return s.roomId === sessionInfo.roomId &&
s.senderKey === sessionInfo.senderKey &&
s.sessionId === sessionInfo.sessionId;
});
}
_onEvictEntry(sessionInfo) {
sessionInfo.release();
}
dispose() {
for (const sessionInfo of this._entries) {
sessionInfo.release();
}
}
}

View file

@ -14,25 +14,29 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
type FindCallback<T> = (value: T) => boolean;
/**
* Very simple least-recently-used cache implementation
* that should be fast enough for very small cache sizes
*/
export class BaseLRUCache<T> {
private _limit: number;
private _entries: T[];
public readonly limit: number;
protected _entries: T[];
constructor(limit: number) {
this._limit = limit;
this.limit = limit;
this._entries = [];
}
get size() { return this._entries.length; }
get limit() { return this._limit; }
_get(findEntryFn: (T) => boolean) {
const idx = this._entries.findIndex(findEntryFn);
protected _get(findEntryFn: FindCallback<T>) {
return this._getByIndexAndMoveUp(this._entries.findIndex(findEntryFn));
}
protected _getByIndexAndMoveUp(idx: number) {
if (idx !== -1) {
const entry = this._entries[idx];
// move to top
@ -44,11 +48,11 @@ export class BaseLRUCache<T> {
}
}
_set(value: T, findEntryFn: (T) => boolean) {
let indexToRemove = this._entries.findIndex(findEntryFn);
protected _set(value: T, findEntryFn?: FindCallback<T>) {
let indexToRemove = findEntryFn ? this._entries.findIndex(findEntryFn) : -1;
this._entries.unshift(value);
if (indexToRemove === -1) {
if (this._entries.length > this._limit) {
if (this._entries.length > this.limit) {
indexToRemove = this._entries.length - 1;
}
} else {
@ -56,22 +60,12 @@ export class BaseLRUCache<T> {
indexToRemove += 1;
}
if (indexToRemove !== -1) {
this._onEvictEntry(this._entries[indexToRemove]);
this.onEvictEntry(this._entries[indexToRemove]);
this._entries.splice(indexToRemove, 1);
}
}
find(callback: (T) => boolean) {
// iterate backwards so least recently used items are found first
for (let i = this._entries.length - 1; i >= 0; i -= 1) {
const entry = this._entries[i];
if (callback(entry)) {
return entry;
}
}
}
_onEvictEntry(entry: T) {}
protected onEvictEntry(entry: T) {}
}
export class LRUCache<T, K> extends BaseLRUCache<T> {