to_device handler for encrypted messages

changes the api of the olm decryption to decrypt in batch
so we can isolate side-effects until we have a write-txn open
and we can parallelize the decryption of different sender keys.
This commit is contained in:
Bruno Windels 2020-09-02 13:33:27 +02:00
parent 3698dd9b92
commit 44e9f91d4c
4 changed files with 345 additions and 123 deletions

View file

@ -0,0 +1,87 @@
/*
Copyright 2020 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import {OLM_ALGORITHM, MEGOLM_ALGORITHM} from "./e2ee/common.js";
// key to store in session store
const PENDING_ENCRYPTED_EVENTS = "pendingEncryptedDeviceEvents";
export class DeviceMessageHandler {
constructor({storage, olmDecryption, megolmEncryption}) {
this._storage = storage;
this._olmDecryption = olmDecryption;
this._megolmEncryption = megolmEncryption;
}
async writeSync(toDeviceEvents, txn) {
const encryptedEvents = toDeviceEvents.filter(e => e.type === "m.room.encrypted");
// store encryptedEvents
let pendingEvents = this._getPendingEvents(txn);
pendingEvents = pendingEvents.concat(encryptedEvents);
txn.session.set(PENDING_ENCRYPTED_EVENTS, pendingEvents);
// we don't handle anything other for now
}
async _handleDecryptedEvents(payloads, txn) {
const megOlmRoomKeysPayloads = payloads.filter(p => {
return p.event.type === "m.room_key" && p.event.content?.algorithm === MEGOLM_ALGORITHM;
});
let megolmChanges;
if (megOlmRoomKeysPayloads.length) {
megolmChanges = await this._megolmEncryption.addRoomKeys(megOlmRoomKeysPayloads, txn);
}
return {megolmChanges};
}
applyChanges({megolmChanges}) {
if (megolmChanges) {
this._megolmEncryption.applyRoomKeyChanges(megolmChanges);
}
}
// not safe to call multiple times without awaiting first call
async decryptPending() {
const readTxn = await this._storage.readTxn([this._storage.storeNames.session]);
const pendingEvents = this._getPendingEvents(readTxn);
// only know olm for now
const olmEvents = pendingEvents.filter(e => e.content?.algorithm === OLM_ALGORITHM);
const decryptChanges = await this._olmDecryption.decryptAll(olmEvents);
for (const err of decryptChanges.errors) {
console.warn("decryption failed for event", err, err.event);
}
const txn = await this._storage.readWriteTxn([
// both to remove the pending events and to modify the olm account
this._storage.storeNames.session,
this._storage.storeNames.olmSessions,
// this._storage.storeNames.megolmInboundSessions,
]);
let changes;
try {
changes = await this._handleDecryptedEvent(decryptChanges.payloads, txn);
decryptChanges.write(txn);
txn.session.remove(PENDING_ENCRYPTED_EVENTS);
} catch (err) {
txn.abort();
throw err;
}
await txn.complete();
this._applyChanges(changes);
}
async _getPendingEvents(txn) {
return (await txn.session.get(PENDING_ENCRYPTED_EVENTS)) || [];
}
}

View file

@ -20,9 +20,10 @@ export const OLM_ALGORITHM = "m.olm.v1.curve25519-aes-sha2";
export const MEGOLM_ALGORITHM = "m.megolm.v1.aes-sha2";
export class DecryptionError extends Error {
constructor(code, detailsObj = null) {
constructor(code, event, detailsObj = null) {
super(`Decryption error ${code}${detailsObj ? ": "+JSON.stringify(detailsObj) : ""}`);
this.code = code;
this.event = event;
this.details = detailsObj;
}
}

View file

@ -22,6 +22,12 @@ function isPreKeyMessage(message) {
return message.type === 0;
}
function sortSessions(sessions) {
sessions.sort((a, b) => {
return b.data.lastUsed - a.data.lastUsed;
});
}
export class Decryption {
constructor({account, pickleKey, now, ownUserId, storage, olm}) {
this._account = account;
@ -33,155 +39,279 @@ export class Decryption {
this._createOutboundSessionPromise = null;
}
// we can't run this in the sync txn because decryption will be async ...
// should we store the encrypted events in the sync loop and then pop them from there?
// it would be good in any case to run the (next) sync request in parallel with decryption
async decrypt(event) {
const senderKey = event.content?.["sender_key"];
// we need decryptAll because there is some parallelization we can do for decrypting different sender keys at once
// but for the same sender key we need to do one by one
//
// also we want to store the room key, etc ... in the same txn as we remove the pending encrypted event
//
// so we need to decrypt events in a batch (so we can decide which ones can run in parallel and which one one by one)
// and also can avoid side-effects before all can be stored this way
//
// doing it one by one would be possible, but we would lose the opportunity for parallelization
async decryptAll(events) {
const eventsPerSenderKey = events.reduce((map, event) => {
const senderKey = event.content?.["sender_key"];
let list = map.get(senderKey);
if (!list) {
list = [];
map.set(senderKey, list);
}
list.push(event);
return map;
}, new Map());
const timestamp = this._now();
const readSessionsTxn = await this._storage.readTxn([this._storage.storeNames.olmSessions]);
// decrypt events for different sender keys in parallel
const results = await Promise.all(Array.from(eventsPerSenderKey.entries()).map(([senderKey, events]) => {
return this._decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn);
}));
const payloads = results.reduce((all, r) => all.concat(r.payloads), []);
const errors = results.reduce((all, r) => all.concat(r.errors), []);
const senderKeyDecryptions = results.map(r => r.senderKeyDecryption);
return new DecryptionChanges(senderKeyDecryptions, payloads, errors);
}
async _decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn) {
const sessions = await this._getSessions(senderKey, readSessionsTxn);
const senderKeyDecryption = new SenderKeyDecryption(senderKey, sessions, this._olm, timestamp);
const payloads = [];
const errors = [];
// events for a single senderKey need to be decrypted one by one
for (const event of events) {
try {
const payload = this._decryptForSenderKey(senderKeyDecryption, event, timestamp);
payloads.push(payload);
} catch (err) {
errors.push(err);
}
}
return {payloads, errors, senderKeyDecryption};
}
_decryptForSenderKey(senderKeyDecryption, event, timestamp) {
const senderKey = senderKeyDecryption.senderKey;
const message = this._getMessageAndValidateEvent(event);
let plaintext;
try {
plaintext = senderKeyDecryption.decrypt(message);
} catch (err) {
// TODO: is it ok that an error on one session prevents other sessions from being attempted?
throw new DecryptionError("OLM_BAD_ENCRYPTED_MESSAGE", event, {senderKey, error: err.message});
}
// could not decrypt with any existing session
if (typeof plaintext !== "string" && isPreKeyMessage(message)) {
const createResult = this._createSessionAndDecrypt(senderKey, message, timestamp);
senderKeyDecryption.addNewSession(createResult.session);
plaintext = createResult.plaintext;
}
if (typeof plaintext === "string") {
const payload = JSON.parse(plaintext);
this._validatePayload(payload, event);
return {event: payload, senderKey};
} else {
throw new DecryptionError("Didn't find any session to decrypt with", event,
{sessionIds: senderKeyDecryption.sessions.map(s => s.id)});
}
}
// only for pre-key messages after having attempted decryption with existing sessions
_createSessionAndDecrypt(senderKey, message, timestamp) {
let plaintext;
// if we have multiple messages encrypted with the same new session,
// this could create multiple sessions as the OTK isn't removed yet
// (this only happens in DecryptionChanges.write)
// This should be ok though as we'll first try to decrypt with the new session
const olmSession = this._account.createInboundOlmSession(senderKey, message.body);
try {
plaintext = olmSession.decrypt(message.type, message.body);
const session = Session.create(senderKey, olmSession, this._olm, this._pickleKey, timestamp);
session.unload(olmSession);
return {session, plaintext};
} catch (err) {
olmSession.free();
throw err;
}
}
_getMessageAndValidateEvent(event) {
const ciphertext = event.content?.ciphertext;
if (!ciphertext) {
throw new DecryptionError("OLM_MISSING_CIPHERTEXT");
throw new DecryptionError("OLM_MISSING_CIPHERTEXT", event);
}
const message = ciphertext?.[this._account.identityKeys.curve25519];
if (!message) {
// TODO: use same error messages as element-web
throw new DecryptionError("OLM_NOT_INCLUDED_IN_RECIPIENTS");
}
const sortedSessionIds = await this._getSortedSessionIds(senderKey);
let plaintext;
for (const sessionId of sortedSessionIds) {
try {
plaintext = await this._attemptDecryption(senderKey, sessionId, message);
} catch (err) {
throw new DecryptionError("OLM_BAD_ENCRYPTED_MESSAGE", {senderKey, error: err.message});
}
if (typeof plaintext === "string") {
break;
}
}
if (typeof plaintext !== "string" && isPreKeyMessage(message)) {
plaintext = await this._createOutboundSessionAndDecrypt(senderKey, message, sortedSessionIds);
}
if (typeof plaintext === "string") {
return this._parseAndValidatePayload(plaintext, event);
throw new DecryptionError("OLM_NOT_INCLUDED_IN_RECIPIENTS", event);
}
return message;
}
async _getSortedSessionIds(senderKey) {
const readTxn = await this._storage.readTxn([this._storage.storeNames.olmSessions]);
const sortedSessions = await readTxn.olmSessions.getAll(senderKey);
async _getSessions(senderKey, txn) {
const sessionEntries = await txn.olmSessions.getAll(senderKey);
// sort most recent used sessions first
sortedSessions.sort((a, b) => {
return b.lastUsed - a.lastUsed;
});
return sortedSessions.map(s => s.sessionId);
const sessions = sessionEntries.map(s => new Session(s, this._pickleKey, this._olm));
sortSessions(sessions);
return sessions;
}
async _createOutboundSessionAndDecrypt(senderKey, message, sortedSessionIds) {
// serialize calls so the account isn't written from multiple
// sessions at once
while (this._createOutboundSessionPromise) {
await this._createOutboundSessionPromise;
_validatePayload(payload, event) {
if (payload.sender !== event.sender) {
throw new DecryptionError("OLM_FORWARDED_MESSAGE", event, {sentBy: event.sender, encryptedBy: payload.sender});
}
this._createOutboundSessionPromise = (async () => {
try {
return await this._createOutboundSessionAndDecryptImpl(senderKey, message, sortedSessionIds);
} finally {
this._createOutboundSessionPromise = null;
}
})();
return await this._createOutboundSessionPromise;
}
// this could internally dispatch to a web-worker
async _createOutboundSessionAndDecryptImpl(senderKey, message, sortedSessionIds) {
let plaintext;
const session = this._account.createInboundOlmSession(senderKey, message.body);
try {
const txn = await this._storage.readWriteTxn([
this._storage.storeNames.session,
this._storage.storeNames.olmSessions,
]);
try {
// do this before removing the OTK removal, so we know decryption succeeded beforehand,
// as we don't have a way of undoing the OTK removal atm.
plaintext = session.decrypt(message.type, message.body);
this._account.writeRemoveOneTimeKey(session, txn);
// remove oldest session if we reach the limit including the new session
if (sortedSessionIds.length >= SESSION_LIMIT_PER_SENDER_KEY) {
// given they are sorted, the oldest one is the last one
const oldestSessionId = sortedSessionIds[sortedSessionIds.length - 1];
txn.olmSessions.remove(senderKey, oldestSessionId);
}
txn.olmSessions.set({
session: session.pickle(this._pickleKey),
sessionId: session.session_id(),
senderKey,
lastUsed: this._now(),
});
} catch (err) {
txn.abort();
throw err;
}
await txn.complete();
} finally {
session.free();
if (payload.recipient !== this._ownUserId) {
throw new DecryptionError("OLM_BAD_RECIPIENT", event, {recipient: payload.recipient});
}
return plaintext;
if (payload.recipient_keys?.ed25519 !== this._account.identityKeys.ed25519) {
throw new DecryptionError("OLM_BAD_RECIPIENT_KEY", event, {key: payload.recipient_keys?.ed25519});
}
// TODO: check room_id
if (!payload.type) {
throw new DecryptionError("missing type on payload", event, {payload});
}
if (!payload.content) {
throw new DecryptionError("missing content on payload", event, {payload});
}
// TODO: how important is it to verify the message?
// we should look at payload.keys.ed25519 for that... and compare it to the key we have fetched
// from /keys/query, which we might not have done yet at this point.
}
}
class Session {
constructor(data, pickleKey, olm, isNew = false) {
this.data = data;
this._olm = olm;
this._pickleKey = pickleKey;
this.isNew = isNew;
this.isModified = isNew;
}
// this could internally dispatch to a web-worker
async _attemptDecryption(senderKey, sessionId, message) {
const txn = await this._storage.readWriteTxn([this._storage.storeNames.olmSessions]);
static create(senderKey, olmSession, olm, pickleKey, timestamp) {
return new Session({
session: olmSession.pickle(pickleKey),
sessionId: olmSession.session_id(),
senderKey,
lastUsed: timestamp,
}, pickleKey, olm, true);
}
get id() {
return this.data.sessionId;
}
load() {
const session = new this._olm.Session();
let plaintext;
session.unpickle(this._pickleKey, this.data.session);
return session;
}
unload(olmSession) {
olmSession.free();
}
save(olmSession) {
this.data.session = olmSession.pickle(this._pickleKey);
this.isModified = true;
}
}
// decryption helper for a single senderKey
class SenderKeyDecryption {
constructor(senderKey, sessions, olm, timestamp) {
this.senderKey = senderKey;
this.sessions = sessions;
this._olm = olm;
this._timestamp = timestamp;
}
addNewSession(session) {
// add at top as it is most recent
this.sessions.unshift(session);
}
decrypt(message) {
for (const session of this.sessions) {
const plaintext = this._decryptWithSession(session, message);
if (typeof plaintext === "string") {
// keep them sorted so will try the same session first for other messages
// and so we can assume the excess ones are at the end
// if they grow too large
sortSessions(this.sessions);
return plaintext;
}
}
}
getModifiedSessions() {
return this.sessions.filter(session => session.isModified);
}
get hasNewSessions() {
return this.sessions.some(session => session.isNew);
}
// this could internally dispatch to a web-worker
// and is why we unpickle/pickle on each iteration
// if this turns out to be a real cost for IE11,
// we could look into adding a less expensive serialization mechanism
// for olm sessions to libolm
_decryptWithSession(session, message) {
const olmSession = session.load();
try {
const sessionEntry = await txn.olmSessions.get(senderKey, sessionId);
session.unpickle(this._pickleKey, sessionEntry.session);
if (isPreKeyMessage(message) && !session.matches_inbound(message.body)) {
if (isPreKeyMessage(message) && !olmSession.matches_inbound(message.body)) {
return;
}
try {
plaintext = session.decrypt(message.type, message.body);
const plaintext = olmSession.decrypt(message.type, message.body);
session.save(olmSession);
session.lastUsed = this._timestamp;
return plaintext;
} catch (err) {
if (isPreKeyMessage(message)) {
throw new Error(`Error decrypting prekey message with existing session id ${sessionId}: ${err.message}`);
throw new Error(`Error decrypting prekey message with existing session id ${session.id}: ${err.message}`);
}
// decryption failed, bail out
return;
}
sessionEntry.session = session.pickle(this._pickleKey);
sessionEntry.lastUsed = this._now();
txn.olmSessions.set(sessionEntry);
} catch(err) {
txn.abort();
throw err;
} finally {
session.free();
session.unload(olmSession);
}
}
}
class DecryptionChanges {
constructor(senderKeyDecryptions, payloads, errors, account) {
this._senderKeyDecryptions = senderKeyDecryptions;
this._account = account;
this.payloads = payloads;
this.errors = errors;
}
get hasNewSessions() {
return this._senderKeyDecryptions.some(skd => skd.hasNewSessions);
}
write(txn) {
for (const senderKeyDecryption of this._senderKeyDecryptions) {
for (const session of senderKeyDecryption.getModifiedSessions()) {
txn.olmSessions.set(session.data);
if (session.isNew) {
const olmSession = session.load();
try {
this._account.writeRemoveOneTimeKey(olmSession, txn);
} finally {
session.unload(olmSession);
}
}
}
if (senderKeyDecryption.sessions.length > SESSION_LIMIT_PER_SENDER_KEY) {
const {senderKey, sessions} = senderKeyDecryption;
// >= because index is zero-based
for (let i = sessions.length - 1; i >= SESSION_LIMIT_PER_SENDER_KEY ; i -= 1) {
const session = sessions[i];
txn.olmSessions.remove(senderKey, session.id);
}
}
}
await txn.complete();
return plaintext;
}
_parseAndValidatePayload(plaintext, event) {
const payload = JSON.parse(plaintext);
if (payload.sender !== event.sender) {
throw new DecryptionError("OLM_FORWARDED_MESSAGE", {sentBy: event.sender, encryptedBy: payload.sender});
}
if (payload.recipient !== this._ownUserId) {
throw new DecryptionError("OLM_BAD_RECIPIENT", {recipient: payload.recipient});
}
if (payload.recipient_keys?.ed25519 !== this._account.identityKeys.ed25519) {
throw new DecryptionError("OLM_BAD_RECIPIENT_KEY", {key: payload.recipient_keys?.ed25519});
}
// TODO: check room_id
if (!payload.type) {
throw new Error("missing type on payload");
}
if (!payload.content) {
throw new Error("missing content on payload");
}
return payload;
}
}

View file

@ -49,4 +49,8 @@ export class SessionStore {
add(key, value) {
return this._sessionStore.put({key, value});
}
remove(key) {
this._sessionStore.delete(key);
}
}