diff --git a/src/matrix/e2ee/megolm/decryption/KeyLoader.ts b/src/matrix/e2ee/megolm/decryption/KeyLoader.ts index 1618a4a6..375fb2ac 100644 --- a/src/matrix/e2ee/megolm/decryption/KeyLoader.ts +++ b/src/matrix/e2ee/megolm/decryption/KeyLoader.ts @@ -156,7 +156,7 @@ export class KeyLoader extends BaseLRUCache { private findIndexSameKey(key: IRoomKey): number { return this._entries.findIndex(op => { - return op.isForKey(key); + return op.isForSameSession(key.roomId, key.senderKey, key.sessionId) && op.isForKey(key); }); } @@ -213,7 +213,6 @@ class KeyOperation { export function tests() { let instances = 0; - let idCounter = 0; class MockRoomKey implements IRoomKey { private _roomId: string; @@ -232,7 +231,7 @@ export function tests() { get senderKey(): string { return this._senderKey; } get sessionId(): string { return this._sessionId; } get claimedEd25519Key(): string { return "claimedEd25519Key"; } - get serializationKey(): string { return "key"; } + get serializationKey(): string { return `key-${this.sessionId}-${this._firstKnownIndex}`; } get serializationType(): string { return "type"; } get eventIds(): string[] | undefined { return undefined; } loadInto(session: OlmInboundGroupSession) { @@ -246,6 +245,10 @@ export function tests() { public sessionId: string = ""; public firstKnownIndex: number = 0; + constructor() { + instances += 1; + } + free(): void { instances -= 1; } pickle(key: string | Uint8Array): string { return `${this.sessionId}-pickled-session`; } unpickle(key: string | Uint8Array, pickle: string) {} @@ -262,18 +265,109 @@ export function tests() { const roomId = "!abc:hs.tld"; const aliceSenderKey = "abc"; const bobSenderKey = "def"; - const sessionId = "s123"; + const sessionId1 = "s123"; + const sessionId2 = "s456"; + const sessionId3 = "s789"; return { "load key gives correct session": async assert => { - const loader = new KeyLoader(olm, PICKLE_KEY, 5); + const loader = new KeyLoader(olm, PICKLE_KEY, 2); + let callback1Called = false; + let callback2Called = false; + const p1 = loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => { + callback1Called = true; + assert.equal(session.session_id(), sessionId1); + assert.equal(session.first_known_index(), 1); + await Promise.resolve(); // make sure they are busy in parallel + }); + const p2 = loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId2, 2), async session => { + callback2Called = true; + assert.equal(session.session_id(), sessionId2); + assert.equal(session.first_known_index(), 2); + await Promise.resolve(); // make sure they are busy in parallel + }); + assert.equal(loader.size, 2); + await Promise.all([p1, p2]); + assert(callback1Called); + assert(callback2Called); + }, + "keys with different first index are kept separate": async assert => { + const loader = new KeyLoader(olm, PICKLE_KEY, 2); + let callback1Called = false; + let callback2Called = false; + const p1 = loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => { + callback1Called = true; + assert.equal(session.session_id(), sessionId1); + assert.equal(session.first_known_index(), 1); + await Promise.resolve(); // make sure they are busy in parallel + }); + const p2 = loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2), async session => { + callback2Called = true; + assert.equal(session.session_id(), sessionId1); + assert.equal(session.first_known_index(), 2); + await Promise.resolve(); // make sure they are busy in parallel + }); + assert.equal(loader.size, 2); + await Promise.all([p1, p2]); + assert(callback1Called); + assert(callback2Called); + }, + "useKey blocks as long as no free sessions are available": async assert => { + const loader = new KeyLoader(olm, PICKLE_KEY, 1); + let resolve; let callbackCalled = false; - await loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId, 1), session => { + loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => { + await new Promise(r => resolve = r); + }); + await Promise.resolve(); + assert.equal(loader.size, 1); + const promise = loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId2, 1), session => { callbackCalled = true; - assert.equal(session.session_id(), sessionId); + }); + assert.equal(callbackCalled, false); + resolve(); + await promise; + assert.equal(callbackCalled, true); + }, + "cache hit while key in use, then replace (check refCount works properly)": async assert => { + const loader = new KeyLoader(olm, PICKLE_KEY, 1); + let resolve1, resolve2; + const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1); + const p1 = loader.useKey(key1, async session => { + await new Promise(r => resolve1 = r); + }); + const p2 = loader.useKey(key1, async session => { + await new Promise(r => resolve2 = r); + }); + await Promise.resolve(); + assert.equal(loader.size, 1); + assert.equal(loader.running, true); + resolve1(); + await p1; + assert.equal(loader.running, true); + resolve2(); + await p2; + assert.equal(loader.running, false); + let callbackCalled = false; + await loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId2, 1), async session => { + callbackCalled = true; + assert.equal(session.session_id(), sessionId2); assert.equal(session.first_known_index(), 1); }); - assert(callbackCalled); + assert.equal(loader.size, 1); + assert.equal(callbackCalled, true); }, + "cache hit while key not in use": async assert => { + const loader = new KeyLoader(olm, PICKLE_KEY, 2); + let resolve1, resolve2, invocations = 0; + const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1); + await loader.useKey(key1, async session => { invocations += 1; }); + assert.equal(loader.size, 1); + const cachedKey = loader.getCachedKey(roomId, aliceSenderKey, sessionId1)!; + assert.equal(cachedKey, key1); + await loader.useKey(cachedKey, async session => { invocations += 1; }); + assert.equal(loader.size, 1); + assert.equal(invocations, 2); + } } }