diff --git a/src/matrix/DeviceMessageHandler.js b/src/matrix/DeviceMessageHandler.js index 4da52e71..27f1d386 100644 --- a/src/matrix/DeviceMessageHandler.js +++ b/src/matrix/DeviceMessageHandler.js @@ -14,11 +14,8 @@ See the License for the specific language governing permissions and limitations under the License. */ -import {OLM_ALGORITHM, MEGOLM_ALGORITHM} from "./e2ee/common.js"; -import {countBy} from "../utils/groupBy.js"; - -// key to store in session store -const PENDING_ENCRYPTED_EVENTS = "pendingEncryptedDeviceEvents"; +import {OLM_ALGORITHM} from "./e2ee/common.js"; +import {countBy, groupBy} from "../utils/groupBy.js"; export class DeviceMessageHandler { constructor({storage}) { @@ -32,90 +29,50 @@ export class DeviceMessageHandler { this._megolmDecryption = megolmDecryption; } - /** - * @return {bool} whether messages are waiting to be decrypted and `decryptPending` should be called. - */ - async writeSync(toDeviceEvents, txn, log) { + obtainSyncLock(toDeviceEvents) { + return this._olmDecryption?.obtainDecryptionLock(toDeviceEvents); + } + + async prepareSync(toDeviceEvents, lock, txn, log) { + log.set("messageTypes", countBy(toDeviceEvents, e => e.type)); const encryptedEvents = toDeviceEvents.filter(e => e.type === "m.room.encrypted"); - log.set("eventsCount", countBy(toDeviceEvents, e => e.type)); - if (!encryptedEvents.length) { - return false; + if (!this._olmDecryption) { + log.log("can't decrypt, encryption not enabled", log.level.Warn); + return; + } + // only know olm for now + const olmEvents = encryptedEvents.filter(e => e.content?.algorithm === OLM_ALGORITHM); + if (olmEvents.length) { + const olmDecryptChanges = await this._olmDecryption.decryptAll(olmEvents, lock, txn); + log.set("decryptedTypes", countBy(olmDecryptChanges.results, r => r.event?.type)); + for (const err of olmDecryptChanges.errors) { + log.child("decrypt_error").catch(err); + } + const newRoomKeys = this._megolmDecryption.roomKeysFromDeviceMessages(olmDecryptChanges.results, log); + return new SyncPreparation(olmDecryptChanges, newRoomKeys); } - // store encryptedEvents - let pendingEvents = await this._getPendingEvents(txn); - pendingEvents = pendingEvents.concat(encryptedEvents); - txn.session.set(PENDING_ENCRYPTED_EVENTS, pendingEvents); - // we don't handle anything other for now - return true; } - /** - * [_writeDecryptedEvents description] - * @param {Array} olmResults - * @param {[type]} txn [description] - * @return {[type]} [description] - */ - async _writeDecryptedEvents(olmResults, txn, log) { - const megOlmRoomKeysResults = olmResults.filter(r => { - return r.event?.type === "m.room_key" && r.event.content?.algorithm === MEGOLM_ALGORITHM; - }); - let roomKeys; - log.set("eventsCount", countBy(olmResults, r => r.event.type)); - log.set("roomKeys", megOlmRoomKeysResults.length); - if (megOlmRoomKeysResults.length) { - roomKeys = await this._megolmDecryption.addRoomKeys(megOlmRoomKeysResults, txn, log); - } - log.set("newRoomKeys", roomKeys.length); - return {roomKeys}; + /** check that prep is not undefined before calling this */ + async writeSync(prep, txn) { + // write olm changes + prep.olmDecryptChanges.write(txn); + await Promise.all(prep.newRoomKeys.map(key => this._megolmDecryption.writeRoomKey(key, txn))); + } +} + +class SyncPreparation { + constructor(olmDecryptChanges, newRoomKeys) { + this.olmDecryptChanges = olmDecryptChanges; + this.newRoomKeys = newRoomKeys; + this.newKeysByRoom = groupBy(newRoomKeys, r => r.roomId); } - async _applyDecryptChanges(rooms, {roomKeys}) { - if (Array.isArray(roomKeys)) { - for (const roomKey of roomKeys) { - const room = rooms.get(roomKey.roomId); - // TODO: this is less parallized than it could be (like sync) - await room?.notifyRoomKey(roomKey); + dispose() { + if (this.newRoomKeys) { + for (const k of this.newRoomKeys) { + k.dispose(); } } } - - // not safe to call multiple times without awaiting first call - async decryptPending(rooms, log) { - if (!this._olmDecryption) { - return; - } - const readTxn = this._storage.readTxn([this._storage.storeNames.session]); - const pendingEvents = await this._getPendingEvents(readTxn); - log.set("eventCount", pendingEvents.length); - if (pendingEvents.length === 0) { - return; - } - // only know olm for now - const olmEvents = pendingEvents.filter(e => e.content?.algorithm === OLM_ALGORITHM); - const decryptChanges = await this._olmDecryption.decryptAll(olmEvents); - for (const err of decryptChanges.errors) { - log.child("decrypt_error").catch(err); - } - const txn = this._storage.readWriteTxn([ - // both to remove the pending events and to modify the olm account - this._storage.storeNames.session, - this._storage.storeNames.olmSessions, - this._storage.storeNames.inboundGroupSessions, - ]); - let changes; - try { - changes = await this._writeDecryptedEvents(decryptChanges.results, txn, log); - decryptChanges.write(txn); - txn.session.remove(PENDING_ENCRYPTED_EVENTS); - } catch (err) { - txn.abort(); - throw err; - } - await txn.complete(); - await this._applyDecryptChanges(rooms, changes); - } - - async _getPendingEvents(txn) { - return (await txn.session.get(PENDING_ENCRYPTED_EVENTS)) || []; - } } diff --git a/src/matrix/Session.js b/src/matrix/Session.js index 005f58d1..4a825507 100644 --- a/src/matrix/Session.js +++ b/src/matrix/Session.js @@ -374,12 +374,25 @@ export class Session { return room; } + async obtainSyncLock(syncResponse) { + const toDeviceEvents = syncResponse.to_device?.events; + if (Array.isArray(toDeviceEvents) && toDeviceEvents.length) { + return await this._deviceMessageHandler.obtainSyncLock(toDeviceEvents); + } + } + + async prepareSync(syncResponse, lock, txn, log) { + const toDeviceEvents = syncResponse.to_device?.events; + if (Array.isArray(toDeviceEvents) && toDeviceEvents.length) { + return await log.wrap("deviceMsgs", log => this._deviceMessageHandler.prepareSync(toDeviceEvents, lock, txn, log)); + } + } + /** @internal */ - async writeSync(syncResponse, syncFilterId, txn, log) { + async writeSync(syncResponse, syncFilterId, preparation, txn, log) { const changes = { syncInfo: null, e2eeAccountChanges: null, - deviceMessageDecryptionPending: false }; const syncToken = syncResponse.next_batch; if (syncToken !== this.syncToken) { @@ -399,10 +412,8 @@ export class Session { await log.wrap("deviceLists", log => this._deviceTracker.writeDeviceChanges(deviceLists.changed, txn, log)); } - const toDeviceEvents = syncResponse.to_device?.events; - if (Array.isArray(toDeviceEvents) && toDeviceEvents.length) { - changes.deviceMessageDecryptionPending = - await log.wrap("deviceMsgs", log => this._deviceMessageHandler.writeSync(toDeviceEvents, txn, log)); + if (preparation) { + await log.wrap("deviceMsgs", log => this._deviceMessageHandler.writeSync(preparation, txn, log)); } // store account data @@ -431,9 +442,6 @@ export class Session { /** @internal */ async afterSyncCompleted(changes, isCatchupSync, log) { const promises = []; - if (changes.deviceMessageDecryptionPending) { - promises.push(log.wrap("decryptPending", log => this._deviceMessageHandler.decryptPending(this.rooms, log))); - } // we don't start uploading one-time keys until we've caught up with // to-device messages, to help us avoid throwing away one-time-keys that we // are about to receive messages for diff --git a/src/matrix/Sync.js b/src/matrix/Sync.js index 89e807b8..9e8f7997 100644 --- a/src/matrix/Sync.js +++ b/src/matrix/Sync.js @@ -40,7 +40,7 @@ function timelineIsEmpty(roomResponse) { * Sync steps in js-pseudocode: * ```js * // can only read some stores - * const preparation = await room.prepareSync(roomResponse, membership, prepareTxn); + * const preparation = await room.prepareSync(roomResponse, membership, newRoomKeys, prepareTxn); * // can do async work that is not related to storage (such as decryption) * await room.afterPrepareSync(preparation); * // writes and calculates changes @@ -190,35 +190,44 @@ export class Sync { const response = await this._currentRequest.response(); const isInitialSync = !syncToken; + const sessionState = new SessionSyncProcessState(); const roomStates = this._parseRoomsResponse(response.rooms, isInitialSync); - await log.wrap("prepare", log => this._prepareRooms(roomStates, log)); - - let sessionChanges; - await log.wrap("write", async log => { - const syncTxn = this._openSyncTxn(); - try { - sessionChanges = await log.wrap("session", log => this._session.writeSync(response, syncFilterId, syncTxn, log)); - await Promise.all(roomStates.map(async rs => { - rs.changes = await log.wrap("room", log => rs.room.writeSync( - rs.roomResponse, isInitialSync, rs.preparation, syncTxn, log)); - })); - } catch(err) { - // avoid corrupting state by only - // storing the sync up till the point - // the exception occurred + try { + // take a lock on olm sessions used in this sync so sending a message doesn't change them while syncing + sessionState.lock = await log.wrap("obtainSyncLock", () => this._session.obtainSyncLock(response)); + await log.wrap("prepare", log => this._prepareSessionAndRooms(sessionState, roomStates, response, log)); + await log.wrap("afterPrepareSync", log => Promise.all(roomStates.map(rs => { + return rs.room.afterPrepareSync(rs.preparation, log); + }))); + await log.wrap("write", async log => { + const syncTxn = this._openSyncTxn(); try { - syncTxn.abort(); - } catch (abortErr) { - log.set("couldNotAbortTxn", true); + sessionState.changes = await log.wrap("session", log => this._session.writeSync( + response, syncFilterId, sessionState.preparation, syncTxn, log)); + await Promise.all(roomStates.map(async rs => { + rs.changes = await log.wrap("room", log => rs.room.writeSync( + rs.roomResponse, isInitialSync, rs.preparation, syncTxn, log)); + })); + } catch(err) { + // avoid corrupting state by only + // storing the sync up till the point + // the exception occurred + try { + syncTxn.abort(); + } catch (abortErr) { + log.set("couldNotAbortTxn", true); + } + throw err; } - throw err; - } - await syncTxn.complete(); - }); + await syncTxn.complete(); + }); + } finally { + sessionState.dispose(); + } log.wrap("after", log => { - log.wrap("session", log => this._session.afterSync(sessionChanges, log), log.level.Detail); + log.wrap("session", log => this._session.afterSync(sessionState.changes, log), log.level.Detail); // emit room related events after txn has been closed for(let rs of roomStates) { log.wrap("room", log => rs.room.afterSync(rs.changes, log), log.level.Detail); @@ -229,7 +238,7 @@ export class Sync { return { syncToken: response.next_batch, roomStates, - sessionChanges, + sessionChanges: sessionState.changes, hadToDeviceMessages: Array.isArray(toDeviceEvents) && toDeviceEvents.length > 0, }; } @@ -237,18 +246,26 @@ export class Sync { _openPrepareSyncTxn() { const storeNames = this._storage.storeNames; return this._storage.readTxn([ + storeNames.olmSessions, storeNames.inboundGroupSessions, ]); } - async _prepareRooms(roomStates, log) { + async _prepareSessionAndRooms(sessionState, roomStates, response, log) { const prepareTxn = this._openPrepareSyncTxn(); + sessionState.preparation = await log.wrap("session", log => this._session.prepareSync( + response, sessionState.lock, prepareTxn, log)); + + const newKeysByRoom = sessionState.preparation?.newKeysByRoom; + await Promise.all(roomStates.map(async rs => { - rs.preparation = await log.wrap("room", log => rs.room.prepareSync(rs.roomResponse, rs.membership, prepareTxn, log), log.level.Detail); + const newKeys = newKeysByRoom?.get(rs.room.id); + rs.preparation = await log.wrap("room", log => rs.room.prepareSync( + rs.roomResponse, rs.membership, newKeys, prepareTxn, log), log.level.Detail); })); + // This is needed for safari to not throw TransactionInactiveErrors on the syncTxn. See docs/INDEXEDDB.md await prepareTxn.complete(); - await Promise.all(roomStates.map(rs => rs.room.afterPrepareSync(rs.preparation, log))); } _openSyncTxn() { @@ -269,6 +286,9 @@ export class Sync { storeNames.outboundGroupSessions, storeNames.operations, storeNames.accountData, + // to decrypt and store new room keys + storeNames.olmSessions, + storeNames.inboundGroupSessions, ]); } @@ -311,6 +331,19 @@ export class Sync { } } +class SessionSyncProcessState { + constructor() { + this.lock = null; + this.preparation = null; + this.changes = null; + } + + dispose() { + this.lock?.release(); + this.preparation?.dispose(); + } +} + class RoomSyncProcessState { constructor(room, roomResponse, membership) { this.room = room; diff --git a/src/matrix/e2ee/RoomEncryption.js b/src/matrix/e2ee/RoomEncryption.js index c0dc17e7..70374743 100644 --- a/src/matrix/e2ee/RoomEncryption.js +++ b/src/matrix/e2ee/RoomEncryption.js @@ -93,7 +93,7 @@ export class RoomEncryption { // this happens before entries exists, as they are created by the syncwriter // but we want to be able to map it back to something in the timeline easily // when retrying decryption. - async prepareDecryptAll(events, source, isTimelineOpen, txn) { + async prepareDecryptAll(events, newKeys, source, isTimelineOpen, txn) { const errors = new Map(); const validEvents = []; for (const event of events) { @@ -107,6 +107,8 @@ export class RoomEncryption { } let customCache; let sessionCache; + // we have different caches so we can keep them small but still + // have backfill and sync not invalidate each other if (source === DecryptionSource.Sync) { sessionCache = this._megolmSyncCache; } else if (source === DecryptionSource.Timeline) { @@ -120,7 +122,7 @@ export class RoomEncryption { throw new Error("Unknown source: " + source); } const preparation = await this._megolmDecryption.prepareDecryptAll( - this._room.id, validEvents, sessionCache, txn); + this._room.id, validEvents, newKeys, sessionCache, txn); if (customCache) { customCache.dispose(); } @@ -188,20 +190,27 @@ export class RoomEncryption { console.warn("Got session key back from backup with different sender key, ignoring", {session, senderKey}); return; } - const txn = this._storage.readWriteTxn([this._storage.storeNames.inboundGroupSessions]); - let roomKey; - try { - roomKey = await this._megolmDecryption.addRoomKeyFromBackup( - this._room.id, sessionId, session, txn); - } catch (err) { - txn.abort(); - throw err; - } - await txn.complete(); - + let roomKey = this._megolmDecryption.roomKeyFromBackup(this._room.id, sessionId, session); if (roomKey) { - // this will reattempt decryption - await this._room.notifyRoomKey(roomKey); + let keyIsBestOne = false; + try { + const txn = this._storage.readWriteTxn([this._storage.storeNames.inboundGroupSessions]); + try { + keyIsBestOne = await this._megolmDecryption.writeRoomKey(roomKey, txn); + } catch (err) { + txn.abort(); + throw err; + } + await txn.complete(); + } finally { + // can still access properties on it afterwards + // this is just clearing the internal sessionInfo + roomKey.dispose(); + } + if (keyIsBestOne) { + // wrote the key, meaning we didn't have a better one, go ahead and reattempt decryption + await this._room.notifyRoomKey(roomKey); + } } } else if (session?.algorithm) { console.info(`Backed-up session of unknown algorithm: ${session.algorithm}`); @@ -212,12 +221,7 @@ export class RoomEncryption { } /** - * @type {RoomKeyDescription} - * @property {RoomKeyDescription} senderKey the curve25519 key of the sender - * @property {RoomKeyDescription} sessionId - * - * - * @param {Array} roomKeys + * @param {RoomKey} roomKeys * @return {Array} the event ids that should be retried to decrypt */ getEventIdsForRoomKey(roomKey) { diff --git a/src/matrix/e2ee/megolm/Decryption.js b/src/matrix/e2ee/megolm/Decryption.js index 88a99cd7..80a55961 100644 --- a/src/matrix/e2ee/megolm/Decryption.js +++ b/src/matrix/e2ee/megolm/Decryption.js @@ -16,11 +16,12 @@ limitations under the License. import {DecryptionError} from "../common.js"; import {groupBy} from "../../../utils/groupBy.js"; - +import * as RoomKey from "./decryption/RoomKey.js"; import {SessionInfo} from "./decryption/SessionInfo.js"; import {DecryptionPreparation} from "./decryption/DecryptionPreparation.js"; import {SessionDecryption} from "./decryption/SessionDecryption.js"; import {SessionCache} from "./decryption/SessionCache.js"; +import {MEGOLM_ALGORITHM} from "../common.js"; function getSenderKey(event) { return event.content?.["sender_key"]; @@ -49,12 +50,13 @@ export class Decryption { * Reads all the state from storage to be able to decrypt the given events. * Decryption can then happen outside of a storage transaction. * @param {[type]} roomId [description] - * @param {[type]} events [description] + * @param {[type]} events [description] + * @param {RoomKey[]?} newKeys keys as returned from extractRoomKeys, but not yet committed to storage. May be undefined. * @param {[type]} sessionCache [description] * @param {[type]} txn [description] * @return {DecryptionPreparation} */ - async prepareDecryptAll(roomId, events, sessionCache, txn) { + async prepareDecryptAll(roomId, events, newKeys, sessionCache, txn) { const errors = new Map(); const validEvents = []; @@ -74,27 +76,38 @@ export class Decryption { }); const sessionDecryptions = []; - await Promise.all(Array.from(eventsBySession.values()).map(async eventsForSession => { - const first = eventsForSession[0]; - const senderKey = getSenderKey(first); - const sessionId = getSessionId(first); - const sessionInfo = await this._getSessionInfo(roomId, senderKey, sessionId, sessionCache, txn); - if (!sessionInfo) { + const firstEvent = eventsForSession[0]; + const sessionInfo = await this._getSessionInfoForEvent(roomId, firstEvent, newKeys, sessionCache, txn); + if (sessionInfo) { + sessionDecryptions.push(new SessionDecryption(sessionInfo, eventsForSession, this._olmWorker)); + } else { for (const event of eventsForSession) { errors.set(event.event_id, new DecryptionError("MEGOLM_NO_SESSION", event)); } - } else { - sessionDecryptions.push(new SessionDecryption(sessionInfo, eventsForSession, this._olmWorker)); } })); return new DecryptionPreparation(roomId, sessionDecryptions, errors); } - async _getSessionInfo(roomId, senderKey, sessionId, sessionCache, txn) { + async _getSessionInfoForEvent(roomId, event, newKeys, sessionCache, txn) { + const senderKey = getSenderKey(event); + const sessionId = getSessionId(event); let sessionInfo; - sessionInfo = sessionCache.get(roomId, senderKey, sessionId); + if (newKeys) { + const key = newKeys.find(k => k.roomId === roomId && k.senderKey === senderKey && k.sessionId === sessionId); + if (key) { + sessionInfo = await key.createSessionInfo(this._olm, this._pickleKey, txn); + if (sessionInfo) { + sessionCache.add(sessionInfo); + } + } + } + // look only in the cache after looking into newKeys as it may contains that are better + if (!sessionInfo) { + sessionInfo = sessionCache.get(roomId, senderKey, sessionId); + } if (!sessionInfo) { const sessionEntry = await txn.inboundGroupSessions.get(roomId, senderKey, sessionId); if (sessionEntry) { @@ -113,111 +126,45 @@ export class Decryption { } /** - * @type {MegolmInboundSessionDescription} - * @property {string} senderKey the sender key of the session - * @property {string} sessionId the session identifier - * - * Adds room keys as inbound group sessions - * @param {Array} decryptionResults an array of m.room_key decryption results. - * @param {[type]} txn a storage transaction with read/write on inboundGroupSessions - * @return {Promise>} an array with the newly added sessions + * Writes the key as an inbound group session if there is not already a better key in the store + * @param {RoomKey} key + * @param {Transaction} txn a storage transaction with read/write on inboundGroupSessions + * @return {Promise} whether the key was the best for the sessio id and was written */ - async addRoomKeys(decryptionResults, txn, log) { - const newSessions = []; - for (const {senderCurve25519Key: senderKey, event, claimedEd25519Key} of decryptionResults) { - await log.wrap("room_key", async log => { - const roomId = event.content?.["room_id"]; - const sessionId = event.content?.["session_id"]; - const sessionKey = event.content?.["session_key"]; + writeRoomKey(key, txn) { + return key.write(this._olm, this._pickleKey, txn); + } - log.set("roomId", roomId); - log.set("sessionId", sessionId); - - if ( - typeof roomId !== "string" || - typeof sessionId !== "string" || - typeof senderKey !== "string" || - typeof sessionKey !== "string" - ) { + /** + * Extracts room keys from decrypted device messages. + * The key won't be persisted yet, you need to call RoomKey.write for that. + * + * @param {Array} decryptionResults, any non megolm m.room_key messages will be ignored. + * @return {Array} an array with validated RoomKey's. Note that it is possible we already have a better version of this key in storage though; writing the key will tell you so. + */ + roomKeysFromDeviceMessages(decryptionResults, log) { + let keys = []; + for (const dr of decryptionResults) { + if (dr.event?.type !== "m.room_key" || dr.event.content?.algorithm !== MEGOLM_ALGORITHM) { + continue; + } + log.wrap("room_key", log => { + const key = RoomKey.fromDeviceMessage(dr); + if (key) { + log.set("roomId", key.roomId); + log.set("id", key.sessionId); + keys.push(key); + } else { log.logLevel = log.level.Warn; log.set("invalid", true); - return; - } - - const session = new this._olm.InboundGroupSession(); - try { - session.create(sessionKey); - const sessionEntry = await this._writeInboundSession( - session, roomId, senderKey, claimedEd25519Key, sessionId, txn); - if (sessionEntry) { - newSessions.push(sessionEntry); - } - } finally { - session.free(); } }, log.level.Detail); } - // this will be passed to the Room in notifyRoomKeys - return newSessions; + return keys; } - /* - sessionInfo is a response from key backup and has the following keys: - algorithm - forwarding_curve25519_key_chain - sender_claimed_keys - sender_key - session_key - */ - async addRoomKeyFromBackup(roomId, sessionId, sessionInfo, txn) { - const sessionKey = sessionInfo["session_key"]; - const senderKey = sessionInfo["sender_key"]; - // TODO: can we just trust this? - const claimedEd25519Key = sessionInfo["sender_claimed_keys"]?.["ed25519"]; - - if ( - typeof roomId !== "string" || - typeof sessionId !== "string" || - typeof senderKey !== "string" || - typeof sessionKey !== "string" || - typeof claimedEd25519Key !== "string" - ) { - return; - } - const session = new this._olm.InboundGroupSession(); - try { - session.import_session(sessionKey); - return await this._writeInboundSession( - session, roomId, senderKey, claimedEd25519Key, sessionId, txn); - } finally { - session.free(); - } - } - - async _writeInboundSession(session, roomId, senderKey, claimedEd25519Key, sessionId, txn) { - let incomingSessionIsBetter = true; - const existingSessionEntry = await txn.inboundGroupSessions.get(roomId, senderKey, sessionId); - if (existingSessionEntry) { - const existingSession = new this._olm.InboundGroupSession(); - try { - existingSession.unpickle(this._pickleKey, existingSessionEntry.session); - incomingSessionIsBetter = session.first_known_index() < existingSession.first_known_index(); - } finally { - existingSession.free(); - } - } - - if (incomingSessionIsBetter) { - const sessionEntry = { - roomId, - senderKey, - sessionId, - session: session.pickle(this._pickleKey), - claimedKeys: {ed25519: claimedEd25519Key}, - }; - txn.inboundGroupSessions.set(sessionEntry); - return sessionEntry; - } + roomKeyFromBackup(roomId, sessionId, sessionInfo) { + return RoomKey.fromBackup(roomId, sessionId, sessionInfo); } } diff --git a/src/matrix/e2ee/megolm/decryption/RoomKey.js b/src/matrix/e2ee/megolm/decryption/RoomKey.js new file mode 100644 index 00000000..1cfd5f95 --- /dev/null +++ b/src/matrix/e2ee/megolm/decryption/RoomKey.js @@ -0,0 +1,154 @@ +import {SessionInfo} from "./SessionInfo.js"; + +export class BaseRoomKey { + constructor() { + this._sessionInfo = null; + this._isBetter = null; + } + + + + async createSessionInfo(olm, pickleKey, txn) { + const session = new olm.InboundGroupSession(); + try { + this._loadSessionKey(session); + this._isBetter = await this._isBetterThanKnown(session, olm, pickleKey, txn); + if (this._isBetter) { + const claimedKeys = {ed25519: this.claimedEd25519Key}; + this._sessionInfo = new SessionInfo(this.roomId, this.senderKey, session, claimedKeys); + // retain the session so we don't have to create a new session during write. + this._sessionInfo.retain(); + return this._sessionInfo; + } else { + session.free(); + return; + } + } catch (err) { + this._sessionInfo = null; + session.free(); + throw err; + } + } + + async _isBetterThanKnown(session, olm, pickleKey, txn) { + let isBetter = true; + const existingSessionEntry = await txn.inboundGroupSessions.get(this.roomId, this.senderKey, this.sessionId); + if (existingSessionEntry) { + const existingSession = new olm.InboundGroupSession(); + try { + existingSession.unpickle(pickleKey, existingSessionEntry.session); + isBetter = session.first_known_index() < existingSession.first_known_index(); + } finally { + existingSession.free(); + } + } + return isBetter; + } + + async write(olm, pickleKey, txn) { + // we checked already and we had a better session in storage, so don't write + if (this._isBetter === false) { + return false; + } + if (!this._sessionInfo) { + await this.createSessionInfo(olm, pickleKey, txn); + } + if (this._sessionInfo) { + const session = this._sessionInfo.session; + const sessionEntry = { + roomId: this.roomId, + senderKey: this.senderKey, + sessionId: this.sessionId, + session: session.pickle(pickleKey), + claimedKeys: this._sessionInfo.claimedKeys, + }; + txn.inboundGroupSessions.set(sessionEntry); + this.dispose(); + return true; + } + return false; + } + + dispose() { + if (this._sessionInfo) { + this._sessionInfo.release(); + this._sessionInfo = null; + } + } +} + +class DeviceMessageRoomKey extends BaseRoomKey { + constructor(decryptionResult) { + super(); + this._decryptionResult = decryptionResult; + } + + get roomId() { return this._decryptionResult.event.content?.["room_id"]; } + get senderKey() { return this._decryptionResult.senderKey; } + get sessionId() { return this._decryptionResult.event.content?.["session_id"]; } + get claimedEd25519Key() { return this._decryptionResult.claimedEd25519Key; } + + _loadSessionKey(session) { + const sessionKey = this._decryptionResult.event.content?.["session_key"]; + session.create(sessionKey); + } +} + +class BackupRoomKey extends BaseRoomKey { + constructor(roomId, sessionId, sessionInfo) { + super(); + this._roomId = roomId; + this._sessionId = sessionId; + this._sessionInfo = sessionInfo; + } + + get roomId() { return this._roomId; } + get senderKey() { return this._sessionInfo["sender_key"]; } + get sessionId() { return this._sessionId; } + get claimedEd25519Key() { return this._sessionInfo["sender_claimed_keys"]?.["ed25519"]; } + + _loadSessionKey(session) { + const sessionKey = this._sessionInfo["session_key"]; + session.import_session(sessionKey); + } +} + +export function fromDeviceMessage(dr) { + const roomId = dr.event.content?.["room_id"]; + const sessionId = dr.event.content?.["session_id"]; + const sessionKey = dr.event.content?.["session_key"]; + if ( + typeof roomId === "string" || + typeof sessionId === "string" || + typeof senderKey === "string" || + typeof sessionKey === "string" + ) { + return new DeviceMessageRoomKey(dr); + } +} + +/* +sessionInfo is a response from key backup and has the following keys: + algorithm + forwarding_curve25519_key_chain + sender_claimed_keys + sender_key + session_key + */ +export function fromBackup(roomId, sessionId, sessionInfo) { + const sessionKey = sessionInfo["session_key"]; + const senderKey = sessionInfo["sender_key"]; + // TODO: can we just trust this? + const claimedEd25519Key = sessionInfo["sender_claimed_keys"]?.["ed25519"]; + + if ( + typeof roomId === "string" && + typeof sessionId === "string" && + typeof senderKey === "string" && + typeof sessionKey === "string" && + typeof claimedEd25519Key === "string" + ) { + return new BackupRoomKey(roomId, sessionId, sessionInfo); + } +} + diff --git a/src/matrix/e2ee/megolm/decryption/SessionInfo.js b/src/matrix/e2ee/megolm/decryption/SessionInfo.js index dedc3222..e8bec3d0 100644 --- a/src/matrix/e2ee/megolm/decryption/SessionInfo.js +++ b/src/matrix/e2ee/megolm/decryption/SessionInfo.js @@ -40,5 +40,6 @@ export class SessionInfo { dispose() { this.session.free(); + this.session = null; } } diff --git a/src/matrix/e2ee/olm/Decryption.js b/src/matrix/e2ee/olm/Decryption.js index 7c4ef7e6..7556c367 100644 --- a/src/matrix/e2ee/olm/Decryption.js +++ b/src/matrix/e2ee/olm/Decryption.js @@ -16,6 +16,7 @@ limitations under the License. import {DecryptionError} from "../common.js"; import {groupBy} from "../../../utils/groupBy.js"; +import {MultiLock} from "../../../utils/Lock.js"; import {Session} from "./Session.js"; import {DecryptionResult} from "../DecryptionResult.js"; @@ -41,6 +42,29 @@ export class Decryption { this._olm = olm; this._senderKeyLock = senderKeyLock; } + + // we need to lock because both encryption and decryption can't be done in one txn, + // so for them not to step on each other toes, we need to lock. + // + // the lock is release from 1 of 3 places, whichever comes first: + // - decryptAll below fails (to release the lock as early as we can) + // - DecryptionChanges.write succeeds + // - Sync finishes the writeSync phase (or an error was thrown, in case we never get to DecryptionChanges.write) + async obtainDecryptionLock(events) { + const senderKeys = new Set(); + for (const event of events) { + const senderKey = event.content?.["sender_key"]; + if (senderKey) { + senderKeys.add(senderKey); + } + } + // take a lock on all senderKeys so encryption or other calls to decryptAll (should not happen) + // don't modify the sessions at the same time + const locks = await Promise.all(Array.from(senderKeys).map(senderKey => { + return this._senderKeyLock.takeLock(senderKey); + })); + return new MultiLock(locks); + } // we need decryptAll because there is some parallelization we can do for decrypting different sender keys at once // but for the same sender key we need to do one by one @@ -54,34 +78,28 @@ export class Decryption { // /** + * It is importants the lock obtained from obtainDecryptionLock is for the same set of events as passed in here. * [decryptAll description] * @param {[type]} events * @return {Promise} [description] */ - async decryptAll(events) { - const eventsPerSenderKey = groupBy(events, event => event.content?.["sender_key"]); - const timestamp = this._now(); - // take a lock on all senderKeys so encryption or other calls to decryptAll (should not happen) - // don't modify the sessions at the same time - const locks = await Promise.all(Array.from(eventsPerSenderKey.keys()).map(senderKey => { - return this._senderKeyLock.takeLock(senderKey); - })); + async decryptAll(events, lock, txn) { try { - const readSessionsTxn = this._storage.readTxn([this._storage.storeNames.olmSessions]); + const eventsPerSenderKey = groupBy(events, event => event.content?.["sender_key"]); + const timestamp = this._now(); // decrypt events for different sender keys in parallel const senderKeyOperations = await Promise.all(Array.from(eventsPerSenderKey.entries()).map(([senderKey, events]) => { - return this._decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn); + return this._decryptAllForSenderKey(senderKey, events, timestamp, txn); })); const results = senderKeyOperations.reduce((all, r) => all.concat(r.results), []); const errors = senderKeyOperations.reduce((all, r) => all.concat(r.errors), []); const senderKeyDecryptions = senderKeyOperations.map(r => r.senderKeyDecryption); - return new DecryptionChanges(senderKeyDecryptions, results, errors, this._account, locks); + return new DecryptionChanges(senderKeyDecryptions, results, errors, this._account, lock); } catch (err) { // make sure the locks are release if something throws // otherwise they will be released in DecryptionChanges after having written - for (const lock of locks) { - lock.release(); - } + // or after the writeSync phase in Sync + lock.release(); throw err; } } @@ -268,12 +286,12 @@ class SenderKeyDecryption { * @property {Array} errors see DecryptionError.event to retrieve the event that failed to decrypt. */ class DecryptionChanges { - constructor(senderKeyDecryptions, results, errors, account, locks) { + constructor(senderKeyDecryptions, results, errors, account, lock) { this._senderKeyDecryptions = senderKeyDecryptions; this._account = account; this.results = results; this.errors = errors; - this._locks = locks; + this._lock = lock; } get hasNewSessions() { @@ -304,9 +322,7 @@ class DecryptionChanges { } } } finally { - for (const lock of this._locks) { - lock.release(); - } + this._lock.release(); } } } diff --git a/src/matrix/room/Room.js b/src/matrix/room/Room.js index dc730b2a..891e9dfe 100644 --- a/src/matrix/room/Room.js +++ b/src/matrix/room/Room.js @@ -148,7 +148,7 @@ export class Room extends EventEmitter { return entry.eventType === EVENT_ENCRYPTED_TYPE; }).map(entry => entry.event); const isTimelineOpen = this._isTimelineOpen; - r.preparation = await this._roomEncryption.prepareDecryptAll(events, source, isTimelineOpen, inboundSessionTxn); + r.preparation = await this._roomEncryption.prepareDecryptAll(events, null, source, isTimelineOpen, inboundSessionTxn); if (r.cancelled) return; const changes = await r.preparation.decrypt(); r.preparation = null; @@ -176,8 +176,11 @@ export class Room extends EventEmitter { return request; } - async prepareSync(roomResponse, membership, txn, log) { + async prepareSync(roomResponse, membership, newKeys, txn, log) { log.set("id", this.id); + if (newKeys) { + log.set("newKeys", newKeys.length); + } const summaryChanges = this._summary.data.applySyncResponse(roomResponse, membership) let roomEncryption = this._roomEncryption; // encryption is enabled in this sync @@ -194,7 +197,7 @@ export class Room extends EventEmitter { return event?.type === EVENT_ENCRYPTED_TYPE; }); decryptPreparation = await roomEncryption.prepareDecryptAll( - eventsToDecrypt, DecryptionSource.Sync, this._isTimelineOpen, txn); + eventsToDecrypt, newKeys, DecryptionSource.Sync, this._isTimelineOpen, txn); } } diff --git a/src/utils/Lock.js b/src/utils/Lock.js index e133f33c..8cfc733f 100644 --- a/src/utils/Lock.js +++ b/src/utils/Lock.js @@ -54,6 +54,18 @@ export class Lock { } } +export class MultiLock { + constructor(locks) { + this.locks = locks; + } + + release() { + for (const lock of this.locks) { + lock.release(); + } + } +} + export function tests() { return { "taking a lock twice returns false": assert => {