implement key caching in KeyLoader
merging session cache into it so we can better manage and recycle keys without exposing too low-level public methods on BaseLRUCache. Using refCount instead of inUse flag as a key can of course be used by multiple useKey calls at the same time.
This commit is contained in:
parent
3bafc89855
commit
66a77519d7
4 changed files with 189 additions and 135 deletions
|
@ -15,7 +15,8 @@ limitations under the License.
|
|||
*/
|
||||
|
||||
import {SessionCache} from "./SessionCache";
|
||||
import {IRoomKey} from "./RoomKey";
|
||||
import {IRoomKey, isBetterThan} from "./RoomKey";
|
||||
import {BaseLRUCache} from "../../../../utils/LRUCache";
|
||||
|
||||
export declare class OlmInboundGroupSession {
|
||||
constructor();
|
||||
|
@ -30,82 +31,193 @@ export declare class OlmInboundGroupSession {
|
|||
export_session(message_index: number): string;
|
||||
}
|
||||
|
||||
// this is what cache.get(...) should return
|
||||
function findIndexBestForSession(ops: KeyOperation[], roomId: string, senderKey: string, sessionId: string): number {
|
||||
return ops.reduce((bestIdx, op, i, arr) => {
|
||||
const bestOp = bestIdx === -1 ? undefined : arr[bestIdx];
|
||||
if (op.isForSameSession(roomId, senderKey, sessionId)) {
|
||||
if (!bestOp || op.isBetter(bestOp)) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return bestIdx;
|
||||
}, -1);
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
Because Olm only has very limited memory available when compiled to wasm,
|
||||
we limit the amount of sessions held in memory.
|
||||
*/
|
||||
export class KeyLoader {
|
||||
export class KeyLoader extends BaseLRUCache<KeyOperation> {
|
||||
|
||||
public readonly cache: SessionCache;
|
||||
private runningOps: Set<KeyOperation>;
|
||||
private unusedOps: Set<KeyOperation>;
|
||||
private pickleKey: string;
|
||||
private olm: any;
|
||||
private resolveUnusedEntry?: () => void;
|
||||
private entryBecomesUnusedPromise?: Promise<void>;
|
||||
private resolveUnusedOperation?: () => void;
|
||||
private operationBecomesUnusedPromise?: Promise<void>;
|
||||
|
||||
constructor(olm: any, pickleKey: string, limit: number) {
|
||||
this.cache = new SessionCache(limit);
|
||||
super(limit);
|
||||
this.pickleKey = pickleKey;
|
||||
this.olm = olm;
|
||||
}
|
||||
|
||||
getCachedKey(roomId: string, senderKey: string, sessionId: string): IRoomKey | undefined {
|
||||
const idx = this.findIndexBestForSession(roomId, senderKey, sessionId);
|
||||
if (idx !== -1) {
|
||||
return this._getByIndexAndMoveUp(idx)!.key;
|
||||
}
|
||||
}
|
||||
|
||||
async useKey<T>(key: IRoomKey, callback: (session: OlmInboundGroupSession, pickleKey: string) => Promise<T> | T): Promise<T> {
|
||||
const cacheEntry = await this.allocateEntry(key);
|
||||
const keyOp = await this.allocateOperation(key);
|
||||
try {
|
||||
const {session} = cacheEntry;
|
||||
key.loadInto(session, this.pickleKey);
|
||||
return await callback(session, this.pickleKey);
|
||||
return await callback(keyOp.session, this.pickleKey);
|
||||
} finally {
|
||||
this.freeEntry(cacheEntry);
|
||||
this.releaseOperation(keyOp);
|
||||
}
|
||||
}
|
||||
|
||||
get running() {
|
||||
return !!this.cache.find(entry => entry.inUse);
|
||||
return this._entries.some(op => op.refCount !== 0);
|
||||
}
|
||||
|
||||
private async allocateEntry(key: IRoomKey): Promise<CacheEntry> {
|
||||
let entry;
|
||||
if (this.cache.size >= this.cache.limit) {
|
||||
while(!(entry = this.cache.find(entry => !entry.inUse))) {
|
||||
await this.entryBecomesUnused();
|
||||
dispose() {
|
||||
for (let i = 0; i < this._entries.length; i += 1) {
|
||||
this._entries[i].dispose();
|
||||
}
|
||||
// remove all entries
|
||||
this._entries.splice(0, this._entries.length);
|
||||
}
|
||||
|
||||
private async allocateOperation(key: IRoomKey): Promise<KeyOperation> {
|
||||
let idx;
|
||||
while((idx = this.findIndexForAllocation(key)) === -1) {
|
||||
await this.operationBecomesUnused();
|
||||
}
|
||||
if (idx < this.size) {
|
||||
const op = this._getByIndexAndMoveUp(idx)!;
|
||||
// cache hit
|
||||
if (op.isForKey(key)) {
|
||||
op.refCount += 1;
|
||||
return op;
|
||||
} else {
|
||||
// refCount should be 0 here
|
||||
op.refCount = 1;
|
||||
op.key = key;
|
||||
key.loadInto(op.session, this.pickleKey);
|
||||
}
|
||||
entry.inUse = true;
|
||||
entry.key = key;
|
||||
return op;
|
||||
} else {
|
||||
const session: OlmInboundGroupSession = new this.olm.InboundGroupSession();
|
||||
const entry = new CacheEntry(key, session);
|
||||
this.cache.add(entry);
|
||||
// create new operation
|
||||
const session = new this.olm.InboundGroupSession();
|
||||
key.loadInto(session, this.pickleKey);
|
||||
const op = new KeyOperation(key, session);
|
||||
this._set(op);
|
||||
return op;
|
||||
}
|
||||
return entry;
|
||||
}
|
||||
|
||||
private freeEntry(entry: CacheEntry) {
|
||||
entry.inUse = false;
|
||||
if (this.resolveUnusedEntry) {
|
||||
this.resolveUnusedEntry();
|
||||
private releaseOperation(op: KeyOperation) {
|
||||
op.refCount -= 1;
|
||||
if (op.refCount <= 0 && this.resolveUnusedOperation) {
|
||||
this.resolveUnusedOperation();
|
||||
// promise is resolved now, we'll need a new one for next await so clear
|
||||
this.entryBecomesUnusedPromise = this.resolveUnusedEntry = undefined;
|
||||
this.operationBecomesUnusedPromise = this.resolveUnusedOperation = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
private entryBecomesUnused(): Promise<void> {
|
||||
if (!this.entryBecomesUnusedPromise) {
|
||||
this.entryBecomesUnusedPromise = new Promise(resolve => {
|
||||
this.resolveUnusedEntry = resolve;
|
||||
private operationBecomesUnused(): Promise<void> {
|
||||
if (!this.operationBecomesUnusedPromise) {
|
||||
this.operationBecomesUnusedPromise = new Promise(resolve => {
|
||||
this.resolveUnusedOperation = resolve;
|
||||
});
|
||||
}
|
||||
return this.entryBecomesUnusedPromise;
|
||||
return this.operationBecomesUnusedPromise;
|
||||
}
|
||||
|
||||
private findIndexForAllocation(key: IRoomKey) {
|
||||
let idx = this.findIndexSameKey(key); // cache hit
|
||||
if (idx === -1) {
|
||||
idx = this.findIndexSameSessionUnused(key);
|
||||
if (idx === -1) {
|
||||
if (this.size < this.limit) {
|
||||
idx = this.size;
|
||||
} else {
|
||||
idx = this.findIndexOldestUnused();
|
||||
}
|
||||
}
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
||||
private findIndexBestForSession(roomId: string, senderKey: string, sessionId: string): number {
|
||||
return this._entries.reduce((bestIdx, op, i, arr) => {
|
||||
const bestOp = bestIdx === -1 ? undefined : arr[bestIdx];
|
||||
if (op.isForSameSession(roomId, senderKey, sessionId)) {
|
||||
if (!bestOp || op.isBetter(bestOp)) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return bestIdx;
|
||||
}, -1);
|
||||
}
|
||||
|
||||
private findIndexSameKey(key: IRoomKey): number {
|
||||
return this._entries.findIndex(op => {
|
||||
return op.isForKey(key);
|
||||
});
|
||||
}
|
||||
|
||||
private findIndexSameSessionUnused(key: IRoomKey): number {
|
||||
for (let i = this._entries.length - 1; i >= 0; i -= 1) {
|
||||
const op = this._entries[i];
|
||||
if (op.refCount === 0 && op.isForSameSession(key.roomId, key.senderKey, key.sessionId)) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
private findIndexOldestUnused(): number {
|
||||
for (let i = this._entries.length - 1; i >= 0; i -= 1) {
|
||||
const op = this._entries[i];
|
||||
if (op.refCount === 0) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
class CacheEntry {
|
||||
inUse: boolean;
|
||||
class KeyOperation {
|
||||
session: OlmInboundGroupSession;
|
||||
key: IRoomKey;
|
||||
refCount: number;
|
||||
|
||||
constructor(key, session) {
|
||||
constructor(key: IRoomKey, session: OlmInboundGroupSession) {
|
||||
this.key = key;
|
||||
this.session = session;
|
||||
this.inUse = true;
|
||||
this.refCount = 1;
|
||||
}
|
||||
|
||||
isForSameSession(roomId: string, senderKey: string, sessionId: string): boolean {
|
||||
return this.key.roomId === roomId && this.key.senderKey === senderKey && this.key.sessionId === sessionId;
|
||||
}
|
||||
|
||||
// assumes isForSameSession is true
|
||||
isBetter(other: KeyOperation) {
|
||||
return isBetterThan(this.session, other.session);
|
||||
}
|
||||
|
||||
isForKey(key: IRoomKey) {
|
||||
return this.key.serializationKey === key.serializationKey &&
|
||||
this.key.serializationType === key.serializationType;
|
||||
}
|
||||
|
||||
dispose() {
|
||||
this.session.free();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,20 +18,25 @@ import type {InboundGroupSessionEntry} from "../../../storage/idb/stores/Inbound
|
|||
import type {Transaction} from "../../../storage/idb/Transaction";
|
||||
import type {DecryptionResult} from "../../DecryptionResult";
|
||||
import type {KeyLoader, OlmInboundGroupSession} from "./KeyLoader";
|
||||
import {SessionCache} from "./SessionCache";
|
||||
|
||||
export interface IRoomKey {
|
||||
get roomId(): string;
|
||||
get senderKey(): string;
|
||||
get sessionId(): string;
|
||||
get claimedEd25519Key(): string;
|
||||
get serializationKey(): string;
|
||||
get serializationType(): string;
|
||||
get eventIds(): string[] | undefined;
|
||||
loadInto(session: OlmInboundGroupSession, pickleKey: string): void;
|
||||
}
|
||||
|
||||
export function isBetterThan(newSession: OlmInboundGroupSession, existingSession: OlmInboundGroupSession) {
|
||||
return newSession.first_known_index() < existingSession.first_known_index();
|
||||
}
|
||||
|
||||
export interface IIncomingRoomKey extends IRoomKey {
|
||||
get isBetter(): boolean | undefined;
|
||||
checkBetterKeyInStorage(loader: KeyLoader, txn: Transaction): Promise<boolean>;
|
||||
checkBetterThanKeyInStorage(loader: KeyLoader, txn: Transaction): Promise<boolean>;
|
||||
write(loader: KeyLoader, txn: Transaction): Promise<boolean>;
|
||||
}
|
||||
|
||||
|
@ -39,8 +44,8 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey {
|
|||
private _eventIds?: string[];
|
||||
private _isBetter?: boolean;
|
||||
|
||||
checkBetterKeyInStorage(loader: KeyLoader, txn: Transaction): Promise<boolean> {
|
||||
return this._checkBetterKeyInStorage(loader, undefined, txn);
|
||||
checkBetterThanKeyInStorage(loader: KeyLoader, txn: Transaction): Promise<boolean> {
|
||||
return this._checkBetterThanKeyInStorage(loader, undefined, txn);
|
||||
}
|
||||
|
||||
async write(loader: KeyLoader, txn: Transaction): Promise<boolean> {
|
||||
|
@ -51,7 +56,7 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey {
|
|||
// we haven't checked if this is the best key yet,
|
||||
// so do that now to not overwrite a better key.
|
||||
// while we have the key deserialized, also pickle it to store it later on here.
|
||||
await this._checkBetterKeyInStorage(loader, (session, pickleKey) => {
|
||||
await this._checkBetterThanKeyInStorage(loader, (session, pickleKey) => {
|
||||
pickledSession = session.pickle(pickleKey);
|
||||
}, txn);
|
||||
}
|
||||
|
@ -76,7 +81,7 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey {
|
|||
get eventIds() { return this._eventIds; }
|
||||
get isBetter() { return this._isBetter; }
|
||||
|
||||
private async _checkBetterKeyInStorage(loader: KeyLoader, callback: (((session: OlmInboundGroupSession, pickleKey: string) => void) | undefined), txn: Transaction): Promise<boolean> {
|
||||
private async _checkBetterThanKeyInStorage(loader: KeyLoader, callback: (((session: OlmInboundGroupSession, pickleKey: string) => void) | undefined), txn: Transaction): Promise<boolean> {
|
||||
if (this._isBetter !== undefined) {
|
||||
return this._isBetter;
|
||||
}
|
||||
|
@ -96,7 +101,7 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey {
|
|||
if (existingKey) {
|
||||
this._isBetter = await loader.useKey(this, newSession => {
|
||||
return loader.useKey(existingKey, (existingSession, pickleKey) => {
|
||||
const isBetter = newSession.first_known_index() < existingSession.first_known_index();
|
||||
const isBetter = isBetterThan(newSession, existingSession);
|
||||
if (isBetter && callback) {
|
||||
callback(newSession, pickleKey);
|
||||
}
|
||||
|
@ -114,6 +119,8 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey {
|
|||
abstract get senderKey(): string;
|
||||
abstract get sessionId(): string;
|
||||
abstract get claimedEd25519Key(): string;
|
||||
abstract get serializationKey(): string;
|
||||
abstract get serializationType(): string;
|
||||
abstract loadInto(session: OlmInboundGroupSession, pickleKey: string): void;
|
||||
}
|
||||
|
||||
|
@ -129,10 +136,11 @@ class DeviceMessageRoomKey extends BaseIncomingRoomKey {
|
|||
get senderKey() { return this._decryptionResult.senderCurve25519Key; }
|
||||
get sessionId() { return this._decryptionResult.event.content?.["session_id"]; }
|
||||
get claimedEd25519Key() { return this._decryptionResult.claimedEd25519Key; }
|
||||
get serializationKey(): string { return this._decryptionResult.event.content?.["session_key"]; }
|
||||
get serializationType(): string { return "create"; }
|
||||
|
||||
loadInto(session) {
|
||||
const sessionKey = this._decryptionResult.event.content?.["session_key"];
|
||||
session.create(sessionKey);
|
||||
session.create(this.serializationKey);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -152,10 +160,11 @@ class BackupRoomKey extends BaseIncomingRoomKey {
|
|||
get senderKey() { return this._backupInfo["sender_key"]; }
|
||||
get sessionId() { return this._sessionId; }
|
||||
get claimedEd25519Key() { return this._backupInfo["sender_claimed_keys"]?.["ed25519"]; }
|
||||
|
||||
get serializationKey(): string { return this._backupInfo["session_key"]; }
|
||||
get serializationType(): string { return "import_session"; }
|
||||
|
||||
loadInto(session) {
|
||||
const sessionKey = this._backupInfo["session_key"];
|
||||
session.import_session(sessionKey);
|
||||
session.import_session(this.serializationKey);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -171,9 +180,11 @@ class StoredRoomKey implements IRoomKey {
|
|||
get sessionId() { return this.storageEntry.sessionId; }
|
||||
get claimedEd25519Key() { return this.storageEntry.claimedKeys!["ed25519"]; }
|
||||
get eventIds() { return this.storageEntry.eventIds; }
|
||||
|
||||
get serializationKey(): string { return this.storageEntry.session || ""; }
|
||||
get serializationType(): string { return "unpickle"; }
|
||||
|
||||
loadInto(session, pickleKey) {
|
||||
session.unpickle(pickleKey, this.storageEntry.session);
|
||||
session.unpickle(pickleKey, this.serializationKey);
|
||||
}
|
||||
|
||||
get hasSession() {
|
||||
|
|
|
@ -1,63 +0,0 @@
|
|||
/*
|
||||
Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
import {BaseLRUCache} from "../../../../utils/LRUCache";
|
||||
const DEFAULT_CACHE_SIZE = 10;
|
||||
|
||||
/**
|
||||
* Cache of unpickled inbound megolm session.
|
||||
*/
|
||||
export class SessionCache extends BaseLRUCache {
|
||||
constructor(limit) {
|
||||
limit = typeof limit === "number" ? limit : DEFAULT_CACHE_SIZE;
|
||||
super(limit);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {string} roomId
|
||||
* @param {string} senderKey
|
||||
* @param {string} sessionId
|
||||
* @return {SessionInfo?}
|
||||
*/
|
||||
get(roomId, senderKey, sessionId) {
|
||||
const sessionInfo = this._get(s => {
|
||||
return s.roomId === roomId &&
|
||||
s.senderKey === senderKey &&
|
||||
sessionId === s.sessionId;
|
||||
});
|
||||
sessionInfo?.retain();
|
||||
return sessionInfo;
|
||||
}
|
||||
|
||||
add(sessionInfo) {
|
||||
sessionInfo.retain();
|
||||
this._set(sessionInfo, s => {
|
||||
return s.roomId === sessionInfo.roomId &&
|
||||
s.senderKey === sessionInfo.senderKey &&
|
||||
s.sessionId === sessionInfo.sessionId;
|
||||
});
|
||||
}
|
||||
|
||||
_onEvictEntry(sessionInfo) {
|
||||
sessionInfo.release();
|
||||
}
|
||||
|
||||
dispose() {
|
||||
for (const sessionInfo of this._entries) {
|
||||
sessionInfo.release();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -14,25 +14,29 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
type FindCallback<T> = (value: T) => boolean;
|
||||
/**
|
||||
* Very simple least-recently-used cache implementation
|
||||
* that should be fast enough for very small cache sizes
|
||||
*/
|
||||
export class BaseLRUCache<T> {
|
||||
|
||||
private _limit: number;
|
||||
private _entries: T[];
|
||||
public readonly limit: number;
|
||||
protected _entries: T[];
|
||||
|
||||
constructor(limit: number) {
|
||||
this._limit = limit;
|
||||
this.limit = limit;
|
||||
this._entries = [];
|
||||
}
|
||||
|
||||
get size() { return this._entries.length; }
|
||||
get limit() { return this._limit; }
|
||||
|
||||
_get(findEntryFn: (T) => boolean) {
|
||||
const idx = this._entries.findIndex(findEntryFn);
|
||||
protected _get(findEntryFn: FindCallback<T>) {
|
||||
return this._getByIndexAndMoveUp(this._entries.findIndex(findEntryFn));
|
||||
}
|
||||
|
||||
protected _getByIndexAndMoveUp(idx: number) {
|
||||
if (idx !== -1) {
|
||||
const entry = this._entries[idx];
|
||||
// move to top
|
||||
|
@ -44,11 +48,11 @@ export class BaseLRUCache<T> {
|
|||
}
|
||||
}
|
||||
|
||||
_set(value: T, findEntryFn: (T) => boolean) {
|
||||
let indexToRemove = this._entries.findIndex(findEntryFn);
|
||||
protected _set(value: T, findEntryFn?: FindCallback<T>) {
|
||||
let indexToRemove = findEntryFn ? this._entries.findIndex(findEntryFn) : -1;
|
||||
this._entries.unshift(value);
|
||||
if (indexToRemove === -1) {
|
||||
if (this._entries.length > this._limit) {
|
||||
if (this._entries.length > this.limit) {
|
||||
indexToRemove = this._entries.length - 1;
|
||||
}
|
||||
} else {
|
||||
|
@ -56,22 +60,12 @@ export class BaseLRUCache<T> {
|
|||
indexToRemove += 1;
|
||||
}
|
||||
if (indexToRemove !== -1) {
|
||||
this._onEvictEntry(this._entries[indexToRemove]);
|
||||
this.onEvictEntry(this._entries[indexToRemove]);
|
||||
this._entries.splice(indexToRemove, 1);
|
||||
}
|
||||
}
|
||||
|
||||
find(callback: (T) => boolean) {
|
||||
// iterate backwards so least recently used items are found first
|
||||
for (let i = this._entries.length - 1; i >= 0; i -= 1) {
|
||||
const entry = this._entries[i];
|
||||
if (callback(entry)) {
|
||||
return entry;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_onEvictEntry(entry: T) {}
|
||||
protected onEvictEntry(entry: T) {}
|
||||
}
|
||||
|
||||
export class LRUCache<T, K> extends BaseLRUCache<T> {
|
||||
|
|
Reference in a new issue