forked from mystiq/hydrogen-web
offload olm account creation in worker
This commit is contained in:
parent
0b26e6f53a
commit
e0d9d703b7
8 changed files with 134 additions and 61 deletions
12
src/main.js
12
src/main.js
|
@ -26,6 +26,7 @@ import {BrawlView} from "./ui/web/BrawlView.js";
|
|||
import {Clock} from "./ui/web/dom/Clock.js";
|
||||
import {OnlineStatus} from "./ui/web/dom/OnlineStatus.js";
|
||||
import {WorkerPool} from "./utils/WorkerPool.js";
|
||||
import {OlmWorker} from "./matrix/e2ee/OlmWorker.js";
|
||||
|
||||
function addScript(src) {
|
||||
return new Promise(function (resolve, reject) {
|
||||
|
@ -65,12 +66,13 @@ function relPath(path, basePath) {
|
|||
return "../".repeat(dirCount) + path;
|
||||
}
|
||||
|
||||
async function loadWorker(paths) {
|
||||
async function loadOlmWorker(paths) {
|
||||
const workerPool = new WorkerPool(paths.worker, 4);
|
||||
await workerPool.init();
|
||||
const path = relPath(paths.olm.legacyBundle, paths.worker);
|
||||
await workerPool.sendAll({type: "load_olm", path});
|
||||
return workerPool;
|
||||
const olmWorker = new OlmWorker(workerPool);
|
||||
return olmWorker;
|
||||
}
|
||||
|
||||
// Don't use a default export here, as we use multiple entries during legacy build,
|
||||
|
@ -100,9 +102,9 @@ export async function main(container, paths) {
|
|||
// if wasm is not supported, we'll want
|
||||
// to run some olm operations in a worker (mainly for IE11)
|
||||
let workerPromise;
|
||||
if (!window.WebAssembly) {
|
||||
workerPromise = loadWorker(paths);
|
||||
}
|
||||
// if (!window.WebAssembly) {
|
||||
workerPromise = loadOlmWorker(paths);
|
||||
// }
|
||||
|
||||
const vm = new BrawlViewModel({
|
||||
createSessionContainer: () => {
|
||||
|
|
|
@ -33,7 +33,7 @@ const PICKLE_KEY = "DEFAULT_KEY";
|
|||
|
||||
export class Session {
|
||||
// sessionInfo contains deviceId, userId and homeServer
|
||||
constructor({clock, storage, hsApi, sessionInfo, olm, workerPool}) {
|
||||
constructor({clock, storage, hsApi, sessionInfo, olm, olmWorker}) {
|
||||
this._clock = clock;
|
||||
this._storage = storage;
|
||||
this._hsApi = hsApi;
|
||||
|
@ -52,7 +52,7 @@ export class Session {
|
|||
this._megolmEncryption = null;
|
||||
this._megolmDecryption = null;
|
||||
this._getSyncToken = () => this.syncToken;
|
||||
this._workerPool = workerPool;
|
||||
this._olmWorker = olmWorker;
|
||||
|
||||
if (olm) {
|
||||
this._olmUtil = new olm.Utility();
|
||||
|
@ -101,7 +101,7 @@ export class Session {
|
|||
this._megolmDecryption = new MegOlmDecryption({
|
||||
pickleKey: PICKLE_KEY,
|
||||
olm: this._olm,
|
||||
workerPool: this._workerPool,
|
||||
olmWorker: this._olmWorker,
|
||||
});
|
||||
this._deviceMessageHandler.enableEncryption({olmDecryption, megolmDecryption: this._megolmDecryption});
|
||||
}
|
||||
|
@ -140,23 +140,15 @@ export class Session {
|
|||
throw new Error("there should not be an e2ee account already on a fresh login");
|
||||
}
|
||||
if (!this._e2eeAccount) {
|
||||
const txn = await this._storage.readWriteTxn([
|
||||
this._storage.storeNames.session
|
||||
]);
|
||||
try {
|
||||
this._e2eeAccount = await E2EEAccount.create({
|
||||
hsApi: this._hsApi,
|
||||
olm: this._olm,
|
||||
pickleKey: PICKLE_KEY,
|
||||
userId: this._sessionInfo.userId,
|
||||
deviceId: this._sessionInfo.deviceId,
|
||||
txn
|
||||
});
|
||||
} catch (err) {
|
||||
txn.abort();
|
||||
throw err;
|
||||
}
|
||||
await txn.complete();
|
||||
this._e2eeAccount = await E2EEAccount.create({
|
||||
hsApi: this._hsApi,
|
||||
olm: this._olm,
|
||||
pickleKey: PICKLE_KEY,
|
||||
userId: this._sessionInfo.userId,
|
||||
deviceId: this._sessionInfo.deviceId,
|
||||
olmWorker: this._olmWorker,
|
||||
storage: this._storage,
|
||||
});
|
||||
this._setupEncryption();
|
||||
}
|
||||
await this._e2eeAccount.generateOTKsIfNeeded(this._storage);
|
||||
|
@ -184,6 +176,7 @@ export class Session {
|
|||
pickleKey: PICKLE_KEY,
|
||||
userId: this._sessionInfo.userId,
|
||||
deviceId: this._sessionInfo.deviceId,
|
||||
olmWorker: this._olmWorker,
|
||||
txn
|
||||
});
|
||||
if (this._e2eeAccount) {
|
||||
|
@ -204,7 +197,7 @@ export class Session {
|
|||
}
|
||||
|
||||
stop() {
|
||||
this._workerPool?.dispose();
|
||||
this._olmWorker?.dispose();
|
||||
this._sendScheduler.stop();
|
||||
}
|
||||
|
||||
|
|
|
@ -153,13 +153,13 @@ export class SessionContainer {
|
|||
homeServer: sessionInfo.homeServer,
|
||||
};
|
||||
const olm = await this._olmPromise;
|
||||
let workerPool = null;
|
||||
let olmWorker = null;
|
||||
if (this._workerPromise) {
|
||||
workerPool = await this._workerPromise;
|
||||
olmWorker = await this._workerPromise;
|
||||
}
|
||||
this._session = new Session({storage: this._storage,
|
||||
sessionInfo: filteredSessionInfo, hsApi, olm,
|
||||
clock: this._clock, workerPool});
|
||||
clock: this._clock, olmWorker});
|
||||
await this._session.load();
|
||||
this._status.set(LoadStatus.SessionSetup);
|
||||
await this._session.beforeFirstSync(isNewLogin);
|
||||
|
|
|
@ -23,7 +23,7 @@ const DEVICE_KEY_FLAG_SESSION_KEY = SESSION_KEY_PREFIX + "areDeviceKeysUploaded"
|
|||
const SERVER_OTK_COUNT_SESSION_KEY = SESSION_KEY_PREFIX + "serverOTKCount";
|
||||
|
||||
export class Account {
|
||||
static async load({olm, pickleKey, hsApi, userId, deviceId, txn}) {
|
||||
static async load({olm, pickleKey, hsApi, userId, deviceId, olmWorker, txn}) {
|
||||
const pickledAccount = await txn.session.get(ACCOUNT_SESSION_KEY);
|
||||
if (pickledAccount) {
|
||||
const account = new olm.Account();
|
||||
|
@ -31,26 +31,39 @@ export class Account {
|
|||
account.unpickle(pickleKey, pickledAccount);
|
||||
const serverOTKCount = await txn.session.get(SERVER_OTK_COUNT_SESSION_KEY);
|
||||
return new Account({pickleKey, hsApi, account, userId,
|
||||
deviceId, areDeviceKeysUploaded, serverOTKCount, olm});
|
||||
deviceId, areDeviceKeysUploaded, serverOTKCount, olm, olmWorker});
|
||||
}
|
||||
}
|
||||
|
||||
static async create({olm, pickleKey, hsApi, userId, deviceId, txn}) {
|
||||
static async create({olm, pickleKey, hsApi, userId, deviceId, olmWorker, storage}) {
|
||||
const account = new olm.Account();
|
||||
account.create();
|
||||
account.generate_one_time_keys(account.max_number_of_one_time_keys());
|
||||
if (olmWorker) {
|
||||
await olmWorker.createAccountAndOTKs(account, account.max_number_of_one_time_keys());
|
||||
} else {
|
||||
account.create();
|
||||
account.generate_one_time_keys(account.max_number_of_one_time_keys());
|
||||
}
|
||||
const pickledAccount = account.pickle(pickleKey);
|
||||
// add will throw if the key already exists
|
||||
// we would not want to overwrite olmAccount here
|
||||
const areDeviceKeysUploaded = false;
|
||||
await txn.session.add(ACCOUNT_SESSION_KEY, pickledAccount);
|
||||
await txn.session.add(DEVICE_KEY_FLAG_SESSION_KEY, areDeviceKeysUploaded);
|
||||
await txn.session.add(SERVER_OTK_COUNT_SESSION_KEY, 0);
|
||||
const txn = await storage.readWriteTxn([
|
||||
storage.storeNames.session
|
||||
]);
|
||||
try {
|
||||
// add will throw if the key already exists
|
||||
// we would not want to overwrite olmAccount here
|
||||
txn.session.add(ACCOUNT_SESSION_KEY, pickledAccount);
|
||||
txn.session.add(DEVICE_KEY_FLAG_SESSION_KEY, areDeviceKeysUploaded);
|
||||
txn.session.add(SERVER_OTK_COUNT_SESSION_KEY, 0);
|
||||
} catch (err) {
|
||||
txn.abort();
|
||||
throw err;
|
||||
}
|
||||
await txn.complete();
|
||||
return new Account({pickleKey, hsApi, account, userId,
|
||||
deviceId, areDeviceKeysUploaded, serverOTKCount: 0, olm});
|
||||
deviceId, areDeviceKeysUploaded, serverOTKCount: 0, olm, olmWorker});
|
||||
}
|
||||
|
||||
constructor({pickleKey, hsApi, account, userId, deviceId, areDeviceKeysUploaded, serverOTKCount, olm}) {
|
||||
constructor({pickleKey, hsApi, account, userId, deviceId, areDeviceKeysUploaded, serverOTKCount, olm, olmWorker}) {
|
||||
this._olm = olm;
|
||||
this._pickleKey = pickleKey;
|
||||
this._hsApi = hsApi;
|
||||
|
@ -59,6 +72,7 @@ export class Account {
|
|||
this._deviceId = deviceId;
|
||||
this._areDeviceKeysUploaded = areDeviceKeysUploaded;
|
||||
this._serverOTKCount = serverOTKCount;
|
||||
this._olmWorker = olmWorker;
|
||||
this._identityKeys = JSON.parse(this._account.identity_keys());
|
||||
}
|
||||
|
||||
|
|
|
@ -14,13 +14,30 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
*/
|
||||
|
||||
export class DecryptionWorker {
|
||||
export class OlmWorker {
|
||||
constructor(workerPool) {
|
||||
this._workerPool = workerPool;
|
||||
}
|
||||
|
||||
decrypt(session, ciphertext) {
|
||||
megolmDecrypt(session, ciphertext) {
|
||||
const sessionKey = session.export_session(session.first_known_index());
|
||||
return this._workerPool.send({type: "megolm_decrypt", ciphertext, sessionKey});
|
||||
}
|
||||
|
||||
async createAccountAndOTKs(account, otkAmount) {
|
||||
// IE11 does not support getRandomValues in a worker, so we have to generate the values upfront.
|
||||
let randomValues;
|
||||
if (window.msCrypto) {
|
||||
randomValues = [
|
||||
window.msCrypto.getRandomValues(new Uint8Array(64)),
|
||||
window.msCrypto.getRandomValues(new Uint8Array(otkAmount * 32)),
|
||||
];
|
||||
}
|
||||
const pickle = await this._workerPool.send({type: "olm_create_account_otks", randomValues, otkAmount}).response();
|
||||
account.unpickle("", pickle);
|
||||
}
|
||||
|
||||
dispose() {
|
||||
this._workerPool.dispose();
|
||||
}
|
||||
}
|
|
@ -21,7 +21,6 @@ import {SessionInfo} from "./decryption/SessionInfo.js";
|
|||
import {DecryptionPreparation} from "./decryption/DecryptionPreparation.js";
|
||||
import {SessionDecryption} from "./decryption/SessionDecryption.js";
|
||||
import {SessionCache} from "./decryption/SessionCache.js";
|
||||
import {DecryptionWorker} from "./decryption/DecryptionWorker.js";
|
||||
|
||||
function getSenderKey(event) {
|
||||
return event.content?.["sender_key"];
|
||||
|
@ -36,10 +35,10 @@ function getCiphertext(event) {
|
|||
}
|
||||
|
||||
export class Decryption {
|
||||
constructor({pickleKey, olm, workerPool}) {
|
||||
constructor({pickleKey, olm, olmWorker}) {
|
||||
this._pickleKey = pickleKey;
|
||||
this._olm = olm;
|
||||
this._decryptor = workerPool ? new DecryptionWorker(workerPool) : null;
|
||||
this._olmWorker = olmWorker;
|
||||
}
|
||||
|
||||
createSessionCache(fallback) {
|
||||
|
@ -86,7 +85,7 @@ export class Decryption {
|
|||
errors.set(event.event_id, new DecryptionError("MEGOLM_NO_SESSION", event));
|
||||
}
|
||||
} else {
|
||||
sessionDecryptions.push(new SessionDecryption(sessionInfo, eventsForSession, this._decryptor));
|
||||
sessionDecryptions.push(new SessionDecryption(sessionInfo, eventsForSession, this._olmWorker));
|
||||
}
|
||||
}));
|
||||
|
||||
|
|
|
@ -22,12 +22,12 @@ import {ReplayDetectionEntry} from "./ReplayDetectionEntry.js";
|
|||
* Does the actual decryption of all events for a given megolm session in a batch
|
||||
*/
|
||||
export class SessionDecryption {
|
||||
constructor(sessionInfo, events, decryptor) {
|
||||
constructor(sessionInfo, events, olmWorker) {
|
||||
sessionInfo.retain();
|
||||
this._sessionInfo = sessionInfo;
|
||||
this._events = events;
|
||||
this._decryptor = decryptor;
|
||||
this._decryptionRequests = decryptor ? [] : null;
|
||||
this._olmWorker = olmWorker;
|
||||
this._decryptionRequests = olmWorker ? [] : null;
|
||||
}
|
||||
|
||||
async decryptAll() {
|
||||
|
@ -41,8 +41,8 @@ export class SessionDecryption {
|
|||
const {session} = this._sessionInfo;
|
||||
const ciphertext = event.content.ciphertext;
|
||||
let decryptionResult;
|
||||
if (this._decryptor) {
|
||||
const request = this._decryptor.decrypt(session, ciphertext);
|
||||
if (this._olmWorker) {
|
||||
const request = this._olmWorker.megolmDecrypt(session, ciphertext);
|
||||
this._decryptionRequests.push(request);
|
||||
decryptionResult = await request.response();
|
||||
} else {
|
||||
|
|
|
@ -32,6 +32,44 @@ function asSuccessMessage(payload) {
|
|||
class MessageHandler {
|
||||
constructor() {
|
||||
this._olm = null;
|
||||
this._randomValues = self.crypto ? null : [];
|
||||
}
|
||||
|
||||
_feedRandomValues(randomValues) {
|
||||
if (this._randomValues) {
|
||||
this._randomValues.push(...randomValues);
|
||||
}
|
||||
}
|
||||
|
||||
_checkRandomValuesUsed() {
|
||||
if (this._randomValues && this._randomValues.length !== 0) {
|
||||
throw new Error(`${this._randomValues.length} random values left`);
|
||||
}
|
||||
}
|
||||
|
||||
_getRandomValues(typedArray) {
|
||||
if (!(typedArray instanceof Uint8Array)) {
|
||||
throw new Error("only Uint8Array is supported: " + JSON.stringify({
|
||||
Int8Array: typedArray instanceof Int8Array,
|
||||
Uint8Array: typedArray instanceof Uint8Array,
|
||||
Int16Array: typedArray instanceof Int16Array,
|
||||
Uint16Array: typedArray instanceof Uint16Array,
|
||||
Int32Array: typedArray instanceof Int32Array,
|
||||
Uint32Array: typedArray instanceof Uint32Array,
|
||||
}));
|
||||
}
|
||||
if (this._randomValues.length === 0) {
|
||||
throw new Error("no more random values, needed one of length " + typedArray.length);
|
||||
}
|
||||
const precalculated = this._randomValues.shift();
|
||||
if (precalculated.length !== typedArray.length) {
|
||||
throw new Error(`typedArray length (${typedArray.length}) does not match precalculated length (${precalculated.length})`);
|
||||
}
|
||||
// copy values
|
||||
for (let i = 0; i < typedArray.length; ++i) {
|
||||
typedArray[i] = precalculated[i];
|
||||
}
|
||||
return typedArray;
|
||||
}
|
||||
|
||||
handleEvent(e) {
|
||||
|
@ -47,7 +85,7 @@ class MessageHandler {
|
|||
|
||||
_toMessage(fn) {
|
||||
try {
|
||||
let payload = fn();
|
||||
const payload = fn();
|
||||
if (payload instanceof Promise) {
|
||||
return payload.then(
|
||||
payload => asSuccessMessage(payload),
|
||||
|
@ -63,18 +101,15 @@ class MessageHandler {
|
|||
|
||||
_loadOlm(path) {
|
||||
return this._toMessage(async () => {
|
||||
// might have some problems here with window vs self as global object?
|
||||
if (self.msCrypto && !self.crypto) {
|
||||
self.crypto = self.msCrypto;
|
||||
if (!self.crypto) {
|
||||
self.crypto = {getRandomValues: this._getRandomValues.bind(this)};
|
||||
}
|
||||
self.importScripts(path);
|
||||
const olm = self.olm_exports;
|
||||
// mangle the globals enough to make olm load believe it is running in a browser
|
||||
// mangle the globals enough to make olm believe it is running in a browser
|
||||
self.window = self;
|
||||
self.document = {};
|
||||
self.importScripts(path);
|
||||
const olm = self.olm_exports;
|
||||
await olm.init();
|
||||
delete self.document;
|
||||
delete self.window;
|
||||
this._olm = olm;
|
||||
});
|
||||
}
|
||||
|
@ -93,6 +128,17 @@ class MessageHandler {
|
|||
});
|
||||
}
|
||||
|
||||
_olmCreateAccountAndOTKs(randomValues, otkAmount) {
|
||||
return this._toMessage(() => {
|
||||
this._feedRandomValues(randomValues);
|
||||
const account = new this._olm.Account();
|
||||
account.create();
|
||||
account.generate_one_time_keys(otkAmount);
|
||||
this._checkRandomValuesUsed();
|
||||
return account.pickle("");
|
||||
});
|
||||
}
|
||||
|
||||
async _handleMessage(message) {
|
||||
const {type} = message;
|
||||
if (type === "ping") {
|
||||
|
@ -101,6 +147,8 @@ class MessageHandler {
|
|||
this._sendReply(message, await this._loadOlm(message.path));
|
||||
} else if (type === "megolm_decrypt") {
|
||||
this._sendReply(message, this._megolmDecrypt(message.sessionKey, message.ciphertext));
|
||||
} else if (type === "olm_create_account_otks") {
|
||||
this._sendReply(message, this._olmCreateAccountAndOTKs(message.randomValues, message.otkAmount));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue