forked from mystiq/hydrogen-web
lock on senderKey while enc/decrypting olm sessions
This commit is contained in:
parent
4ecd853348
commit
4f4808b94c
5 changed files with 266 additions and 53 deletions
|
@ -23,6 +23,8 @@ import {DeviceMessageHandler} from "./DeviceMessageHandler.js";
|
|||
import {Decryption as OlmDecryption} from "./e2ee/olm/Decryption.js";
|
||||
import {Decryption as MegOlmDecryption} from "./e2ee/megolm/Decryption.js";
|
||||
import {DeviceTracker} from "./e2ee/DeviceTracker.js";
|
||||
import {LockMap} from "../utils/LockMap.js";
|
||||
|
||||
const PICKLE_KEY = "DEFAULT_KEY";
|
||||
|
||||
export class Session {
|
||||
|
@ -54,6 +56,7 @@ export class Session {
|
|||
|
||||
// called once this._e2eeAccount is assigned
|
||||
_setupEncryption() {
|
||||
const senderKeyLock = new LockMap();
|
||||
const olmDecryption = new OlmDecryption({
|
||||
account: this._e2eeAccount,
|
||||
pickleKey: PICKLE_KEY,
|
||||
|
@ -61,6 +64,7 @@ export class Session {
|
|||
ownUserId: this._user.id,
|
||||
storage: this._storage,
|
||||
olm: this._olm,
|
||||
senderKeyLock
|
||||
});
|
||||
const megolmDecryption = new MegOlmDecryption({pickleKey: PICKLE_KEY, olm: this._olm});
|
||||
this._deviceMessageHandler.enableEncryption({olmDecryption, megolmDecryption});
|
||||
|
|
|
@ -31,14 +31,14 @@ function sortSessions(sessions) {
|
|||
}
|
||||
|
||||
export class Decryption {
|
||||
constructor({account, pickleKey, now, ownUserId, storage, olm}) {
|
||||
constructor({account, pickleKey, now, ownUserId, storage, olm, senderKeyLock}) {
|
||||
this._account = account;
|
||||
this._pickleKey = pickleKey;
|
||||
this._now = now;
|
||||
this._ownUserId = ownUserId;
|
||||
this._storage = storage;
|
||||
this._olm = olm;
|
||||
this._createOutboundSessionPromise = null;
|
||||
this._senderKeyLock = senderKeyLock;
|
||||
}
|
||||
|
||||
// we need decryptAll because there is some parallelization we can do for decrypting different sender keys at once
|
||||
|
@ -53,15 +53,28 @@ export class Decryption {
|
|||
async decryptAll(events) {
|
||||
const eventsPerSenderKey = groupBy(events, event => event.content?.["sender_key"]);
|
||||
const timestamp = this._now();
|
||||
const readSessionsTxn = await this._storage.readTxn([this._storage.storeNames.olmSessions]);
|
||||
// decrypt events for different sender keys in parallel
|
||||
const results = await Promise.all(Array.from(eventsPerSenderKey.entries()).map(([senderKey, events]) => {
|
||||
return this._decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn);
|
||||
// take a lock on all senderKeys so encryption or other calls to decryptAll (should not happen)
|
||||
// don't modify the sessions at the same time
|
||||
const locks = await Promise.all(Array.from(eventsPerSenderKey.keys()).map(senderKey => {
|
||||
return this._senderKeyLock.takeLock(senderKey);
|
||||
}));
|
||||
const payloads = results.reduce((all, r) => all.concat(r.payloads), []);
|
||||
const errors = results.reduce((all, r) => all.concat(r.errors), []);
|
||||
const senderKeyDecryptions = results.map(r => r.senderKeyDecryption);
|
||||
return new DecryptionChanges(senderKeyDecryptions, payloads, errors);
|
||||
try {
|
||||
const readSessionsTxn = await this._storage.readTxn([this._storage.storeNames.olmSessions]);
|
||||
// decrypt events for different sender keys in parallel
|
||||
const results = await Promise.all(Array.from(eventsPerSenderKey.entries()).map(([senderKey, events]) => {
|
||||
return this._decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn);
|
||||
}));
|
||||
const payloads = results.reduce((all, r) => all.concat(r.payloads), []);
|
||||
const errors = results.reduce((all, r) => all.concat(r.errors), []);
|
||||
const senderKeyDecryptions = results.map(r => r.senderKeyDecryption);
|
||||
return new DecryptionChanges(senderKeyDecryptions, payloads, errors, locks);
|
||||
} catch (err) {
|
||||
// make sure the locks are release if something throws
|
||||
for (const lock of locks) {
|
||||
lock.release();
|
||||
}
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
async _decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn) {
|
||||
|
@ -235,11 +248,12 @@ class SenderKeyDecryption {
|
|||
}
|
||||
|
||||
class DecryptionChanges {
|
||||
constructor(senderKeyDecryptions, payloads, errors, account) {
|
||||
constructor(senderKeyDecryptions, payloads, errors, account, locks) {
|
||||
this._senderKeyDecryptions = senderKeyDecryptions;
|
||||
this._account = account;
|
||||
this.payloads = payloads;
|
||||
this.errors = errors;
|
||||
this._locks = locks;
|
||||
}
|
||||
|
||||
get hasNewSessions() {
|
||||
|
@ -247,25 +261,31 @@ class DecryptionChanges {
|
|||
}
|
||||
|
||||
write(txn) {
|
||||
for (const senderKeyDecryption of this._senderKeyDecryptions) {
|
||||
for (const session of senderKeyDecryption.getModifiedSessions()) {
|
||||
txn.olmSessions.set(session.data);
|
||||
if (session.isNew) {
|
||||
const olmSession = session.load();
|
||||
try {
|
||||
this._account.writeRemoveOneTimeKey(olmSession, txn);
|
||||
} finally {
|
||||
session.unload(olmSession);
|
||||
try {
|
||||
for (const senderKeyDecryption of this._senderKeyDecryptions) {
|
||||
for (const session of senderKeyDecryption.getModifiedSessions()) {
|
||||
txn.olmSessions.set(session.data);
|
||||
if (session.isNew) {
|
||||
const olmSession = session.load();
|
||||
try {
|
||||
this._account.writeRemoveOneTimeKey(olmSession, txn);
|
||||
} finally {
|
||||
session.unload(olmSession);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (senderKeyDecryption.sessions.length > SESSION_LIMIT_PER_SENDER_KEY) {
|
||||
const {senderKey, sessions} = senderKeyDecryption;
|
||||
// >= because index is zero-based
|
||||
for (let i = sessions.length - 1; i >= SESSION_LIMIT_PER_SENDER_KEY ; i -= 1) {
|
||||
const session = sessions[i];
|
||||
txn.olmSessions.remove(senderKey, session.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (senderKeyDecryption.sessions.length > SESSION_LIMIT_PER_SENDER_KEY) {
|
||||
const {senderKey, sessions} = senderKeyDecryption;
|
||||
// >= because index is zero-based
|
||||
for (let i = sessions.length - 1; i >= SESSION_LIMIT_PER_SENDER_KEY ; i -= 1) {
|
||||
const session = sessions[i];
|
||||
txn.olmSessions.remove(senderKey, session.id);
|
||||
}
|
||||
} finally {
|
||||
for (const lock of this._locks) {
|
||||
lock.release();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ function findFirstSessionId(sessionIds) {
|
|||
const OTK_ALGORITHM = "signed_curve25519";
|
||||
|
||||
export class Encryption {
|
||||
constructor({account, olm, olmUtil, userId, storage, now, pickleKey}) {
|
||||
constructor({account, olm, olmUtil, userId, storage, now, pickleKey, senderKeyLock}) {
|
||||
this._account = account;
|
||||
this._olm = olm;
|
||||
this._olmUtil = olmUtil;
|
||||
|
@ -39,37 +39,47 @@ export class Encryption {
|
|||
this._storage = storage;
|
||||
this._now = now;
|
||||
this._pickleKey = pickleKey;
|
||||
this._senderKeyLock = senderKeyLock;
|
||||
}
|
||||
|
||||
async encrypt(type, content, devices, hsApi) {
|
||||
const {
|
||||
devicesWithoutSession,
|
||||
existingEncryptionTargets
|
||||
} = await this._findExistingSessions(devices);
|
||||
|
||||
const timestamp = this._now();
|
||||
|
||||
let encryptionTargets = [];
|
||||
// TODO: see if we can only hold some of the locks until after the /keys/claim call (if needed)
|
||||
// take a lock on all senderKeys so decryption and other calls to encrypt (should not happen)
|
||||
// don't modify the sessions at the same time
|
||||
const locks = await Promise.all(devices.map(device => {
|
||||
return this._senderKeyLock.takeLock(device.curve25519Key);
|
||||
}));
|
||||
try {
|
||||
if (devicesWithoutSession.length) {
|
||||
const newEncryptionTargets = await this._createNewSessions(
|
||||
devicesWithoutSession, hsApi, timestamp);
|
||||
encryptionTargets = encryptionTargets.concat(newEncryptionTargets);
|
||||
const {
|
||||
devicesWithoutSession,
|
||||
existingEncryptionTargets,
|
||||
} = await this._findExistingSessions(devices);
|
||||
|
||||
const timestamp = this._now();
|
||||
|
||||
let encryptionTargets = [];
|
||||
try {
|
||||
if (devicesWithoutSession.length) {
|
||||
const newEncryptionTargets = await this._createNewSessions(
|
||||
devicesWithoutSession, hsApi, timestamp);
|
||||
encryptionTargets = encryptionTargets.concat(newEncryptionTargets);
|
||||
}
|
||||
await this._loadSessions(existingEncryptionTargets);
|
||||
encryptionTargets = encryptionTargets.concat(existingEncryptionTargets);
|
||||
const messages = encryptionTargets.map(target => {
|
||||
const content = this._encryptForDevice(type, content, target);
|
||||
return new EncryptedMessage(content, target.device);
|
||||
});
|
||||
await this._storeSessions(encryptionTargets, timestamp);
|
||||
return messages;
|
||||
} finally {
|
||||
for (const target of encryptionTargets) {
|
||||
target.dispose();
|
||||
}
|
||||
}
|
||||
// TODO: if we read and write in two different txns,
|
||||
// is there a chance we overwrite a session modified by the decryption during sync?
|
||||
// I think so. We'll have to have a lock while sending ...
|
||||
await this._loadSessions(existingEncryptionTargets);
|
||||
encryptionTargets = encryptionTargets.concat(existingEncryptionTargets);
|
||||
const messages = encryptionTargets.map(target => {
|
||||
const content = this._encryptForDevice(type, content, target);
|
||||
return new EncryptedMessage(content, target.device);
|
||||
});
|
||||
await this._storeSessions(encryptionTargets, timestamp);
|
||||
return messages;
|
||||
} finally {
|
||||
for (const target of encryptionTargets) {
|
||||
target.dispose();
|
||||
for (const lock of locks) {
|
||||
lock.release();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
86
src/utils/Lock.js
Normal file
86
src/utils/Lock.js
Normal file
|
@ -0,0 +1,86 @@
|
|||
/*
|
||||
Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
export class Lock {
|
||||
constructor() {
|
||||
this._promise = null;
|
||||
this._resolve = null;
|
||||
}
|
||||
|
||||
take() {
|
||||
if (!this._promise) {
|
||||
this._promise = new Promise(resolve => {
|
||||
this._resolve = resolve;
|
||||
});
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
get isTaken() {
|
||||
return !!this._promise;
|
||||
}
|
||||
|
||||
release() {
|
||||
if (this._resolve) {
|
||||
this._promise = null;
|
||||
const resolve = this._resolve;
|
||||
this._resolve = null;
|
||||
resolve();
|
||||
}
|
||||
}
|
||||
|
||||
released() {
|
||||
return this._promise;
|
||||
}
|
||||
}
|
||||
|
||||
export function tests() {
|
||||
return {
|
||||
"taking a lock twice returns false": assert => {
|
||||
const lock = new Lock();
|
||||
assert.equal(lock.take(), true);
|
||||
assert.equal(lock.isTaken, true);
|
||||
assert.equal(lock.take(), false);
|
||||
},
|
||||
"can take a released lock again": assert => {
|
||||
const lock = new Lock();
|
||||
lock.take();
|
||||
lock.release();
|
||||
assert.equal(lock.isTaken, false);
|
||||
assert.equal(lock.take(), true);
|
||||
},
|
||||
"2 waiting for lock, only first one gets it": async assert => {
|
||||
const lock = new Lock();
|
||||
lock.take();
|
||||
|
||||
let first = false;
|
||||
lock.released().then(() => first = lock.take());
|
||||
let second = false;
|
||||
lock.released().then(() => second = lock.take());
|
||||
const promise = lock.released();
|
||||
lock.release();
|
||||
await promise;
|
||||
assert.equal(first, true);
|
||||
assert.equal(second, false);
|
||||
},
|
||||
"await non-taken lock": async assert => {
|
||||
const lock = new Lock();
|
||||
await lock.released();
|
||||
assert(true);
|
||||
}
|
||||
}
|
||||
}
|
93
src/utils/LockMap.js
Normal file
93
src/utils/LockMap.js
Normal file
|
@ -0,0 +1,93 @@
|
|||
/*
|
||||
Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
import {Lock} from "./Lock.js";
|
||||
|
||||
export class LockMap {
|
||||
constructor() {
|
||||
this._map = new Map();
|
||||
}
|
||||
|
||||
async takeLock(key) {
|
||||
let lock = this._map.get(key);
|
||||
if (lock) {
|
||||
while (!lock.take()) {
|
||||
await lock.released();
|
||||
}
|
||||
} else {
|
||||
lock = new Lock();
|
||||
lock.take();
|
||||
this._map.set(key, lock);
|
||||
}
|
||||
// don't leave old locks lying around
|
||||
lock.released().then(() => {
|
||||
// give others a chance to take the lock first
|
||||
Promise.resolve().then(() => {
|
||||
if (!lock.isTaken) {
|
||||
this._map.delete(key);
|
||||
}
|
||||
});
|
||||
});
|
||||
return lock;
|
||||
}
|
||||
}
|
||||
|
||||
export function tests() {
|
||||
return {
|
||||
"taking a lock on the same key blocks": async assert => {
|
||||
const lockMap = new LockMap();
|
||||
const lock = await lockMap.takeLock("foo");
|
||||
let second = false;
|
||||
const prom = lockMap.takeLock("foo").then(() => {
|
||||
second = true;
|
||||
});
|
||||
assert.equal(second, false);
|
||||
// do a delay to make sure prom does not resolve on its own
|
||||
await Promise.resolve();
|
||||
lock.release();
|
||||
await prom;
|
||||
assert.equal(second, true);
|
||||
},
|
||||
"lock is not cleaned up with second request": async assert => {
|
||||
const lockMap = new LockMap();
|
||||
const lock = await lockMap.takeLock("foo");
|
||||
let ranSecond = false;
|
||||
const prom = lockMap.takeLock("foo").then(returnedLock => {
|
||||
ranSecond = true;
|
||||
assert.equal(returnedLock.isTaken, true);
|
||||
// peek into internals, naughty
|
||||
assert.equal(lockMap._map.get("foo"), returnedLock);
|
||||
});
|
||||
lock.release();
|
||||
await prom;
|
||||
// double delay to make sure cleanup logic ran
|
||||
await Promise.resolve();
|
||||
await Promise.resolve();
|
||||
assert.equal(ranSecond, true);
|
||||
},
|
||||
"lock is cleaned up without other request": async assert => {
|
||||
const lockMap = new LockMap();
|
||||
const lock = await lockMap.takeLock("foo");
|
||||
await Promise.resolve();
|
||||
lock.release();
|
||||
// double delay to make sure cleanup logic ran
|
||||
await Promise.resolve();
|
||||
await Promise.resolve();
|
||||
assert.equal(lockMap._map.has("foo"), false);
|
||||
},
|
||||
|
||||
};
|
||||
}
|
Loading…
Reference in a new issue