From 44e9f91d4c3ae6ca591b9f688de9bef23c5a593c Mon Sep 17 00:00:00 2001 From: Bruno Windels Date: Wed, 2 Sep 2020 13:33:27 +0200 Subject: [PATCH] to_device handler for encrypted messages changes the api of the olm decryption to decrypt in batch so we can isolate side-effects until we have a write-txn open and we can parallelize the decryption of different sender keys. --- src/matrix/DeviceMessageHandler.js | 87 ++++ src/matrix/e2ee/common.js | 3 +- src/matrix/e2ee/olm/Decryption.js | 374 ++++++++++++------ src/matrix/storage/idb/stores/SessionStore.js | 4 + 4 files changed, 345 insertions(+), 123 deletions(-) create mode 100644 src/matrix/DeviceMessageHandler.js diff --git a/src/matrix/DeviceMessageHandler.js b/src/matrix/DeviceMessageHandler.js new file mode 100644 index 00000000..a26bfe33 --- /dev/null +++ b/src/matrix/DeviceMessageHandler.js @@ -0,0 +1,87 @@ +/* +Copyright 2020 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import {OLM_ALGORITHM, MEGOLM_ALGORITHM} from "./e2ee/common.js"; + +// key to store in session store +const PENDING_ENCRYPTED_EVENTS = "pendingEncryptedDeviceEvents"; + +export class DeviceMessageHandler { + constructor({storage, olmDecryption, megolmEncryption}) { + this._storage = storage; + this._olmDecryption = olmDecryption; + this._megolmEncryption = megolmEncryption; + } + + async writeSync(toDeviceEvents, txn) { + const encryptedEvents = toDeviceEvents.filter(e => e.type === "m.room.encrypted"); + // store encryptedEvents + let pendingEvents = this._getPendingEvents(txn); + pendingEvents = pendingEvents.concat(encryptedEvents); + txn.session.set(PENDING_ENCRYPTED_EVENTS, pendingEvents); + // we don't handle anything other for now + } + + async _handleDecryptedEvents(payloads, txn) { + const megOlmRoomKeysPayloads = payloads.filter(p => { + return p.event.type === "m.room_key" && p.event.content?.algorithm === MEGOLM_ALGORITHM; + }); + let megolmChanges; + if (megOlmRoomKeysPayloads.length) { + megolmChanges = await this._megolmEncryption.addRoomKeys(megOlmRoomKeysPayloads, txn); + } + return {megolmChanges}; + } + + applyChanges({megolmChanges}) { + if (megolmChanges) { + this._megolmEncryption.applyRoomKeyChanges(megolmChanges); + } + } + + // not safe to call multiple times without awaiting first call + async decryptPending() { + const readTxn = await this._storage.readTxn([this._storage.storeNames.session]); + const pendingEvents = this._getPendingEvents(readTxn); + // only know olm for now + const olmEvents = pendingEvents.filter(e => e.content?.algorithm === OLM_ALGORITHM); + const decryptChanges = await this._olmDecryption.decryptAll(olmEvents); + for (const err of decryptChanges.errors) { + console.warn("decryption failed for event", err, err.event); + } + const txn = await this._storage.readWriteTxn([ + // both to remove the pending events and to modify the olm account + this._storage.storeNames.session, + this._storage.storeNames.olmSessions, + // this._storage.storeNames.megolmInboundSessions, + ]); + let changes; + try { + changes = await this._handleDecryptedEvent(decryptChanges.payloads, txn); + decryptChanges.write(txn); + txn.session.remove(PENDING_ENCRYPTED_EVENTS); + } catch (err) { + txn.abort(); + throw err; + } + await txn.complete(); + this._applyChanges(changes); + } + + async _getPendingEvents(txn) { + return (await txn.session.get(PENDING_ENCRYPTED_EVENTS)) || []; + } +} diff --git a/src/matrix/e2ee/common.js b/src/matrix/e2ee/common.js index ef758feb..c5e7399f 100644 --- a/src/matrix/e2ee/common.js +++ b/src/matrix/e2ee/common.js @@ -20,9 +20,10 @@ export const OLM_ALGORITHM = "m.olm.v1.curve25519-aes-sha2"; export const MEGOLM_ALGORITHM = "m.megolm.v1.aes-sha2"; export class DecryptionError extends Error { - constructor(code, detailsObj = null) { + constructor(code, event, detailsObj = null) { super(`Decryption error ${code}${detailsObj ? ": "+JSON.stringify(detailsObj) : ""}`); this.code = code; + this.event = event; this.details = detailsObj; } } diff --git a/src/matrix/e2ee/olm/Decryption.js b/src/matrix/e2ee/olm/Decryption.js index 582f96d2..f701f4df 100644 --- a/src/matrix/e2ee/olm/Decryption.js +++ b/src/matrix/e2ee/olm/Decryption.js @@ -22,6 +22,12 @@ function isPreKeyMessage(message) { return message.type === 0; } +function sortSessions(sessions) { + sessions.sort((a, b) => { + return b.data.lastUsed - a.data.lastUsed; + }); +} + export class Decryption { constructor({account, pickleKey, now, ownUserId, storage, olm}) { this._account = account; @@ -33,155 +39,279 @@ export class Decryption { this._createOutboundSessionPromise = null; } - // we can't run this in the sync txn because decryption will be async ... - // should we store the encrypted events in the sync loop and then pop them from there? - // it would be good in any case to run the (next) sync request in parallel with decryption - async decrypt(event) { - const senderKey = event.content?.["sender_key"]; + // we need decryptAll because there is some parallelization we can do for decrypting different sender keys at once + // but for the same sender key we need to do one by one + // + // also we want to store the room key, etc ... in the same txn as we remove the pending encrypted event + // + // so we need to decrypt events in a batch (so we can decide which ones can run in parallel and which one one by one) + // and also can avoid side-effects before all can be stored this way + // + // doing it one by one would be possible, but we would lose the opportunity for parallelization + async decryptAll(events) { + const eventsPerSenderKey = events.reduce((map, event) => { + const senderKey = event.content?.["sender_key"]; + let list = map.get(senderKey); + if (!list) { + list = []; + map.set(senderKey, list); + } + list.push(event); + return map; + }, new Map()); + const timestamp = this._now(); + const readSessionsTxn = await this._storage.readTxn([this._storage.storeNames.olmSessions]); + // decrypt events for different sender keys in parallel + const results = await Promise.all(Array.from(eventsPerSenderKey.entries()).map(([senderKey, events]) => { + return this._decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn); + })); + const payloads = results.reduce((all, r) => all.concat(r.payloads), []); + const errors = results.reduce((all, r) => all.concat(r.errors), []); + const senderKeyDecryptions = results.map(r => r.senderKeyDecryption); + return new DecryptionChanges(senderKeyDecryptions, payloads, errors); + } + + async _decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn) { + const sessions = await this._getSessions(senderKey, readSessionsTxn); + const senderKeyDecryption = new SenderKeyDecryption(senderKey, sessions, this._olm, timestamp); + const payloads = []; + const errors = []; + // events for a single senderKey need to be decrypted one by one + for (const event of events) { + try { + const payload = this._decryptForSenderKey(senderKeyDecryption, event, timestamp); + payloads.push(payload); + } catch (err) { + errors.push(err); + } + } + return {payloads, errors, senderKeyDecryption}; + } + + _decryptForSenderKey(senderKeyDecryption, event, timestamp) { + const senderKey = senderKeyDecryption.senderKey; + const message = this._getMessageAndValidateEvent(event); + let plaintext; + try { + plaintext = senderKeyDecryption.decrypt(message); + } catch (err) { + // TODO: is it ok that an error on one session prevents other sessions from being attempted? + throw new DecryptionError("OLM_BAD_ENCRYPTED_MESSAGE", event, {senderKey, error: err.message}); + } + // could not decrypt with any existing session + if (typeof plaintext !== "string" && isPreKeyMessage(message)) { + const createResult = this._createSessionAndDecrypt(senderKey, message, timestamp); + senderKeyDecryption.addNewSession(createResult.session); + plaintext = createResult.plaintext; + } + if (typeof plaintext === "string") { + const payload = JSON.parse(plaintext); + this._validatePayload(payload, event); + return {event: payload, senderKey}; + } else { + throw new DecryptionError("Didn't find any session to decrypt with", event, + {sessionIds: senderKeyDecryption.sessions.map(s => s.id)}); + } + } + + // only for pre-key messages after having attempted decryption with existing sessions + _createSessionAndDecrypt(senderKey, message, timestamp) { + let plaintext; + // if we have multiple messages encrypted with the same new session, + // this could create multiple sessions as the OTK isn't removed yet + // (this only happens in DecryptionChanges.write) + // This should be ok though as we'll first try to decrypt with the new session + const olmSession = this._account.createInboundOlmSession(senderKey, message.body); + try { + plaintext = olmSession.decrypt(message.type, message.body); + const session = Session.create(senderKey, olmSession, this._olm, this._pickleKey, timestamp); + session.unload(olmSession); + return {session, plaintext}; + } catch (err) { + olmSession.free(); + throw err; + } + } + + _getMessageAndValidateEvent(event) { const ciphertext = event.content?.ciphertext; if (!ciphertext) { - throw new DecryptionError("OLM_MISSING_CIPHERTEXT"); + throw new DecryptionError("OLM_MISSING_CIPHERTEXT", event); } const message = ciphertext?.[this._account.identityKeys.curve25519]; if (!message) { - // TODO: use same error messages as element-web - throw new DecryptionError("OLM_NOT_INCLUDED_IN_RECIPIENTS"); - } - const sortedSessionIds = await this._getSortedSessionIds(senderKey); - let plaintext; - for (const sessionId of sortedSessionIds) { - try { - plaintext = await this._attemptDecryption(senderKey, sessionId, message); - } catch (err) { - throw new DecryptionError("OLM_BAD_ENCRYPTED_MESSAGE", {senderKey, error: err.message}); - } - if (typeof plaintext === "string") { - break; - } - } - if (typeof plaintext !== "string" && isPreKeyMessage(message)) { - plaintext = await this._createOutboundSessionAndDecrypt(senderKey, message, sortedSessionIds); - } - if (typeof plaintext === "string") { - return this._parseAndValidatePayload(plaintext, event); + throw new DecryptionError("OLM_NOT_INCLUDED_IN_RECIPIENTS", event); } + + return message; } - async _getSortedSessionIds(senderKey) { - const readTxn = await this._storage.readTxn([this._storage.storeNames.olmSessions]); - const sortedSessions = await readTxn.olmSessions.getAll(senderKey); + async _getSessions(senderKey, txn) { + const sessionEntries = await txn.olmSessions.getAll(senderKey); // sort most recent used sessions first - sortedSessions.sort((a, b) => { - return b.lastUsed - a.lastUsed; - }); - return sortedSessions.map(s => s.sessionId); + const sessions = sessionEntries.map(s => new Session(s, this._pickleKey, this._olm)); + sortSessions(sessions); + return sessions; } - async _createOutboundSessionAndDecrypt(senderKey, message, sortedSessionIds) { - // serialize calls so the account isn't written from multiple - // sessions at once - while (this._createOutboundSessionPromise) { - await this._createOutboundSessionPromise; + _validatePayload(payload, event) { + if (payload.sender !== event.sender) { + throw new DecryptionError("OLM_FORWARDED_MESSAGE", event, {sentBy: event.sender, encryptedBy: payload.sender}); } - this._createOutboundSessionPromise = (async () => { - try { - return await this._createOutboundSessionAndDecryptImpl(senderKey, message, sortedSessionIds); - } finally { - this._createOutboundSessionPromise = null; - } - })(); - return await this._createOutboundSessionPromise; - } - - // this could internally dispatch to a web-worker - async _createOutboundSessionAndDecryptImpl(senderKey, message, sortedSessionIds) { - let plaintext; - const session = this._account.createInboundOlmSession(senderKey, message.body); - try { - const txn = await this._storage.readWriteTxn([ - this._storage.storeNames.session, - this._storage.storeNames.olmSessions, - ]); - try { - // do this before removing the OTK removal, so we know decryption succeeded beforehand, - // as we don't have a way of undoing the OTK removal atm. - plaintext = session.decrypt(message.type, message.body); - this._account.writeRemoveOneTimeKey(session, txn); - // remove oldest session if we reach the limit including the new session - if (sortedSessionIds.length >= SESSION_LIMIT_PER_SENDER_KEY) { - // given they are sorted, the oldest one is the last one - const oldestSessionId = sortedSessionIds[sortedSessionIds.length - 1]; - txn.olmSessions.remove(senderKey, oldestSessionId); - } - txn.olmSessions.set({ - session: session.pickle(this._pickleKey), - sessionId: session.session_id(), - senderKey, - lastUsed: this._now(), - }); - } catch (err) { - txn.abort(); - throw err; - } - await txn.complete(); - } finally { - session.free(); + if (payload.recipient !== this._ownUserId) { + throw new DecryptionError("OLM_BAD_RECIPIENT", event, {recipient: payload.recipient}); } - return plaintext; + if (payload.recipient_keys?.ed25519 !== this._account.identityKeys.ed25519) { + throw new DecryptionError("OLM_BAD_RECIPIENT_KEY", event, {key: payload.recipient_keys?.ed25519}); + } + // TODO: check room_id + if (!payload.type) { + throw new DecryptionError("missing type on payload", event, {payload}); + } + if (!payload.content) { + throw new DecryptionError("missing content on payload", event, {payload}); + } + // TODO: how important is it to verify the message? + // we should look at payload.keys.ed25519 for that... and compare it to the key we have fetched + // from /keys/query, which we might not have done yet at this point. + } +} + +class Session { + constructor(data, pickleKey, olm, isNew = false) { + this.data = data; + this._olm = olm; + this._pickleKey = pickleKey; + this.isNew = isNew; + this.isModified = isNew; } - // this could internally dispatch to a web-worker - async _attemptDecryption(senderKey, sessionId, message) { - const txn = await this._storage.readWriteTxn([this._storage.storeNames.olmSessions]); + static create(senderKey, olmSession, olm, pickleKey, timestamp) { + return new Session({ + session: olmSession.pickle(pickleKey), + sessionId: olmSession.session_id(), + senderKey, + lastUsed: timestamp, + }, pickleKey, olm, true); + } + + get id() { + return this.data.sessionId; + } + + load() { const session = new this._olm.Session(); - let plaintext; + session.unpickle(this._pickleKey, this.data.session); + return session; + } + + unload(olmSession) { + olmSession.free(); + } + + save(olmSession) { + this.data.session = olmSession.pickle(this._pickleKey); + this.isModified = true; + } +} + +// decryption helper for a single senderKey +class SenderKeyDecryption { + constructor(senderKey, sessions, olm, timestamp) { + this.senderKey = senderKey; + this.sessions = sessions; + this._olm = olm; + this._timestamp = timestamp; + } + + addNewSession(session) { + // add at top as it is most recent + this.sessions.unshift(session); + } + + decrypt(message) { + for (const session of this.sessions) { + const plaintext = this._decryptWithSession(session, message); + if (typeof plaintext === "string") { + // keep them sorted so will try the same session first for other messages + // and so we can assume the excess ones are at the end + // if they grow too large + sortSessions(this.sessions); + return plaintext; + } + } + } + + getModifiedSessions() { + return this.sessions.filter(session => session.isModified); + } + + get hasNewSessions() { + return this.sessions.some(session => session.isNew); + } + + // this could internally dispatch to a web-worker + // and is why we unpickle/pickle on each iteration + // if this turns out to be a real cost for IE11, + // we could look into adding a less expensive serialization mechanism + // for olm sessions to libolm + _decryptWithSession(session, message) { + const olmSession = session.load(); try { - const sessionEntry = await txn.olmSessions.get(senderKey, sessionId); - session.unpickle(this._pickleKey, sessionEntry.session); - if (isPreKeyMessage(message) && !session.matches_inbound(message.body)) { + if (isPreKeyMessage(message) && !olmSession.matches_inbound(message.body)) { return; } try { - plaintext = session.decrypt(message.type, message.body); + const plaintext = olmSession.decrypt(message.type, message.body); + session.save(olmSession); + session.lastUsed = this._timestamp; + return plaintext; } catch (err) { if (isPreKeyMessage(message)) { - throw new Error(`Error decrypting prekey message with existing session id ${sessionId}: ${err.message}`); + throw new Error(`Error decrypting prekey message with existing session id ${session.id}: ${err.message}`); } // decryption failed, bail out return; } - sessionEntry.session = session.pickle(this._pickleKey); - sessionEntry.lastUsed = this._now(); - txn.olmSessions.set(sessionEntry); - } catch(err) { - txn.abort(); - throw err; } finally { - session.free(); + session.unload(olmSession); + } + } +} + +class DecryptionChanges { + constructor(senderKeyDecryptions, payloads, errors, account) { + this._senderKeyDecryptions = senderKeyDecryptions; + this._account = account; + this.payloads = payloads; + this.errors = errors; + } + + get hasNewSessions() { + return this._senderKeyDecryptions.some(skd => skd.hasNewSessions); + } + + write(txn) { + for (const senderKeyDecryption of this._senderKeyDecryptions) { + for (const session of senderKeyDecryption.getModifiedSessions()) { + txn.olmSessions.set(session.data); + if (session.isNew) { + const olmSession = session.load(); + try { + this._account.writeRemoveOneTimeKey(olmSession, txn); + } finally { + session.unload(olmSession); + } + } + } + if (senderKeyDecryption.sessions.length > SESSION_LIMIT_PER_SENDER_KEY) { + const {senderKey, sessions} = senderKeyDecryption; + // >= because index is zero-based + for (let i = sessions.length - 1; i >= SESSION_LIMIT_PER_SENDER_KEY ; i -= 1) { + const session = sessions[i]; + txn.olmSessions.remove(senderKey, session.id); + } + } } - await txn.complete(); - return plaintext; - } - - _parseAndValidatePayload(plaintext, event) { - const payload = JSON.parse(plaintext); - - if (payload.sender !== event.sender) { - throw new DecryptionError("OLM_FORWARDED_MESSAGE", {sentBy: event.sender, encryptedBy: payload.sender}); - } - if (payload.recipient !== this._ownUserId) { - throw new DecryptionError("OLM_BAD_RECIPIENT", {recipient: payload.recipient}); - } - if (payload.recipient_keys?.ed25519 !== this._account.identityKeys.ed25519) { - throw new DecryptionError("OLM_BAD_RECIPIENT_KEY", {key: payload.recipient_keys?.ed25519}); - } - // TODO: check room_id - if (!payload.type) { - throw new Error("missing type on payload"); - } - if (!payload.content) { - throw new Error("missing content on payload"); - } - return payload; } } diff --git a/src/matrix/storage/idb/stores/SessionStore.js b/src/matrix/storage/idb/stores/SessionStore.js index f64a8299..25ea2351 100644 --- a/src/matrix/storage/idb/stores/SessionStore.js +++ b/src/matrix/storage/idb/stores/SessionStore.js @@ -49,4 +49,8 @@ export class SessionStore { add(key, value) { return this._sessionStore.put({key, value}); } + + remove(key) { + this._sessionStore.delete(key); + } }