lock on senderKey while enc/decrypting olm sessions

This commit is contained in:
Bruno Windels 2020-09-03 12:12:33 +02:00
parent 4ecd853348
commit 4f4808b94c
5 changed files with 266 additions and 53 deletions

View file

@ -23,6 +23,8 @@ import {DeviceMessageHandler} from "./DeviceMessageHandler.js";
import {Decryption as OlmDecryption} from "./e2ee/olm/Decryption.js"; import {Decryption as OlmDecryption} from "./e2ee/olm/Decryption.js";
import {Decryption as MegOlmDecryption} from "./e2ee/megolm/Decryption.js"; import {Decryption as MegOlmDecryption} from "./e2ee/megolm/Decryption.js";
import {DeviceTracker} from "./e2ee/DeviceTracker.js"; import {DeviceTracker} from "./e2ee/DeviceTracker.js";
import {LockMap} from "../utils/LockMap.js";
const PICKLE_KEY = "DEFAULT_KEY"; const PICKLE_KEY = "DEFAULT_KEY";
export class Session { export class Session {
@ -54,6 +56,7 @@ export class Session {
// called once this._e2eeAccount is assigned // called once this._e2eeAccount is assigned
_setupEncryption() { _setupEncryption() {
const senderKeyLock = new LockMap();
const olmDecryption = new OlmDecryption({ const olmDecryption = new OlmDecryption({
account: this._e2eeAccount, account: this._e2eeAccount,
pickleKey: PICKLE_KEY, pickleKey: PICKLE_KEY,
@ -61,6 +64,7 @@ export class Session {
ownUserId: this._user.id, ownUserId: this._user.id,
storage: this._storage, storage: this._storage,
olm: this._olm, olm: this._olm,
senderKeyLock
}); });
const megolmDecryption = new MegOlmDecryption({pickleKey: PICKLE_KEY, olm: this._olm}); const megolmDecryption = new MegOlmDecryption({pickleKey: PICKLE_KEY, olm: this._olm});
this._deviceMessageHandler.enableEncryption({olmDecryption, megolmDecryption}); this._deviceMessageHandler.enableEncryption({olmDecryption, megolmDecryption});

View file

@ -31,14 +31,14 @@ function sortSessions(sessions) {
} }
export class Decryption { export class Decryption {
constructor({account, pickleKey, now, ownUserId, storage, olm}) { constructor({account, pickleKey, now, ownUserId, storage, olm, senderKeyLock}) {
this._account = account; this._account = account;
this._pickleKey = pickleKey; this._pickleKey = pickleKey;
this._now = now; this._now = now;
this._ownUserId = ownUserId; this._ownUserId = ownUserId;
this._storage = storage; this._storage = storage;
this._olm = olm; this._olm = olm;
this._createOutboundSessionPromise = null; this._senderKeyLock = senderKeyLock;
} }
// we need decryptAll because there is some parallelization we can do for decrypting different sender keys at once // we need decryptAll because there is some parallelization we can do for decrypting different sender keys at once
@ -53,15 +53,28 @@ export class Decryption {
async decryptAll(events) { async decryptAll(events) {
const eventsPerSenderKey = groupBy(events, event => event.content?.["sender_key"]); const eventsPerSenderKey = groupBy(events, event => event.content?.["sender_key"]);
const timestamp = this._now(); const timestamp = this._now();
const readSessionsTxn = await this._storage.readTxn([this._storage.storeNames.olmSessions]); // take a lock on all senderKeys so encryption or other calls to decryptAll (should not happen)
// decrypt events for different sender keys in parallel // don't modify the sessions at the same time
const results = await Promise.all(Array.from(eventsPerSenderKey.entries()).map(([senderKey, events]) => { const locks = await Promise.all(Array.from(eventsPerSenderKey.keys()).map(senderKey => {
return this._decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn); return this._senderKeyLock.takeLock(senderKey);
})); }));
const payloads = results.reduce((all, r) => all.concat(r.payloads), []); try {
const errors = results.reduce((all, r) => all.concat(r.errors), []); const readSessionsTxn = await this._storage.readTxn([this._storage.storeNames.olmSessions]);
const senderKeyDecryptions = results.map(r => r.senderKeyDecryption); // decrypt events for different sender keys in parallel
return new DecryptionChanges(senderKeyDecryptions, payloads, errors); const results = await Promise.all(Array.from(eventsPerSenderKey.entries()).map(([senderKey, events]) => {
return this._decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn);
}));
const payloads = results.reduce((all, r) => all.concat(r.payloads), []);
const errors = results.reduce((all, r) => all.concat(r.errors), []);
const senderKeyDecryptions = results.map(r => r.senderKeyDecryption);
return new DecryptionChanges(senderKeyDecryptions, payloads, errors, locks);
} catch (err) {
// make sure the locks are release if something throws
for (const lock of locks) {
lock.release();
}
throw err;
}
} }
async _decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn) { async _decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn) {
@ -235,11 +248,12 @@ class SenderKeyDecryption {
} }
class DecryptionChanges { class DecryptionChanges {
constructor(senderKeyDecryptions, payloads, errors, account) { constructor(senderKeyDecryptions, payloads, errors, account, locks) {
this._senderKeyDecryptions = senderKeyDecryptions; this._senderKeyDecryptions = senderKeyDecryptions;
this._account = account; this._account = account;
this.payloads = payloads; this.payloads = payloads;
this.errors = errors; this.errors = errors;
this._locks = locks;
} }
get hasNewSessions() { get hasNewSessions() {
@ -247,25 +261,31 @@ class DecryptionChanges {
} }
write(txn) { write(txn) {
for (const senderKeyDecryption of this._senderKeyDecryptions) { try {
for (const session of senderKeyDecryption.getModifiedSessions()) { for (const senderKeyDecryption of this._senderKeyDecryptions) {
txn.olmSessions.set(session.data); for (const session of senderKeyDecryption.getModifiedSessions()) {
if (session.isNew) { txn.olmSessions.set(session.data);
const olmSession = session.load(); if (session.isNew) {
try { const olmSession = session.load();
this._account.writeRemoveOneTimeKey(olmSession, txn); try {
} finally { this._account.writeRemoveOneTimeKey(olmSession, txn);
session.unload(olmSession); } finally {
session.unload(olmSession);
}
}
}
if (senderKeyDecryption.sessions.length > SESSION_LIMIT_PER_SENDER_KEY) {
const {senderKey, sessions} = senderKeyDecryption;
// >= because index is zero-based
for (let i = sessions.length - 1; i >= SESSION_LIMIT_PER_SENDER_KEY ; i -= 1) {
const session = sessions[i];
txn.olmSessions.remove(senderKey, session.id);
} }
} }
} }
if (senderKeyDecryption.sessions.length > SESSION_LIMIT_PER_SENDER_KEY) { } finally {
const {senderKey, sessions} = senderKeyDecryption; for (const lock of this._locks) {
// >= because index is zero-based lock.release();
for (let i = sessions.length - 1; i >= SESSION_LIMIT_PER_SENDER_KEY ; i -= 1) {
const session = sessions[i];
txn.olmSessions.remove(senderKey, session.id);
}
} }
} }
} }

View file

@ -31,7 +31,7 @@ function findFirstSessionId(sessionIds) {
const OTK_ALGORITHM = "signed_curve25519"; const OTK_ALGORITHM = "signed_curve25519";
export class Encryption { export class Encryption {
constructor({account, olm, olmUtil, userId, storage, now, pickleKey}) { constructor({account, olm, olmUtil, userId, storage, now, pickleKey, senderKeyLock}) {
this._account = account; this._account = account;
this._olm = olm; this._olm = olm;
this._olmUtil = olmUtil; this._olmUtil = olmUtil;
@ -39,37 +39,47 @@ export class Encryption {
this._storage = storage; this._storage = storage;
this._now = now; this._now = now;
this._pickleKey = pickleKey; this._pickleKey = pickleKey;
this._senderKeyLock = senderKeyLock;
} }
async encrypt(type, content, devices, hsApi) { async encrypt(type, content, devices, hsApi) {
const { // TODO: see if we can only hold some of the locks until after the /keys/claim call (if needed)
devicesWithoutSession, // take a lock on all senderKeys so decryption and other calls to encrypt (should not happen)
existingEncryptionTargets // don't modify the sessions at the same time
} = await this._findExistingSessions(devices); const locks = await Promise.all(devices.map(device => {
return this._senderKeyLock.takeLock(device.curve25519Key);
const timestamp = this._now(); }));
let encryptionTargets = [];
try { try {
if (devicesWithoutSession.length) { const {
const newEncryptionTargets = await this._createNewSessions( devicesWithoutSession,
devicesWithoutSession, hsApi, timestamp); existingEncryptionTargets,
encryptionTargets = encryptionTargets.concat(newEncryptionTargets); } = await this._findExistingSessions(devices);
const timestamp = this._now();
let encryptionTargets = [];
try {
if (devicesWithoutSession.length) {
const newEncryptionTargets = await this._createNewSessions(
devicesWithoutSession, hsApi, timestamp);
encryptionTargets = encryptionTargets.concat(newEncryptionTargets);
}
await this._loadSessions(existingEncryptionTargets);
encryptionTargets = encryptionTargets.concat(existingEncryptionTargets);
const messages = encryptionTargets.map(target => {
const content = this._encryptForDevice(type, content, target);
return new EncryptedMessage(content, target.device);
});
await this._storeSessions(encryptionTargets, timestamp);
return messages;
} finally {
for (const target of encryptionTargets) {
target.dispose();
}
} }
// TODO: if we read and write in two different txns,
// is there a chance we overwrite a session modified by the decryption during sync?
// I think so. We'll have to have a lock while sending ...
await this._loadSessions(existingEncryptionTargets);
encryptionTargets = encryptionTargets.concat(existingEncryptionTargets);
const messages = encryptionTargets.map(target => {
const content = this._encryptForDevice(type, content, target);
return new EncryptedMessage(content, target.device);
});
await this._storeSessions(encryptionTargets, timestamp);
return messages;
} finally { } finally {
for (const target of encryptionTargets) { for (const lock of locks) {
target.dispose(); lock.release();
} }
} }
} }

86
src/utils/Lock.js Normal file
View file

@ -0,0 +1,86 @@
/*
Copyright 2020 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
export class Lock {
constructor() {
this._promise = null;
this._resolve = null;
}
take() {
if (!this._promise) {
this._promise = new Promise(resolve => {
this._resolve = resolve;
});
return true;
}
return false;
}
get isTaken() {
return !!this._promise;
}
release() {
if (this._resolve) {
this._promise = null;
const resolve = this._resolve;
this._resolve = null;
resolve();
}
}
released() {
return this._promise;
}
}
export function tests() {
return {
"taking a lock twice returns false": assert => {
const lock = new Lock();
assert.equal(lock.take(), true);
assert.equal(lock.isTaken, true);
assert.equal(lock.take(), false);
},
"can take a released lock again": assert => {
const lock = new Lock();
lock.take();
lock.release();
assert.equal(lock.isTaken, false);
assert.equal(lock.take(), true);
},
"2 waiting for lock, only first one gets it": async assert => {
const lock = new Lock();
lock.take();
let first = false;
lock.released().then(() => first = lock.take());
let second = false;
lock.released().then(() => second = lock.take());
const promise = lock.released();
lock.release();
await promise;
assert.equal(first, true);
assert.equal(second, false);
},
"await non-taken lock": async assert => {
const lock = new Lock();
await lock.released();
assert(true);
}
}
}

93
src/utils/LockMap.js Normal file
View file

@ -0,0 +1,93 @@
/*
Copyright 2020 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import {Lock} from "./Lock.js";
export class LockMap {
constructor() {
this._map = new Map();
}
async takeLock(key) {
let lock = this._map.get(key);
if (lock) {
while (!lock.take()) {
await lock.released();
}
} else {
lock = new Lock();
lock.take();
this._map.set(key, lock);
}
// don't leave old locks lying around
lock.released().then(() => {
// give others a chance to take the lock first
Promise.resolve().then(() => {
if (!lock.isTaken) {
this._map.delete(key);
}
});
});
return lock;
}
}
export function tests() {
return {
"taking a lock on the same key blocks": async assert => {
const lockMap = new LockMap();
const lock = await lockMap.takeLock("foo");
let second = false;
const prom = lockMap.takeLock("foo").then(() => {
second = true;
});
assert.equal(second, false);
// do a delay to make sure prom does not resolve on its own
await Promise.resolve();
lock.release();
await prom;
assert.equal(second, true);
},
"lock is not cleaned up with second request": async assert => {
const lockMap = new LockMap();
const lock = await lockMap.takeLock("foo");
let ranSecond = false;
const prom = lockMap.takeLock("foo").then(returnedLock => {
ranSecond = true;
assert.equal(returnedLock.isTaken, true);
// peek into internals, naughty
assert.equal(lockMap._map.get("foo"), returnedLock);
});
lock.release();
await prom;
// double delay to make sure cleanup logic ran
await Promise.resolve();
await Promise.resolve();
assert.equal(ranSecond, true);
},
"lock is cleaned up without other request": async assert => {
const lockMap = new LockMap();
const lock = await lockMap.takeLock("foo");
await Promise.resolve();
lock.release();
// double delay to make sure cleanup logic ran
await Promise.resolve();
await Promise.resolve();
assert.equal(lockMap._map.has("foo"), false);
},
};
}