forked from mystiq/hydrogen-web
Merge pull request #80 from vector-im/bwindels/olm-encrypt
Implement olm encryption
This commit is contained in:
commit
90867d9558
13 changed files with 706 additions and 102 deletions
|
@ -64,6 +64,9 @@ export class DeviceMessageHandler {
|
||||||
}
|
}
|
||||||
const readTxn = await this._storage.readTxn([this._storage.storeNames.session]);
|
const readTxn = await this._storage.readTxn([this._storage.storeNames.session]);
|
||||||
const pendingEvents = await this._getPendingEvents(readTxn);
|
const pendingEvents = await this._getPendingEvents(readTxn);
|
||||||
|
if (pendingEvents.length === 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
// only know olm for now
|
// only know olm for now
|
||||||
const olmEvents = pendingEvents.filter(e => e.content?.algorithm === OLM_ALGORITHM);
|
const olmEvents = pendingEvents.filter(e => e.content?.algorithm === OLM_ALGORITHM);
|
||||||
const decryptChanges = await this._olmDecryption.decryptAll(olmEvents);
|
const decryptChanges = await this._olmDecryption.decryptAll(olmEvents);
|
||||||
|
|
|
@ -21,8 +21,11 @@ import {User} from "./User.js";
|
||||||
import {Account as E2EEAccount} from "./e2ee/Account.js";
|
import {Account as E2EEAccount} from "./e2ee/Account.js";
|
||||||
import {DeviceMessageHandler} from "./DeviceMessageHandler.js";
|
import {DeviceMessageHandler} from "./DeviceMessageHandler.js";
|
||||||
import {Decryption as OlmDecryption} from "./e2ee/olm/Decryption.js";
|
import {Decryption as OlmDecryption} from "./e2ee/olm/Decryption.js";
|
||||||
|
import {Encryption as OlmEncryption} from "./e2ee/olm/Encryption.js";
|
||||||
import {Decryption as MegOlmDecryption} from "./e2ee/megolm/Decryption.js";
|
import {Decryption as MegOlmDecryption} from "./e2ee/megolm/Decryption.js";
|
||||||
import {DeviceTracker} from "./e2ee/DeviceTracker.js";
|
import {DeviceTracker} from "./e2ee/DeviceTracker.js";
|
||||||
|
import {LockMap} from "../utils/LockMap.js";
|
||||||
|
|
||||||
const PICKLE_KEY = "DEFAULT_KEY";
|
const PICKLE_KEY = "DEFAULT_KEY";
|
||||||
|
|
||||||
export class Session {
|
export class Session {
|
||||||
|
@ -42,18 +45,23 @@ export class Session {
|
||||||
this._olmUtil = null;
|
this._olmUtil = null;
|
||||||
this._e2eeAccount = null;
|
this._e2eeAccount = null;
|
||||||
this._deviceTracker = null;
|
this._deviceTracker = null;
|
||||||
|
this._olmEncryption = null;
|
||||||
if (olm) {
|
if (olm) {
|
||||||
this._olmUtil = new olm.Utility();
|
this._olmUtil = new olm.Utility();
|
||||||
this._deviceTracker = new DeviceTracker({
|
this._deviceTracker = new DeviceTracker({
|
||||||
storage,
|
storage,
|
||||||
getSyncToken: () => this.syncToken,
|
getSyncToken: () => this.syncToken,
|
||||||
olmUtil: this._olmUtil,
|
olmUtil: this._olmUtil,
|
||||||
|
ownUserId: sessionInfo.userId,
|
||||||
|
ownDeviceId: sessionInfo.deviceId,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// called once this._e2eeAccount is assigned
|
// called once this._e2eeAccount is assigned
|
||||||
_setupEncryption() {
|
_setupEncryption() {
|
||||||
|
console.log("loaded e2ee account with keys", this._e2eeAccount.identityKeys);
|
||||||
|
const senderKeyLock = new LockMap();
|
||||||
const olmDecryption = new OlmDecryption({
|
const olmDecryption = new OlmDecryption({
|
||||||
account: this._e2eeAccount,
|
account: this._e2eeAccount,
|
||||||
pickleKey: PICKLE_KEY,
|
pickleKey: PICKLE_KEY,
|
||||||
|
@ -61,6 +69,17 @@ export class Session {
|
||||||
ownUserId: this._user.id,
|
ownUserId: this._user.id,
|
||||||
storage: this._storage,
|
storage: this._storage,
|
||||||
olm: this._olm,
|
olm: this._olm,
|
||||||
|
senderKeyLock
|
||||||
|
});
|
||||||
|
this._olmEncryption = new OlmEncryption({
|
||||||
|
account: this._e2eeAccount,
|
||||||
|
pickleKey: PICKLE_KEY,
|
||||||
|
now: this._clock.now,
|
||||||
|
ownUserId: this._user.id,
|
||||||
|
storage: this._storage,
|
||||||
|
olm: this._olm,
|
||||||
|
olmUtil: this._olmUtil,
|
||||||
|
senderKeyLock
|
||||||
});
|
});
|
||||||
const megolmDecryption = new MegOlmDecryption({pickleKey: PICKLE_KEY, olm: this._olm});
|
const megolmDecryption = new MegOlmDecryption({pickleKey: PICKLE_KEY, olm: this._olm});
|
||||||
this._deviceMessageHandler.enableEncryption({olmDecryption, megolmDecryption});
|
this._deviceMessageHandler.enableEncryption({olmDecryption, megolmDecryption});
|
||||||
|
|
|
@ -126,8 +126,24 @@ export class Account {
|
||||||
|
|
||||||
createInboundOlmSession(senderKey, body) {
|
createInboundOlmSession(senderKey, body) {
|
||||||
const newSession = new this._olm.Session();
|
const newSession = new this._olm.Session();
|
||||||
newSession.create_inbound_from(this._account, senderKey, body);
|
try {
|
||||||
return newSession;
|
newSession.create_inbound_from(this._account, senderKey, body);
|
||||||
|
return newSession;
|
||||||
|
} catch (err) {
|
||||||
|
newSession.free();
|
||||||
|
throw err;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
createOutboundOlmSession(theirIdentityKey, theirOneTimeKey) {
|
||||||
|
const newSession = new this._olm.Session();
|
||||||
|
try {
|
||||||
|
newSession.create_outbound(this._account, theirIdentityKey, theirOneTimeKey);
|
||||||
|
return newSession;
|
||||||
|
} catch (err) {
|
||||||
|
newSession.free();
|
||||||
|
throw err;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
writeRemoveOneTimeKey(session, txn) {
|
writeRemoveOneTimeKey(session, txn) {
|
||||||
|
|
|
@ -14,13 +14,11 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import anotherjson from "../../../lib/another-json/index.js";
|
import {verifyEd25519Signature, SIGNATURE_ALGORITHM} from "./common.js";
|
||||||
|
|
||||||
const TRACKING_STATUS_OUTDATED = 0;
|
const TRACKING_STATUS_OUTDATED = 0;
|
||||||
const TRACKING_STATUS_UPTODATE = 1;
|
const TRACKING_STATUS_UPTODATE = 1;
|
||||||
|
|
||||||
const DEVICE_KEYS_SIGNATURE_ALGORITHM = "ed25519";
|
|
||||||
|
|
||||||
// map 1 device from /keys/query response to DeviceIdentity
|
// map 1 device from /keys/query response to DeviceIdentity
|
||||||
function deviceKeysAsDeviceIdentity(deviceSection) {
|
function deviceKeysAsDeviceIdentity(deviceSection) {
|
||||||
const deviceId = deviceSection["device_id"];
|
const deviceId = deviceSection["device_id"];
|
||||||
|
@ -36,11 +34,13 @@ function deviceKeysAsDeviceIdentity(deviceSection) {
|
||||||
}
|
}
|
||||||
|
|
||||||
export class DeviceTracker {
|
export class DeviceTracker {
|
||||||
constructor({storage, getSyncToken, olmUtil}) {
|
constructor({storage, getSyncToken, olmUtil, ownUserId, ownDeviceId}) {
|
||||||
this._storage = storage;
|
this._storage = storage;
|
||||||
this._getSyncToken = getSyncToken;
|
this._getSyncToken = getSyncToken;
|
||||||
this._identityChangedForRoom = null;
|
this._identityChangedForRoom = null;
|
||||||
this._olmUtil = olmUtil;
|
this._olmUtil = olmUtil;
|
||||||
|
this._ownUserId = ownUserId;
|
||||||
|
this._ownDeviceId = ownDeviceId;
|
||||||
}
|
}
|
||||||
|
|
||||||
async writeDeviceChanges(deviceLists, txn) {
|
async writeDeviceChanges(deviceLists, txn) {
|
||||||
|
@ -200,7 +200,11 @@ export class DeviceTracker {
|
||||||
if (deviceIdOnKeys !== deviceId) {
|
if (deviceIdOnKeys !== deviceId) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return this._verifyUserDeviceKeys(deviceKeys);
|
// don't store our own device
|
||||||
|
if (userId === this._ownUserId && deviceId === this._ownDeviceId) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return this._hasValidSignature(deviceKeys);
|
||||||
});
|
});
|
||||||
const verifiedKeys = verifiedEntries.map(([, deviceKeys]) => deviceKeys);
|
const verifiedKeys = verifiedEntries.map(([, deviceKeys]) => deviceKeys);
|
||||||
return {userId, verifiedKeys};
|
return {userId, verifiedKeys};
|
||||||
|
@ -208,26 +212,11 @@ export class DeviceTracker {
|
||||||
return verifiedKeys;
|
return verifiedKeys;
|
||||||
}
|
}
|
||||||
|
|
||||||
_verifyUserDeviceKeys(deviceSection) {
|
_hasValidSignature(deviceSection) {
|
||||||
const deviceId = deviceSection["device_id"];
|
const deviceId = deviceSection["device_id"];
|
||||||
const userId = deviceSection["user_id"];
|
const userId = deviceSection["user_id"];
|
||||||
const clone = Object.assign({}, deviceSection);
|
const ed25519Key = deviceSection?.keys?.[`${SIGNATURE_ALGORITHM}:${deviceId}`];
|
||||||
delete clone.unsigned;
|
return verifyEd25519Signature(this._olmUtil, userId, deviceId, ed25519Key, deviceSection);
|
||||||
delete clone.signatures;
|
|
||||||
const canonicalJson = anotherjson.stringify(clone);
|
|
||||||
const key = deviceSection?.keys?.[`${DEVICE_KEYS_SIGNATURE_ALGORITHM}:${deviceId}`];
|
|
||||||
const signature = deviceSection?.signatures?.[userId]?.[`${DEVICE_KEYS_SIGNATURE_ALGORITHM}:${deviceId}`];
|
|
||||||
try {
|
|
||||||
if (!signature) {
|
|
||||||
throw new Error("no signature");
|
|
||||||
}
|
|
||||||
// throws when signature is invalid
|
|
||||||
this._olmUtil.ed25519_verify(key, canonicalJson, signature);
|
|
||||||
return true;
|
|
||||||
} catch (err) {
|
|
||||||
console.warn("Invalid device signature, ignoring device.", key, canonicalJson, signature, err);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -275,6 +264,10 @@ export class DeviceTracker {
|
||||||
if (queriedDevices && queriedDevices.length) {
|
if (queriedDevices && queriedDevices.length) {
|
||||||
flattenedDevices = flattenedDevices.concat(queriedDevices);
|
flattenedDevices = flattenedDevices.concat(queriedDevices);
|
||||||
}
|
}
|
||||||
return flattenedDevices;
|
// filter out our own devices if it got in somehow (even though we should not store it)
|
||||||
|
const devices = flattenedDevices.filter(device => {
|
||||||
|
return !(device.userId === this._ownUserId && device.deviceId === this._ownDeviceId);
|
||||||
|
});
|
||||||
|
return devices;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,8 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
import anotherjson from "../../../lib/another-json/index.js";
|
||||||
|
|
||||||
// use common prefix so it's easy to clear properties that are not e2ee related during session clear
|
// use common prefix so it's easy to clear properties that are not e2ee related during session clear
|
||||||
export const SESSION_KEY_PREFIX = "e2ee:";
|
export const SESSION_KEY_PREFIX = "e2ee:";
|
||||||
export const OLM_ALGORITHM = "m.olm.v1.curve25519-aes-sha2";
|
export const OLM_ALGORITHM = "m.olm.v1.curve25519-aes-sha2";
|
||||||
|
@ -27,3 +29,24 @@ export class DecryptionError extends Error {
|
||||||
this.details = detailsObj;
|
this.details = detailsObj;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const SIGNATURE_ALGORITHM = "ed25519";
|
||||||
|
|
||||||
|
export function verifyEd25519Signature(olmUtil, userId, deviceOrKeyId, ed25519Key, value) {
|
||||||
|
const clone = Object.assign({}, value);
|
||||||
|
delete clone.unsigned;
|
||||||
|
delete clone.signatures;
|
||||||
|
const canonicalJson = anotherjson.stringify(clone);
|
||||||
|
const signature = value?.signatures?.[userId]?.[`${SIGNATURE_ALGORITHM}:${deviceOrKeyId}`];
|
||||||
|
try {
|
||||||
|
if (!signature) {
|
||||||
|
throw new Error("no signature");
|
||||||
|
}
|
||||||
|
// throws when signature is invalid
|
||||||
|
olmUtil.ed25519_verify(ed25519Key, canonicalJson, signature);
|
||||||
|
return true;
|
||||||
|
} catch (err) {
|
||||||
|
console.warn("Invalid signature, ignoring.", ed25519Key, canonicalJson, signature, err);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import {DecryptionError} from "../common.js";
|
import {DecryptionError} from "../common.js";
|
||||||
|
import {groupBy} from "../../../utils/groupBy.js";
|
||||||
|
import {Session} from "./Session.js";
|
||||||
|
|
||||||
const SESSION_LIMIT_PER_SENDER_KEY = 4;
|
const SESSION_LIMIT_PER_SENDER_KEY = 4;
|
||||||
|
|
||||||
|
@ -29,14 +31,14 @@ function sortSessions(sessions) {
|
||||||
}
|
}
|
||||||
|
|
||||||
export class Decryption {
|
export class Decryption {
|
||||||
constructor({account, pickleKey, now, ownUserId, storage, olm}) {
|
constructor({account, pickleKey, now, ownUserId, storage, olm, senderKeyLock}) {
|
||||||
this._account = account;
|
this._account = account;
|
||||||
this._pickleKey = pickleKey;
|
this._pickleKey = pickleKey;
|
||||||
this._now = now;
|
this._now = now;
|
||||||
this._ownUserId = ownUserId;
|
this._ownUserId = ownUserId;
|
||||||
this._storage = storage;
|
this._storage = storage;
|
||||||
this._olm = olm;
|
this._olm = olm;
|
||||||
this._createOutboundSessionPromise = null;
|
this._senderKeyLock = senderKeyLock;
|
||||||
}
|
}
|
||||||
|
|
||||||
// we need decryptAll because there is some parallelization we can do for decrypting different sender keys at once
|
// we need decryptAll because there is some parallelization we can do for decrypting different sender keys at once
|
||||||
|
@ -49,26 +51,30 @@ export class Decryption {
|
||||||
//
|
//
|
||||||
// doing it one by one would be possible, but we would lose the opportunity for parallelization
|
// doing it one by one would be possible, but we would lose the opportunity for parallelization
|
||||||
async decryptAll(events) {
|
async decryptAll(events) {
|
||||||
const eventsPerSenderKey = events.reduce((map, event) => {
|
const eventsPerSenderKey = groupBy(events, event => event.content?.["sender_key"]);
|
||||||
const senderKey = event.content?.["sender_key"];
|
|
||||||
let list = map.get(senderKey);
|
|
||||||
if (!list) {
|
|
||||||
list = [];
|
|
||||||
map.set(senderKey, list);
|
|
||||||
}
|
|
||||||
list.push(event);
|
|
||||||
return map;
|
|
||||||
}, new Map());
|
|
||||||
const timestamp = this._now();
|
const timestamp = this._now();
|
||||||
const readSessionsTxn = await this._storage.readTxn([this._storage.storeNames.olmSessions]);
|
// take a lock on all senderKeys so encryption or other calls to decryptAll (should not happen)
|
||||||
// decrypt events for different sender keys in parallel
|
// don't modify the sessions at the same time
|
||||||
const results = await Promise.all(Array.from(eventsPerSenderKey.entries()).map(([senderKey, events]) => {
|
const locks = await Promise.all(Array.from(eventsPerSenderKey.keys()).map(senderKey => {
|
||||||
return this._decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn);
|
return this._senderKeyLock.takeLock(senderKey);
|
||||||
}));
|
}));
|
||||||
const payloads = results.reduce((all, r) => all.concat(r.payloads), []);
|
try {
|
||||||
const errors = results.reduce((all, r) => all.concat(r.errors), []);
|
const readSessionsTxn = await this._storage.readTxn([this._storage.storeNames.olmSessions]);
|
||||||
const senderKeyDecryptions = results.map(r => r.senderKeyDecryption);
|
// decrypt events for different sender keys in parallel
|
||||||
return new DecryptionChanges(senderKeyDecryptions, payloads, errors);
|
const results = await Promise.all(Array.from(eventsPerSenderKey.entries()).map(([senderKey, events]) => {
|
||||||
|
return this._decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn);
|
||||||
|
}));
|
||||||
|
const payloads = results.reduce((all, r) => all.concat(r.payloads), []);
|
||||||
|
const errors = results.reduce((all, r) => all.concat(r.errors), []);
|
||||||
|
const senderKeyDecryptions = results.map(r => r.senderKeyDecryption);
|
||||||
|
return new DecryptionChanges(senderKeyDecryptions, payloads, errors, this._account, locks);
|
||||||
|
} catch (err) {
|
||||||
|
// make sure the locks are release if something throws
|
||||||
|
for (const lock of locks) {
|
||||||
|
lock.release();
|
||||||
|
}
|
||||||
|
throw err;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async _decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn) {
|
async _decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn) {
|
||||||
|
@ -105,7 +111,12 @@ export class Decryption {
|
||||||
plaintext = createResult.plaintext;
|
plaintext = createResult.plaintext;
|
||||||
}
|
}
|
||||||
if (typeof plaintext === "string") {
|
if (typeof plaintext === "string") {
|
||||||
const payload = JSON.parse(plaintext);
|
let payload;
|
||||||
|
try {
|
||||||
|
payload = JSON.parse(plaintext);
|
||||||
|
} catch (err) {
|
||||||
|
throw new DecryptionError("Could not JSON decode plaintext", event, {plaintext, err});
|
||||||
|
}
|
||||||
this._validatePayload(payload, event);
|
this._validatePayload(payload, event);
|
||||||
return {event: payload, senderKey};
|
return {event: payload, senderKey};
|
||||||
} else {
|
} else {
|
||||||
|
@ -177,44 +188,6 @@ export class Decryption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class Session {
|
|
||||||
constructor(data, pickleKey, olm, isNew = false) {
|
|
||||||
this.data = data;
|
|
||||||
this._olm = olm;
|
|
||||||
this._pickleKey = pickleKey;
|
|
||||||
this.isNew = isNew;
|
|
||||||
this.isModified = isNew;
|
|
||||||
}
|
|
||||||
|
|
||||||
static create(senderKey, olmSession, olm, pickleKey, timestamp) {
|
|
||||||
return new Session({
|
|
||||||
session: olmSession.pickle(pickleKey),
|
|
||||||
sessionId: olmSession.session_id(),
|
|
||||||
senderKey,
|
|
||||||
lastUsed: timestamp,
|
|
||||||
}, pickleKey, olm, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
get id() {
|
|
||||||
return this.data.sessionId;
|
|
||||||
}
|
|
||||||
|
|
||||||
load() {
|
|
||||||
const session = new this._olm.Session();
|
|
||||||
session.unpickle(this._pickleKey, this.data.session);
|
|
||||||
return session;
|
|
||||||
}
|
|
||||||
|
|
||||||
unload(olmSession) {
|
|
||||||
olmSession.free();
|
|
||||||
}
|
|
||||||
|
|
||||||
save(olmSession) {
|
|
||||||
this.data.session = olmSession.pickle(this._pickleKey);
|
|
||||||
this.isModified = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// decryption helper for a single senderKey
|
// decryption helper for a single senderKey
|
||||||
class SenderKeyDecryption {
|
class SenderKeyDecryption {
|
||||||
constructor(senderKey, sessions, olm, timestamp) {
|
constructor(senderKey, sessions, olm, timestamp) {
|
||||||
|
@ -280,11 +253,12 @@ class SenderKeyDecryption {
|
||||||
}
|
}
|
||||||
|
|
||||||
class DecryptionChanges {
|
class DecryptionChanges {
|
||||||
constructor(senderKeyDecryptions, payloads, errors, account) {
|
constructor(senderKeyDecryptions, payloads, errors, account, locks) {
|
||||||
this._senderKeyDecryptions = senderKeyDecryptions;
|
this._senderKeyDecryptions = senderKeyDecryptions;
|
||||||
this._account = account;
|
this._account = account;
|
||||||
this.payloads = payloads;
|
this.payloads = payloads;
|
||||||
this.errors = errors;
|
this.errors = errors;
|
||||||
|
this._locks = locks;
|
||||||
}
|
}
|
||||||
|
|
||||||
get hasNewSessions() {
|
get hasNewSessions() {
|
||||||
|
@ -292,25 +266,31 @@ class DecryptionChanges {
|
||||||
}
|
}
|
||||||
|
|
||||||
write(txn) {
|
write(txn) {
|
||||||
for (const senderKeyDecryption of this._senderKeyDecryptions) {
|
try {
|
||||||
for (const session of senderKeyDecryption.getModifiedSessions()) {
|
for (const senderKeyDecryption of this._senderKeyDecryptions) {
|
||||||
txn.olmSessions.set(session.data);
|
for (const session of senderKeyDecryption.getModifiedSessions()) {
|
||||||
if (session.isNew) {
|
txn.olmSessions.set(session.data);
|
||||||
const olmSession = session.load();
|
if (session.isNew) {
|
||||||
try {
|
const olmSession = session.load();
|
||||||
this._account.writeRemoveOneTimeKey(olmSession, txn);
|
try {
|
||||||
} finally {
|
this._account.writeRemoveOneTimeKey(olmSession, txn);
|
||||||
session.unload(olmSession);
|
} finally {
|
||||||
|
session.unload(olmSession);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (senderKeyDecryption.sessions.length > SESSION_LIMIT_PER_SENDER_KEY) {
|
||||||
|
const {senderKey, sessions} = senderKeyDecryption;
|
||||||
|
// >= because index is zero-based
|
||||||
|
for (let i = sessions.length - 1; i >= SESSION_LIMIT_PER_SENDER_KEY ; i -= 1) {
|
||||||
|
const session = sessions[i];
|
||||||
|
txn.olmSessions.remove(senderKey, session.id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (senderKeyDecryption.sessions.length > SESSION_LIMIT_PER_SENDER_KEY) {
|
} finally {
|
||||||
const {senderKey, sessions} = senderKeyDecryption;
|
for (const lock of this._locks) {
|
||||||
// >= because index is zero-based
|
lock.release();
|
||||||
for (let i = sessions.length - 1; i >= SESSION_LIMIT_PER_SENDER_KEY ; i -= 1) {
|
|
||||||
const session = sessions[i];
|
|
||||||
txn.olmSessions.remove(senderKey, session.id);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
274
src/matrix/e2ee/olm/Encryption.js
Normal file
274
src/matrix/e2ee/olm/Encryption.js
Normal file
|
@ -0,0 +1,274 @@
|
||||||
|
/*
|
||||||
|
Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {groupByWithCreator} from "../../../utils/groupBy.js";
|
||||||
|
import {verifyEd25519Signature, OLM_ALGORITHM} from "../common.js";
|
||||||
|
import {createSessionEntry} from "./Session.js";
|
||||||
|
|
||||||
|
function findFirstSessionId(sessionIds) {
|
||||||
|
return sessionIds.reduce((first, sessionId) => {
|
||||||
|
if (!first || sessionId < first) {
|
||||||
|
return sessionId;
|
||||||
|
} else {
|
||||||
|
return first;
|
||||||
|
}
|
||||||
|
}, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
const OTK_ALGORITHM = "signed_curve25519";
|
||||||
|
|
||||||
|
export class Encryption {
|
||||||
|
constructor({account, olm, olmUtil, ownUserId, storage, now, pickleKey, senderKeyLock}) {
|
||||||
|
this._account = account;
|
||||||
|
this._olm = olm;
|
||||||
|
this._olmUtil = olmUtil;
|
||||||
|
this._ownUserId = ownUserId;
|
||||||
|
this._storage = storage;
|
||||||
|
this._now = now;
|
||||||
|
this._pickleKey = pickleKey;
|
||||||
|
this._senderKeyLock = senderKeyLock;
|
||||||
|
}
|
||||||
|
|
||||||
|
async encrypt(type, content, devices, hsApi) {
|
||||||
|
// TODO: see if we can only hold some of the locks until after the /keys/claim call (if needed)
|
||||||
|
// take a lock on all senderKeys so decryption and other calls to encrypt (should not happen)
|
||||||
|
// don't modify the sessions at the same time
|
||||||
|
const locks = await Promise.all(devices.map(device => {
|
||||||
|
return this._senderKeyLock.takeLock(device.curve25519Key);
|
||||||
|
}));
|
||||||
|
try {
|
||||||
|
const {
|
||||||
|
devicesWithoutSession,
|
||||||
|
existingEncryptionTargets,
|
||||||
|
} = await this._findExistingSessions(devices);
|
||||||
|
|
||||||
|
const timestamp = this._now();
|
||||||
|
|
||||||
|
let encryptionTargets = [];
|
||||||
|
try {
|
||||||
|
if (devicesWithoutSession.length) {
|
||||||
|
const newEncryptionTargets = await this._createNewSessions(
|
||||||
|
devicesWithoutSession, hsApi, timestamp);
|
||||||
|
encryptionTargets = encryptionTargets.concat(newEncryptionTargets);
|
||||||
|
}
|
||||||
|
await this._loadSessions(existingEncryptionTargets);
|
||||||
|
encryptionTargets = encryptionTargets.concat(existingEncryptionTargets);
|
||||||
|
const messages = encryptionTargets.map(target => {
|
||||||
|
const encryptedContent = this._encryptForDevice(type, content, target);
|
||||||
|
return new EncryptedMessage(encryptedContent, target.device);
|
||||||
|
});
|
||||||
|
await this._storeSessions(encryptionTargets, timestamp);
|
||||||
|
return messages;
|
||||||
|
} finally {
|
||||||
|
for (const target of encryptionTargets) {
|
||||||
|
target.dispose();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
for (const lock of locks) {
|
||||||
|
lock.release();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async _findExistingSessions(devices) {
|
||||||
|
const txn = await this._storage.readTxn([this._storage.storeNames.olmSessions]);
|
||||||
|
const sessionIdsForDevice = await Promise.all(devices.map(async device => {
|
||||||
|
return await txn.olmSessions.getSessionIds(device.curve25519Key);
|
||||||
|
}));
|
||||||
|
const devicesWithoutSession = devices.filter((_, i) => {
|
||||||
|
const sessionIds = sessionIdsForDevice[i];
|
||||||
|
return !(sessionIds?.length);
|
||||||
|
});
|
||||||
|
|
||||||
|
const existingEncryptionTargets = devices.map((device, i) => {
|
||||||
|
const sessionIds = sessionIdsForDevice[i];
|
||||||
|
if (sessionIds?.length > 0) {
|
||||||
|
const sessionId = findFirstSessionId(sessionIds);
|
||||||
|
return EncryptionTarget.fromSessionId(device, sessionId);
|
||||||
|
}
|
||||||
|
}).filter(target => !!target);
|
||||||
|
|
||||||
|
return {devicesWithoutSession, existingEncryptionTargets};
|
||||||
|
}
|
||||||
|
|
||||||
|
_encryptForDevice(type, content, target) {
|
||||||
|
const {session, device} = target;
|
||||||
|
const plaintext = JSON.stringify(this._buildPlainTextMessageForDevice(type, content, device));
|
||||||
|
const message = session.encrypt(plaintext);
|
||||||
|
const encryptedContent = {
|
||||||
|
algorithm: OLM_ALGORITHM,
|
||||||
|
sender_key: this._account.identityKeys.curve25519,
|
||||||
|
ciphertext: {
|
||||||
|
[device.curve25519Key]: message
|
||||||
|
}
|
||||||
|
};
|
||||||
|
return encryptedContent;
|
||||||
|
}
|
||||||
|
|
||||||
|
_buildPlainTextMessageForDevice(type, content, device) {
|
||||||
|
return {
|
||||||
|
keys: {
|
||||||
|
"ed25519": this._account.identityKeys.ed25519
|
||||||
|
},
|
||||||
|
recipient_keys: {
|
||||||
|
"ed25519": device.ed25519Key
|
||||||
|
},
|
||||||
|
recipient: device.userId,
|
||||||
|
sender: this._ownUserId,
|
||||||
|
content,
|
||||||
|
type
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async _createNewSessions(devicesWithoutSession, hsApi, timestamp) {
|
||||||
|
const newEncryptionTargets = await this._claimOneTimeKeys(hsApi, devicesWithoutSession);
|
||||||
|
try {
|
||||||
|
for (const target of newEncryptionTargets) {
|
||||||
|
const {device, oneTimeKey} = target;
|
||||||
|
target.session = this._account.createOutboundOlmSession(device.curve25519Key, oneTimeKey);
|
||||||
|
}
|
||||||
|
this._storeSessions(newEncryptionTargets, timestamp);
|
||||||
|
} catch (err) {
|
||||||
|
for (const target of newEncryptionTargets) {
|
||||||
|
target.dispose();
|
||||||
|
}
|
||||||
|
throw err;
|
||||||
|
}
|
||||||
|
return newEncryptionTargets;
|
||||||
|
}
|
||||||
|
|
||||||
|
async _claimOneTimeKeys(hsApi, deviceIdentities) {
|
||||||
|
// create a Map<userId, Map<deviceId, deviceIdentity>>
|
||||||
|
const devicesByUser = groupByWithCreator(deviceIdentities,
|
||||||
|
device => device.userId,
|
||||||
|
() => new Map(),
|
||||||
|
(deviceMap, device) => deviceMap.set(device.deviceId, device)
|
||||||
|
);
|
||||||
|
const oneTimeKeys = Array.from(devicesByUser.entries()).reduce((usersObj, [userId, deviceMap]) => {
|
||||||
|
usersObj[userId] = Array.from(deviceMap.values()).reduce((devicesObj, device) => {
|
||||||
|
devicesObj[device.deviceId] = OTK_ALGORITHM;
|
||||||
|
return devicesObj;
|
||||||
|
}, {});
|
||||||
|
return usersObj;
|
||||||
|
}, {});
|
||||||
|
const claimResponse = await hsApi.claimKeys({
|
||||||
|
timeout: 10000,
|
||||||
|
one_time_keys: oneTimeKeys
|
||||||
|
}).response();
|
||||||
|
// TODO: log claimResponse.failures
|
||||||
|
const userKeyMap = claimResponse?.["one_time_keys"];
|
||||||
|
return this._verifyAndCreateOTKTargets(userKeyMap, devicesByUser);
|
||||||
|
}
|
||||||
|
|
||||||
|
_verifyAndCreateOTKTargets(userKeyMap, devicesByUser) {
|
||||||
|
const verifiedEncryptionTargets = [];
|
||||||
|
for (const [userId, userSection] of Object.entries(userKeyMap)) {
|
||||||
|
for (const [deviceId, deviceSection] of Object.entries(userSection)) {
|
||||||
|
const [firstPropName, keySection] = Object.entries(deviceSection)[0];
|
||||||
|
const [keyAlgorithm] = firstPropName.split(":");
|
||||||
|
if (keyAlgorithm === OTK_ALGORITHM) {
|
||||||
|
const device = devicesByUser.get(userId)?.get(deviceId);
|
||||||
|
if (device) {
|
||||||
|
const isValidSignature = verifyEd25519Signature(
|
||||||
|
this._olmUtil, userId, deviceId, device.ed25519Key, keySection);
|
||||||
|
if (isValidSignature) {
|
||||||
|
const target = EncryptionTarget.fromOTK(device, keySection.key);
|
||||||
|
verifiedEncryptionTargets.push(target);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return verifiedEncryptionTargets;
|
||||||
|
}
|
||||||
|
|
||||||
|
async _loadSessions(encryptionTargets) {
|
||||||
|
const txn = await this._storage.readTxn([this._storage.storeNames.olmSessions]);
|
||||||
|
// given we run loading in parallel, there might still be some
|
||||||
|
// storage requests that will finish later once one has failed.
|
||||||
|
// those should not allocate a session anymore.
|
||||||
|
let failed = false;
|
||||||
|
try {
|
||||||
|
await Promise.all(encryptionTargets.map(async encryptionTarget => {
|
||||||
|
const sessionEntry = await txn.olmSessions.get(
|
||||||
|
encryptionTarget.device.curve25519Key, encryptionTarget.sessionId);
|
||||||
|
if (sessionEntry && !failed) {
|
||||||
|
const olmSession = new this._olm.Session();
|
||||||
|
olmSession.unpickle(this._pickleKey, sessionEntry.session);
|
||||||
|
encryptionTarget.session = olmSession;
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
} catch (err) {
|
||||||
|
failed = true;
|
||||||
|
// clean up the sessions that did load
|
||||||
|
for (const target of encryptionTargets) {
|
||||||
|
target.dispose();
|
||||||
|
}
|
||||||
|
throw err;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async _storeSessions(encryptionTargets, timestamp) {
|
||||||
|
const txn = await this._storage.readWriteTxn([this._storage.storeNames.olmSessions]);
|
||||||
|
try {
|
||||||
|
for (const target of encryptionTargets) {
|
||||||
|
const sessionEntry = createSessionEntry(
|
||||||
|
target.session, target.device.curve25519Key, timestamp, this._pickleKey);
|
||||||
|
txn.olmSessions.set(sessionEntry);
|
||||||
|
}
|
||||||
|
} catch (err) {
|
||||||
|
txn.abort();
|
||||||
|
throw err;
|
||||||
|
}
|
||||||
|
await txn.complete();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// just a container needed to encrypt a message for a recipient device
|
||||||
|
// it is constructed with either a oneTimeKey
|
||||||
|
// (and later converted to a session) in case of a new session
|
||||||
|
// or an existing session
|
||||||
|
class EncryptionTarget {
|
||||||
|
constructor(device, oneTimeKey, sessionId) {
|
||||||
|
this.device = device;
|
||||||
|
this.oneTimeKey = oneTimeKey;
|
||||||
|
this.sessionId = sessionId;
|
||||||
|
// an olmSession, should probably be called olmSession
|
||||||
|
this.session = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
static fromOTK(device, oneTimeKey) {
|
||||||
|
return new EncryptionTarget(device, oneTimeKey, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
static fromSessionId(device, sessionId) {
|
||||||
|
return new EncryptionTarget(device, null, sessionId);
|
||||||
|
}
|
||||||
|
|
||||||
|
dispose() {
|
||||||
|
if (this.session) {
|
||||||
|
this.session.free();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class EncryptedMessage {
|
||||||
|
constructor(content, device) {
|
||||||
|
this.content = content;
|
||||||
|
this.device = device;
|
||||||
|
}
|
||||||
|
}
|
58
src/matrix/e2ee/olm/Session.js
Normal file
58
src/matrix/e2ee/olm/Session.js
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
/*
|
||||||
|
Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
export function createSessionEntry(olmSession, senderKey, timestamp, pickleKey) {
|
||||||
|
return {
|
||||||
|
session: olmSession.pickle(pickleKey),
|
||||||
|
sessionId: olmSession.session_id(),
|
||||||
|
senderKey,
|
||||||
|
lastUsed: timestamp,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export class Session {
|
||||||
|
constructor(data, pickleKey, olm, isNew = false) {
|
||||||
|
this.data = data;
|
||||||
|
this._olm = olm;
|
||||||
|
this._pickleKey = pickleKey;
|
||||||
|
this.isNew = isNew;
|
||||||
|
this.isModified = isNew;
|
||||||
|
}
|
||||||
|
|
||||||
|
static create(senderKey, olmSession, olm, pickleKey, timestamp) {
|
||||||
|
const data = createSessionEntry(olmSession, senderKey, timestamp, pickleKey);
|
||||||
|
return new Session(data, pickleKey, olm, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
get id() {
|
||||||
|
return this.data.sessionId;
|
||||||
|
}
|
||||||
|
|
||||||
|
load() {
|
||||||
|
const session = new this._olm.Session();
|
||||||
|
session.unpickle(this._pickleKey, this.data.session);
|
||||||
|
return session;
|
||||||
|
}
|
||||||
|
|
||||||
|
unload(olmSession) {
|
||||||
|
olmSession.free();
|
||||||
|
}
|
||||||
|
|
||||||
|
save(olmSession) {
|
||||||
|
this.data.session = olmSession.pickle(this._pickleKey);
|
||||||
|
this.isModified = true;
|
||||||
|
}
|
||||||
|
}
|
|
@ -168,6 +168,10 @@ export class HomeServerApi {
|
||||||
return this._post("/keys/query", null, queryRequest, options);
|
return this._post("/keys/query", null, queryRequest, options);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
claimKeys(payload, options = null) {
|
||||||
|
return this._post("/keys/claim", null, payload, options);
|
||||||
|
}
|
||||||
|
|
||||||
get mediaRepository() {
|
get mediaRepository() {
|
||||||
return this._mediaRepository;
|
return this._mediaRepository;
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,11 +18,31 @@ function encodeKey(senderKey, sessionId) {
|
||||||
return `${senderKey}|${sessionId}`;
|
return `${senderKey}|${sessionId}`;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function decodeKey(key) {
|
||||||
|
const [senderKey, sessionId] = key.split("|");
|
||||||
|
return {senderKey, sessionId};
|
||||||
|
}
|
||||||
|
|
||||||
export class OlmSessionStore {
|
export class OlmSessionStore {
|
||||||
constructor(store) {
|
constructor(store) {
|
||||||
this._store = store;
|
this._store = store;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async getSessionIds(senderKey) {
|
||||||
|
const sessionIds = [];
|
||||||
|
const range = IDBKeyRange.lowerBound(encodeKey(senderKey, ""));
|
||||||
|
await this._store.iterateKeys(range, key => {
|
||||||
|
const decodedKey = decodeKey(key);
|
||||||
|
// prevent running into the next room
|
||||||
|
if (decodedKey.senderKey === senderKey) {
|
||||||
|
sessionIds.push(decodedKey.sessionId);
|
||||||
|
return false; // fetch more
|
||||||
|
}
|
||||||
|
return true; // done
|
||||||
|
});
|
||||||
|
return sessionIds;
|
||||||
|
}
|
||||||
|
|
||||||
getAll(senderKey) {
|
getAll(senderKey) {
|
||||||
const range = IDBKeyRange.lowerBound(encodeKey(senderKey, ""));
|
const range = IDBKeyRange.lowerBound(encodeKey(senderKey, ""));
|
||||||
return this._store.selectWhile(range, session => {
|
return this._store.selectWhile(range, session => {
|
||||||
|
|
86
src/utils/Lock.js
Normal file
86
src/utils/Lock.js
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
/*
|
||||||
|
Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
export class Lock {
|
||||||
|
constructor() {
|
||||||
|
this._promise = null;
|
||||||
|
this._resolve = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
take() {
|
||||||
|
if (!this._promise) {
|
||||||
|
this._promise = new Promise(resolve => {
|
||||||
|
this._resolve = resolve;
|
||||||
|
});
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
get isTaken() {
|
||||||
|
return !!this._promise;
|
||||||
|
}
|
||||||
|
|
||||||
|
release() {
|
||||||
|
if (this._resolve) {
|
||||||
|
this._promise = null;
|
||||||
|
const resolve = this._resolve;
|
||||||
|
this._resolve = null;
|
||||||
|
resolve();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
released() {
|
||||||
|
return this._promise;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function tests() {
|
||||||
|
return {
|
||||||
|
"taking a lock twice returns false": assert => {
|
||||||
|
const lock = new Lock();
|
||||||
|
assert.equal(lock.take(), true);
|
||||||
|
assert.equal(lock.isTaken, true);
|
||||||
|
assert.equal(lock.take(), false);
|
||||||
|
},
|
||||||
|
"can take a released lock again": assert => {
|
||||||
|
const lock = new Lock();
|
||||||
|
lock.take();
|
||||||
|
lock.release();
|
||||||
|
assert.equal(lock.isTaken, false);
|
||||||
|
assert.equal(lock.take(), true);
|
||||||
|
},
|
||||||
|
"2 waiting for lock, only first one gets it": async assert => {
|
||||||
|
const lock = new Lock();
|
||||||
|
lock.take();
|
||||||
|
|
||||||
|
let first;
|
||||||
|
lock.released().then(() => first = lock.take());
|
||||||
|
let second;
|
||||||
|
lock.released().then(() => second = lock.take());
|
||||||
|
const promise = lock.released();
|
||||||
|
lock.release();
|
||||||
|
await promise;
|
||||||
|
assert.strictEqual(first, true);
|
||||||
|
assert.strictEqual(second, false);
|
||||||
|
},
|
||||||
|
"await non-taken lock": async assert => {
|
||||||
|
const lock = new Lock();
|
||||||
|
await lock.released();
|
||||||
|
assert(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
93
src/utils/LockMap.js
Normal file
93
src/utils/LockMap.js
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
/*
|
||||||
|
Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {Lock} from "./Lock.js";
|
||||||
|
|
||||||
|
export class LockMap {
|
||||||
|
constructor() {
|
||||||
|
this._map = new Map();
|
||||||
|
}
|
||||||
|
|
||||||
|
async takeLock(key) {
|
||||||
|
let lock = this._map.get(key);
|
||||||
|
if (lock) {
|
||||||
|
while (!lock.take()) {
|
||||||
|
await lock.released();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
lock = new Lock();
|
||||||
|
lock.take();
|
||||||
|
this._map.set(key, lock);
|
||||||
|
}
|
||||||
|
// don't leave old locks lying around
|
||||||
|
lock.released().then(() => {
|
||||||
|
// give others a chance to take the lock first
|
||||||
|
Promise.resolve().then(() => {
|
||||||
|
if (!lock.isTaken) {
|
||||||
|
this._map.delete(key);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
return lock;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function tests() {
|
||||||
|
return {
|
||||||
|
"taking a lock on the same key blocks": async assert => {
|
||||||
|
const lockMap = new LockMap();
|
||||||
|
const lock = await lockMap.takeLock("foo");
|
||||||
|
let second = false;
|
||||||
|
const prom = lockMap.takeLock("foo").then(() => {
|
||||||
|
second = true;
|
||||||
|
});
|
||||||
|
assert.equal(second, false);
|
||||||
|
// do a delay to make sure prom does not resolve on its own
|
||||||
|
await Promise.resolve();
|
||||||
|
lock.release();
|
||||||
|
await prom;
|
||||||
|
assert.equal(second, true);
|
||||||
|
},
|
||||||
|
"lock is not cleaned up with second request": async assert => {
|
||||||
|
const lockMap = new LockMap();
|
||||||
|
const lock = await lockMap.takeLock("foo");
|
||||||
|
let ranSecond = false;
|
||||||
|
const prom = lockMap.takeLock("foo").then(returnedLock => {
|
||||||
|
ranSecond = true;
|
||||||
|
assert.equal(returnedLock.isTaken, true);
|
||||||
|
// peek into internals, naughty
|
||||||
|
assert.equal(lockMap._map.get("foo"), returnedLock);
|
||||||
|
});
|
||||||
|
lock.release();
|
||||||
|
await prom;
|
||||||
|
// double delay to make sure cleanup logic ran
|
||||||
|
await Promise.resolve();
|
||||||
|
await Promise.resolve();
|
||||||
|
assert.equal(ranSecond, true);
|
||||||
|
},
|
||||||
|
"lock is cleaned up without other request": async assert => {
|
||||||
|
const lockMap = new LockMap();
|
||||||
|
const lock = await lockMap.takeLock("foo");
|
||||||
|
await Promise.resolve();
|
||||||
|
lock.release();
|
||||||
|
// double delay to make sure cleanup logic ran
|
||||||
|
await Promise.resolve();
|
||||||
|
await Promise.resolve();
|
||||||
|
assert.equal(lockMap._map.has("foo"), false);
|
||||||
|
},
|
||||||
|
|
||||||
|
};
|
||||||
|
}
|
35
src/utils/groupBy.js
Normal file
35
src/utils/groupBy.js
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
/*
|
||||||
|
Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
export function groupBy(array, groupFn) {
|
||||||
|
return groupByWithCreator(array, groupFn,
|
||||||
|
() => {return [];},
|
||||||
|
(array, value) => array.push(value)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function groupByWithCreator(array, groupFn, createCollectionFn, addCollectionFn) {
|
||||||
|
return array.reduce((map, value) => {
|
||||||
|
const key = groupFn(value);
|
||||||
|
let collection = map.get(key);
|
||||||
|
if (!collection) {
|
||||||
|
collection = createCollectionFn();
|
||||||
|
map.set(key, collection);
|
||||||
|
}
|
||||||
|
addCollectionFn(collection, value);
|
||||||
|
return map;
|
||||||
|
}, new Map());
|
||||||
|
}
|
Loading…
Reference in a new issue