Merge pull request #670 from vector-im/bwindels/ts-olm

Convert olm code to typescript
This commit is contained in:
Bruno Windels 2022-03-01 18:53:22 +01:00 committed by GitHub
commit 2e1283d199
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 306 additions and 198 deletions

View File

@ -26,8 +26,8 @@ import {User} from "./User.js";
import {DeviceMessageHandler} from "./DeviceMessageHandler.js"; import {DeviceMessageHandler} from "./DeviceMessageHandler.js";
import {Account as E2EEAccount} from "./e2ee/Account.js"; import {Account as E2EEAccount} from "./e2ee/Account.js";
import {uploadAccountAsDehydratedDevice} from "./e2ee/Dehydration.js"; import {uploadAccountAsDehydratedDevice} from "./e2ee/Dehydration.js";
import {Decryption as OlmDecryption} from "./e2ee/olm/Decryption.js"; import {Decryption as OlmDecryption} from "./e2ee/olm/Decryption";
import {Encryption as OlmEncryption} from "./e2ee/olm/Encryption.js"; import {Encryption as OlmEncryption} from "./e2ee/olm/Encryption";
import {Decryption as MegOlmDecryption} from "./e2ee/megolm/Decryption"; import {Decryption as MegOlmDecryption} from "./e2ee/megolm/Decryption";
import {KeyLoader as MegOlmKeyLoader} from "./e2ee/megolm/decryption/KeyLoader"; import {KeyLoader as MegOlmKeyLoader} from "./e2ee/megolm/decryption/KeyLoader";
import {KeyBackup} from "./e2ee/megolm/keybackup/KeyBackup"; import {KeyBackup} from "./e2ee/megolm/keybackup/KeyBackup";
@ -123,25 +123,24 @@ export class Session {
// TODO: this should all go in a wrapper in e2ee/ that is bootstrapped by passing in the account // TODO: this should all go in a wrapper in e2ee/ that is bootstrapped by passing in the account
// and can create RoomEncryption objects and handle encrypted to_device messages and device list changes. // and can create RoomEncryption objects and handle encrypted to_device messages and device list changes.
const senderKeyLock = new LockMap(); const senderKeyLock = new LockMap();
const olmDecryption = new OlmDecryption({ const olmDecryption = new OlmDecryption(
account: this._e2eeAccount, this._e2eeAccount,
pickleKey: PICKLE_KEY, PICKLE_KEY,
olm: this._olm, this._platform.clock.now,
storage: this._storage, this._user.id,
now: this._platform.clock.now, this._olm,
ownUserId: this._user.id,
senderKeyLock senderKeyLock
}); );
this._olmEncryption = new OlmEncryption({ this._olmEncryption = new OlmEncryption(
account: this._e2eeAccount, this._e2eeAccount,
pickleKey: PICKLE_KEY, PICKLE_KEY,
olm: this._olm, this._olm,
storage: this._storage, this._storage,
now: this._platform.clock.now, this._platform.clock.now,
ownUserId: this._user.id, this._user.id,
olmUtil: this._olmUtil, this._olmUtil,
senderKeyLock senderKeyLock
}); );
this._keyLoader = new MegOlmKeyLoader(this._olm, PICKLE_KEY, 20); this._keyLoader = new MegOlmKeyLoader(this._olm, PICKLE_KEY, 20);
this._megolmEncryption = new MegOlmEncryption({ this._megolmEncryption = new MegOlmEncryption({
account: this._e2eeAccount, account: this._e2eeAccount,

View File

@ -26,35 +26,41 @@ limitations under the License.
* see DeviceTracker * see DeviceTracker
*/ */
import type {DeviceIdentity} from "../storage/idb/stores/DeviceIdentityStore";
type DecryptedEvent = {
type?: string,
content?: Record<string, any>
}
export class DecryptionResult { export class DecryptionResult {
constructor(event, senderCurve25519Key, claimedEd25519Key) { private device?: DeviceIdentity;
this.event = event; private roomTracked: boolean = true;
this.senderCurve25519Key = senderCurve25519Key;
this.claimedEd25519Key = claimedEd25519Key; constructor(
this._device = null; public readonly event: DecryptedEvent,
this._roomTracked = true; public readonly senderCurve25519Key: string,
public readonly claimedEd25519Key: string
) {}
setDevice(device: DeviceIdentity): void {
this.device = device;
} }
setDevice(device) { setRoomNotTrackedYet(): void {
this._device = device; this.roomTracked = false;
} }
setRoomNotTrackedYet() { get isVerified(): boolean {
this._roomTracked = false; if (this.device) {
} const comesFromDevice = this.device.ed25519Key === this.claimedEd25519Key;
get isVerified() {
if (this._device) {
const comesFromDevice = this._device.ed25519Key === this.claimedEd25519Key;
return comesFromDevice; return comesFromDevice;
} }
return false; return false;
} }
get isUnverified() { get isUnverified(): boolean {
if (this._device) { if (this.device) {
return !this.isVerified; return !this.isVerified;
} else if (this.isVerificationUnknown) { } else if (this.isVerificationUnknown) {
return false; return false;
@ -63,8 +69,8 @@ export class DecryptionResult {
} }
} }
get isVerificationUnknown() { get isVerificationUnknown(): boolean {
// verification is unknown if we haven't yet fetched the devices for the room // verification is unknown if we haven't yet fetched the devices for the room
return !this._device && !this._roomTracked; return !this.device && !this.roomTracked;
} }
} }

View File

@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
import {DecryptionResult} from "../../DecryptionResult.js"; import {DecryptionResult} from "../../DecryptionResult";
import {DecryptionError} from "../../common.js"; import {DecryptionError} from "../../common.js";
import {ReplayDetectionEntry} from "./ReplayDetectionEntry"; import {ReplayDetectionEntry} from "./ReplayDetectionEntry";
import type {RoomKey} from "./RoomKey"; import type {RoomKey} from "./RoomKey";

View File

@ -16,32 +16,47 @@ limitations under the License.
import {DecryptionError} from "../common.js"; import {DecryptionError} from "../common.js";
import {groupBy} from "../../../utils/groupBy"; import {groupBy} from "../../../utils/groupBy";
import {MultiLock} from "../../../utils/Lock"; import {MultiLock, ILock} from "../../../utils/Lock";
import {Session} from "./Session.js"; import {Session} from "./Session";
import {DecryptionResult} from "../DecryptionResult.js"; import {DecryptionResult} from "../DecryptionResult";
import {OlmPayloadType} from "./types";
import type {OlmMessage, OlmPayload} from "./types";
import type {Account} from "../Account";
import type {LockMap} from "../../../utils/LockMap";
import type {Transaction} from "../../storage/idb/Transaction";
import type {OlmEncryptedEvent} from "./types";
import type * as OlmNamespace from "@matrix-org/olm";
type Olm = typeof OlmNamespace;
const SESSION_LIMIT_PER_SENDER_KEY = 4; const SESSION_LIMIT_PER_SENDER_KEY = 4;
function isPreKeyMessage(message) { type DecryptionResults = {
return message.type === 0; results: DecryptionResult[],
} errors: DecryptionError[],
senderKeyDecryption: SenderKeyDecryption
};
function sortSessions(sessions) { type CreateAndDecryptResult = {
session: Session,
plaintext: string
};
function sortSessions(sessions: Session[]): void {
sessions.sort((a, b) => { sessions.sort((a, b) => {
return b.data.lastUsed - a.data.lastUsed; return b.data.lastUsed - a.data.lastUsed;
}); });
} }
export class Decryption { export class Decryption {
constructor({account, pickleKey, now, ownUserId, storage, olm, senderKeyLock}) { constructor(
this._account = account; private readonly account: Account,
this._pickleKey = pickleKey; private readonly pickleKey: string,
this._now = now; private readonly now: () => number,
this._ownUserId = ownUserId; private readonly ownUserId: string,
this._storage = storage; private readonly olm: Olm,
this._olm = olm; private readonly senderKeyLock: LockMap<string>
this._senderKeyLock = senderKeyLock; ) {}
}
// we need to lock because both encryption and decryption can't be done in one txn, // we need to lock because both encryption and decryption can't be done in one txn,
// so for them not to step on each other toes, we need to lock. // so for them not to step on each other toes, we need to lock.
@ -50,8 +65,8 @@ export class Decryption {
// - decryptAll below fails (to release the lock as early as we can) // - decryptAll below fails (to release the lock as early as we can)
// - DecryptionChanges.write succeeds // - DecryptionChanges.write succeeds
// - Sync finishes the writeSync phase (or an error was thrown, in case we never get to DecryptionChanges.write) // - Sync finishes the writeSync phase (or an error was thrown, in case we never get to DecryptionChanges.write)
async obtainDecryptionLock(events) { async obtainDecryptionLock(events: OlmEncryptedEvent[]): Promise<ILock> {
const senderKeys = new Set(); const senderKeys = new Set<string>();
for (const event of events) { for (const event of events) {
const senderKey = event.content?.["sender_key"]; const senderKey = event.content?.["sender_key"];
if (senderKey) { if (senderKey) {
@ -61,7 +76,7 @@ export class Decryption {
// take a lock on all senderKeys so encryption or other calls to decryptAll (should not happen) // take a lock on all senderKeys so encryption or other calls to decryptAll (should not happen)
// don't modify the sessions at the same time // don't modify the sessions at the same time
const locks = await Promise.all(Array.from(senderKeys).map(senderKey => { const locks = await Promise.all(Array.from(senderKeys).map(senderKey => {
return this._senderKeyLock.takeLock(senderKey); return this.senderKeyLock.takeLock(senderKey);
})); }));
return new MultiLock(locks); return new MultiLock(locks);
} }
@ -83,18 +98,18 @@ export class Decryption {
* @param {[type]} events * @param {[type]} events
* @return {Promise<DecryptionChanges>} [description] * @return {Promise<DecryptionChanges>} [description]
*/ */
async decryptAll(events, lock, txn) { async decryptAll(events: OlmEncryptedEvent[], lock: ILock, txn: Transaction): Promise<DecryptionChanges> {
try { try {
const eventsPerSenderKey = groupBy(events, event => event.content?.["sender_key"]); const eventsPerSenderKey = groupBy(events, (event: OlmEncryptedEvent) => event.content?.["sender_key"]);
const timestamp = this._now(); const timestamp = this.now();
// decrypt events for different sender keys in parallel // decrypt events for different sender keys in parallel
const senderKeyOperations = await Promise.all(Array.from(eventsPerSenderKey.entries()).map(([senderKey, events]) => { const senderKeyOperations = await Promise.all(Array.from(eventsPerSenderKey.entries()).map(([senderKey, events]) => {
return this._decryptAllForSenderKey(senderKey, events, timestamp, txn); return this._decryptAllForSenderKey(senderKey!, events, timestamp, txn);
})); }));
const results = senderKeyOperations.reduce((all, r) => all.concat(r.results), []); const results = senderKeyOperations.reduce((all, r) => all.concat(r.results), [] as DecryptionResult[]);
const errors = senderKeyOperations.reduce((all, r) => all.concat(r.errors), []); const errors = senderKeyOperations.reduce((all, r) => all.concat(r.errors), [] as DecryptionError[]);
const senderKeyDecryptions = senderKeyOperations.map(r => r.senderKeyDecryption); const senderKeyDecryptions = senderKeyOperations.map(r => r.senderKeyDecryption);
return new DecryptionChanges(senderKeyDecryptions, results, errors, this._account, lock); return new DecryptionChanges(senderKeyDecryptions, results, errors, this.account, lock);
} catch (err) { } catch (err) {
// make sure the locks are release if something throws // make sure the locks are release if something throws
// otherwise they will be released in DecryptionChanges after having written // otherwise they will be released in DecryptionChanges after having written
@ -104,11 +119,11 @@ export class Decryption {
} }
} }
async _decryptAllForSenderKey(senderKey, events, timestamp, readSessionsTxn) { async _decryptAllForSenderKey(senderKey: string, events: OlmEncryptedEvent[], timestamp: number, readSessionsTxn: Transaction): Promise<DecryptionResults> {
const sessions = await this._getSessions(senderKey, readSessionsTxn); const sessions = await this._getSessions(senderKey, readSessionsTxn);
const senderKeyDecryption = new SenderKeyDecryption(senderKey, sessions, this._olm, timestamp); const senderKeyDecryption = new SenderKeyDecryption(senderKey, sessions, timestamp);
const results = []; const results: DecryptionResult[] = [];
const errors = []; const errors: DecryptionError[] = [];
// events for a single senderKey need to be decrypted one by one // events for a single senderKey need to be decrypted one by one
for (const event of events) { for (const event of events) {
try { try {
@ -121,10 +136,10 @@ export class Decryption {
return {results, errors, senderKeyDecryption}; return {results, errors, senderKeyDecryption};
} }
_decryptForSenderKey(senderKeyDecryption, event, timestamp) { _decryptForSenderKey(senderKeyDecryption: SenderKeyDecryption, event: OlmEncryptedEvent, timestamp: number): DecryptionResult {
const senderKey = senderKeyDecryption.senderKey; const senderKey = senderKeyDecryption.senderKey;
const message = this._getMessageAndValidateEvent(event); const message = this._getMessageAndValidateEvent(event);
let plaintext; let plaintext: string | undefined;
try { try {
plaintext = senderKeyDecryption.decrypt(message); plaintext = senderKeyDecryption.decrypt(message);
} catch (err) { } catch (err) {
@ -132,8 +147,8 @@ export class Decryption {
throw new DecryptionError("OLM_BAD_ENCRYPTED_MESSAGE", event, {senderKey, error: err.message}); throw new DecryptionError("OLM_BAD_ENCRYPTED_MESSAGE", event, {senderKey, error: err.message});
} }
// could not decrypt with any existing session // could not decrypt with any existing session
if (typeof plaintext !== "string" && isPreKeyMessage(message)) { if (typeof plaintext !== "string" && message.type === OlmPayloadType.PreKey) {
let createResult; let createResult: CreateAndDecryptResult;
try { try {
createResult = this._createSessionAndDecrypt(senderKey, message, timestamp); createResult = this._createSessionAndDecrypt(senderKey, message, timestamp);
} catch (error) { } catch (error) {
@ -143,14 +158,14 @@ export class Decryption {
plaintext = createResult.plaintext; plaintext = createResult.plaintext;
} }
if (typeof plaintext === "string") { if (typeof plaintext === "string") {
let payload; let payload: OlmPayload;
try { try {
payload = JSON.parse(plaintext); payload = JSON.parse(plaintext);
} catch (error) { } catch (error) {
throw new DecryptionError("PLAINTEXT_NOT_JSON", event, {plaintext, error}); throw new DecryptionError("PLAINTEXT_NOT_JSON", event, {plaintext, error});
} }
this._validatePayload(payload, event); this._validatePayload(payload, event);
return new DecryptionResult(payload, senderKey, payload.keys.ed25519); return new DecryptionResult(payload, senderKey, payload.keys!.ed25519!);
} else { } else {
throw new DecryptionError("OLM_NO_MATCHING_SESSION", event, throw new DecryptionError("OLM_NO_MATCHING_SESSION", event,
{knownSessionIds: senderKeyDecryption.sessions.map(s => s.id)}); {knownSessionIds: senderKeyDecryption.sessions.map(s => s.id)});
@ -158,16 +173,16 @@ export class Decryption {
} }
// only for pre-key messages after having attempted decryption with existing sessions // only for pre-key messages after having attempted decryption with existing sessions
_createSessionAndDecrypt(senderKey, message, timestamp) { _createSessionAndDecrypt(senderKey: string, message: OlmMessage, timestamp: number): CreateAndDecryptResult {
let plaintext; let plaintext;
// if we have multiple messages encrypted with the same new session, // if we have multiple messages encrypted with the same new session,
// this could create multiple sessions as the OTK isn't removed yet // this could create multiple sessions as the OTK isn't removed yet
// (this only happens in DecryptionChanges.write) // (this only happens in DecryptionChanges.write)
// This should be ok though as we'll first try to decrypt with the new session // This should be ok though as we'll first try to decrypt with the new session
const olmSession = this._account.createInboundOlmSession(senderKey, message.body); const olmSession = this.account.createInboundOlmSession(senderKey, message.body);
try { try {
plaintext = olmSession.decrypt(message.type, message.body); plaintext = olmSession.decrypt(message.type, message.body);
const session = Session.create(senderKey, olmSession, this._olm, this._pickleKey, timestamp); const session = Session.create(senderKey, olmSession, this.olm, this.pickleKey, timestamp);
session.unload(olmSession); session.unload(olmSession);
return {session, plaintext}; return {session, plaintext};
} catch (err) { } catch (err) {
@ -176,12 +191,12 @@ export class Decryption {
} }
} }
_getMessageAndValidateEvent(event) { _getMessageAndValidateEvent(event: OlmEncryptedEvent): OlmMessage {
const ciphertext = event.content?.ciphertext; const ciphertext = event.content?.ciphertext;
if (!ciphertext) { if (!ciphertext) {
throw new DecryptionError("OLM_MISSING_CIPHERTEXT", event); throw new DecryptionError("OLM_MISSING_CIPHERTEXT", event);
} }
const message = ciphertext?.[this._account.identityKeys.curve25519]; const message = ciphertext?.[this.account.identityKeys.curve25519];
if (!message) { if (!message) {
throw new DecryptionError("OLM_NOT_INCLUDED_IN_RECIPIENTS", event); throw new DecryptionError("OLM_NOT_INCLUDED_IN_RECIPIENTS", event);
} }
@ -189,22 +204,22 @@ export class Decryption {
return message; return message;
} }
async _getSessions(senderKey, txn) { async _getSessions(senderKey: string, txn: Transaction): Promise<Session[]> {
const sessionEntries = await txn.olmSessions.getAll(senderKey); const sessionEntries = await txn.olmSessions.getAll(senderKey);
// sort most recent used sessions first // sort most recent used sessions first
const sessions = sessionEntries.map(s => new Session(s, this._pickleKey, this._olm)); const sessions = sessionEntries.map(s => new Session(s, this.pickleKey, this.olm));
sortSessions(sessions); sortSessions(sessions);
return sessions; return sessions;
} }
_validatePayload(payload, event) { _validatePayload(payload: OlmPayload, event: OlmEncryptedEvent): void {
if (payload.sender !== event.sender) { if (payload.sender !== event.sender) {
throw new DecryptionError("OLM_FORWARDED_MESSAGE", event, {sentBy: event.sender, encryptedBy: payload.sender}); throw new DecryptionError("OLM_FORWARDED_MESSAGE", event, {sentBy: event.sender, encryptedBy: payload.sender});
} }
if (payload.recipient !== this._ownUserId) { if (payload.recipient !== this.ownUserId) {
throw new DecryptionError("OLM_BAD_RECIPIENT", event, {recipient: payload.recipient}); throw new DecryptionError("OLM_BAD_RECIPIENT", event, {recipient: payload.recipient});
} }
if (payload.recipient_keys?.ed25519 !== this._account.identityKeys.ed25519) { if (payload.recipient_keys?.ed25519 !== this.account.identityKeys.ed25519) {
throw new DecryptionError("OLM_BAD_RECIPIENT_KEY", event, {key: payload.recipient_keys?.ed25519}); throw new DecryptionError("OLM_BAD_RECIPIENT_KEY", event, {key: payload.recipient_keys?.ed25519});
} }
// TODO: check room_id // TODO: check room_id
@ -219,21 +234,20 @@ export class Decryption {
// decryption helper for a single senderKey // decryption helper for a single senderKey
class SenderKeyDecryption { class SenderKeyDecryption {
constructor(senderKey, sessions, olm, timestamp) { constructor(
this.senderKey = senderKey; public readonly senderKey: string,
this.sessions = sessions; public readonly sessions: Session[],
this._olm = olm; private readonly timestamp: number
this._timestamp = timestamp; ) {}
}
addNewSession(session) { addNewSession(session: Session): void {
// add at top as it is most recent // add at top as it is most recent
this.sessions.unshift(session); this.sessions.unshift(session);
} }
decrypt(message) { decrypt(message: OlmMessage): string | undefined {
for (const session of this.sessions) { for (const session of this.sessions) {
const plaintext = this._decryptWithSession(session, message); const plaintext = this.decryptWithSession(session, message);
if (typeof plaintext === "string") { if (typeof plaintext === "string") {
// keep them sorted so will try the same session first for other messages // keep them sorted so will try the same session first for other messages
// and so we can assume the excess ones are at the end // and so we can assume the excess ones are at the end
@ -244,11 +258,11 @@ class SenderKeyDecryption {
} }
} }
getModifiedSessions() { getModifiedSessions(): Session[] {
return this.sessions.filter(session => session.isModified); return this.sessions.filter(session => session.isModified);
} }
get hasNewSessions() { get hasNewSessions(): boolean {
return this.sessions.some(session => session.isNew); return this.sessions.some(session => session.isNew);
} }
@ -257,19 +271,22 @@ class SenderKeyDecryption {
// if this turns out to be a real cost for IE11, // if this turns out to be a real cost for IE11,
// we could look into adding a less expensive serialization mechanism // we could look into adding a less expensive serialization mechanism
// for olm sessions to libolm // for olm sessions to libolm
_decryptWithSession(session, message) { private decryptWithSession(session: Session, message: OlmMessage): string | undefined {
if (message.type === undefined || message.body === undefined) {
throw new Error("Invalid message without type or body");
}
const olmSession = session.load(); const olmSession = session.load();
try { try {
if (isPreKeyMessage(message) && !olmSession.matches_inbound(message.body)) { if (message.type === OlmPayloadType.PreKey && !olmSession.matches_inbound(message.body)) {
return; return;
} }
try { try {
const plaintext = olmSession.decrypt(message.type, message.body); const plaintext = olmSession.decrypt(message.type as number, message.body!);
session.save(olmSession); session.save(olmSession);
session.lastUsed = this._timestamp; session.data.lastUsed = this.timestamp;
return plaintext; return plaintext;
} catch (err) { } catch (err) {
if (isPreKeyMessage(message)) { if (message.type === OlmPayloadType.PreKey) {
throw new Error(`Error decrypting prekey message with existing session id ${session.id}: ${err.message}`); throw new Error(`Error decrypting prekey message with existing session id ${session.id}: ${err.message}`);
} }
// decryption failed, bail out // decryption failed, bail out
@ -286,27 +303,27 @@ class SenderKeyDecryption {
* @property {Array<DecryptionError>} errors see DecryptionError.event to retrieve the event that failed to decrypt. * @property {Array<DecryptionError>} errors see DecryptionError.event to retrieve the event that failed to decrypt.
*/ */
class DecryptionChanges { class DecryptionChanges {
constructor(senderKeyDecryptions, results, errors, account, lock) { constructor(
this._senderKeyDecryptions = senderKeyDecryptions; private readonly senderKeyDecryptions: SenderKeyDecryption[],
this._account = account; public readonly results: DecryptionResult[],
this.results = results; public readonly errors: DecryptionError[],
this.errors = errors; private readonly account: Account,
this._lock = lock; private readonly lock: ILock
) {}
get hasNewSessions(): boolean {
return this.senderKeyDecryptions.some(skd => skd.hasNewSessions);
} }
get hasNewSessions() { write(txn: Transaction): void {
return this._senderKeyDecryptions.some(skd => skd.hasNewSessions);
}
write(txn) {
try { try {
for (const senderKeyDecryption of this._senderKeyDecryptions) { for (const senderKeyDecryption of this.senderKeyDecryptions) {
for (const session of senderKeyDecryption.getModifiedSessions()) { for (const session of senderKeyDecryption.getModifiedSessions()) {
txn.olmSessions.set(session.data); txn.olmSessions.set(session.data);
if (session.isNew) { if (session.isNew) {
const olmSession = session.load(); const olmSession = session.load();
try { try {
this._account.writeRemoveOneTimeKey(olmSession, txn); this.account.writeRemoveOneTimeKey(olmSession, txn);
} finally { } finally {
session.unload(olmSession); session.unload(olmSession);
} }
@ -322,7 +339,7 @@ class DecryptionChanges {
} }
} }
} finally { } finally {
this._lock.release(); this.lock.release();
} }
} }
} }

View File

@ -16,7 +16,33 @@ limitations under the License.
import {groupByWithCreator} from "../../../utils/groupBy"; import {groupByWithCreator} from "../../../utils/groupBy";
import {verifyEd25519Signature, OLM_ALGORITHM} from "../common.js"; import {verifyEd25519Signature, OLM_ALGORITHM} from "../common.js";
import {createSessionEntry} from "./Session.js"; import {createSessionEntry} from "./Session";
import type {OlmMessage, OlmPayload, OlmEncryptedMessageContent} from "./types";
import type {Account} from "../Account";
import type {LockMap} from "../../../utils/LockMap";
import type {Storage} from "../../storage/idb/Storage";
import type {Transaction} from "../../storage/idb/Transaction";
import type {DeviceIdentity} from "../../storage/idb/stores/DeviceIdentityStore";
import type {HomeServerApi} from "../../net/HomeServerApi";
import type {ILogItem} from "../../../logging/types";
import type * as OlmNamespace from "@matrix-org/olm";
type Olm = typeof OlmNamespace;
type ClaimedOTKResponse = {
[userId: string]: {
[deviceId: string]: {
[algorithmAndOtk: string]: {
key: string,
signatures: {
[userId: string]: {
[algorithmAndDevice: string]: string
}
}
}
}
}
};
function findFirstSessionId(sessionIds) { function findFirstSessionId(sessionIds) {
return sessionIds.reduce((first, sessionId) => { return sessionIds.reduce((first, sessionId) => {
@ -36,19 +62,19 @@ const OTK_ALGORITHM = "signed_curve25519";
const MAX_BATCH_SIZE = 20; const MAX_BATCH_SIZE = 20;
export class Encryption { export class Encryption {
constructor({account, olm, olmUtil, ownUserId, storage, now, pickleKey, senderKeyLock}) { constructor(
this._account = account; private readonly account: Account,
this._olm = olm; private readonly pickleKey: string,
this._olmUtil = olmUtil; private readonly olm: Olm,
this._ownUserId = ownUserId; private readonly storage: Storage,
this._storage = storage; private readonly now: () => number,
this._now = now; private readonly ownUserId: string,
this._pickleKey = pickleKey; private readonly olmUtil: Olm.Utility,
this._senderKeyLock = senderKeyLock; private readonly senderKeyLock: LockMap<string>
} ) {}
async encrypt(type, content, devices, hsApi, log) { async encrypt(type: string, content: Record<string, any>, devices: DeviceIdentity[], hsApi: HomeServerApi, log: ILogItem): Promise<EncryptedMessage[]> {
let messages = []; let messages: EncryptedMessage[] = [];
for (let i = 0; i < devices.length ; i += MAX_BATCH_SIZE) { for (let i = 0; i < devices.length ; i += MAX_BATCH_SIZE) {
const batchDevices = devices.slice(i, i + MAX_BATCH_SIZE); const batchDevices = devices.slice(i, i + MAX_BATCH_SIZE);
const batchMessages = await this._encryptForMaxDevices(type, content, batchDevices, hsApi, log); const batchMessages = await this._encryptForMaxDevices(type, content, batchDevices, hsApi, log);
@ -57,12 +83,12 @@ export class Encryption {
return messages; return messages;
} }
async _encryptForMaxDevices(type, content, devices, hsApi, log) { async _encryptForMaxDevices(type: string, content: Record<string, any>, devices: DeviceIdentity[], hsApi: HomeServerApi, log: ILogItem): Promise<EncryptedMessage[]> {
// TODO: see if we can only hold some of the locks until after the /keys/claim call (if needed) // TODO: see if we can only hold some of the locks until after the /keys/claim call (if needed)
// take a lock on all senderKeys so decryption and other calls to encrypt (should not happen) // take a lock on all senderKeys so decryption and other calls to encrypt (should not happen)
// don't modify the sessions at the same time // don't modify the sessions at the same time
const locks = await Promise.all(devices.map(device => { const locks = await Promise.all(devices.map(device => {
return this._senderKeyLock.takeLock(device.curve25519Key); return this.senderKeyLock.takeLock(device.curve25519Key);
})); }));
try { try {
const { const {
@ -70,9 +96,9 @@ export class Encryption {
existingEncryptionTargets, existingEncryptionTargets,
} = await this._findExistingSessions(devices); } = await this._findExistingSessions(devices);
const timestamp = this._now(); const timestamp = this.now();
let encryptionTargets = []; let encryptionTargets: EncryptionTarget[] = [];
try { try {
if (devicesWithoutSession.length) { if (devicesWithoutSession.length) {
const newEncryptionTargets = await log.wrap("create sessions", log => this._createNewSessions( const newEncryptionTargets = await log.wrap("create sessions", log => this._createNewSessions(
@ -100,8 +126,8 @@ export class Encryption {
} }
} }
async _findExistingSessions(devices) { async _findExistingSessions(devices: DeviceIdentity[]): Promise<{devicesWithoutSession: DeviceIdentity[], existingEncryptionTargets: EncryptionTarget[]}> {
const txn = await this._storage.readTxn([this._storage.storeNames.olmSessions]); const txn = await this.storage.readTxn([this.storage.storeNames.olmSessions]);
const sessionIdsForDevice = await Promise.all(devices.map(async device => { const sessionIdsForDevice = await Promise.all(devices.map(async device => {
return await txn.olmSessions.getSessionIds(device.curve25519Key); return await txn.olmSessions.getSessionIds(device.curve25519Key);
})); }));
@ -116,18 +142,18 @@ export class Encryption {
const sessionId = findFirstSessionId(sessionIds); const sessionId = findFirstSessionId(sessionIds);
return EncryptionTarget.fromSessionId(device, sessionId); return EncryptionTarget.fromSessionId(device, sessionId);
} }
}).filter(target => !!target); }).filter(target => !!target) as EncryptionTarget[];
return {devicesWithoutSession, existingEncryptionTargets}; return {devicesWithoutSession, existingEncryptionTargets};
} }
_encryptForDevice(type, content, target) { _encryptForDevice(type: string, content: Record<string, any>, target: EncryptionTarget): OlmEncryptedMessageContent {
const {session, device} = target; const {session, device} = target;
const plaintext = JSON.stringify(this._buildPlainTextMessageForDevice(type, content, device)); const plaintext = JSON.stringify(this._buildPlainTextMessageForDevice(type, content, device));
const message = session.encrypt(plaintext); const message = session!.encrypt(plaintext);
const encryptedContent = { const encryptedContent = {
algorithm: OLM_ALGORITHM, algorithm: OLM_ALGORITHM,
sender_key: this._account.identityKeys.curve25519, sender_key: this.account.identityKeys.curve25519,
ciphertext: { ciphertext: {
[device.curve25519Key]: message [device.curve25519Key]: message
} }
@ -135,27 +161,27 @@ export class Encryption {
return encryptedContent; return encryptedContent;
} }
_buildPlainTextMessageForDevice(type, content, device) { _buildPlainTextMessageForDevice(type: string, content: Record<string, any>, device: DeviceIdentity): OlmPayload {
return { return {
keys: { keys: {
"ed25519": this._account.identityKeys.ed25519 "ed25519": this.account.identityKeys.ed25519
}, },
recipient_keys: { recipient_keys: {
"ed25519": device.ed25519Key "ed25519": device.ed25519Key
}, },
recipient: device.userId, recipient: device.userId,
sender: this._ownUserId, sender: this.ownUserId,
content, content,
type type
} }
} }
async _createNewSessions(devicesWithoutSession, hsApi, timestamp, log) { async _createNewSessions(devicesWithoutSession: DeviceIdentity[], hsApi: HomeServerApi, timestamp: number, log: ILogItem): Promise<EncryptionTarget[]> {
const newEncryptionTargets = await log.wrap("claim", log => this._claimOneTimeKeys(hsApi, devicesWithoutSession, log)); const newEncryptionTargets = await log.wrap("claim", log => this._claimOneTimeKeys(hsApi, devicesWithoutSession, log));
try { try {
for (const target of newEncryptionTargets) { for (const target of newEncryptionTargets) {
const {device, oneTimeKey} = target; const {device, oneTimeKey} = target;
target.session = await this._account.createOutboundOlmSession(device.curve25519Key, oneTimeKey); target.session = await this.account.createOutboundOlmSession(device.curve25519Key, oneTimeKey);
} }
await this._storeSessions(newEncryptionTargets, timestamp); await this._storeSessions(newEncryptionTargets, timestamp);
} catch (err) { } catch (err) {
@ -167,12 +193,12 @@ export class Encryption {
return newEncryptionTargets; return newEncryptionTargets;
} }
async _claimOneTimeKeys(hsApi, deviceIdentities, log) { async _claimOneTimeKeys(hsApi: HomeServerApi, deviceIdentities: DeviceIdentity[], log: ILogItem): Promise<EncryptionTarget[]> {
// create a Map<userId, Map<deviceId, deviceIdentity>> // create a Map<userId, Map<deviceId, deviceIdentity>>
const devicesByUser = groupByWithCreator(deviceIdentities, const devicesByUser = groupByWithCreator(deviceIdentities,
device => device.userId, (device: DeviceIdentity) => device.userId,
() => new Map(), (): Map<string, DeviceIdentity> => new Map(),
(deviceMap, device) => deviceMap.set(device.deviceId, device) (deviceMap: Map<string, DeviceIdentity>, device: DeviceIdentity) => deviceMap.set(device.deviceId, device)
); );
const oneTimeKeys = Array.from(devicesByUser.entries()).reduce((usersObj, [userId, deviceMap]) => { const oneTimeKeys = Array.from(devicesByUser.entries()).reduce((usersObj, [userId, deviceMap]) => {
usersObj[userId] = Array.from(deviceMap.values()).reduce((devicesObj, device) => { usersObj[userId] = Array.from(deviceMap.values()).reduce((devicesObj, device) => {
@ -188,12 +214,12 @@ export class Encryption {
if (Object.keys(claimResponse.failures).length) { if (Object.keys(claimResponse.failures).length) {
log.log({l: "failures", servers: Object.keys(claimResponse.failures)}, log.level.Warn); log.log({l: "failures", servers: Object.keys(claimResponse.failures)}, log.level.Warn);
} }
const userKeyMap = claimResponse?.["one_time_keys"]; const userKeyMap = claimResponse?.["one_time_keys"] as ClaimedOTKResponse;
return this._verifyAndCreateOTKTargets(userKeyMap, devicesByUser, log); return this._verifyAndCreateOTKTargets(userKeyMap, devicesByUser, log);
} }
_verifyAndCreateOTKTargets(userKeyMap, devicesByUser, log) { _verifyAndCreateOTKTargets(userKeyMap: ClaimedOTKResponse, devicesByUser: Map<string, Map<string, DeviceIdentity>>, log: ILogItem): EncryptionTarget[] {
const verifiedEncryptionTargets = []; const verifiedEncryptionTargets: EncryptionTarget[] = [];
for (const [userId, userSection] of Object.entries(userKeyMap)) { for (const [userId, userSection] of Object.entries(userKeyMap)) {
for (const [deviceId, deviceSection] of Object.entries(userSection)) { for (const [deviceId, deviceSection] of Object.entries(userSection)) {
const [firstPropName, keySection] = Object.entries(deviceSection)[0]; const [firstPropName, keySection] = Object.entries(deviceSection)[0];
@ -202,7 +228,7 @@ export class Encryption {
const device = devicesByUser.get(userId)?.get(deviceId); const device = devicesByUser.get(userId)?.get(deviceId);
if (device) { if (device) {
const isValidSignature = verifyEd25519Signature( const isValidSignature = verifyEd25519Signature(
this._olmUtil, userId, deviceId, device.ed25519Key, keySection, log); this.olmUtil, userId, deviceId, device.ed25519Key, keySection, log);
if (isValidSignature) { if (isValidSignature) {
const target = EncryptionTarget.fromOTK(device, keySection.key); const target = EncryptionTarget.fromOTK(device, keySection.key);
verifiedEncryptionTargets.push(target); verifiedEncryptionTargets.push(target);
@ -214,8 +240,8 @@ export class Encryption {
return verifiedEncryptionTargets; return verifiedEncryptionTargets;
} }
async _loadSessions(encryptionTargets) { async _loadSessions(encryptionTargets: EncryptionTarget[]): Promise<void> {
const txn = await this._storage.readTxn([this._storage.storeNames.olmSessions]); const txn = await this.storage.readTxn([this.storage.storeNames.olmSessions]);
// given we run loading in parallel, there might still be some // given we run loading in parallel, there might still be some
// storage requests that will finish later once one has failed. // storage requests that will finish later once one has failed.
// those should not allocate a session anymore. // those should not allocate a session anymore.
@ -223,10 +249,10 @@ export class Encryption {
try { try {
await Promise.all(encryptionTargets.map(async encryptionTarget => { await Promise.all(encryptionTargets.map(async encryptionTarget => {
const sessionEntry = await txn.olmSessions.get( const sessionEntry = await txn.olmSessions.get(
encryptionTarget.device.curve25519Key, encryptionTarget.sessionId); encryptionTarget.device.curve25519Key, encryptionTarget.sessionId!);
if (sessionEntry && !failed) { if (sessionEntry && !failed) {
const olmSession = new this._olm.Session(); const olmSession = new this.olm.Session();
olmSession.unpickle(this._pickleKey, sessionEntry.session); olmSession.unpickle(this.pickleKey, sessionEntry.session);
encryptionTarget.session = olmSession; encryptionTarget.session = olmSession;
} }
})); }));
@ -240,12 +266,12 @@ export class Encryption {
} }
} }
async _storeSessions(encryptionTargets, timestamp) { async _storeSessions(encryptionTargets: EncryptionTarget[], timestamp: number): Promise<void> {
const txn = await this._storage.readWriteTxn([this._storage.storeNames.olmSessions]); const txn = await this.storage.readWriteTxn([this.storage.storeNames.olmSessions]);
try { try {
for (const target of encryptionTargets) { for (const target of encryptionTargets) {
const sessionEntry = createSessionEntry( const sessionEntry = createSessionEntry(
target.session, target.device.curve25519Key, timestamp, this._pickleKey); target.session!, target.device.curve25519Key, timestamp, this.pickleKey);
txn.olmSessions.set(sessionEntry); txn.olmSessions.set(sessionEntry);
} }
} catch (err) { } catch (err) {
@ -261,23 +287,24 @@ export class Encryption {
// (and later converted to a session) in case of a new session // (and later converted to a session) in case of a new session
// or an existing session // or an existing session
class EncryptionTarget { class EncryptionTarget {
constructor(device, oneTimeKey, sessionId) {
this.device = device; public session: Olm.Session | null = null;
this.oneTimeKey = oneTimeKey;
this.sessionId = sessionId;
// an olmSession, should probably be called olmSession
this.session = null;
}
static fromOTK(device, oneTimeKey) { constructor(
public readonly device: DeviceIdentity,
public readonly oneTimeKey: string | null,
public readonly sessionId: string | null
) {}
static fromOTK(device: DeviceIdentity, oneTimeKey: string): EncryptionTarget {
return new EncryptionTarget(device, oneTimeKey, null); return new EncryptionTarget(device, oneTimeKey, null);
} }
static fromSessionId(device, sessionId) { static fromSessionId(device: DeviceIdentity, sessionId: string): EncryptionTarget {
return new EncryptionTarget(device, null, sessionId); return new EncryptionTarget(device, null, sessionId);
} }
dispose() { dispose(): void {
if (this.session) { if (this.session) {
this.session.free(); this.session.free();
} }
@ -285,8 +312,8 @@ class EncryptionTarget {
} }
class EncryptedMessage { class EncryptedMessage {
constructor(content, device) { constructor(
this.content = content; public readonly content: OlmEncryptedMessageContent,
this.device = device; public readonly device: DeviceIdentity
} ) {}
} }

View File

@ -14,7 +14,11 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
export function createSessionEntry(olmSession, senderKey, timestamp, pickleKey) { import type {OlmSessionEntry} from "../../storage/idb/stores/OlmSessionStore";
import type * as OlmNamespace from "@matrix-org/olm";
type Olm = typeof OlmNamespace;
export function createSessionEntry(olmSession: Olm.Session, senderKey: string, timestamp: number, pickleKey: string): OlmSessionEntry {
return { return {
session: olmSession.pickle(pickleKey), session: olmSession.pickle(pickleKey),
sessionId: olmSession.session_id(), sessionId: olmSession.session_id(),
@ -24,35 +28,38 @@ export function createSessionEntry(olmSession, senderKey, timestamp, pickleKey)
} }
export class Session { export class Session {
constructor(data, pickleKey, olm, isNew = false) { public isModified: boolean;
this.data = data;
this._olm = olm; constructor(
this._pickleKey = pickleKey; public readonly data: OlmSessionEntry,
this.isNew = isNew; private readonly pickleKey: string,
private readonly olm: Olm,
public isNew: boolean = false
) {
this.isModified = isNew; this.isModified = isNew;
} }
static create(senderKey, olmSession, olm, pickleKey, timestamp) { static create(senderKey: string, olmSession: Olm.Session, olm: Olm, pickleKey: string, timestamp: number): Session {
const data = createSessionEntry(olmSession, senderKey, timestamp, pickleKey); const data = createSessionEntry(olmSession, senderKey, timestamp, pickleKey);
return new Session(data, pickleKey, olm, true); return new Session(data, pickleKey, olm, true);
} }
get id() { get id(): string {
return this.data.sessionId; return this.data.sessionId;
} }
load() { load(): Olm.Session {
const session = new this._olm.Session(); const session = new this.olm.Session();
session.unpickle(this._pickleKey, this.data.session); session.unpickle(this.pickleKey, this.data.session);
return session; return session;
} }
unload(olmSession) { unload(olmSession: Olm.Session): void {
olmSession.free(); olmSession.free();
} }
save(olmSession) { save(olmSession: Olm.Session): void {
this.data.session = olmSession.pickle(this._pickleKey); this.data.session = olmSession.pickle(this.pickleKey);
this.isModified = true; this.isModified = true;
} }
} }

View File

@ -0,0 +1,48 @@
/*
Copyright 2022 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
export const enum OlmPayloadType {
PreKey = 0,
Normal = 1
}
export type OlmMessage = {
type?: OlmPayloadType,
body?: string
}
export type OlmEncryptedMessageContent = {
algorithm?: "m.olm.v1.curve25519-aes-sha2"
sender_key?: string,
ciphertext?: {
[deviceCurve25519Key: string]: OlmMessage
}
}
export type OlmEncryptedEvent = {
type?: "m.room.encrypted",
content?: OlmEncryptedMessageContent
sender?: string
}
export type OlmPayload = {
type?: string;
content?: Record<string, any>;
sender?: string;
recipient?: string;
recipient_keys?: {ed25519?: string};
keys?: {ed25519?: string};
}

View File

@ -24,19 +24,19 @@ function decodeKey(key: string): { senderKey: string, sessionId: string } {
return {senderKey, sessionId}; return {senderKey, sessionId};
} }
interface OlmSession { export type OlmSessionEntry = {
session: string; session: string;
sessionId: string; sessionId: string;
senderKey: string; senderKey: string;
lastUsed: number; lastUsed: number;
} }
type OlmSessionEntry = OlmSession & { key: string }; type OlmSessionStoredEntry = OlmSessionEntry & { key: string };
export class OlmSessionStore { export class OlmSessionStore {
private _store: Store<OlmSessionEntry>; private _store: Store<OlmSessionStoredEntry>;
constructor(store: Store<OlmSessionEntry>) { constructor(store: Store<OlmSessionStoredEntry>) {
this._store = store; this._store = store;
} }
@ -55,20 +55,20 @@ export class OlmSessionStore {
return sessionIds; return sessionIds;
} }
getAll(senderKey: string): Promise<OlmSession[]> { getAll(senderKey: string): Promise<OlmSessionEntry[]> {
const range = this._store.IDBKeyRange.lowerBound(encodeKey(senderKey, "")); const range = this._store.IDBKeyRange.lowerBound(encodeKey(senderKey, ""));
return this._store.selectWhile(range, session => { return this._store.selectWhile(range, session => {
return session.senderKey === senderKey; return session.senderKey === senderKey;
}); });
} }
get(senderKey: string, sessionId: string): Promise<OlmSession | undefined> { get(senderKey: string, sessionId: string): Promise<OlmSessionEntry | undefined> {
return this._store.get(encodeKey(senderKey, sessionId)); return this._store.get(encodeKey(senderKey, sessionId));
} }
set(session: OlmSession): void { set(session: OlmSessionEntry): void {
(session as OlmSessionEntry).key = encodeKey(session.senderKey, session.sessionId); (session as OlmSessionStoredEntry).key = encodeKey(session.senderKey, session.sessionId);
this._store.put(session as OlmSessionEntry); this._store.put(session as OlmSessionStoredEntry);
} }
remove(senderKey: string, sessionId: string): void { remove(senderKey: string, sessionId: string): void {

View File

@ -14,7 +14,11 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
export class Lock { export interface ILock {
release(): void;
}
export class Lock implements ILock {
private _promise?: Promise<void>; private _promise?: Promise<void>;
private _resolve?: (() => void); private _resolve?: (() => void);
@ -52,7 +56,7 @@ export class Lock {
} }
} }
export class MultiLock { export class MultiLock implements ILock {
constructor(public readonly locks: Lock[]) { constructor(public readonly locks: Lock[]) {
} }