prevent cache hiding better keys in storage (+ tests)

This commit is contained in:
Bruno Windels 2021-10-25 19:17:13 +02:00
parent 3c2604b384
commit ab2f15b5a2
4 changed files with 125 additions and 78 deletions

View file

@ -20,7 +20,7 @@ import {SessionDecryption} from "./decryption/SessionDecryption";
import {MEGOLM_ALGORITHM} from "../common.js";
import {validateEvent, groupEventsBySession} from "./decryption/utils";
import {keyFromStorage, keyFromDeviceMessage, keyFromBackup} from "./decryption/RoomKey";
import type {IRoomKey, IIncomingRoomKey} from "./decryption/RoomKey";
import type {RoomKey, IncomingRoomKey} from "./decryption/RoomKey";
import type {KeyLoader} from "./decryption/KeyLoader";
import type {OlmWorker} from "../OlmWorker";
import type {Transaction} from "../../storage/idb/Transaction";
@ -78,7 +78,7 @@ export class Decryption {
* @param {[type]} txn [description]
* @return {DecryptionPreparation}
*/
async prepareDecryptAll(roomId: string, events: TimelineEvent[], newKeys: IIncomingRoomKey[] | undefined, txn: Transaction) {
async prepareDecryptAll(roomId: string, events: TimelineEvent[], newKeys: IncomingRoomKey[] | undefined, txn: Transaction) {
const errors = new Map();
const validEvents: TimelineEvent[] = [];
@ -107,7 +107,7 @@ export class Decryption {
return new DecryptionPreparation(roomId, sessionDecryptions, errors);
}
private async getRoomKey(roomId: string, senderKey: string, sessionId: string, newKeys: IIncomingRoomKey[] | undefined, txn: Transaction): Promise<IRoomKey | undefined> {
private async getRoomKey(roomId: string, senderKey: string, sessionId: string, newKeys: IncomingRoomKey[] | undefined, txn: Transaction): Promise<RoomKey | undefined> {
if (newKeys) {
const key = newKeys.find(k => k.roomId === roomId && k.senderKey === senderKey && k.sessionId === sessionId);
if (key && await key.checkBetterThanKeyInStorage(this.keyLoader, txn)) {
@ -128,7 +128,7 @@ export class Decryption {
/**
* Writes the key as an inbound group session if there is not already a better key in the store
*/
writeRoomKey(key: IIncomingRoomKey, txn: Transaction): Promise<boolean> {
writeRoomKey(key: IncomingRoomKey, txn: Transaction): Promise<boolean> {
return key.write(this.keyLoader, txn);
}
@ -136,8 +136,8 @@ export class Decryption {
* Extracts room keys from decrypted device messages.
* The key won't be persisted yet, you need to call RoomKey.write for that.
*/
roomKeysFromDeviceMessages(decryptionResults: DecryptionResult[], log: LogItem): IIncomingRoomKey[] {
const keys: IIncomingRoomKey[] = [];
roomKeysFromDeviceMessages(decryptionResults: DecryptionResult[], log: LogItem): IncomingRoomKey[] {
const keys: IncomingRoomKey[] = [];
for (const dr of decryptionResults) {
if (dr.event?.type !== "m.room_key" || dr.event.content?.algorithm !== MEGOLM_ALGORITHM) {
continue;
@ -157,7 +157,7 @@ export class Decryption {
return keys;
}
roomKeyFromBackup(roomId: string, sessionId: string, sessionInfo: string): IIncomingRoomKey | undefined {
roomKeyFromBackup(roomId: string, sessionId: string, sessionInfo: string): IncomingRoomKey | undefined {
return keyFromBackup(roomId, sessionId, sessionInfo);
}

View file

@ -14,9 +14,9 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
import {IRoomKey, isBetterThan} from "./RoomKey";
import {isBetterThan, IncomingRoomKey} from "./RoomKey";
import {BaseLRUCache} from "../../../../utils/LRUCache";
import type {RoomKey} from "./RoomKey";
export declare class OlmDecryptionResult {
readonly plaintext: string;
@ -53,14 +53,14 @@ export class KeyLoader extends BaseLRUCache<KeyOperation> {
this.olm = olm;
}
getCachedKey(roomId: string, senderKey: string, sessionId: string): IRoomKey | undefined {
const idx = this.findIndexBestForSession(roomId, senderKey, sessionId);
getCachedKey(roomId: string, senderKey: string, sessionId: string): RoomKey | undefined {
const idx = this.findCachedKeyIndex(roomId, senderKey, sessionId);
if (idx !== -1) {
return this._getByIndexAndMoveUp(idx)!.key;
}
}
async useKey<T>(key: IRoomKey, callback: (session: OlmInboundGroupSession, pickleKey: string) => Promise<T> | T): Promise<T> {
async useKey<T>(key: RoomKey, callback: (session: OlmInboundGroupSession, pickleKey: string) => Promise<T> | T): Promise<T> {
const keyOp = await this.allocateOperation(key);
try {
return await callback(keyOp.session, this.pickleKey);
@ -81,7 +81,7 @@ export class KeyLoader extends BaseLRUCache<KeyOperation> {
this._entries.splice(0, this._entries.length);
}
private async allocateOperation(key: IRoomKey): Promise<KeyOperation> {
private async allocateOperation(key: RoomKey): Promise<KeyOperation> {
let idx;
while((idx = this.findIndexForAllocation(key)) === -1) {
await this.operationBecomesUnused();
@ -127,14 +127,14 @@ export class KeyLoader extends BaseLRUCache<KeyOperation> {
return this.operationBecomesUnusedPromise;
}
private findIndexForAllocation(key: IRoomKey) {
private findIndexForAllocation(key: RoomKey) {
let idx = this.findIndexSameKey(key); // cache hit
if (idx === -1) {
idx = this.findIndexSameSessionUnused(key);
if (idx === -1) {
if (this.size < this.limit) {
idx = this.size;
} else {
idx = this.findIndexSameSessionUnused(key);
if (idx === -1) {
idx = this.findIndexOldestUnused();
}
}
@ -142,10 +142,11 @@ export class KeyLoader extends BaseLRUCache<KeyOperation> {
return idx;
}
private findIndexBestForSession(roomId: string, senderKey: string, sessionId: string): number {
private findCachedKeyIndex(roomId: string, senderKey: string, sessionId: string): number {
return this._entries.reduce((bestIdx, op, i, arr) => {
const bestOp = bestIdx === -1 ? undefined : arr[bestIdx];
if (op.isForSameSession(roomId, senderKey, sessionId)) {
// only operations that are the "best" for their session can be used, see comment on isBest
if (op.isBest === true && op.isForSameSession(roomId, senderKey, sessionId)) {
if (!bestOp || op.isBetter(bestOp)) {
return i;
}
@ -154,20 +155,23 @@ export class KeyLoader extends BaseLRUCache<KeyOperation> {
}, -1);
}
private findIndexSameKey(key: IRoomKey): number {
private findIndexSameKey(key: RoomKey): number {
return this._entries.findIndex(op => {
return op.isForSameSession(key.roomId, key.senderKey, key.sessionId) && op.isForKey(key);
});
}
private findIndexSameSessionUnused(key: IRoomKey): number {
for (let i = this._entries.length - 1; i >= 0; i -= 1) {
const op = this._entries[i];
private findIndexSameSessionUnused(key: RoomKey): number {
return this._entries.reduce((worstIdx, op, i, arr) => {
const worst = worstIdx === -1 ? undefined : arr[worstIdx];
// we try to pick the worst operation to overwrite, so the best one stays in the cache
if (op.refCount === 0 && op.isForSameSession(key.roomId, key.senderKey, key.sessionId)) {
if (!worst || !op.isBetter(worst)) {
return i;
}
}
return -1;
return worstIdx;
}, -1);
}
private findIndexOldestUnused(): number {
@ -183,10 +187,10 @@ export class KeyLoader extends BaseLRUCache<KeyOperation> {
class KeyOperation {
session: OlmInboundGroupSession;
key: IRoomKey;
key: RoomKey;
refCount: number;
constructor(key: IRoomKey, session: OlmInboundGroupSession) {
constructor(key: RoomKey, session: OlmInboundGroupSession) {
this.key = key;
this.session = session;
this.refCount = 1;
@ -201,7 +205,7 @@ class KeyOperation {
return isBetterThan(this.session, other.session);
}
isForKey(key: IRoomKey) {
isForKey(key: RoomKey) {
return this.key.serializationKey === key.serializationKey &&
this.key.serializationType === key.serializationType;
}
@ -209,18 +213,27 @@ class KeyOperation {
dispose() {
this.session.free();
}
/** returns whether the key for this operation has been checked at some point against storage
* and was determined to be the better key, undefined if it hasn't been checked yet.
* Only keys that are the best keys can be returned by getCachedKey as returning a cache hit
* will usually not check for a better session in storage. Also see RoomKey.isBetter. */
get isBest(): boolean | undefined {
return this.key.isBetter;
}
}
export function tests() {
let instances = 0;
class MockRoomKey implements IRoomKey {
class MockRoomKey extends IncomingRoomKey {
private _roomId: string;
private _senderKey: string;
private _sessionId: string;
private _firstKnownIndex: number;
constructor(roomId: string, senderKey: string, sessionId: string, firstKnownIndex: number) {
super();
this._roomId = roomId;
this._senderKey = senderKey;
this._sessionId = sessionId;
@ -267,7 +280,6 @@ export function tests() {
const bobSenderKey = "def";
const sessionId1 = "s123";
const sessionId2 = "s456";
const sessionId3 = "s789";
return {
"load key gives correct session": async assert => {
@ -362,6 +374,7 @@ export function tests() {
let resolve1, resolve2, invocations = 0;
const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1);
await loader.useKey(key1, async session => { invocations += 1; });
key1.isBetter = true;
assert.equal(loader.size, 1);
const cachedKey = loader.getCachedKey(roomId, aliceSenderKey, sessionId1)!;
assert.equal(cachedKey, key1);
@ -379,6 +392,42 @@ export function tests() {
loader.dispose();
assert.strictEqual(instances, 0, "instances");
assert.strictEqual(loader.size, 0, "loader.size");
}
},
"checkBetterThanKeyInStorage false with cache": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 2);
const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2);
await loader.useKey(key1, async session => {});
// fake we've checked with storage that this is the best key,
// and as long is it remains the best key with newly added keys,
// it will be returned from getCachedKey (as called from checkBetterThanKeyInStorage)
key1.isBetter = true;
const key2 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 3);
// this will hit cache of key 1 so we pass in null as txn
const isBetter = await key2.checkBetterThanKeyInStorage(loader, null as any);
assert.strictEqual(isBetter, false);
assert.strictEqual(key2.isBetter, false);
},
"checkBetterThanKeyInStorage true with cache": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 2);
const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2);
key1.isBetter = true; // fake we've check with storage so far (not including key2) this is the best key
await loader.useKey(key1, async session => {});
const key2 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1);
// this will hit cache of key 1 so we pass in null as txn
const isBetter = await key2.checkBetterThanKeyInStorage(loader, null as any);
assert.strictEqual(isBetter, true);
assert.strictEqual(key2.isBetter, true);
},
"prefer to remove worst key for a session from cache": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 2);
const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2);
await loader.useKey(key1, async session => {});
key1.isBetter = true; // set to true just so it gets returned from getCachedKey
const key2 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 4);
await loader.useKey(key2, async session => {});
const key3 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 3);
await loader.useKey(key3, async session => {});
assert.strictEqual(loader.getCachedKey(roomId, aliceSenderKey, sessionId1), key1);
},
}
}

View file

@ -19,30 +19,33 @@ import type {Transaction} from "../../../storage/idb/Transaction";
import type {DecryptionResult} from "../../DecryptionResult";
import type {KeyLoader, OlmInboundGroupSession} from "./KeyLoader";
export interface IRoomKey {
get roomId(): string;
get senderKey(): string;
get sessionId(): string;
get claimedEd25519Key(): string;
get serializationKey(): string;
get serializationType(): string;
get eventIds(): string[] | undefined;
loadInto(session: OlmInboundGroupSession, pickleKey: string): void;
export abstract class RoomKey {
private _isBetter: boolean | undefined;
abstract get roomId(): string;
abstract get senderKey(): string;
abstract get sessionId(): string;
abstract get claimedEd25519Key(): string;
abstract get serializationKey(): string;
abstract get serializationType(): string;
abstract get eventIds(): string[] | undefined;
abstract loadInto(session: OlmInboundGroupSession, pickleKey: string): void;
/* Whether the key has been checked against storage (or is from storage)
* to be the better key for a given session. Given that all keys are checked to be better
* as part of writing, we can trust that when this returns true, it really is the best key
* available between storage and cached keys in memory. This is why keys with this field set to
* true are used by the key loader to return cached keys. Also see KeyOperation.isBest there. */
get isBetter(): boolean | undefined { return this._isBetter; }
// should only be set in key.checkBetterThanKeyInStorage
set isBetter(value: boolean | undefined) { this._isBetter = value; }
}
export function isBetterThan(newSession: OlmInboundGroupSession, existingSession: OlmInboundGroupSession) {
return newSession.first_known_index() < existingSession.first_known_index();
}
export interface IIncomingRoomKey extends IRoomKey {
get isBetter(): boolean | undefined;
checkBetterThanKeyInStorage(loader: KeyLoader, txn: Transaction): Promise<boolean>;
write(loader: KeyLoader, txn: Transaction): Promise<boolean>;
}
abstract class BaseIncomingRoomKey implements IIncomingRoomKey {
export abstract class IncomingRoomKey extends RoomKey {
private _eventIds?: string[];
private _isBetter?: boolean;
checkBetterThanKeyInStorage(loader: KeyLoader, txn: Transaction): Promise<boolean> {
return this._checkBetterThanKeyInStorage(loader, undefined, txn);
@ -51,7 +54,7 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey {
async write(loader: KeyLoader, txn: Transaction): Promise<boolean> {
// we checked already and we had a better session in storage, so don't write
let pickledSession;
if (this._isBetter === undefined) {
if (this.isBetter === undefined) {
// if this key wasn't used to decrypt any messages in the same sync,
// we haven't checked if this is the best key yet,
// so do that now to not overwrite a better key.
@ -60,7 +63,7 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey {
pickledSession = session.pickle(pickleKey);
}, txn);
}
if (this._isBetter === false) {
if (this.isBetter === false) {
return false;
}
// before calling write in parallel, we need to check loader.running is false so we are sure our transaction will not be closed
@ -79,11 +82,10 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey {
}
get eventIds() { return this._eventIds; }
get isBetter() { return this._isBetter; }
private async _checkBetterThanKeyInStorage(loader: KeyLoader, callback: (((session: OlmInboundGroupSession, pickleKey: string) => void) | undefined), txn: Transaction): Promise<boolean> {
if (this._isBetter !== undefined) {
return this._isBetter;
if (this.isBetter !== undefined) {
return this.isBetter;
}
let existingKey = loader.getCachedKey(this.roomId, this.senderKey, this.sessionId);
if (!existingKey) {
@ -100,32 +102,26 @@ abstract class BaseIncomingRoomKey implements IIncomingRoomKey {
}
if (existingKey) {
const key = existingKey;
this._isBetter = await loader.useKey(this, newSession => {
return loader.useKey(key, (existingSession, pickleKey) => {
const isBetter = isBetterThan(newSession, existingSession);
if (isBetter && callback) {
await loader.useKey(this, async newSession => {
await loader.useKey(key, (existingSession, pickleKey) => {
// set isBetter as soon as possible, on both keys compared,
// as it is is used to determine whether a key can be used for the cache
this.isBetter = isBetterThan(newSession, existingSession);
key.isBetter = !this.isBetter;
if (this.isBetter && callback) {
callback(newSession, pickleKey);
}
return isBetter;
});
});
} else {
// no previous key, so we're the best \o/
this._isBetter = true;
this.isBetter = true;
}
return this.isBetter!;
}
return this._isBetter!;
}
abstract get roomId(): string;
abstract get senderKey(): string;
abstract get sessionId(): string;
abstract get claimedEd25519Key(): string;
abstract get serializationKey(): string;
abstract get serializationType(): string;
abstract loadInto(session: OlmInboundGroupSession, pickleKey: string): void;
}
class DeviceMessageRoomKey extends BaseIncomingRoomKey {
class DeviceMessageRoomKey extends IncomingRoomKey {
private _decryptionResult: DecryptionResult;
constructor(decryptionResult: DecryptionResult) {
@ -145,7 +141,7 @@ class DeviceMessageRoomKey extends BaseIncomingRoomKey {
}
}
class BackupRoomKey extends BaseIncomingRoomKey {
class BackupRoomKey extends IncomingRoomKey {
private _roomId: string;
private _sessionId: string;
private _backupInfo: string;
@ -169,10 +165,12 @@ class BackupRoomKey extends BaseIncomingRoomKey {
}
}
class StoredRoomKey implements IRoomKey {
class StoredRoomKey extends RoomKey {
private storageEntry: InboundGroupSessionEntry;
constructor(storageEntry: InboundGroupSessionEntry) {
super();
this.isBetter = true; // usually the key in storage is the best until checks prove otherwise
this.storageEntry = storageEntry;
}
@ -192,7 +190,7 @@ class StoredRoomKey implements IRoomKey {
// sessions are stored before they are received
// to keep track of events that need it to be decrypted.
// This is used to retry decryption of those events once the session is received.
return !!this.storageEntry.session;
return !!this.serializationKey;
}
}

View file

@ -17,7 +17,7 @@ limitations under the License.
import {DecryptionResult} from "../../DecryptionResult.js";
import {DecryptionError} from "../../common.js";
import {ReplayDetectionEntry} from "./ReplayDetectionEntry";
import type {IRoomKey} from "./RoomKey.js";
import type {RoomKey} from "./RoomKey.js";
import type {KeyLoader, OlmDecryptionResult} from "./KeyLoader";
import type {OlmWorker} from "../../OlmWorker";
import type {TimelineEvent} from "../../../storage/types";
@ -31,13 +31,13 @@ interface DecryptAllResult {
* Does the actual decryption of all events for a given megolm session in a batch
*/
export class SessionDecryption {
private key: IRoomKey;
private key: RoomKey;
private events: TimelineEvent[];
private keyLoader: KeyLoader;
private olmWorker?: OlmWorker;
private decryptionRequests?: any[];
constructor(key: IRoomKey, events: TimelineEvent[], olmWorker: OlmWorker | undefined, keyLoader: KeyLoader) {
constructor(key: RoomKey, events: TimelineEvent[], olmWorker: OlmWorker | undefined, keyLoader: KeyLoader) {
this.key = key;
this.events = events;
this.olmWorker = olmWorker;