implement key caching in KeyLoader

merging session cache into it so we can better manage and recycle
keys without exposing too low-level public methods on BaseLRUCache.

Using refCount instead of inUse flag as a key can of course be used
by multiple useKey calls at the same time.
This commit is contained in:
Bruno Windels 2021-10-21 11:12:54 +02:00
parent 3bafc89855
commit 66a77519d7
4 changed files with 189 additions and 135 deletions

View file

@ -15,7 +15,8 @@ limitations under the License.
*/ */
import {SessionCache} from "./SessionCache"; import {SessionCache} from "./SessionCache";
import {IRoomKey} from "./RoomKey"; import {IRoomKey, isBetterThan} from "./RoomKey";
import {BaseLRUCache} from "../../../../utils/LRUCache";
export declare class OlmInboundGroupSession { export declare class OlmInboundGroupSession {
constructor(); constructor();
@ -30,82 +31,193 @@ export declare class OlmInboundGroupSession {
export_session(message_index: number): string; export_session(message_index: number): string;
} }
// this is what cache.get(...) should return
function findIndexBestForSession(ops: KeyOperation[], roomId: string, senderKey: string, sessionId: string): number {
return ops.reduce((bestIdx, op, i, arr) => {
const bestOp = bestIdx === -1 ? undefined : arr[bestIdx];
if (op.isForSameSession(roomId, senderKey, sessionId)) {
if (!bestOp || op.isBetter(bestOp)) {
return i;
}
}
return bestIdx;
}, -1);
}
/* /*
Because Olm only has very limited memory available when compiled to wasm, Because Olm only has very limited memory available when compiled to wasm,
we limit the amount of sessions held in memory. we limit the amount of sessions held in memory.
*/ */
export class KeyLoader { export class KeyLoader extends BaseLRUCache<KeyOperation> {
public readonly cache: SessionCache; private runningOps: Set<KeyOperation>;
private unusedOps: Set<KeyOperation>;
private pickleKey: string; private pickleKey: string;
private olm: any; private olm: any;
private resolveUnusedEntry?: () => void; private resolveUnusedOperation?: () => void;
private entryBecomesUnusedPromise?: Promise<void>; private operationBecomesUnusedPromise?: Promise<void>;
constructor(olm: any, pickleKey: string, limit: number) { constructor(olm: any, pickleKey: string, limit: number) {
this.cache = new SessionCache(limit); super(limit);
this.pickleKey = pickleKey; this.pickleKey = pickleKey;
this.olm = olm; this.olm = olm;
} }
getCachedKey(roomId: string, senderKey: string, sessionId: string): IRoomKey | undefined {
const idx = this.findIndexBestForSession(roomId, senderKey, sessionId);
if (idx !== -1) {
return this._getByIndexAndMoveUp(idx)!.key;
}
}
async useKey<T>(key: IRoomKey, callback: (session: OlmInboundGroupSession, pickleKey: string) => Promise<T> | T): Promise<T> { async useKey<T>(key: IRoomKey, callback: (session: OlmInboundGroupSession, pickleKey: string) => Promise<T> | T): Promise<T> {
const cacheEntry = await this.allocateEntry(key); const keyOp = await this.allocateOperation(key);
try { try {
const {session} = cacheEntry; return await callback(keyOp.session, this.pickleKey);
key.loadInto(session, this.pickleKey);
return await callback(session, this.pickleKey);
} finally { } finally {
this.freeEntry(cacheEntry); this.releaseOperation(keyOp);
} }
} }
get running() { get running() {
return !!this.cache.find(entry => entry.inUse); return this._entries.some(op => op.refCount !== 0);
} }
private async allocateEntry(key: IRoomKey): Promise<CacheEntry> { dispose() {
let entry; for (let i = 0; i < this._entries.length; i += 1) {
if (this.cache.size >= this.cache.limit) { this._entries[i].dispose();
while(!(entry = this.cache.find(entry => !entry.inUse))) { }
await this.entryBecomesUnused(); // remove all entries
this._entries.splice(0, this._entries.length);
}
private async allocateOperation(key: IRoomKey): Promise<KeyOperation> {
let idx;
while((idx = this.findIndexForAllocation(key)) === -1) {
await this.operationBecomesUnused();
}
if (idx < this.size) {
const op = this._getByIndexAndMoveUp(idx)!;
// cache hit
if (op.isForKey(key)) {
op.refCount += 1;
return op;
} else {
// refCount should be 0 here
op.refCount = 1;
op.key = key;
key.loadInto(op.session, this.pickleKey);
} }
entry.inUse = true; return op;
entry.key = key;
} else { } else {
const session: OlmInboundGroupSession = new this.olm.InboundGroupSession(); // create new operation
const entry = new CacheEntry(key, session); const session = new this.olm.InboundGroupSession();
this.cache.add(entry); key.loadInto(session, this.pickleKey);
const op = new KeyOperation(key, session);
this._set(op);
return op;
} }
return entry;
} }
private freeEntry(entry: CacheEntry) { private releaseOperation(op: KeyOperation) {
entry.inUse = false; op.refCount -= 1;
if (this.resolveUnusedEntry) { if (op.refCount <= 0 && this.resolveUnusedOperation) {
this.resolveUnusedEntry(); this.resolveUnusedOperation();
// promise is resolved now, we'll need a new one for next await so clear // promise is resolved now, we'll need a new one for next await so clear
this.entryBecomesUnusedPromise = this.resolveUnusedEntry = undefined; this.operationBecomesUnusedPromise = this.resolveUnusedOperation = undefined;
} }
} }
private entryBecomesUnused(): Promise<void> { private operationBecomesUnused(): Promise<void> {
if (!this.entryBecomesUnusedPromise) { if (!this.operationBecomesUnusedPromise) {
this.entryBecomesUnusedPromise = new Promise(resolve => { this.operationBecomesUnusedPromise = new Promise(resolve => {
this.resolveUnusedEntry = resolve; this.resolveUnusedOperation = resolve;
}); });
} }
return this.entryBecomesUnusedPromise; return this.operationBecomesUnusedPromise;
}
private findIndexForAllocation(key: IRoomKey) {
let idx = this.findIndexSameKey(key); // cache hit
if (idx === -1) {
idx = this.findIndexSameSessionUnused(key);
if (idx === -1) {
if (this.size < this.limit) {
idx = this.size;
} else {
idx = this.findIndexOldestUnused();
}
}
}
return idx;
}
private findIndexBestForSession(roomId: string, senderKey: string, sessionId: string): number {
return this._entries.reduce((bestIdx, op, i, arr) => {
const bestOp = bestIdx === -1 ? undefined : arr[bestIdx];
if (op.isForSameSession(roomId, senderKey, sessionId)) {
if (!bestOp || op.isBetter(bestOp)) {
return i;
}
}
return bestIdx;
}, -1);
}
private findIndexSameKey(key: IRoomKey): number {
return this._entries.findIndex(op => {
return op.isForKey(key);
});
}
private findIndexSameSessionUnused(key: IRoomKey): number {
for (let i = this._entries.length - 1; i >= 0; i -= 1) {
const op = this._entries[i];
if (op.refCount === 0 && op.isForSameSession(key.roomId, key.senderKey, key.sessionId)) {
return i;
}
}
return -1;
}
private findIndexOldestUnused(): number {
for (let i = this._entries.length - 1; i >= 0; i -= 1) {
const op = this._entries[i];
if (op.refCount === 0) {
return i;
}
}
return -1;
} }
} }
class CacheEntry { class KeyOperation {
inUse: boolean;
session: OlmInboundGroupSession; session: OlmInboundGroupSession;
key: IRoomKey; key: IRoomKey;
refCount: number;
constructor(key, session) { constructor(key: IRoomKey, session: OlmInboundGroupSession) {
this.key = key; this.key = key;
this.session = session; this.session = session;
this.inUse = true; this.refCount = 1;
}
isForSameSession(roomId: string, senderKey: string, sessionId: string): boolean {
return this.key.roomId === roomId && this.key.senderKey === senderKey && this.key.sessionId === sessionId;
}
// assumes isForSameSession is true
isBetter(other: KeyOperation) {
return isBetterThan(this.session, other.session);
}
isForKey(key: IRoomKey) {
return this.key.serializationKey === key.serializationKey &&
this.key.serializationType === key.serializationType;
}
dispose() {
this.session.free();
} }
} }

View file

@ -18,20 +18,25 @@ import type {InboundGroupSessionEntry} from "../../../storage/idb/stores/Inbound
import type {Transaction} from "../../../storage/idb/Transaction"; import type {Transaction} from "../../../storage/idb/Transaction";
import type {DecryptionResult} from "../../DecryptionResult"; import type {DecryptionResult} from "../../DecryptionResult";
import type {KeyLoader, OlmInboundGroupSession} from "./KeyLoader"; import type {KeyLoader, OlmInboundGroupSession} from "./KeyLoader";
import {SessionCache} from "./SessionCache";
export interface IRoomKey { export interface IRoomKey {
get roomId(): string; get roomId(): string;
get senderKey(): string; get senderKey(): string;
get sessionId(): string; get sessionId(): string;
get claimedEd25519Key(): string; get claimedEd25519Key(): string;
get serializationKey(): string;
get serializationType(): string;
get eventIds(): string[] | undefined; get eventIds(): string[] | undefined;
loadInto(session: OlmInboundGroupSession, pickleKey: string): void; loadInto(session: OlmInboundGroupSession, pickleKey: string): void;
} }
export function isBetterThan(newSession: OlmInboundGroupSession, existingSession: OlmInboundGroupSession) {
return newSession.first_known_index() < existingSession.first_known_index();
}
export interface IIncomingRoomKey extends IRoomKey { export interface IIncomingRoomKey extends IRoomKey {
get isBetter(): boolean | undefined; get isBetter(): boolean | undefined;
checkBetterKeyInStorage(loader: KeyLoader, txn: Transaction): Promise<boolean>; checkBetterThanKeyInStorage(loader: KeyLoader, txn: Transaction): Promise<boolean>;
write(loader: KeyLoader, txn: Transaction): Promise<boolean>; write(loader: KeyLoader, txn: Transaction): Promise<boolean>;
} }
@ -39,8 +44,8 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey {
private _eventIds?: string[]; private _eventIds?: string[];
private _isBetter?: boolean; private _isBetter?: boolean;
checkBetterKeyInStorage(loader: KeyLoader, txn: Transaction): Promise<boolean> { checkBetterThanKeyInStorage(loader: KeyLoader, txn: Transaction): Promise<boolean> {
return this._checkBetterKeyInStorage(loader, undefined, txn); return this._checkBetterThanKeyInStorage(loader, undefined, txn);
} }
async write(loader: KeyLoader, txn: Transaction): Promise<boolean> { async write(loader: KeyLoader, txn: Transaction): Promise<boolean> {
@ -51,7 +56,7 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey {
// we haven't checked if this is the best key yet, // we haven't checked if this is the best key yet,
// so do that now to not overwrite a better key. // so do that now to not overwrite a better key.
// while we have the key deserialized, also pickle it to store it later on here. // while we have the key deserialized, also pickle it to store it later on here.
await this._checkBetterKeyInStorage(loader, (session, pickleKey) => { await this._checkBetterThanKeyInStorage(loader, (session, pickleKey) => {
pickledSession = session.pickle(pickleKey); pickledSession = session.pickle(pickleKey);
}, txn); }, txn);
} }
@ -76,7 +81,7 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey {
get eventIds() { return this._eventIds; } get eventIds() { return this._eventIds; }
get isBetter() { return this._isBetter; } get isBetter() { return this._isBetter; }
private async _checkBetterKeyInStorage(loader: KeyLoader, callback: (((session: OlmInboundGroupSession, pickleKey: string) => void) | undefined), txn: Transaction): Promise<boolean> { private async _checkBetterThanKeyInStorage(loader: KeyLoader, callback: (((session: OlmInboundGroupSession, pickleKey: string) => void) | undefined), txn: Transaction): Promise<boolean> {
if (this._isBetter !== undefined) { if (this._isBetter !== undefined) {
return this._isBetter; return this._isBetter;
} }
@ -96,7 +101,7 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey {
if (existingKey) { if (existingKey) {
this._isBetter = await loader.useKey(this, newSession => { this._isBetter = await loader.useKey(this, newSession => {
return loader.useKey(existingKey, (existingSession, pickleKey) => { return loader.useKey(existingKey, (existingSession, pickleKey) => {
const isBetter = newSession.first_known_index() < existingSession.first_known_index(); const isBetter = isBetterThan(newSession, existingSession);
if (isBetter && callback) { if (isBetter && callback) {
callback(newSession, pickleKey); callback(newSession, pickleKey);
} }
@ -114,6 +119,8 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey {
abstract get senderKey(): string; abstract get senderKey(): string;
abstract get sessionId(): string; abstract get sessionId(): string;
abstract get claimedEd25519Key(): string; abstract get claimedEd25519Key(): string;
abstract get serializationKey(): string;
abstract get serializationType(): string;
abstract loadInto(session: OlmInboundGroupSession, pickleKey: string): void; abstract loadInto(session: OlmInboundGroupSession, pickleKey: string): void;
} }
@ -129,10 +136,11 @@ class DeviceMessageRoomKey extends BaseIncomingRoomKey {
get senderKey() { return this._decryptionResult.senderCurve25519Key; } get senderKey() { return this._decryptionResult.senderCurve25519Key; }
get sessionId() { return this._decryptionResult.event.content?.["session_id"]; } get sessionId() { return this._decryptionResult.event.content?.["session_id"]; }
get claimedEd25519Key() { return this._decryptionResult.claimedEd25519Key; } get claimedEd25519Key() { return this._decryptionResult.claimedEd25519Key; }
get serializationKey(): string { return this._decryptionResult.event.content?.["session_key"]; }
get serializationType(): string { return "create"; }
loadInto(session) { loadInto(session) {
const sessionKey = this._decryptionResult.event.content?.["session_key"]; session.create(this.serializationKey);
session.create(sessionKey);
} }
} }
@ -152,10 +160,11 @@ class BackupRoomKey extends BaseIncomingRoomKey {
get senderKey() { return this._backupInfo["sender_key"]; } get senderKey() { return this._backupInfo["sender_key"]; }
get sessionId() { return this._sessionId; } get sessionId() { return this._sessionId; }
get claimedEd25519Key() { return this._backupInfo["sender_claimed_keys"]?.["ed25519"]; } get claimedEd25519Key() { return this._backupInfo["sender_claimed_keys"]?.["ed25519"]; }
get serializationKey(): string { return this._backupInfo["session_key"]; }
get serializationType(): string { return "import_session"; }
loadInto(session) { loadInto(session) {
const sessionKey = this._backupInfo["session_key"]; session.import_session(this.serializationKey);
session.import_session(sessionKey);
} }
} }
@ -171,9 +180,11 @@ class StoredRoomKey implements IRoomKey {
get sessionId() { return this.storageEntry.sessionId; } get sessionId() { return this.storageEntry.sessionId; }
get claimedEd25519Key() { return this.storageEntry.claimedKeys!["ed25519"]; } get claimedEd25519Key() { return this.storageEntry.claimedKeys!["ed25519"]; }
get eventIds() { return this.storageEntry.eventIds; } get eventIds() { return this.storageEntry.eventIds; }
get serializationKey(): string { return this.storageEntry.session || ""; }
get serializationType(): string { return "unpickle"; }
loadInto(session, pickleKey) { loadInto(session, pickleKey) {
session.unpickle(pickleKey, this.storageEntry.session); session.unpickle(pickleKey, this.serializationKey);
} }
get hasSession() { get hasSession() {

View file

@ -1,63 +0,0 @@
/*
Copyright 2020 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import {BaseLRUCache} from "../../../../utils/LRUCache";
const DEFAULT_CACHE_SIZE = 10;
/**
* Cache of unpickled inbound megolm session.
*/
export class SessionCache extends BaseLRUCache {
constructor(limit) {
limit = typeof limit === "number" ? limit : DEFAULT_CACHE_SIZE;
super(limit);
}
/**
* @param {string} roomId
* @param {string} senderKey
* @param {string} sessionId
* @return {SessionInfo?}
*/
get(roomId, senderKey, sessionId) {
const sessionInfo = this._get(s => {
return s.roomId === roomId &&
s.senderKey === senderKey &&
sessionId === s.sessionId;
});
sessionInfo?.retain();
return sessionInfo;
}
add(sessionInfo) {
sessionInfo.retain();
this._set(sessionInfo, s => {
return s.roomId === sessionInfo.roomId &&
s.senderKey === sessionInfo.senderKey &&
s.sessionId === sessionInfo.sessionId;
});
}
_onEvictEntry(sessionInfo) {
sessionInfo.release();
}
dispose() {
for (const sessionInfo of this._entries) {
sessionInfo.release();
}
}
}

View file

@ -14,25 +14,29 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
type FindCallback<T> = (value: T) => boolean;
/** /**
* Very simple least-recently-used cache implementation * Very simple least-recently-used cache implementation
* that should be fast enough for very small cache sizes * that should be fast enough for very small cache sizes
*/ */
export class BaseLRUCache<T> { export class BaseLRUCache<T> {
private _limit: number; public readonly limit: number;
private _entries: T[]; protected _entries: T[];
constructor(limit: number) { constructor(limit: number) {
this._limit = limit; this.limit = limit;
this._entries = []; this._entries = [];
} }
get size() { return this._entries.length; } get size() { return this._entries.length; }
get limit() { return this._limit; }
_get(findEntryFn: (T) => boolean) { protected _get(findEntryFn: FindCallback<T>) {
const idx = this._entries.findIndex(findEntryFn); return this._getByIndexAndMoveUp(this._entries.findIndex(findEntryFn));
}
protected _getByIndexAndMoveUp(idx: number) {
if (idx !== -1) { if (idx !== -1) {
const entry = this._entries[idx]; const entry = this._entries[idx];
// move to top // move to top
@ -44,11 +48,11 @@ export class BaseLRUCache<T> {
} }
} }
_set(value: T, findEntryFn: (T) => boolean) { protected _set(value: T, findEntryFn?: FindCallback<T>) {
let indexToRemove = this._entries.findIndex(findEntryFn); let indexToRemove = findEntryFn ? this._entries.findIndex(findEntryFn) : -1;
this._entries.unshift(value); this._entries.unshift(value);
if (indexToRemove === -1) { if (indexToRemove === -1) {
if (this._entries.length > this._limit) { if (this._entries.length > this.limit) {
indexToRemove = this._entries.length - 1; indexToRemove = this._entries.length - 1;
} }
} else { } else {
@ -56,22 +60,12 @@ export class BaseLRUCache<T> {
indexToRemove += 1; indexToRemove += 1;
} }
if (indexToRemove !== -1) { if (indexToRemove !== -1) {
this._onEvictEntry(this._entries[indexToRemove]); this.onEvictEntry(this._entries[indexToRemove]);
this._entries.splice(indexToRemove, 1); this._entries.splice(indexToRemove, 1);
} }
} }
find(callback: (T) => boolean) { protected onEvictEntry(entry: T) {}
// iterate backwards so least recently used items are found first
for (let i = this._entries.length - 1; i >= 0; i -= 1) {
const entry = this._entries[i];
if (callback(entry)) {
return entry;
}
}
}
_onEvictEntry(entry: T) {}
} }
export class LRUCache<T, K> extends BaseLRUCache<T> { export class LRUCache<T, K> extends BaseLRUCache<T> {