diff --git a/src/matrix/e2ee/DeviceTracker.js b/src/matrix/e2ee/DeviceTracker.js index 0068a1f9..c640b04a 100644 --- a/src/matrix/e2ee/DeviceTracker.js +++ b/src/matrix/e2ee/DeviceTracker.js @@ -363,3 +363,125 @@ export class DeviceTracker { return await txn.deviceIdentities.getByCurve25519Key(curve25519Key); } } + +import {createMockStorage} from "../../mocks/Storage"; +import {Instance as NullLoggerInstance} from "../../logging/NullLogger"; + +export function tests() { + + function createUntrackedRoomMock(roomId, joinedUserIds, invitedUserIds = []) { + return { + isTrackingMembers: false, + isEncrypted: true, + loadMemberList: () => { + const joinedMembers = joinedUserIds.map(userId => {return {membership: "join", roomId, userId};}); + const invitedMembers = invitedUserIds.map(userId => {return {membership: "invite", roomId, userId};}); + const members = joinedMembers.concat(invitedMembers); + const memberMap = members.reduce((map, member) => { + map.set(member.userId, member); + return map; + }, new Map()); + return {members: memberMap, release() {}} + }, + writeIsTrackingMembers(isTrackingMembers) { + if (this.isTrackingMembers !== isTrackingMembers) { + return isTrackingMembers; + } + return undefined; + }, + applyIsTrackingMembersChanges(isTrackingMembers) { + if (isTrackingMembers !== undefined) { + this.isTrackingMembers = isTrackingMembers; + } + }, + } + } + + function createQueryKeysHSApiMock(createKey = (algorithm, userId, deviceId) => `${algorithm}:${userId}:${deviceId}:key`) { + return { + queryKeys(payload) { + const {device_keys: deviceKeys} = payload; + const userKeys = Object.entries(deviceKeys).reduce((userKeys, [userId, deviceIds]) => { + if (deviceIds.length === 0) { + deviceIds = ["device1"]; + } + userKeys[userId] = deviceIds.filter(d => d === "device1").reduce((deviceKeys, deviceId) => { + deviceKeys[deviceId] = { + "algorithms": [ + "m.olm.v1.curve25519-aes-sha2", + "m.megolm.v1.aes-sha2" + ], + "device_id": deviceId, + "keys": { + [`curve25519:${deviceId}`]: createKey("curve25519", userId, deviceId), + [`ed25519:${deviceId}`]: createKey("ed25519", userId, deviceId), + }, + "signatures": { + [userId]: { + [`ed25519:${deviceId}`]: `ed25519:${userId}:${deviceId}:signature` + } + }, + "unsigned": { + "device_display_name": `${userId} Phone` + }, + "user_id": userId + }; + return deviceKeys; + }, {}); + return userKeys; + }, {}); + const response = {device_keys: userKeys}; + return { + async response() { + return response; + } + }; + } + }; + } + const roomId = "!abc:hs.tld"; + + return { + "trackRoom only writes joined members": async assert => { + const storage = await createMockStorage(); + const tracker = new DeviceTracker({ + storage, + getSyncToken: () => "token", + olmUtil: {ed25519_verify: () => {}}, // valid if it does not throw + ownUserId: "@alice:hs.tld", + ownDeviceId: "ABCD", + }); + const room = createUntrackedRoomMock(roomId, ["@alice:hs.tld", "@bob:hs.tld"], ["@charly:hs.tld"]); + await tracker.trackRoom(room, NullLoggerInstance.item); + const txn = await storage.readTxn([storage.storeNames.userIdentities]); + assert.deepEqual(await txn.userIdentities.get("@alice:hs.tld"), { + userId: "@alice:hs.tld", + roomIds: [roomId], + deviceTrackingStatus: TRACKING_STATUS_OUTDATED + }); + assert.deepEqual(await txn.userIdentities.get("@bob:hs.tld"), { + userId: "@bob:hs.tld", + roomIds: [roomId], + deviceTrackingStatus: TRACKING_STATUS_OUTDATED + }); + assert.equal(await txn.userIdentities.get("@charly:hs.tld"), undefined); + }, + "getting devices for tracked room yields correct keys": async assert => { + const storage = await createMockStorage(); + const tracker = new DeviceTracker({ + storage, + getSyncToken: () => "token", + olmUtil: {ed25519_verify: () => {}}, // valid if it does not throw + ownUserId: "@alice:hs.tld", + ownDeviceId: "ABCD", + }); + const room = createUntrackedRoomMock(roomId, ["@alice:hs.tld", "@bob:hs.tld"]); + await tracker.trackRoom(room, NullLoggerInstance.item); + const hsApi = createQueryKeysHSApiMock(); + const devices = await tracker.devicesForRoomMembers(roomId, ["@alice:hs.tld", "@bob:hs.tld"], hsApi, NullLoggerInstance.item); + assert.equal(devices.length, 2); + assert.equal(devices.find(d => d.userId === "@alice:hs.tld").ed25519Key, "ed25519:@alice:hs.tld:device1:key"); + assert.equal(devices.find(d => d.userId === "@bob:hs.tld").ed25519Key, "ed25519:@bob:hs.tld:device1:key"); + }, + } +}