make KeyLoader use proper olm types

This commit is contained in:
Bruno Windels 2022-01-20 11:15:48 +01:00
parent 30438846e9
commit a4d924acd1
3 changed files with 25 additions and 34 deletions

View file

@ -17,25 +17,14 @@ limitations under the License.
import {isBetterThan, IncomingRoomKey} from "./RoomKey"; import {isBetterThan, IncomingRoomKey} from "./RoomKey";
import {BaseLRUCache} from "../../../../utils/LRUCache"; import {BaseLRUCache} from "../../../../utils/LRUCache";
import type {RoomKey} from "./RoomKey"; import type {RoomKey} from "./RoomKey";
import type * as OlmNamespace from "@matrix-org/olm";
type Olm = typeof OlmNamespace;
export declare class OlmDecryptionResult { export declare class OlmDecryptionResult {
readonly plaintext: string; readonly plaintext: string;
readonly message_index: number; readonly message_index: number;
} }
export declare class OlmInboundGroupSession {
constructor();
free(): void;
pickle(key: string | Uint8Array): string;
unpickle(key: string | Uint8Array, pickle: string);
create(session_key: string): string;
import_session(session_key: string): string;
decrypt(message: string): OlmDecryptionResult;
session_id(): string;
first_known_index(): number;
export_session(message_index: number): string;
}
/* /*
Because Olm only has very limited memory available when compiled to wasm, Because Olm only has very limited memory available when compiled to wasm,
we limit the amount of sessions held in memory. we limit the amount of sessions held in memory.
@ -43,11 +32,11 @@ we limit the amount of sessions held in memory.
export class KeyLoader extends BaseLRUCache<KeyOperation> { export class KeyLoader extends BaseLRUCache<KeyOperation> {
private pickleKey: string; private pickleKey: string;
private olm: any; private olm: Olm;
private resolveUnusedOperation?: () => void; private resolveUnusedOperation?: () => void;
private operationBecomesUnusedPromise?: Promise<void>; private operationBecomesUnusedPromise?: Promise<void>;
constructor(olm: any, pickleKey: string, limit: number) { constructor(olm: Olm, pickleKey: string, limit: number) {
super(limit); super(limit);
this.pickleKey = pickleKey; this.pickleKey = pickleKey;
this.olm = olm; this.olm = olm;
@ -60,7 +49,7 @@ export class KeyLoader extends BaseLRUCache<KeyOperation> {
} }
} }
async useKey<T>(key: RoomKey, callback: (session: OlmInboundGroupSession, pickleKey: string) => Promise<T> | T): Promise<T> { async useKey<T>(key: RoomKey, callback: (session: Olm.InboundGroupSession, pickleKey: string) => Promise<T> | T): Promise<T> {
const keyOp = await this.allocateOperation(key); const keyOp = await this.allocateOperation(key);
try { try {
return await callback(keyOp.session, this.pickleKey); return await callback(keyOp.session, this.pickleKey);
@ -186,11 +175,11 @@ export class KeyLoader extends BaseLRUCache<KeyOperation> {
} }
class KeyOperation { class KeyOperation {
session: OlmInboundGroupSession; session: Olm.InboundGroupSession;
key: RoomKey; key: RoomKey;
refCount: number; refCount: number;
constructor(key: RoomKey, session: OlmInboundGroupSession) { constructor(key: RoomKey, session: Olm.InboundGroupSession) {
this.key = key; this.key = key;
this.session = session; this.session = session;
this.refCount = 1; this.refCount = 1;
@ -248,7 +237,7 @@ export function tests() {
get serializationKey(): string { return `key-${this.sessionId}-${this._firstKnownIndex}`; } get serializationKey(): string { return `key-${this.sessionId}-${this._firstKnownIndex}`; }
get serializationType(): string { return "type"; } get serializationType(): string { return "type"; }
get eventIds(): string[] | undefined { return undefined; } get eventIds(): string[] | undefined { return undefined; }
loadInto(session: OlmInboundGroupSession) { loadInto(session: Olm.InboundGroupSession) {
const mockSession = session as MockInboundSession; const mockSession = session as MockInboundSession;
mockSession.sessionId = this.sessionId; mockSession.sessionId = this.sessionId;
mockSession.firstKnownIndex = this._firstKnownIndex; mockSession.firstKnownIndex = this._firstKnownIndex;
@ -284,7 +273,7 @@ export function tests() {
return { return {
"load key gives correct session": async assert => { "load key gives correct session": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 2); const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 2);
let callback1Called = false; let callback1Called = false;
let callback2Called = false; let callback2Called = false;
const p1 = loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => { const p1 = loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => {
@ -305,7 +294,7 @@ export function tests() {
assert(callback2Called); assert(callback2Called);
}, },
"keys with different first index are kept separate": async assert => { "keys with different first index are kept separate": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 2); const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 2);
let callback1Called = false; let callback1Called = false;
let callback2Called = false; let callback2Called = false;
const p1 = loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => { const p1 = loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => {
@ -326,7 +315,7 @@ export function tests() {
assert(callback2Called); assert(callback2Called);
}, },
"useKey blocks as long as no free sessions are available": async assert => { "useKey blocks as long as no free sessions are available": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 1); const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 1);
let resolve; let resolve;
let callbackCalled = false; let callbackCalled = false;
loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => { loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => {
@ -343,7 +332,7 @@ export function tests() {
assert.equal(callbackCalled, true); assert.equal(callbackCalled, true);
}, },
"cache hit while key in use, then replace (check refCount works properly)": async assert => { "cache hit while key in use, then replace (check refCount works properly)": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 1); const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 1);
let resolve1, resolve2; let resolve1, resolve2;
const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1); const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1);
const p1 = loader.useKey(key1, async session => { const p1 = loader.useKey(key1, async session => {
@ -371,7 +360,7 @@ export function tests() {
assert.equal(callbackCalled, true); assert.equal(callbackCalled, true);
}, },
"cache hit while key not in use": async assert => { "cache hit while key not in use": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 2); const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 2);
let resolve1, resolve2, invocations = 0; let resolve1, resolve2, invocations = 0;
const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1); const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1);
await loader.useKey(key1, async session => { invocations += 1; }); await loader.useKey(key1, async session => { invocations += 1; });
@ -385,7 +374,7 @@ export function tests() {
}, },
"dispose calls free on all sessions": async assert => { "dispose calls free on all sessions": async assert => {
instances = 0; instances = 0;
const loader = new KeyLoader(olm, PICKLE_KEY, 2); const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 2);
await loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => {}); await loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => {});
await loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId2, 1), async session => {}); await loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId2, 1), async session => {});
assert.equal(instances, 2); assert.equal(instances, 2);
@ -395,7 +384,7 @@ export function tests() {
assert.strictEqual(loader.size, 0, "loader.size"); assert.strictEqual(loader.size, 0, "loader.size");
}, },
"checkBetterThanKeyInStorage false with cache": async assert => { "checkBetterThanKeyInStorage false with cache": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 2); const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 2);
const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2); const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2);
await loader.useKey(key1, async session => {}); await loader.useKey(key1, async session => {});
// fake we've checked with storage that this is the best key, // fake we've checked with storage that this is the best key,
@ -409,7 +398,7 @@ export function tests() {
assert.strictEqual(key2.isBetter, false); assert.strictEqual(key2.isBetter, false);
}, },
"checkBetterThanKeyInStorage true with cache": async assert => { "checkBetterThanKeyInStorage true with cache": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 2); const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 2);
const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2); const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2);
key1.isBetter = true; // fake we've check with storage so far (not including key2) this is the best key key1.isBetter = true; // fake we've check with storage so far (not including key2) this is the best key
await loader.useKey(key1, async session => {}); await loader.useKey(key1, async session => {});
@ -420,7 +409,7 @@ export function tests() {
assert.strictEqual(key2.isBetter, true); assert.strictEqual(key2.isBetter, true);
}, },
"prefer to remove worst key for a session from cache": async assert => { "prefer to remove worst key for a session from cache": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 2); const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 2);
const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2); const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2);
await loader.useKey(key1, async session => {}); await loader.useKey(key1, async session => {});
key1.isBetter = true; // set to true just so it gets returned from getCachedKey key1.isBetter = true; // set to true just so it gets returned from getCachedKey

View file

@ -17,7 +17,9 @@ limitations under the License.
import type {InboundGroupSessionEntry} from "../../../storage/idb/stores/InboundGroupSessionStore"; import type {InboundGroupSessionEntry} from "../../../storage/idb/stores/InboundGroupSessionStore";
import type {Transaction} from "../../../storage/idb/Transaction"; import type {Transaction} from "../../../storage/idb/Transaction";
import type {DecryptionResult} from "../../DecryptionResult"; import type {DecryptionResult} from "../../DecryptionResult";
import type {KeyLoader, OlmInboundGroupSession} from "./KeyLoader"; import type {KeyLoader} from "./KeyLoader";
import type * as OlmNamespace from "@matrix-org/olm";
type Olm = typeof OlmNamespace;
export abstract class RoomKey { export abstract class RoomKey {
private _isBetter: boolean | undefined; private _isBetter: boolean | undefined;
@ -33,7 +35,7 @@ export abstract class RoomKey {
abstract get serializationKey(): string; abstract get serializationKey(): string;
abstract get serializationType(): string; abstract get serializationType(): string;
abstract get eventIds(): string[] | undefined; abstract get eventIds(): string[] | undefined;
abstract loadInto(session: OlmInboundGroupSession, pickleKey: string): void; abstract loadInto(session: Olm.InboundGroupSession, pickleKey: string): void;
/* Whether the key has been checked against storage (or is from storage) /* Whether the key has been checked against storage (or is from storage)
* to be the better key for a given session. Given that all keys are checked to be better * to be the better key for a given session. Given that all keys are checked to be better
* as part of writing, we can trust that when this returns true, it really is the best key * as part of writing, we can trust that when this returns true, it really is the best key
@ -44,7 +46,7 @@ export abstract class RoomKey {
set isBetter(value: boolean | undefined) { this._isBetter = value; } set isBetter(value: boolean | undefined) { this._isBetter = value; }
} }
export function isBetterThan(newSession: OlmInboundGroupSession, existingSession: OlmInboundGroupSession) { export function isBetterThan(newSession: Olm.InboundGroupSession, existingSession: Olm.InboundGroupSession) {
return newSession.first_known_index() < existingSession.first_known_index(); return newSession.first_known_index() < existingSession.first_known_index();
} }
@ -87,7 +89,7 @@ export abstract class IncomingRoomKey extends RoomKey {
get eventIds() { return this._eventIds; } get eventIds() { return this._eventIds; }
private async _checkBetterThanKeyInStorage(loader: KeyLoader, callback: (((session: OlmInboundGroupSession, pickleKey: string) => void) | undefined), txn: Transaction): Promise<boolean> { private async _checkBetterThanKeyInStorage(loader: KeyLoader, callback: (((session: Olm.InboundGroupSession, pickleKey: string) => void) | undefined), txn: Transaction): Promise<boolean> {
if (this.isBetter !== undefined) { if (this.isBetter !== undefined) {
return this.isBetter; return this.isBetter;
} }

View file

@ -17,7 +17,7 @@ limitations under the License.
import {DecryptionResult} from "../../DecryptionResult.js"; import {DecryptionResult} from "../../DecryptionResult.js";
import {DecryptionError} from "../../common.js"; import {DecryptionError} from "../../common.js";
import {ReplayDetectionEntry} from "./ReplayDetectionEntry"; import {ReplayDetectionEntry} from "./ReplayDetectionEntry";
import type {RoomKey} from "./RoomKey.js"; import type {RoomKey} from "./RoomKey";
import type {KeyLoader, OlmDecryptionResult} from "./KeyLoader"; import type {KeyLoader, OlmDecryptionResult} from "./KeyLoader";
import type {OlmWorker} from "../../OlmWorker"; import type {OlmWorker} from "../../OlmWorker";
import type {TimelineEvent} from "../../../storage/types"; import type {TimelineEvent} from "../../../storage/types";
@ -61,7 +61,7 @@ export class SessionDecryption {
this.decryptionRequests!.push(request); this.decryptionRequests!.push(request);
decryptionResult = await request.response(); decryptionResult = await request.response();
} else { } else {
decryptionResult = session.decrypt(ciphertext); decryptionResult = session.decrypt(ciphertext) as OlmDecryptionResult;
} }
const {plaintext} = decryptionResult!; const {plaintext} = decryptionResult!;
let payload; let payload;