expose multi-step decryption from RoomEncryption, adjust room timeline

sync code hasn't been adjusted yet
This commit is contained in:
Bruno Windels 2020-09-10 12:09:17 +02:00
parent 7c1f9dbed0
commit 1c77c3b876
5 changed files with 199 additions and 133 deletions

View file

@ -14,8 +14,9 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
import {MEGOLM_ALGORITHM} from "./common.js"; import {MEGOLM_ALGORITHM, DecryptionSource} from "./common.js";
import {groupBy} from "../../utils/groupBy.js"; import {groupBy} from "../../utils/groupBy.js";
import {mergeMap} from "../../utils/mergeMap.js";
import {makeTxnId} from "../common.js"; import {makeTxnId} from "../common.js";
const ENCRYPTED_TYPE = "m.room.encrypted"; const ENCRYPTED_TYPE = "m.room.encrypted";
@ -55,23 +56,54 @@ export class RoomEncryption {
return await this._deviceTracker.writeMemberChanges(this._room, memberChanges, txn); return await this._deviceTracker.writeMemberChanges(this._room, memberChanges, txn);
} }
async decrypt(event, isSync, isTimelineOpen, retryData, txn) { // this happens before entries exists, as they are created by the syncwriter
// but we want to be able to map it back to something in the timeline easily
// when retrying decryption.
async prepareDecryptAll(events, source, isTimelineOpen, txn) {
const errors = [];
const validEvents = [];
for (const event of events) {
if (event.redacted_because || event.unsigned?.redacted_because) { if (event.redacted_because || event.unsigned?.redacted_because) {
return; continue;
} }
if (event.content?.algorithm !== MEGOLM_ALGORITHM) { if (event.content?.algorithm !== MEGOLM_ALGORITHM) {
throw new Error("Unsupported algorithm: " + event.content?.algorithm); errors.set(event.event_id, new Error("Unsupported algorithm: " + event.content?.algorithm));
} }
let sessionCache = isSync ? this._megolmSyncCache : this._megolmBackfillCache; validEvents.push(event);
const result = await this._megolmDecryption.decrypt(
this._room.id, event, sessionCache, txn);
if (!result) {
this._addMissingSessionEvent(event, isSync, retryData);
} }
if (result && isTimelineOpen) { let customCache;
let sessionCache;
if (source === DecryptionSource.Sync) {
sessionCache = this._megolmSyncCache;
} else if (source === DecryptionSource.Timeline) {
sessionCache = this._megolmBackfillCache;
} else if (source === DecryptionSource.Retry) {
// when retrying, we could have mixed events from at the bottom of the timeline (sync)
// and somewhere else, so create a custom cache we use just for this operation.
customCache = this._megolmEncryption.createSessionCache();
sessionCache = customCache;
} else {
throw new Error("Unknown source: " + source);
}
const preparation = await this._megolmDecryption.prepareDecryptAll(
this._room.id, validEvents, sessionCache, txn);
if (customCache) {
customCache.dispose();
}
return new DecryptionPreparation(preparation, errors, {isTimelineOpen}, this);
}
async _processDecryptionResults(results, errors, flags, txn) {
for (const error of errors.values()) {
if (error.code === "MEGOLM_NO_SESSION") {
this._addMissingSessionEvent(error.event);
}
}
if (flags.isTimelineOpen) {
for (const result of results.values()) {
await this._verifyDecryptionResult(result, txn); await this._verifyDecryptionResult(result, txn);
} }
return result; }
} }
async _verifyDecryptionResult(result, txn) { async _verifyDecryptionResult(result, txn) {
@ -87,30 +119,30 @@ export class RoomEncryption {
} }
} }
_addMissingSessionEvent(event, isSync, data) { _addMissingSessionEvent(event) {
const senderKey = event.content?.["sender_key"]; const senderKey = event.content?.["sender_key"];
const sessionId = event.content?.["session_id"]; const sessionId = event.content?.["session_id"];
const key = `${senderKey}|${sessionId}`; const key = `${senderKey}|${sessionId}`;
let eventIds = this._eventIdsByMissingSession.get(key); let eventIds = this._eventIdsByMissingSession.get(key);
if (!eventIds) { if (!eventIds) {
eventIds = new Map(); eventIds = new Set();
this._eventIdsByMissingSession.set(key, eventIds); this._eventIdsByMissingSession.set(key, eventIds);
} }
eventIds.set(event.event_id, {data, isSync}); eventIds.add(event.event_id);
} }
applyRoomKeys(roomKeys) { applyRoomKeys(roomKeys) {
// retry decryption with the new sessions // retry decryption with the new sessions
const retryEntries = []; const retryEventIds = [];
for (const roomKey of roomKeys) { for (const roomKey of roomKeys) {
const key = `${roomKey.senderKey}|${roomKey.sessionId}`; const key = `${roomKey.senderKey}|${roomKey.sessionId}`;
const entriesForSession = this._eventIdsByMissingSession.get(key); const entriesForSession = this._eventIdsByMissingSession.get(key);
if (entriesForSession) { if (entriesForSession) {
this._eventIdsByMissingSession.delete(key); this._eventIdsByMissingSession.delete(key);
retryEntries.push(...entriesForSession.values()); retryEventIds.push(...entriesForSession);
} }
} }
return retryEntries; return retryEventIds;
} }
async encrypt(type, content, hsApi) { async encrypt(type, content, hsApi) {
@ -214,3 +246,67 @@ export class RoomEncryption {
await hsApi.sendToDevice(type, payload, txnId).response(); await hsApi.sendToDevice(type, payload, txnId).response();
} }
} }
/**
* wrappers around megolm decryption classes to be able to post-process
* the decryption results before turning them
*/
class DecryptionPreparation {
constructor(megolmDecryptionPreparation, extraErrors, flags, roomEncryption) {
this._megolmDecryptionPreparation = megolmDecryptionPreparation;
this._extraErrors = extraErrors;
this._flags = flags;
this._roomEncryption = roomEncryption;
}
async decrypt() {
return new DecryptionChanges(
await this._megolmDecryptionPreparation.decrypt(),
this._extraErrors,
this._flags,
this._roomEncryption);
}
dispose() {
this._megolmDecryptionChanges.dispose();
}
}
class DecryptionChanges {
constructor(megolmDecryptionChanges, extraErrors, flags, roomEncryption) {
this._megolmDecryptionChanges = megolmDecryptionChanges;
this._extraErrors = extraErrors;
this._flags = flags;
this._roomEncryption = roomEncryption;
}
async write(txn) {
const {results, errors} = await this._megolmDecryptionChanges.write(txn);
mergeMap(this._extraErrors, errors);
await this._roomEncryption._processDecryptionResults(results, errors, this._flags, txn);
return new BatchDecryptionResult(results, errors);
}
}
class BatchDecryptionResult {
constructor(results, errors) {
this.results = results;
this.errors = errors;
console.log("BatchDecryptionResult", this);
}
applyToEntries(entries) {
console.log("BatchDecryptionResult.applyToEntries", this);
for (const entry of entries) {
const result = this.results.get(entry.id);
if (result) {
entry.setDecryptionResult(result);
} else {
const error = this.errors.get(entry.id);
if (error) {
entry.setDecryptionError(error);
}
}
}
}
}

View file

@ -15,6 +15,9 @@ limitations under the License.
*/ */
import anotherjson from "../../../lib/another-json/index.js"; import anotherjson from "../../../lib/another-json/index.js";
import {createEnum} from "../../utils/enum.js";
export const DecryptionSource = createEnum(["Sync", "Timeline", "Retry"]);
// use common prefix so it's easy to clear properties that are not e2ee related during session clear // use common prefix so it's easy to clear properties that are not e2ee related during session clear
export const SESSION_KEY_PREFIX = "e2ee:"; export const SESSION_KEY_PREFIX = "e2ee:";

View file

@ -26,6 +26,9 @@ import {fetchOrLoadMembers} from "./members/load.js";
import {MemberList} from "./members/MemberList.js"; import {MemberList} from "./members/MemberList.js";
import {Heroes} from "./members/Heroes.js"; import {Heroes} from "./members/Heroes.js";
import {EventEntry} from "./timeline/entries/EventEntry.js"; import {EventEntry} from "./timeline/entries/EventEntry.js";
import {DecryptionSource} from "../e2ee/common.js";
const EVENT_ENCRYPTED_TYPE = "m.room.encrypted";
export class Room extends EventEmitter { export class Room extends EventEmitter {
constructor({roomId, storage, hsApi, emitCollectionChange, sendScheduler, pendingEvents, user, createRoomEncryption, getSyncToken}) { constructor({roomId, storage, hsApi, emitCollectionChange, sendScheduler, pendingEvents, user, createRoomEncryption, getSyncToken}) {
@ -49,67 +52,70 @@ export class Room extends EventEmitter {
async notifyRoomKeys(roomKeys) { async notifyRoomKeys(roomKeys) {
if (this._roomEncryption) { if (this._roomEncryption) {
// array of {data, isSync} let retryEventIds = this._roomEncryption.applyRoomKeys(roomKeys);
let retryEntries = this._roomEncryption.applyRoomKeys(roomKeys); if (retryEventIds.length) {
let decryptedEntries = []; const retryEntries = [];
if (retryEntries.length) { const txn = await this._storage.readTxn([
// groupSessionDecryptions can be written, the other stores not
const txn = await this._storage.readWriteTxn([
this._storage.storeNames.timelineEvents, this._storage.storeNames.timelineEvents,
this._storage.storeNames.inboundGroupSessions, this._storage.storeNames.inboundGroupSessions,
this._storage.storeNames.groupSessionDecryptions,
this._storage.storeNames.deviceIdentities,
]); ]);
try { for (const eventId of retryEventIds) {
for (const retryEntry of retryEntries) { const storageEntry = await txn.timelineEvents.getByEventId(this._roomId, eventId);
const {data: eventKey} = retryEntry;
let entry = this._timeline?.findEntry(eventKey);
if (!entry) {
const storageEntry = await txn.timelineEvents.get(this._roomId, eventKey);
if (storageEntry) { if (storageEntry) {
entry = new EventEntry(storageEntry, this._fragmentIdComparer); retryEntries.push(new EventEntry(storageEntry, this._fragmentIdComparer));
} }
} }
if (entry) { await this._decryptEntries(DecryptionSource.Retry, retryEntries, txn);
entry = await this._decryptEntry(entry, txn, retryEntry.isSync);
decryptedEntries.push(entry);
}
}
} catch (err) {
txn.abort();
throw err;
}
await txn.complete();
}
if (this._timeline) { if (this._timeline) {
// only adds if already present // only adds if already present
this._timeline.replaceEntries(decryptedEntries); this._timeline.replaceEntries(retryEntries);
} }
// pass decryptedEntries to roomSummary // pass decryptedEntries to roomSummary
} }
} }
}
_enableEncryption(encryptionParams) { _enableEncryption(encryptionParams) {
this._roomEncryption = this._createRoomEncryption(this, encryptionParams); this._roomEncryption = this._createRoomEncryption(this, encryptionParams);
if (this._roomEncryption) { if (this._roomEncryption) {
this._sendQueue.enableEncryption(this._roomEncryption); this._sendQueue.enableEncryption(this._roomEncryption);
if (this._timeline) { if (this._timeline) {
this._timeline.enableEncryption(this._decryptEntries.bind(this)); this._timeline.enableEncryption(this._decryptEntries.bind(this, DecryptionSource.Timeline));
} }
} }
} }
async _decryptEntry(entry, txn, isSync) { /**
if (entry.eventType === "m.room.encrypted") { * Used for decrypting when loading/filling the timeline, and retrying decryption,
try { * not during sync, where it is split up during the multiple phases.
const decryptionResult = await this._roomEncryption.decrypt( */
entry.event, isSync, !!this._timeline, entry.asEventKey(), txn); async _decryptEntries(source, entries, inboundSessionTxn = null) {
if (decryptionResult) { if (!inboundSessionTxn) {
entry.setDecryptionResult(decryptionResult); inboundSessionTxn = await this._storage.readTxn([this._storage.storeNames.inboundGroupSessions]);
} }
const events = entries.filter(entry => {
return entry.eventType === EVENT_ENCRYPTED_TYPE;
}).map(entry => entry.event);
const isTimelineOpen = this._isTimelineOpen;
const preparation = await this._roomEncryption.prepareDecryptAll(events, source, isTimelineOpen, inboundSessionTxn);
const changes = await preparation.decrypt();
const stores = [this._storage.storeNames.groupSessionDecryptions];
if (isTimelineOpen) {
// read to fetch devices if timeline is open
stores.push(this._storage.storeNames.deviceIdentities);
}
const writeTxn = await this._storage.readWriteTxn(stores);
let decryption;
try {
decryption = await changes.write(writeTxn);
} catch (err) { } catch (err) {
console.warn("event decryption error", err, entry.event); writeTxn.abort();
entry.setDecryptionError(err); throw err;
}
await writeTxn.complete();
decryption.applyToEntries(entries);
}
} }
} }
return entry; return entry;
@ -299,19 +305,11 @@ export class Room extends EventEmitter {
} }
}).response(); }).response();
let stores = [ const txn = await this._storage.readWriteTxn([
this._storage.storeNames.pendingEvents, this._storage.storeNames.pendingEvents,
this._storage.storeNames.timelineEvents, this._storage.storeNames.timelineEvents,
this._storage.storeNames.timelineFragments, this._storage.storeNames.timelineFragments,
];
if (this._roomEncryption) {
stores = stores.concat([
this._storage.storeNames.inboundGroupSessions,
this._storage.storeNames.groupSessionDecryptions,
this._storage.storeNames.deviceIdentities,
]); ]);
}
const txn = await this._storage.readWriteTxn(stores);
let removedPendingEvents; let removedPendingEvents;
let gapResult; let gapResult;
try { try {
@ -324,14 +322,14 @@ export class Room extends EventEmitter {
fragmentIdComparer: this._fragmentIdComparer, fragmentIdComparer: this._fragmentIdComparer,
}); });
gapResult = await gapWriter.writeFragmentFill(fragmentEntry, response, txn); gapResult = await gapWriter.writeFragmentFill(fragmentEntry, response, txn);
if (this._roomEncryption) {
gapResult.entries = await this._decryptEntries(gapResult.entries, txn, false);
}
} catch (err) { } catch (err) {
txn.abort(); txn.abort();
throw err; throw err;
} }
await txn.complete(); await txn.complete();
if (this._roomEncryption) {
await this._decryptEntries(DecryptionSource.Timeline, gapResult.entries);
}
// once txn is committed, update in-memory state & emit events // once txn is committed, update in-memory state & emit events
for (const fragment of gapResult.fragments) { for (const fragment of gapResult.fragments) {
this._fragmentIdComparer.add(fragment); this._fragmentIdComparer.add(fragment);
@ -406,6 +404,10 @@ export class Room extends EventEmitter {
} }
} }
get _isTimelineOpen() {
return !!this._timeline;
}
async clearUnread() { async clearUnread() {
if (this.isUnread || this.notificationCount) { if (this.isUnread || this.notificationCount) {
const txn = await this._storage.readWriteTxn([ const txn = await this._storage.readWriteTxn([
@ -458,7 +460,7 @@ export class Room extends EventEmitter {
user: this._user, user: this._user,
}); });
if (this._roomEncryption) { if (this._roomEncryption) {
this._timeline.enableEncryption(this._decryptEntries.bind(this)); this._timeline.enableEncryption(this._decryptEntries.bind(this, DecryptionSource.Timeline));
} }
await this._timeline.load(); await this._timeline.load();
return this._timeline; return this._timeline;

View file

@ -46,21 +46,6 @@ export class Timeline {
this._remoteEntries.setManySorted(entries); this._remoteEntries.setManySorted(entries);
} }
findEntry(eventKey) {
// a storage event entry has a fragmentId and eventIndex property, used for sorting,
// just like an EventKey, so this will work, but perhaps a bit brittle.
const entry = new EventEntry(eventKey, this._fragmentIdComparer);
try {
const idx = this._remoteEntries.indexOf(entry);
if (idx !== -1) {
return this._remoteEntries.get(idx);
}
} catch (err) {
// fragmentIdComparer threw, ignore
return;
}
}
replaceEntries(entries) { replaceEntries(entries) {
for (const entry of entries) { for (const entry of entries) {
this._remoteEntries.replace(entry); this._remoteEntries.replace(entry);

View file

@ -32,34 +32,19 @@ export class TimelineReader {
} }
_openTxn() { _openTxn() {
const stores = [
this._storage.storeNames.timelineEvents,
this._storage.storeNames.timelineFragments,
];
if (this._decryptEntries) { if (this._decryptEntries) {
return this._storage.readWriteTxn([ stores.push(this._storage.storeNames.inboundGroupSessions);
this._storage.storeNames.timelineEvents,
this._storage.storeNames.timelineFragments,
this._storage.storeNames.inboundGroupSessions,
this._storage.storeNames.groupSessionDecryptions,
this._storage.storeNames.deviceIdentities,
]);
} else {
return this._storage.readTxn([
this._storage.storeNames.timelineEvents,
this._storage.storeNames.timelineFragments,
]);
} }
return this._storage.readTxn(stores);
} }
async readFrom(eventKey, direction, amount) { async readFrom(eventKey, direction, amount) {
const txn = await this._openTxn(); const txn = await this._openTxn();
let entries; return await this._readFrom(eventKey, direction, amount, txn);
try {
entries = await this._readFrom(eventKey, direction, amount, txn);
} catch (err) {
txn.abort();
throw err;
}
await txn.complete();
return entries;
} }
async _readFrom(eventKey, direction, amount, txn) { async _readFrom(eventKey, direction, amount, txn) {
@ -75,9 +60,6 @@ export class TimelineReader {
eventsWithinFragment = await timelineStore.eventsBefore(this._roomId, eventKey, amount); eventsWithinFragment = await timelineStore.eventsBefore(this._roomId, eventKey, amount);
} }
let eventEntries = eventsWithinFragment.map(e => new EventEntry(e, this._fragmentIdComparer)); let eventEntries = eventsWithinFragment.map(e => new EventEntry(e, this._fragmentIdComparer));
if (this._decryptEntries) {
eventEntries = await this._decryptEntries(eventEntries, txn);
}
entries = directionalConcat(entries, eventEntries, direction); entries = directionalConcat(entries, eventEntries, direction);
// prepend or append eventsWithinFragment to entries, and wrap them in EventEntry // prepend or append eventsWithinFragment to entries, and wrap them in EventEntry
@ -100,14 +82,17 @@ export class TimelineReader {
} }
} }
if (this._decryptEntries) {
await this._decryptEntries(entries, txn);
}
return entries; return entries;
} }
async readFromEnd(amount) { async readFromEnd(amount) {
const txn = await this._openTxn(); const txn = await this._openTxn();
let entries;
try {
const liveFragment = await txn.timelineFragments.liveFragment(this._roomId); const liveFragment = await txn.timelineFragments.liveFragment(this._roomId);
let entries;
// room hasn't been synced yet // room hasn't been synced yet
if (!liveFragment) { if (!liveFragment) {
entries = []; entries = [];
@ -118,11 +103,6 @@ export class TimelineReader {
entries = await this._readFrom(eventKey, Direction.Backward, amount, txn); entries = await this._readFrom(eventKey, Direction.Backward, amount, txn);
entries.unshift(liveFragmentEntry); entries.unshift(liveFragmentEntry);
} }
} catch (err) {
txn.abort();
throw err;
}
await txn.complete();
return entries; return entries;
} }
} }