lock on senderKey while enc/decrypting olm sessions
This commit is contained in:
parent
4ecd853348
commit
4f4808b94c
5 changed files with 266 additions and 53 deletions
|
@ -23,6 +23,8 @@ import {DeviceMessageHandler} from "./DeviceMessageHandler.js";
|
||||||
import {Decryption as OlmDecryption} from "./e2ee/olm/Decryption.js";
|
import {Decryption as OlmDecryption} from "./e2ee/olm/Decryption.js";
|
||||||
import {Decryption as MegOlmDecryption} from "./e2ee/megolm/Decryption.js";
|
import {Decryption as MegOlmDecryption} from "./e2ee/megolm/Decryption.js";
|
||||||
import {DeviceTracker} from "./e2ee/DeviceTracker.js";
|
import {DeviceTracker} from "./e2ee/DeviceTracker.js";
|
||||||
|
import {LockMap} from "../utils/LockMap.js";
|
||||||
|
|
||||||
const PICKLE_KEY = "DEFAULT_KEY";
|
const PICKLE_KEY = "DEFAULT_KEY";
|
||||||
|
|
||||||
export class Session {
|
export class Session {
|
||||||
|
@ -54,6 +56,7 @@ export class Session {
|
||||||
|
|
||||||
// called once this._e2eeAccount is assigned
|
// called once this._e2eeAccount is assigned
|
||||||
_setupEncryption() {
|
_setupEncryption() {
|
||||||
|
const senderKeyLock = new LockMap();
|
||||||
const olmDecryption = new OlmDecryption({
|
const olmDecryption = new OlmDecryption({
|
||||||
account: this._e2eeAccount,
|
account: this._e2eeAccount,
|
||||||
pickleKey: PICKLE_KEY,
|
pickleKey: PICKLE_KEY,
|
||||||
|
@ -61,6 +64,7 @@ export class Session {
|
||||||
ownUserId: this._user.id,
|
ownUserId: this._user.id,
|
||||||
storage: this._storage,
|
storage: this._storage,
|
||||||
olm: this._olm,
|
olm: this._olm,
|
||||||
|
senderKeyLock
|
||||||
});
|
});
|
||||||
const megolmDecryption = new MegOlmDecryption({pickleKey: PICKLE_KEY, olm: this._olm});
|
const megolmDecryption = new MegOlmDecryption({pickleKey: PICKLE_KEY, olm: this._olm});
|
||||||
this._deviceMessageHandler.enableEncryption({olmDecryption, megolmDecryption});
|
this._deviceMessageHandler.enableEncryption({olmDecryption, megolmDecryption});
|
||||||
|
|
|
@ -31,14 +31,14 @@ function sortSessions(sessions) {
|
||||||
}
|
}
|
||||||
|
|
||||||
export class Decryption {
|
export class Decryption {
|
||||||
constructor({account, pickleKey, now, ownUserId, storage, olm}) {
|
constructor({account, pickleKey, now, ownUserId, storage, olm, senderKeyLock}) {
|
||||||
this._account = account;
|
this._account = account;
|
||||||
this._pickleKey = pickleKey;
|
this._pickleKey = pickleKey;
|
||||||
this._now = now;
|
this._now = now;
|
||||||
this._ownUserId = ownUserId;
|
this._ownUserId = ownUserId;
|
||||||
this._storage = storage;
|
this._storage = storage;
|
||||||
this._olm = olm;
|
this._olm = olm;
|
||||||
this._createOutboundSessionPromise = null;
|
this._senderKeyLock = senderKeyLock;
|
||||||
}
|
}
|
||||||
|
|
||||||
// we need decryptAll because there is some parallelization we can do for decrypting different sender keys at once
|
// we need decryptAll because there is some parallelization we can do for decrypting different sender keys at once
|
||||||
|
@ -53,15 +53,28 @@ export class Decryption {
|
||||||
async decryptAll(events) {
|
async decryptAll(events) {
|
||||||
const eventsPerSenderKey = groupBy(events, event => event.content?.["sender_key"]);
|
const eventsPerSenderKey = groupBy(events, event => event.content?.["sender_key"]);
|
||||||
const timestamp = this._now();
|
const timestamp = this._now();
|
||||||
const readSessionsTxn = await this._storage.readTxn([this._storage.storeNames.olmSessions]);
|
// take a lock on all senderKeys so encryption or other calls to decryptAll (should not happen)
|
||||||
// decrypt events for different sender keys in parallel
|
// don't modify the sessions at the same time
|
||||||
const results = await Promise.all(Array.from(eventsPerSenderKey.entries()).map(([senderKey, events]) => {
|
const locks = await Promise.all(Array.from(eventsPerSenderKey.keys()).map(senderKey => {
|
||||||
return this._decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn);
|
return this._senderKeyLock.takeLock(senderKey);
|
||||||
}));
|
}));
|
||||||
const payloads = results.reduce((all, r) => all.concat(r.payloads), []);
|
try {
|
||||||
const errors = results.reduce((all, r) => all.concat(r.errors), []);
|
const readSessionsTxn = await this._storage.readTxn([this._storage.storeNames.olmSessions]);
|
||||||
const senderKeyDecryptions = results.map(r => r.senderKeyDecryption);
|
// decrypt events for different sender keys in parallel
|
||||||
return new DecryptionChanges(senderKeyDecryptions, payloads, errors);
|
const results = await Promise.all(Array.from(eventsPerSenderKey.entries()).map(([senderKey, events]) => {
|
||||||
|
return this._decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn);
|
||||||
|
}));
|
||||||
|
const payloads = results.reduce((all, r) => all.concat(r.payloads), []);
|
||||||
|
const errors = results.reduce((all, r) => all.concat(r.errors), []);
|
||||||
|
const senderKeyDecryptions = results.map(r => r.senderKeyDecryption);
|
||||||
|
return new DecryptionChanges(senderKeyDecryptions, payloads, errors, locks);
|
||||||
|
} catch (err) {
|
||||||
|
// make sure the locks are release if something throws
|
||||||
|
for (const lock of locks) {
|
||||||
|
lock.release();
|
||||||
|
}
|
||||||
|
throw err;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async _decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn) {
|
async _decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn) {
|
||||||
|
@ -235,11 +248,12 @@ class SenderKeyDecryption {
|
||||||
}
|
}
|
||||||
|
|
||||||
class DecryptionChanges {
|
class DecryptionChanges {
|
||||||
constructor(senderKeyDecryptions, payloads, errors, account) {
|
constructor(senderKeyDecryptions, payloads, errors, account, locks) {
|
||||||
this._senderKeyDecryptions = senderKeyDecryptions;
|
this._senderKeyDecryptions = senderKeyDecryptions;
|
||||||
this._account = account;
|
this._account = account;
|
||||||
this.payloads = payloads;
|
this.payloads = payloads;
|
||||||
this.errors = errors;
|
this.errors = errors;
|
||||||
|
this._locks = locks;
|
||||||
}
|
}
|
||||||
|
|
||||||
get hasNewSessions() {
|
get hasNewSessions() {
|
||||||
|
@ -247,25 +261,31 @@ class DecryptionChanges {
|
||||||
}
|
}
|
||||||
|
|
||||||
write(txn) {
|
write(txn) {
|
||||||
for (const senderKeyDecryption of this._senderKeyDecryptions) {
|
try {
|
||||||
for (const session of senderKeyDecryption.getModifiedSessions()) {
|
for (const senderKeyDecryption of this._senderKeyDecryptions) {
|
||||||
txn.olmSessions.set(session.data);
|
for (const session of senderKeyDecryption.getModifiedSessions()) {
|
||||||
if (session.isNew) {
|
txn.olmSessions.set(session.data);
|
||||||
const olmSession = session.load();
|
if (session.isNew) {
|
||||||
try {
|
const olmSession = session.load();
|
||||||
this._account.writeRemoveOneTimeKey(olmSession, txn);
|
try {
|
||||||
} finally {
|
this._account.writeRemoveOneTimeKey(olmSession, txn);
|
||||||
session.unload(olmSession);
|
} finally {
|
||||||
|
session.unload(olmSession);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (senderKeyDecryption.sessions.length > SESSION_LIMIT_PER_SENDER_KEY) {
|
||||||
|
const {senderKey, sessions} = senderKeyDecryption;
|
||||||
|
// >= because index is zero-based
|
||||||
|
for (let i = sessions.length - 1; i >= SESSION_LIMIT_PER_SENDER_KEY ; i -= 1) {
|
||||||
|
const session = sessions[i];
|
||||||
|
txn.olmSessions.remove(senderKey, session.id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (senderKeyDecryption.sessions.length > SESSION_LIMIT_PER_SENDER_KEY) {
|
} finally {
|
||||||
const {senderKey, sessions} = senderKeyDecryption;
|
for (const lock of this._locks) {
|
||||||
// >= because index is zero-based
|
lock.release();
|
||||||
for (let i = sessions.length - 1; i >= SESSION_LIMIT_PER_SENDER_KEY ; i -= 1) {
|
|
||||||
const session = sessions[i];
|
|
||||||
txn.olmSessions.remove(senderKey, session.id);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,7 +31,7 @@ function findFirstSessionId(sessionIds) {
|
||||||
const OTK_ALGORITHM = "signed_curve25519";
|
const OTK_ALGORITHM = "signed_curve25519";
|
||||||
|
|
||||||
export class Encryption {
|
export class Encryption {
|
||||||
constructor({account, olm, olmUtil, userId, storage, now, pickleKey}) {
|
constructor({account, olm, olmUtil, userId, storage, now, pickleKey, senderKeyLock}) {
|
||||||
this._account = account;
|
this._account = account;
|
||||||
this._olm = olm;
|
this._olm = olm;
|
||||||
this._olmUtil = olmUtil;
|
this._olmUtil = olmUtil;
|
||||||
|
@ -39,37 +39,47 @@ export class Encryption {
|
||||||
this._storage = storage;
|
this._storage = storage;
|
||||||
this._now = now;
|
this._now = now;
|
||||||
this._pickleKey = pickleKey;
|
this._pickleKey = pickleKey;
|
||||||
|
this._senderKeyLock = senderKeyLock;
|
||||||
}
|
}
|
||||||
|
|
||||||
async encrypt(type, content, devices, hsApi) {
|
async encrypt(type, content, devices, hsApi) {
|
||||||
const {
|
// TODO: see if we can only hold some of the locks until after the /keys/claim call (if needed)
|
||||||
devicesWithoutSession,
|
// take a lock on all senderKeys so decryption and other calls to encrypt (should not happen)
|
||||||
existingEncryptionTargets
|
// don't modify the sessions at the same time
|
||||||
} = await this._findExistingSessions(devices);
|
const locks = await Promise.all(devices.map(device => {
|
||||||
|
return this._senderKeyLock.takeLock(device.curve25519Key);
|
||||||
const timestamp = this._now();
|
}));
|
||||||
|
|
||||||
let encryptionTargets = [];
|
|
||||||
try {
|
try {
|
||||||
if (devicesWithoutSession.length) {
|
const {
|
||||||
const newEncryptionTargets = await this._createNewSessions(
|
devicesWithoutSession,
|
||||||
devicesWithoutSession, hsApi, timestamp);
|
existingEncryptionTargets,
|
||||||
encryptionTargets = encryptionTargets.concat(newEncryptionTargets);
|
} = await this._findExistingSessions(devices);
|
||||||
|
|
||||||
|
const timestamp = this._now();
|
||||||
|
|
||||||
|
let encryptionTargets = [];
|
||||||
|
try {
|
||||||
|
if (devicesWithoutSession.length) {
|
||||||
|
const newEncryptionTargets = await this._createNewSessions(
|
||||||
|
devicesWithoutSession, hsApi, timestamp);
|
||||||
|
encryptionTargets = encryptionTargets.concat(newEncryptionTargets);
|
||||||
|
}
|
||||||
|
await this._loadSessions(existingEncryptionTargets);
|
||||||
|
encryptionTargets = encryptionTargets.concat(existingEncryptionTargets);
|
||||||
|
const messages = encryptionTargets.map(target => {
|
||||||
|
const content = this._encryptForDevice(type, content, target);
|
||||||
|
return new EncryptedMessage(content, target.device);
|
||||||
|
});
|
||||||
|
await this._storeSessions(encryptionTargets, timestamp);
|
||||||
|
return messages;
|
||||||
|
} finally {
|
||||||
|
for (const target of encryptionTargets) {
|
||||||
|
target.dispose();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// TODO: if we read and write in two different txns,
|
|
||||||
// is there a chance we overwrite a session modified by the decryption during sync?
|
|
||||||
// I think so. We'll have to have a lock while sending ...
|
|
||||||
await this._loadSessions(existingEncryptionTargets);
|
|
||||||
encryptionTargets = encryptionTargets.concat(existingEncryptionTargets);
|
|
||||||
const messages = encryptionTargets.map(target => {
|
|
||||||
const content = this._encryptForDevice(type, content, target);
|
|
||||||
return new EncryptedMessage(content, target.device);
|
|
||||||
});
|
|
||||||
await this._storeSessions(encryptionTargets, timestamp);
|
|
||||||
return messages;
|
|
||||||
} finally {
|
} finally {
|
||||||
for (const target of encryptionTargets) {
|
for (const lock of locks) {
|
||||||
target.dispose();
|
lock.release();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
86
src/utils/Lock.js
Normal file
86
src/utils/Lock.js
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
/*
|
||||||
|
Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
export class Lock {
|
||||||
|
constructor() {
|
||||||
|
this._promise = null;
|
||||||
|
this._resolve = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
take() {
|
||||||
|
if (!this._promise) {
|
||||||
|
this._promise = new Promise(resolve => {
|
||||||
|
this._resolve = resolve;
|
||||||
|
});
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
get isTaken() {
|
||||||
|
return !!this._promise;
|
||||||
|
}
|
||||||
|
|
||||||
|
release() {
|
||||||
|
if (this._resolve) {
|
||||||
|
this._promise = null;
|
||||||
|
const resolve = this._resolve;
|
||||||
|
this._resolve = null;
|
||||||
|
resolve();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
released() {
|
||||||
|
return this._promise;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function tests() {
|
||||||
|
return {
|
||||||
|
"taking a lock twice returns false": assert => {
|
||||||
|
const lock = new Lock();
|
||||||
|
assert.equal(lock.take(), true);
|
||||||
|
assert.equal(lock.isTaken, true);
|
||||||
|
assert.equal(lock.take(), false);
|
||||||
|
},
|
||||||
|
"can take a released lock again": assert => {
|
||||||
|
const lock = new Lock();
|
||||||
|
lock.take();
|
||||||
|
lock.release();
|
||||||
|
assert.equal(lock.isTaken, false);
|
||||||
|
assert.equal(lock.take(), true);
|
||||||
|
},
|
||||||
|
"2 waiting for lock, only first one gets it": async assert => {
|
||||||
|
const lock = new Lock();
|
||||||
|
lock.take();
|
||||||
|
|
||||||
|
let first = false;
|
||||||
|
lock.released().then(() => first = lock.take());
|
||||||
|
let second = false;
|
||||||
|
lock.released().then(() => second = lock.take());
|
||||||
|
const promise = lock.released();
|
||||||
|
lock.release();
|
||||||
|
await promise;
|
||||||
|
assert.equal(first, true);
|
||||||
|
assert.equal(second, false);
|
||||||
|
},
|
||||||
|
"await non-taken lock": async assert => {
|
||||||
|
const lock = new Lock();
|
||||||
|
await lock.released();
|
||||||
|
assert(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
93
src/utils/LockMap.js
Normal file
93
src/utils/LockMap.js
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
/*
|
||||||
|
Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {Lock} from "./Lock.js";
|
||||||
|
|
||||||
|
export class LockMap {
|
||||||
|
constructor() {
|
||||||
|
this._map = new Map();
|
||||||
|
}
|
||||||
|
|
||||||
|
async takeLock(key) {
|
||||||
|
let lock = this._map.get(key);
|
||||||
|
if (lock) {
|
||||||
|
while (!lock.take()) {
|
||||||
|
await lock.released();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
lock = new Lock();
|
||||||
|
lock.take();
|
||||||
|
this._map.set(key, lock);
|
||||||
|
}
|
||||||
|
// don't leave old locks lying around
|
||||||
|
lock.released().then(() => {
|
||||||
|
// give others a chance to take the lock first
|
||||||
|
Promise.resolve().then(() => {
|
||||||
|
if (!lock.isTaken) {
|
||||||
|
this._map.delete(key);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
return lock;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function tests() {
|
||||||
|
return {
|
||||||
|
"taking a lock on the same key blocks": async assert => {
|
||||||
|
const lockMap = new LockMap();
|
||||||
|
const lock = await lockMap.takeLock("foo");
|
||||||
|
let second = false;
|
||||||
|
const prom = lockMap.takeLock("foo").then(() => {
|
||||||
|
second = true;
|
||||||
|
});
|
||||||
|
assert.equal(second, false);
|
||||||
|
// do a delay to make sure prom does not resolve on its own
|
||||||
|
await Promise.resolve();
|
||||||
|
lock.release();
|
||||||
|
await prom;
|
||||||
|
assert.equal(second, true);
|
||||||
|
},
|
||||||
|
"lock is not cleaned up with second request": async assert => {
|
||||||
|
const lockMap = new LockMap();
|
||||||
|
const lock = await lockMap.takeLock("foo");
|
||||||
|
let ranSecond = false;
|
||||||
|
const prom = lockMap.takeLock("foo").then(returnedLock => {
|
||||||
|
ranSecond = true;
|
||||||
|
assert.equal(returnedLock.isTaken, true);
|
||||||
|
// peek into internals, naughty
|
||||||
|
assert.equal(lockMap._map.get("foo"), returnedLock);
|
||||||
|
});
|
||||||
|
lock.release();
|
||||||
|
await prom;
|
||||||
|
// double delay to make sure cleanup logic ran
|
||||||
|
await Promise.resolve();
|
||||||
|
await Promise.resolve();
|
||||||
|
assert.equal(ranSecond, true);
|
||||||
|
},
|
||||||
|
"lock is cleaned up without other request": async assert => {
|
||||||
|
const lockMap = new LockMap();
|
||||||
|
const lock = await lockMap.takeLock("foo");
|
||||||
|
await Promise.resolve();
|
||||||
|
lock.release();
|
||||||
|
// double delay to make sure cleanup logic ran
|
||||||
|
await Promise.resolve();
|
||||||
|
await Promise.resolve();
|
||||||
|
assert.equal(lockMap._map.has("foo"), false);
|
||||||
|
},
|
||||||
|
|
||||||
|
};
|
||||||
|
}
|
Reference in a new issue