diff --git a/src/matrix/e2ee/olm/Encryption.js b/src/matrix/e2ee/olm/Encryption.js index 928fc6a9..fee485a4 100644 --- a/src/matrix/e2ee/olm/Encryption.js +++ b/src/matrix/e2ee/olm/Encryption.js @@ -42,6 +42,39 @@ export class Encryption { } async encrypt(type, content, devices, hsApi) { + const { + devicesWithoutSession, + existingEncryptionTargets + } = await this._findExistingSessions(devices); + + const timestamp = this._now(); + + let encryptionTargets = []; + try { + if (devicesWithoutSession.length) { + const newEncryptionTargets = await this._createNewSessions( + devicesWithoutSession, hsApi, timestamp); + encryptionTargets = encryptionTargets.concat(newEncryptionTargets); + } + // TODO: if we read and write in two different txns, + // is there a chance we overwrite a session modified by the decryption during sync? + // I think so. We'll have to have a lock while sending ... + await this._loadSessions(existingEncryptionTargets); + encryptionTargets = encryptionTargets.concat(existingEncryptionTargets); + const messages = encryptionTargets.map(target => { + const content = this._encryptForDevice(type, content, target); + return new EncryptedMessage(content, target.device); + }); + await this._storeSessions(encryptionTargets, timestamp); + return messages; + } finally { + for (const target of encryptionTargets) { + target.dispose(); + } + } + } + + async _findExistingSessions(devices) { const txn = this._storage.readTxn([this._storage.storeNames.olmSessions]); const sessionIdsForDevice = await Promise.all(devices.map(async device => { return await txn.olmSessions.getSessionIds(device.curve25519Key); @@ -51,49 +84,28 @@ export class Encryption { return !(sessionIds?.length); }); - const timestamp = this._now(); - - let encryptionTuples = []; - - if (devicesWithoutSession.length) { - const newEncryptionTuples = await this._claimOneTimeKeys(hsApi, devicesWithoutSession); - try { - for (const tuple of newEncryptionTuples) { - const {device, oneTimeKey} = tuple; - tuple.session = this._account.createOutboundOlmSession(device.curve25519Key, oneTimeKey); - } - this._storeSessions(newEncryptionTuples, timestamp); - } catch (err) { - for (const tuple of newEncryptionTuples) { - tuple.dispose(); - } - } - encryptionTuples = encryptionTuples.concat(newEncryptionTuples); - } - - const existingEncryptionTuples = devices.map((device, i) => { + const existingEncryptionTargets = devices.map((device, i) => { const sessionIds = sessionIdsForDevice[i]; if (sessionIds?.length > 0) { const sessionId = findFirstSessionId(sessionIds); - return EncryptionTuple.fromSessionId(device, sessionId); + return EncryptionTarget.fromSessionId(device, sessionId); } - }).filter(tuple => !!tuple); - - // TODO: if we read and write in two different txns, - // is there a chance we overwrite a session modified by the decryption during sync? - // I think so. We'll have to have a lock while sending ... - await this._loadSessions(existingEncryptionTuples); - encryptionTuples = encryptionTuples.concat(existingEncryptionTuples); - const ciphertext = this._buildCipherText(type, content, encryptionTuples); - await this._storeSessions(encryptionTuples, timestamp); - return { - type: "m.room.encrypted", - content: { - algorithm: OLM_ALGORITHM, - sender_key: this._account.identityKeys.curve25519, - ciphertext + }).filter(target => !!target); + + return {devicesWithoutSession, existingEncryptionTargets}; + } + + _encryptForDevice(type, content, target) { + const {session, device} = target; + const message = session.encrypt(this._buildPlainTextMessageForDevice(type, content, device)); + const encryptedContent = { + algorithm: OLM_ALGORITHM, + sender_key: this._account.identityKeys.curve25519, + ciphertext: { + [device.curve25519Key]: message } }; + return encryptedContent; } _buildPlainTextMessageForDevice(type, content, device) { @@ -111,15 +123,20 @@ export class Encryption { } } - _buildCipherText(type, content, encryptionTuples) { - const ciphertext = {}; - for (const {device, session} of encryptionTuples) { - if (session) { - const message = session.encrypt(this._buildPlainTextMessageForDevice(type, content, device)); - ciphertext[device.curve25519Key] = message; + async _createNewSessions(devicesWithoutSession, hsApi, timestamp) { + const newEncryptionTargets = await this._claimOneTimeKeys(hsApi, devicesWithoutSession); + try { + for (const target of newEncryptionTargets) { + const {device, oneTimeKey} = target; + target.session = this._account.createOutboundOlmSession(device.curve25519Key, oneTimeKey); + } + this._storeSessions(newEncryptionTargets, timestamp); + } catch (err) { + for (const target of newEncryptionTargets) { + target.dispose(); } } - return ciphertext; + return newEncryptionTargets; } async _claimOneTimeKeys(hsApi, deviceIdentities) { @@ -142,11 +159,11 @@ export class Encryption { }).response(); // TODO: log claimResponse.failures const userKeyMap = claimResponse?.["one_time_keys"]; - return this._verifyAndCreateOTKTuples(userKeyMap, devicesByUser); + return this._verifyAndCreateOTKTargets(userKeyMap, devicesByUser); } - _verifyAndCreateOTKTuples(userKeyMap, devicesByUser) { - const verifiedEncryptionTuples = []; + _verifyAndCreateOTKTargets(userKeyMap, devicesByUser) { + const verifiedEncryptionTargets = []; for (const [userId, userSection] of Object.entries(userKeyMap)) { for (const [deviceId, deviceSection] of Object.entries(userSection)) { const [firstPropName, keySection] = Object.entries(deviceSection)[0]; @@ -157,34 +174,48 @@ export class Encryption { const isValidSignature = verifyEd25519Signature( this._olmUtil, userId, deviceId, device.ed25519Key, keySection); if (isValidSignature) { - verifiedEncryptionTuples.push(EncryptionTuple.fromOTK(device, keySection.key)); + const target = EncryptionTarget.fromOTK(device, keySection.key); + verifiedEncryptionTargets.push(target); } } } } } - return verifiedEncryptionTuples; + return verifiedEncryptionTargets; } - async _loadSessions(encryptionTuples) { + async _loadSessions(encryptionTargets) { const txn = this._storage.readTxn([this._storage.storeNames.olmSessions]); - await Promise.all(encryptionTuples.map(async encryptionTuple => { - const sessionEntry = await txn.olmSessions.get( - encryptionTuple.device.curve25519Key, encryptionTuple.sessionId); - if (sessionEntry) { - const olmSession = new this._olm.Session(); - encryptionTuple.session = + // given we run loading in parallel, there might still be some + // storage requests that will finish later once one has failed. + // those should not allocate a session anymore. + let failed = false; + try { + await Promise.all(encryptionTargets.map(async encryptionTarget => { + const sessionEntry = await txn.olmSessions.get( + encryptionTarget.device.curve25519Key, encryptionTarget.sessionId); + if (sessionEntry && !failed) { + const olmSession = new this._olm.Session(); + olmSession.unpickle(this._pickleKey, sessionEntry.session); + encryptionTarget.session = olmSession; + } + })); + } catch (err) { + failed = true; + // clean up the sessions that did load + for (const target of encryptionTargets) { + target.dispose(); } - - })); + throw err; + } } - async _storeSessions(encryptionTuples, timestamp) { + async _storeSessions(encryptionTargets, timestamp) { const txn = this._storage.readWriteTxn([this._storage.storeNames.olmSessions]); try { - for (const tuple of encryptionTuples) { + for (const target of encryptionTargets) { const sessionEntry = createSessionEntry( - tuple.session, tuple.device.curve25519Key, timestamp, this._pickleKey); + target.session, target.device.curve25519Key, timestamp, this._pickleKey); txn.olmSessions.set(sessionEntry); } } catch (err) { @@ -199,7 +230,7 @@ export class Encryption { // it is constructed with either a oneTimeKey // (and later converted to a session) in case of a new session // or an existing session -class EncryptionTuple { +class EncryptionTarget { constructor(device, oneTimeKey, sessionId) { this.device = device; this.oneTimeKey = oneTimeKey; @@ -209,11 +240,11 @@ class EncryptionTuple { } static fromOTK(device, oneTimeKey) { - return new EncryptionTuple(device, oneTimeKey, null); + return new EncryptionTarget(device, oneTimeKey, null); } static fromSessionId(device, sessionId) { - return new EncryptionTuple(device, null, sessionId); + return new EncryptionTarget(device, null, sessionId); } dispose() { @@ -222,3 +253,10 @@ class EncryptionTuple { } } } + +class EncryptedMessage { + constructor(content, device) { + this.content = content; + this.device = device; + } +}