diff --git a/src/matrix/Session.js b/src/matrix/Session.js index 2752fca7..803aadc8 100644 --- a/src/matrix/Session.js +++ b/src/matrix/Session.js @@ -23,6 +23,8 @@ import {DeviceMessageHandler} from "./DeviceMessageHandler.js"; import {Decryption as OlmDecryption} from "./e2ee/olm/Decryption.js"; import {Decryption as MegOlmDecryption} from "./e2ee/megolm/Decryption.js"; import {DeviceTracker} from "./e2ee/DeviceTracker.js"; +import {LockMap} from "../utils/LockMap.js"; + const PICKLE_KEY = "DEFAULT_KEY"; export class Session { @@ -54,6 +56,7 @@ export class Session { // called once this._e2eeAccount is assigned _setupEncryption() { + const senderKeyLock = new LockMap(); const olmDecryption = new OlmDecryption({ account: this._e2eeAccount, pickleKey: PICKLE_KEY, @@ -61,6 +64,7 @@ export class Session { ownUserId: this._user.id, storage: this._storage, olm: this._olm, + senderKeyLock }); const megolmDecryption = new MegOlmDecryption({pickleKey: PICKLE_KEY, olm: this._olm}); this._deviceMessageHandler.enableEncryption({olmDecryption, megolmDecryption}); diff --git a/src/matrix/e2ee/olm/Decryption.js b/src/matrix/e2ee/olm/Decryption.js index 01362266..c21c4b3d 100644 --- a/src/matrix/e2ee/olm/Decryption.js +++ b/src/matrix/e2ee/olm/Decryption.js @@ -31,14 +31,14 @@ function sortSessions(sessions) { } export class Decryption { - constructor({account, pickleKey, now, ownUserId, storage, olm}) { + constructor({account, pickleKey, now, ownUserId, storage, olm, senderKeyLock}) { this._account = account; this._pickleKey = pickleKey; this._now = now; this._ownUserId = ownUserId; this._storage = storage; this._olm = olm; - this._createOutboundSessionPromise = null; + this._senderKeyLock = senderKeyLock; } // we need decryptAll because there is some parallelization we can do for decrypting different sender keys at once @@ -53,15 +53,28 @@ export class Decryption { async decryptAll(events) { const eventsPerSenderKey = groupBy(events, event => event.content?.["sender_key"]); const timestamp = this._now(); - const readSessionsTxn = await this._storage.readTxn([this._storage.storeNames.olmSessions]); - // decrypt events for different sender keys in parallel - const results = await Promise.all(Array.from(eventsPerSenderKey.entries()).map(([senderKey, events]) => { - return this._decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn); + // take a lock on all senderKeys so encryption or other calls to decryptAll (should not happen) + // don't modify the sessions at the same time + const locks = await Promise.all(Array.from(eventsPerSenderKey.keys()).map(senderKey => { + return this._senderKeyLock.takeLock(senderKey); })); - const payloads = results.reduce((all, r) => all.concat(r.payloads), []); - const errors = results.reduce((all, r) => all.concat(r.errors), []); - const senderKeyDecryptions = results.map(r => r.senderKeyDecryption); - return new DecryptionChanges(senderKeyDecryptions, payloads, errors); + try { + const readSessionsTxn = await this._storage.readTxn([this._storage.storeNames.olmSessions]); + // decrypt events for different sender keys in parallel + const results = await Promise.all(Array.from(eventsPerSenderKey.entries()).map(([senderKey, events]) => { + return this._decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn); + })); + const payloads = results.reduce((all, r) => all.concat(r.payloads), []); + const errors = results.reduce((all, r) => all.concat(r.errors), []); + const senderKeyDecryptions = results.map(r => r.senderKeyDecryption); + return new DecryptionChanges(senderKeyDecryptions, payloads, errors, locks); + } catch (err) { + // make sure the locks are release if something throws + for (const lock of locks) { + lock.release(); + } + throw err; + } } async _decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn) { @@ -235,11 +248,12 @@ class SenderKeyDecryption { } class DecryptionChanges { - constructor(senderKeyDecryptions, payloads, errors, account) { + constructor(senderKeyDecryptions, payloads, errors, account, locks) { this._senderKeyDecryptions = senderKeyDecryptions; this._account = account; this.payloads = payloads; this.errors = errors; + this._locks = locks; } get hasNewSessions() { @@ -247,25 +261,31 @@ class DecryptionChanges { } write(txn) { - for (const senderKeyDecryption of this._senderKeyDecryptions) { - for (const session of senderKeyDecryption.getModifiedSessions()) { - txn.olmSessions.set(session.data); - if (session.isNew) { - const olmSession = session.load(); - try { - this._account.writeRemoveOneTimeKey(olmSession, txn); - } finally { - session.unload(olmSession); + try { + for (const senderKeyDecryption of this._senderKeyDecryptions) { + for (const session of senderKeyDecryption.getModifiedSessions()) { + txn.olmSessions.set(session.data); + if (session.isNew) { + const olmSession = session.load(); + try { + this._account.writeRemoveOneTimeKey(olmSession, txn); + } finally { + session.unload(olmSession); + } + } + } + if (senderKeyDecryption.sessions.length > SESSION_LIMIT_PER_SENDER_KEY) { + const {senderKey, sessions} = senderKeyDecryption; + // >= because index is zero-based + for (let i = sessions.length - 1; i >= SESSION_LIMIT_PER_SENDER_KEY ; i -= 1) { + const session = sessions[i]; + txn.olmSessions.remove(senderKey, session.id); } } } - if (senderKeyDecryption.sessions.length > SESSION_LIMIT_PER_SENDER_KEY) { - const {senderKey, sessions} = senderKeyDecryption; - // >= because index is zero-based - for (let i = sessions.length - 1; i >= SESSION_LIMIT_PER_SENDER_KEY ; i -= 1) { - const session = sessions[i]; - txn.olmSessions.remove(senderKey, session.id); - } + } finally { + for (const lock of this._locks) { + lock.release(); } } } diff --git a/src/matrix/e2ee/olm/Encryption.js b/src/matrix/e2ee/olm/Encryption.js index fee485a4..1461b94c 100644 --- a/src/matrix/e2ee/olm/Encryption.js +++ b/src/matrix/e2ee/olm/Encryption.js @@ -31,7 +31,7 @@ function findFirstSessionId(sessionIds) { const OTK_ALGORITHM = "signed_curve25519"; export class Encryption { - constructor({account, olm, olmUtil, userId, storage, now, pickleKey}) { + constructor({account, olm, olmUtil, userId, storage, now, pickleKey, senderKeyLock}) { this._account = account; this._olm = olm; this._olmUtil = olmUtil; @@ -39,37 +39,47 @@ export class Encryption { this._storage = storage; this._now = now; this._pickleKey = pickleKey; + this._senderKeyLock = senderKeyLock; } async encrypt(type, content, devices, hsApi) { - const { - devicesWithoutSession, - existingEncryptionTargets - } = await this._findExistingSessions(devices); - - const timestamp = this._now(); - - let encryptionTargets = []; + // TODO: see if we can only hold some of the locks until after the /keys/claim call (if needed) + // take a lock on all senderKeys so decryption and other calls to encrypt (should not happen) + // don't modify the sessions at the same time + const locks = await Promise.all(devices.map(device => { + return this._senderKeyLock.takeLock(device.curve25519Key); + })); try { - if (devicesWithoutSession.length) { - const newEncryptionTargets = await this._createNewSessions( - devicesWithoutSession, hsApi, timestamp); - encryptionTargets = encryptionTargets.concat(newEncryptionTargets); + const { + devicesWithoutSession, + existingEncryptionTargets, + } = await this._findExistingSessions(devices); + + const timestamp = this._now(); + + let encryptionTargets = []; + try { + if (devicesWithoutSession.length) { + const newEncryptionTargets = await this._createNewSessions( + devicesWithoutSession, hsApi, timestamp); + encryptionTargets = encryptionTargets.concat(newEncryptionTargets); + } + await this._loadSessions(existingEncryptionTargets); + encryptionTargets = encryptionTargets.concat(existingEncryptionTargets); + const messages = encryptionTargets.map(target => { + const content = this._encryptForDevice(type, content, target); + return new EncryptedMessage(content, target.device); + }); + await this._storeSessions(encryptionTargets, timestamp); + return messages; + } finally { + for (const target of encryptionTargets) { + target.dispose(); + } } - // TODO: if we read and write in two different txns, - // is there a chance we overwrite a session modified by the decryption during sync? - // I think so. We'll have to have a lock while sending ... - await this._loadSessions(existingEncryptionTargets); - encryptionTargets = encryptionTargets.concat(existingEncryptionTargets); - const messages = encryptionTargets.map(target => { - const content = this._encryptForDevice(type, content, target); - return new EncryptedMessage(content, target.device); - }); - await this._storeSessions(encryptionTargets, timestamp); - return messages; } finally { - for (const target of encryptionTargets) { - target.dispose(); + for (const lock of locks) { + lock.release(); } } } diff --git a/src/utils/Lock.js b/src/utils/Lock.js new file mode 100644 index 00000000..21d5d7a2 --- /dev/null +++ b/src/utils/Lock.js @@ -0,0 +1,86 @@ +/* +Copyright 2020 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +export class Lock { + constructor() { + this._promise = null; + this._resolve = null; + } + + take() { + if (!this._promise) { + this._promise = new Promise(resolve => { + this._resolve = resolve; + }); + return true; + } + return false; + } + + get isTaken() { + return !!this._promise; + } + + release() { + if (this._resolve) { + this._promise = null; + const resolve = this._resolve; + this._resolve = null; + resolve(); + } + } + + released() { + return this._promise; + } +} + +export function tests() { + return { + "taking a lock twice returns false": assert => { + const lock = new Lock(); + assert.equal(lock.take(), true); + assert.equal(lock.isTaken, true); + assert.equal(lock.take(), false); + }, + "can take a released lock again": assert => { + const lock = new Lock(); + lock.take(); + lock.release(); + assert.equal(lock.isTaken, false); + assert.equal(lock.take(), true); + }, + "2 waiting for lock, only first one gets it": async assert => { + const lock = new Lock(); + lock.take(); + + let first = false; + lock.released().then(() => first = lock.take()); + let second = false; + lock.released().then(() => second = lock.take()); + const promise = lock.released(); + lock.release(); + await promise; + assert.equal(first, true); + assert.equal(second, false); + }, + "await non-taken lock": async assert => { + const lock = new Lock(); + await lock.released(); + assert(true); + } + } +} diff --git a/src/utils/LockMap.js b/src/utils/LockMap.js new file mode 100644 index 00000000..f99776cc --- /dev/null +++ b/src/utils/LockMap.js @@ -0,0 +1,93 @@ +/* +Copyright 2020 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import {Lock} from "./Lock.js"; + +export class LockMap { + constructor() { + this._map = new Map(); + } + + async takeLock(key) { + let lock = this._map.get(key); + if (lock) { + while (!lock.take()) { + await lock.released(); + } + } else { + lock = new Lock(); + lock.take(); + this._map.set(key, lock); + } + // don't leave old locks lying around + lock.released().then(() => { + // give others a chance to take the lock first + Promise.resolve().then(() => { + if (!lock.isTaken) { + this._map.delete(key); + } + }); + }); + return lock; + } +} + +export function tests() { + return { + "taking a lock on the same key blocks": async assert => { + const lockMap = new LockMap(); + const lock = await lockMap.takeLock("foo"); + let second = false; + const prom = lockMap.takeLock("foo").then(() => { + second = true; + }); + assert.equal(second, false); + // do a delay to make sure prom does not resolve on its own + await Promise.resolve(); + lock.release(); + await prom; + assert.equal(second, true); + }, + "lock is not cleaned up with second request": async assert => { + const lockMap = new LockMap(); + const lock = await lockMap.takeLock("foo"); + let ranSecond = false; + const prom = lockMap.takeLock("foo").then(returnedLock => { + ranSecond = true; + assert.equal(returnedLock.isTaken, true); + // peek into internals, naughty + assert.equal(lockMap._map.get("foo"), returnedLock); + }); + lock.release(); + await prom; + // double delay to make sure cleanup logic ran + await Promise.resolve(); + await Promise.resolve(); + assert.equal(ranSecond, true); + }, + "lock is cleaned up without other request": async assert => { + const lockMap = new LockMap(); + const lock = await lockMap.takeLock("foo"); + await Promise.resolve(); + lock.release(); + // double delay to make sure cleanup logic ran + await Promise.resolve(); + await Promise.resolve(); + assert.equal(lockMap._map.has("foo"), false); + }, + + }; +}