offload olm account creation in worker

This commit is contained in:
Bruno Windels 2020-09-11 10:43:17 +02:00
parent 0b26e6f53a
commit e0d9d703b7
8 changed files with 134 additions and 61 deletions

View file

@ -26,6 +26,7 @@ import {BrawlView} from "./ui/web/BrawlView.js";
import {Clock} from "./ui/web/dom/Clock.js";
import {OnlineStatus} from "./ui/web/dom/OnlineStatus.js";
import {WorkerPool} from "./utils/WorkerPool.js";
import {OlmWorker} from "./matrix/e2ee/OlmWorker.js";
function addScript(src) {
return new Promise(function (resolve, reject) {
@ -65,12 +66,13 @@ function relPath(path, basePath) {
return "../".repeat(dirCount) + path;
}
async function loadWorker(paths) {
async function loadOlmWorker(paths) {
const workerPool = new WorkerPool(paths.worker, 4);
await workerPool.init();
const path = relPath(paths.olm.legacyBundle, paths.worker);
await workerPool.sendAll({type: "load_olm", path});
return workerPool;
const olmWorker = new OlmWorker(workerPool);
return olmWorker;
}
// Don't use a default export here, as we use multiple entries during legacy build,
@ -100,9 +102,9 @@ export async function main(container, paths) {
// if wasm is not supported, we'll want
// to run some olm operations in a worker (mainly for IE11)
let workerPromise;
if (!window.WebAssembly) {
workerPromise = loadWorker(paths);
}
// if (!window.WebAssembly) {
workerPromise = loadOlmWorker(paths);
// }
const vm = new BrawlViewModel({
createSessionContainer: () => {

View file

@ -33,7 +33,7 @@ const PICKLE_KEY = "DEFAULT_KEY";
export class Session {
// sessionInfo contains deviceId, userId and homeServer
constructor({clock, storage, hsApi, sessionInfo, olm, workerPool}) {
constructor({clock, storage, hsApi, sessionInfo, olm, olmWorker}) {
this._clock = clock;
this._storage = storage;
this._hsApi = hsApi;
@ -52,7 +52,7 @@ export class Session {
this._megolmEncryption = null;
this._megolmDecryption = null;
this._getSyncToken = () => this.syncToken;
this._workerPool = workerPool;
this._olmWorker = olmWorker;
if (olm) {
this._olmUtil = new olm.Utility();
@ -101,7 +101,7 @@ export class Session {
this._megolmDecryption = new MegOlmDecryption({
pickleKey: PICKLE_KEY,
olm: this._olm,
workerPool: this._workerPool,
olmWorker: this._olmWorker,
});
this._deviceMessageHandler.enableEncryption({olmDecryption, megolmDecryption: this._megolmDecryption});
}
@ -140,23 +140,15 @@ export class Session {
throw new Error("there should not be an e2ee account already on a fresh login");
}
if (!this._e2eeAccount) {
const txn = await this._storage.readWriteTxn([
this._storage.storeNames.session
]);
try {
this._e2eeAccount = await E2EEAccount.create({
hsApi: this._hsApi,
olm: this._olm,
pickleKey: PICKLE_KEY,
userId: this._sessionInfo.userId,
deviceId: this._sessionInfo.deviceId,
txn
olmWorker: this._olmWorker,
storage: this._storage,
});
} catch (err) {
txn.abort();
throw err;
}
await txn.complete();
this._setupEncryption();
}
await this._e2eeAccount.generateOTKsIfNeeded(this._storage);
@ -184,6 +176,7 @@ export class Session {
pickleKey: PICKLE_KEY,
userId: this._sessionInfo.userId,
deviceId: this._sessionInfo.deviceId,
olmWorker: this._olmWorker,
txn
});
if (this._e2eeAccount) {
@ -204,7 +197,7 @@ export class Session {
}
stop() {
this._workerPool?.dispose();
this._olmWorker?.dispose();
this._sendScheduler.stop();
}

View file

@ -153,13 +153,13 @@ export class SessionContainer {
homeServer: sessionInfo.homeServer,
};
const olm = await this._olmPromise;
let workerPool = null;
let olmWorker = null;
if (this._workerPromise) {
workerPool = await this._workerPromise;
olmWorker = await this._workerPromise;
}
this._session = new Session({storage: this._storage,
sessionInfo: filteredSessionInfo, hsApi, olm,
clock: this._clock, workerPool});
clock: this._clock, olmWorker});
await this._session.load();
this._status.set(LoadStatus.SessionSetup);
await this._session.beforeFirstSync(isNewLogin);

View file

@ -23,7 +23,7 @@ const DEVICE_KEY_FLAG_SESSION_KEY = SESSION_KEY_PREFIX + "areDeviceKeysUploaded"
const SERVER_OTK_COUNT_SESSION_KEY = SESSION_KEY_PREFIX + "serverOTKCount";
export class Account {
static async load({olm, pickleKey, hsApi, userId, deviceId, txn}) {
static async load({olm, pickleKey, hsApi, userId, deviceId, olmWorker, txn}) {
const pickledAccount = await txn.session.get(ACCOUNT_SESSION_KEY);
if (pickledAccount) {
const account = new olm.Account();
@ -31,26 +31,39 @@ export class Account {
account.unpickle(pickleKey, pickledAccount);
const serverOTKCount = await txn.session.get(SERVER_OTK_COUNT_SESSION_KEY);
return new Account({pickleKey, hsApi, account, userId,
deviceId, areDeviceKeysUploaded, serverOTKCount, olm});
deviceId, areDeviceKeysUploaded, serverOTKCount, olm, olmWorker});
}
}
static async create({olm, pickleKey, hsApi, userId, deviceId, txn}) {
static async create({olm, pickleKey, hsApi, userId, deviceId, olmWorker, storage}) {
const account = new olm.Account();
if (olmWorker) {
await olmWorker.createAccountAndOTKs(account, account.max_number_of_one_time_keys());
} else {
account.create();
account.generate_one_time_keys(account.max_number_of_one_time_keys());
}
const pickledAccount = account.pickle(pickleKey);
const areDeviceKeysUploaded = false;
const txn = await storage.readWriteTxn([
storage.storeNames.session
]);
try {
// add will throw if the key already exists
// we would not want to overwrite olmAccount here
const areDeviceKeysUploaded = false;
await txn.session.add(ACCOUNT_SESSION_KEY, pickledAccount);
await txn.session.add(DEVICE_KEY_FLAG_SESSION_KEY, areDeviceKeysUploaded);
await txn.session.add(SERVER_OTK_COUNT_SESSION_KEY, 0);
txn.session.add(ACCOUNT_SESSION_KEY, pickledAccount);
txn.session.add(DEVICE_KEY_FLAG_SESSION_KEY, areDeviceKeysUploaded);
txn.session.add(SERVER_OTK_COUNT_SESSION_KEY, 0);
} catch (err) {
txn.abort();
throw err;
}
await txn.complete();
return new Account({pickleKey, hsApi, account, userId,
deviceId, areDeviceKeysUploaded, serverOTKCount: 0, olm});
deviceId, areDeviceKeysUploaded, serverOTKCount: 0, olm, olmWorker});
}
constructor({pickleKey, hsApi, account, userId, deviceId, areDeviceKeysUploaded, serverOTKCount, olm}) {
constructor({pickleKey, hsApi, account, userId, deviceId, areDeviceKeysUploaded, serverOTKCount, olm, olmWorker}) {
this._olm = olm;
this._pickleKey = pickleKey;
this._hsApi = hsApi;
@ -59,6 +72,7 @@ export class Account {
this._deviceId = deviceId;
this._areDeviceKeysUploaded = areDeviceKeysUploaded;
this._serverOTKCount = serverOTKCount;
this._olmWorker = olmWorker;
this._identityKeys = JSON.parse(this._account.identity_keys());
}

View file

@ -14,13 +14,30 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
export class DecryptionWorker {
export class OlmWorker {
constructor(workerPool) {
this._workerPool = workerPool;
}
decrypt(session, ciphertext) {
megolmDecrypt(session, ciphertext) {
const sessionKey = session.export_session(session.first_known_index());
return this._workerPool.send({type: "megolm_decrypt", ciphertext, sessionKey});
}
async createAccountAndOTKs(account, otkAmount) {
// IE11 does not support getRandomValues in a worker, so we have to generate the values upfront.
let randomValues;
if (window.msCrypto) {
randomValues = [
window.msCrypto.getRandomValues(new Uint8Array(64)),
window.msCrypto.getRandomValues(new Uint8Array(otkAmount * 32)),
];
}
const pickle = await this._workerPool.send({type: "olm_create_account_otks", randomValues, otkAmount}).response();
account.unpickle("", pickle);
}
dispose() {
this._workerPool.dispose();
}
}

View file

@ -21,7 +21,6 @@ import {SessionInfo} from "./decryption/SessionInfo.js";
import {DecryptionPreparation} from "./decryption/DecryptionPreparation.js";
import {SessionDecryption} from "./decryption/SessionDecryption.js";
import {SessionCache} from "./decryption/SessionCache.js";
import {DecryptionWorker} from "./decryption/DecryptionWorker.js";
function getSenderKey(event) {
return event.content?.["sender_key"];
@ -36,10 +35,10 @@ function getCiphertext(event) {
}
export class Decryption {
constructor({pickleKey, olm, workerPool}) {
constructor({pickleKey, olm, olmWorker}) {
this._pickleKey = pickleKey;
this._olm = olm;
this._decryptor = workerPool ? new DecryptionWorker(workerPool) : null;
this._olmWorker = olmWorker;
}
createSessionCache(fallback) {
@ -86,7 +85,7 @@ export class Decryption {
errors.set(event.event_id, new DecryptionError("MEGOLM_NO_SESSION", event));
}
} else {
sessionDecryptions.push(new SessionDecryption(sessionInfo, eventsForSession, this._decryptor));
sessionDecryptions.push(new SessionDecryption(sessionInfo, eventsForSession, this._olmWorker));
}
}));

View file

@ -22,12 +22,12 @@ import {ReplayDetectionEntry} from "./ReplayDetectionEntry.js";
* Does the actual decryption of all events for a given megolm session in a batch
*/
export class SessionDecryption {
constructor(sessionInfo, events, decryptor) {
constructor(sessionInfo, events, olmWorker) {
sessionInfo.retain();
this._sessionInfo = sessionInfo;
this._events = events;
this._decryptor = decryptor;
this._decryptionRequests = decryptor ? [] : null;
this._olmWorker = olmWorker;
this._decryptionRequests = olmWorker ? [] : null;
}
async decryptAll() {
@ -41,8 +41,8 @@ export class SessionDecryption {
const {session} = this._sessionInfo;
const ciphertext = event.content.ciphertext;
let decryptionResult;
if (this._decryptor) {
const request = this._decryptor.decrypt(session, ciphertext);
if (this._olmWorker) {
const request = this._olmWorker.megolmDecrypt(session, ciphertext);
this._decryptionRequests.push(request);
decryptionResult = await request.response();
} else {

View file

@ -32,6 +32,44 @@ function asSuccessMessage(payload) {
class MessageHandler {
constructor() {
this._olm = null;
this._randomValues = self.crypto ? null : [];
}
_feedRandomValues(randomValues) {
if (this._randomValues) {
this._randomValues.push(...randomValues);
}
}
_checkRandomValuesUsed() {
if (this._randomValues && this._randomValues.length !== 0) {
throw new Error(`${this._randomValues.length} random values left`);
}
}
_getRandomValues(typedArray) {
if (!(typedArray instanceof Uint8Array)) {
throw new Error("only Uint8Array is supported: " + JSON.stringify({
Int8Array: typedArray instanceof Int8Array,
Uint8Array: typedArray instanceof Uint8Array,
Int16Array: typedArray instanceof Int16Array,
Uint16Array: typedArray instanceof Uint16Array,
Int32Array: typedArray instanceof Int32Array,
Uint32Array: typedArray instanceof Uint32Array,
}));
}
if (this._randomValues.length === 0) {
throw new Error("no more random values, needed one of length " + typedArray.length);
}
const precalculated = this._randomValues.shift();
if (precalculated.length !== typedArray.length) {
throw new Error(`typedArray length (${typedArray.length}) does not match precalculated length (${precalculated.length})`);
}
// copy values
for (let i = 0; i < typedArray.length; ++i) {
typedArray[i] = precalculated[i];
}
return typedArray;
}
handleEvent(e) {
@ -47,7 +85,7 @@ class MessageHandler {
_toMessage(fn) {
try {
let payload = fn();
const payload = fn();
if (payload instanceof Promise) {
return payload.then(
payload => asSuccessMessage(payload),
@ -63,18 +101,15 @@ class MessageHandler {
_loadOlm(path) {
return this._toMessage(async () => {
// might have some problems here with window vs self as global object?
if (self.msCrypto && !self.crypto) {
self.crypto = self.msCrypto;
if (!self.crypto) {
self.crypto = {getRandomValues: this._getRandomValues.bind(this)};
}
self.importScripts(path);
const olm = self.olm_exports;
// mangle the globals enough to make olm load believe it is running in a browser
// mangle the globals enough to make olm believe it is running in a browser
self.window = self;
self.document = {};
self.importScripts(path);
const olm = self.olm_exports;
await olm.init();
delete self.document;
delete self.window;
this._olm = olm;
});
}
@ -93,6 +128,17 @@ class MessageHandler {
});
}
_olmCreateAccountAndOTKs(randomValues, otkAmount) {
return this._toMessage(() => {
this._feedRandomValues(randomValues);
const account = new this._olm.Account();
account.create();
account.generate_one_time_keys(otkAmount);
this._checkRandomValuesUsed();
return account.pickle("");
});
}
async _handleMessage(message) {
const {type} = message;
if (type === "ping") {
@ -101,6 +147,8 @@ class MessageHandler {
this._sendReply(message, await this._loadOlm(message.path));
} else if (type === "megolm_decrypt") {
this._sendReply(message, this._megolmDecrypt(message.sessionKey, message.ciphertext));
} else if (type === "olm_create_account_otks") {
this._sendReply(message, this._olmCreateAccountAndOTKs(message.randomValues, message.otkAmount));
}
}
}