Merge pull request #651 from vector-im/bwindels/write-session-backup

Session backup writing
This commit is contained in:
Bruno Windels 2022-02-01 11:54:53 +01:00 committed by GitHub
commit 247d13f97a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
39 changed files with 956 additions and 314 deletions

View file

@ -16,7 +16,7 @@ limitations under the License.
import {ViewModel} from "./ViewModel.js"; import {ViewModel} from "./ViewModel.js";
import {KeyType} from "../matrix/ssss/index"; import {KeyType} from "../matrix/ssss/index";
import {Status} from "./session/settings/SessionBackupViewModel.js"; import {Status} from "./session/settings/KeyBackupViewModel.js";
export class AccountSetupViewModel extends ViewModel { export class AccountSetupViewModel extends ViewModel {
constructor(accountSetup) { constructor(accountSetup) {
@ -50,7 +50,7 @@ export class AccountSetupViewModel extends ViewModel {
} }
} }
// this vm adopts the same shape as SessionBackupViewModel so the same view can be reused. // this vm adopts the same shape as KeyBackupViewModel so the same view can be reused.
class DecryptDehydratedDeviceViewModel extends ViewModel { class DecryptDehydratedDeviceViewModel extends ViewModel {
constructor(accountSetupViewModel, decryptedCallback) { constructor(accountSetupViewModel, decryptedCallback) {
super(); super();

View file

@ -36,7 +36,7 @@ export class SessionStatusViewModel extends ViewModel {
this._reconnector = reconnector; this._reconnector = reconnector;
this._status = this._calculateState(reconnector.connectionStatus.get(), sync.status.get()); this._status = this._calculateState(reconnector.connectionStatus.get(), sync.status.get());
this._session = session; this._session = session;
this._setupSessionBackupUrl = this.urlCreator.urlForSegment("settings"); this._setupKeyBackupUrl = this.urlCreator.urlForSegment("settings");
this._dismissSecretStorage = false; this._dismissSecretStorage = false;
} }
@ -44,17 +44,17 @@ export class SessionStatusViewModel extends ViewModel {
const update = () => this._updateStatus(); const update = () => this._updateStatus();
this.track(this._sync.status.subscribe(update)); this.track(this._sync.status.subscribe(update));
this.track(this._reconnector.connectionStatus.subscribe(update)); this.track(this._reconnector.connectionStatus.subscribe(update));
this.track(this._session.needsSessionBackup.subscribe(() => { this.track(this._session.needsKeyBackup.subscribe(() => {
this.emitChange(); this.emitChange();
})); }));
} }
get setupSessionBackupUrl () { get setupKeyBackupUrl () {
return this._setupSessionBackupUrl; return this._setupKeyBackupUrl;
} }
get isShown() { get isShown() {
return (this._session.needsSessionBackup.get() && !this._dismissSecretStorage) || this._status !== SessionStatus.Syncing; return (this._session.needsKeyBackup.get() && !this._dismissSecretStorage) || this._status !== SessionStatus.Syncing;
} }
get statusLabel() { get statusLabel() {
@ -70,7 +70,7 @@ export class SessionStatusViewModel extends ViewModel {
case SessionStatus.SyncError: case SessionStatus.SyncError:
return this.i18n`Sync failed because of ${this._sync.error}`; return this.i18n`Sync failed because of ${this._sync.error}`;
} }
if (this._session.needsSessionBackup.get()) { if (this._session.needsKeyBackup.get()) {
return this.i18n`Set up session backup to decrypt older messages.`; return this.i18n`Set up session backup to decrypt older messages.`;
} }
return ""; return "";
@ -135,7 +135,7 @@ export class SessionStatusViewModel extends ViewModel {
get isSecretStorageShown() { get isSecretStorageShown() {
// TODO: we need a model here where we can have multiple messages queued up and their buttons don't bleed into each other. // TODO: we need a model here where we can have multiple messages queued up and their buttons don't bleed into each other.
return this._status === SessionStatus.Syncing && this._session.needsSessionBackup.get() && !this._dismissSecretStorage; return this._status === SessionStatus.Syncing && this._session.needsKeyBackup.get() && !this._dismissSecretStorage;
} }
get canDismiss() { get canDismiss() {

View file

@ -18,9 +18,10 @@ import {ViewModel} from "../../ViewModel.js";
import {KeyType} from "../../../matrix/ssss/index"; import {KeyType} from "../../../matrix/ssss/index";
import {createEnum} from "../../../utils/enum"; import {createEnum} from "../../../utils/enum";
export const Status = createEnum("Enabled", "SetupKey", "SetupPhrase", "Pending"); export const Status = createEnum("Enabled", "SetupKey", "SetupPhrase", "Pending", "NewVersionAvailable");
export const BackupWriteStatus = createEnum("Writing", "Stopped", "Done", "Pending");
export class SessionBackupViewModel extends ViewModel { export class KeyBackupViewModel extends ViewModel {
constructor(options) { constructor(options) {
super(options); super(options);
this._session = options.session; this._session = options.session;
@ -28,8 +29,16 @@ export class SessionBackupViewModel extends ViewModel {
this._isBusy = false; this._isBusy = false;
this._dehydratedDeviceId = undefined; this._dehydratedDeviceId = undefined;
this._status = undefined; this._status = undefined;
this._backupOperation = this._session.keyBackup.flatMap(keyBackup => keyBackup.operationInProgress);
this._progress = this._backupOperation.flatMap(op => op.progress);
this.track(this._backupOperation.subscribe(() => {
// see if needsNewKey might be set
this._reevaluateStatus(); this._reevaluateStatus();
this.track(this._session.hasSecretStorageKey.subscribe(() => { this.emitChange("isBackingUp");
}));
this.track(this._progress.subscribe(() => this.emitChange("backupPercentage")));
this._reevaluateStatus();
this.track(this._session.keyBackup.subscribe(() => {
if (this._reevaluateStatus()) { if (this._reevaluateStatus()) {
this.emitChange("status"); this.emitChange("status");
} }
@ -41,11 +50,11 @@ export class SessionBackupViewModel extends ViewModel {
return false; return false;
} }
let status; let status;
const hasSecretStorageKey = this._session.hasSecretStorageKey.get(); const keyBackup = this._session.keyBackup.get();
if (hasSecretStorageKey === true) { if (keyBackup) {
status = this._session.sessionBackup ? Status.Enabled : Status.SetupKey; status = keyBackup.needsNewKey ? Status.NewVersionAvailable : Status.Enabled;
} else if (hasSecretStorageKey === false) { } else if (keyBackup === null) {
status = Status.SetupKey; status = this.showPhraseSetup() ? Status.SetupPhrase : Status.SetupKey;
} else { } else {
status = Status.Pending; status = Status.Pending;
} }
@ -59,7 +68,7 @@ export class SessionBackupViewModel extends ViewModel {
} }
get purpose() { get purpose() {
return this.i18n`set up session backup`; return this.i18n`set up key backup`;
} }
offerDehydratedDeviceSetup() { offerDehydratedDeviceSetup() {
@ -75,7 +84,28 @@ export class SessionBackupViewModel extends ViewModel {
} }
get backupVersion() { get backupVersion() {
return this._session.sessionBackup?.version; return this._session.keyBackup.get()?.version;
}
get backupWriteStatus() {
const keyBackup = this._session.keyBackup.get();
if (!keyBackup) {
return BackupWriteStatus.Pending;
} else if (keyBackup.hasStopped) {
return BackupWriteStatus.Stopped;
}
const operation = keyBackup.operationInProgress.get();
if (operation) {
return BackupWriteStatus.Writing;
} else if (keyBackup.hasBackedUpAllKeys) {
return BackupWriteStatus.Done;
} else {
return BackupWriteStatus.Pending;
}
}
get backupError() {
return this._session.keyBackup.get()?.error?.message;
} }
get status() { get status() {
@ -144,4 +174,33 @@ export class SessionBackupViewModel extends ViewModel {
this.emitChange(""); this.emitChange("");
} }
} }
get isBackingUp() {
return !!this._backupOperation.get();
}
get backupPercentage() {
const progress = this._progress.get();
if (progress) {
return Math.round((progress.finished / progress.total) * 100);
}
return 0;
}
get backupInProgressLabel() {
const progress = this._progress.get();
if (progress) {
return this.i18n`${progress.finished} of ${progress.total}`;
}
return this.i18n``;
}
cancelBackup() {
this._backupOperation.get()?.abort();
}
startBackup() {
this._session.keyBackup.get()?.flush();
}
} }

View file

@ -15,7 +15,7 @@ limitations under the License.
*/ */
import {ViewModel} from "../../ViewModel.js"; import {ViewModel} from "../../ViewModel.js";
import {SessionBackupViewModel} from "./SessionBackupViewModel.js"; import {KeyBackupViewModel} from "./KeyBackupViewModel.js";
class PushNotificationStatus { class PushNotificationStatus {
constructor() { constructor() {
@ -43,7 +43,7 @@ export class SettingsViewModel extends ViewModel {
this._updateService = options.updateService; this._updateService = options.updateService;
const {client} = options; const {client} = options;
this._client = client; this._client = client;
this._sessionBackupViewModel = this.track(new SessionBackupViewModel(this.childOptions({session: this._session}))); this._keyBackupViewModel = this.track(new KeyBackupViewModel(this.childOptions({session: this._session})));
this._closeUrl = this.urlCreator.urlUntilSegment("session"); this._closeUrl = this.urlCreator.urlUntilSegment("session");
this._estimate = null; this._estimate = null;
this.sentImageSizeLimit = null; this.sentImageSizeLimit = null;
@ -115,8 +115,8 @@ export class SettingsViewModel extends ViewModel {
return !!this.platform.updateService; return !!this.platform.updateService;
} }
get sessionBackupViewModel() { get keyBackupViewModel() {
return this._sessionBackupViewModel; return this._keyBackupViewModel;
} }
get storageQuota() { get storageQuota() {

View file

@ -57,7 +57,8 @@ export class DeviceMessageHandler {
async writeSync(prep, txn) { async writeSync(prep, txn) {
// write olm changes // write olm changes
prep.olmDecryptChanges.write(txn); prep.olmDecryptChanges.write(txn);
await Promise.all(prep.newRoomKeys.map(key => this._megolmDecryption.writeRoomKey(key, txn))); const didWriteValues = await Promise.all(prep.newRoomKeys.map(key => this._megolmDecryption.writeRoomKey(key, txn)));
return didWriteValues.some(didWrite => !!didWrite);
} }
} }

View file

@ -29,7 +29,7 @@ import {Decryption as OlmDecryption} from "./e2ee/olm/Decryption.js";
import {Encryption as OlmEncryption} from "./e2ee/olm/Encryption.js"; import {Encryption as OlmEncryption} from "./e2ee/olm/Encryption.js";
import {Decryption as MegOlmDecryption} from "./e2ee/megolm/Decryption"; import {Decryption as MegOlmDecryption} from "./e2ee/megolm/Decryption";
import {KeyLoader as MegOlmKeyLoader} from "./e2ee/megolm/decryption/KeyLoader"; import {KeyLoader as MegOlmKeyLoader} from "./e2ee/megolm/decryption/KeyLoader";
import {SessionBackup} from "./e2ee/megolm/SessionBackup.js"; import {KeyBackup} from "./e2ee/megolm/keybackup/KeyBackup";
import {Encryption as MegOlmEncryption} from "./e2ee/megolm/Encryption.js"; import {Encryption as MegOlmEncryption} from "./e2ee/megolm/Encryption.js";
import {MEGOLM_ALGORITHM} from "./e2ee/common.js"; import {MEGOLM_ALGORITHM} from "./e2ee/common.js";
import {RoomEncryption} from "./e2ee/RoomEncryption.js"; import {RoomEncryption} from "./e2ee/RoomEncryption.js";
@ -70,12 +70,12 @@ export class Session {
this._e2eeAccount = null; this._e2eeAccount = null;
this._deviceTracker = null; this._deviceTracker = null;
this._olmEncryption = null; this._olmEncryption = null;
this._keyLoader = null;
this._megolmEncryption = null; this._megolmEncryption = null;
this._megolmDecryption = null; this._megolmDecryption = null;
this._getSyncToken = () => this.syncToken; this._getSyncToken = () => this.syncToken;
this._olmWorker = olmWorker; this._olmWorker = olmWorker;
this._sessionBackup = null; this._keyBackup = new ObservableValue(undefined);
this._hasSecretStorageKey = new ObservableValue(null);
this._observedRoomStatus = new Map(); this._observedRoomStatus = new Map();
if (olm) { if (olm) {
@ -90,7 +90,7 @@ export class Session {
} }
this._createRoomEncryption = this._createRoomEncryption.bind(this); this._createRoomEncryption = this._createRoomEncryption.bind(this);
this._forgetArchivedRoom = this._forgetArchivedRoom.bind(this); this._forgetArchivedRoom = this._forgetArchivedRoom.bind(this);
this.needsSessionBackup = new ObservableValue(false); this.needsKeyBackup = new ObservableValue(false);
} }
get fingerprintKey() { get fingerprintKey() {
@ -133,16 +133,17 @@ export class Session {
olmUtil: this._olmUtil, olmUtil: this._olmUtil,
senderKeyLock senderKeyLock
}); });
this._keyLoader = new MegOlmKeyLoader(this._olm, PICKLE_KEY, 20);
this._megolmEncryption = new MegOlmEncryption({ this._megolmEncryption = new MegOlmEncryption({
account: this._e2eeAccount, account: this._e2eeAccount,
pickleKey: PICKLE_KEY, pickleKey: PICKLE_KEY,
olm: this._olm, olm: this._olm,
storage: this._storage, storage: this._storage,
keyLoader: this._keyLoader,
now: this._platform.clock.now, now: this._platform.clock.now,
ownDeviceId: this._sessionInfo.deviceId, ownDeviceId: this._sessionInfo.deviceId,
}); });
const keyLoader = new MegOlmKeyLoader(this._olm, PICKLE_KEY, 20); this._megolmDecryption = new MegOlmDecryption(this._keyLoader, this._olmWorker);
this._megolmDecryption = new MegOlmDecryption(keyLoader, this._olmWorker);
this._deviceMessageHandler.enableEncryption({olmDecryption, megolmDecryption: this._megolmDecryption}); this._deviceMessageHandler.enableEncryption({olmDecryption, megolmDecryption: this._megolmDecryption});
} }
@ -169,11 +170,11 @@ export class Session {
megolmEncryption: this._megolmEncryption, megolmEncryption: this._megolmEncryption,
megolmDecryption: this._megolmDecryption, megolmDecryption: this._megolmDecryption,
storage: this._storage, storage: this._storage,
sessionBackup: this._sessionBackup, keyBackup: this._keyBackup?.get(),
encryptionParams, encryptionParams,
notifyMissingMegolmSession: () => { notifyMissingMegolmSession: () => {
if (!this._sessionBackup) { if (!this._keyBackup.get()) {
this.needsSessionBackup.set(true) this.needsKeyBackup.set(true)
} }
}, },
clock: this._platform.clock clock: this._platform.clock
@ -182,38 +183,59 @@ export class Session {
/** /**
* Enable secret storage by providing the secret storage credential. * Enable secret storage by providing the secret storage credential.
* This will also see if there is a megolm session backup and try to enable that if so. * This will also see if there is a megolm key backup and try to enable that if so.
* *
* @param {string} type either "passphrase" or "recoverykey" * @param {string} type either "passphrase" or "recoverykey"
* @param {string} credential either the passphrase or the recovery key, depending on the type * @param {string} credential either the passphrase or the recovery key, depending on the type
* @return {Promise} resolves or rejects after having tried to enable secret storage * @return {Promise} resolves or rejects after having tried to enable secret storage
*/ */
async enableSecretStorage(type, credential) { enableSecretStorage(type, credential, log = undefined) {
return this._platform.logger.wrapOrRun(log, "enable secret storage", async log => {
if (!this._olm) { if (!this._olm) {
throw new Error("olm required"); throw new Error("olm required");
} }
if (this._sessionBackup) { if (this._keyBackup.get()) {
return false; this._keyBackup.get().dispose();
this._keyBackup.set(null);
} }
const key = await ssssKeyFromCredential(type, credential, this._storage, this._platform, this._olm); const key = await ssssKeyFromCredential(type, credential, this._storage, this._platform, this._olm);
// and create session backup, which needs to read from accountData // and create key backup, which needs to read from accountData
const readTxn = await this._storage.readTxn([ const readTxn = await this._storage.readTxn([
this._storage.storeNames.accountData, this._storage.storeNames.accountData,
]); ]);
await this._createSessionBackup(key, readTxn); if (await this._createKeyBackup(key, readTxn, log)) {
await this._writeSSSSKey(key);
this._hasSecretStorageKey.set(true);
return key;
}
async _writeSSSSKey(key) {
// only after having read a secret, write the key // only after having read a secret, write the key
// as we only find out if it was good if the MAC verification succeeds // as we only find out if it was good if the MAC verification succeeds
await this._writeSSSSKey(key, log);
this._keyBackup.get().flush(log);
return key;
} else {
throw new Error("Could not read key backup with the given key");
}
});
}
async _writeSSSSKey(key, log) {
// we're going to write the 4S key, and also the backup version.
// this way, we can detect when we enter a key for a new backup version
// and mark all inbound sessions to be backed up again
const keyBackup = this._keyBackup.get();
if (!keyBackup) {
return;
}
const backupVersion = keyBackup.version;
const writeTxn = await this._storage.readWriteTxn([ const writeTxn = await this._storage.readWriteTxn([
this._storage.storeNames.session, this._storage.storeNames.session,
this._storage.storeNames.inboundGroupSessions,
]); ]);
try { try {
ssssWriteKey(key, writeTxn); const previousBackupVersion = await ssssWriteKey(key, backupVersion, writeTxn);
log.set("previousBackupVersion", previousBackupVersion);
log.set("backupVersion", backupVersion);
if (!!previousBackupVersion && previousBackupVersion !== backupVersion) {
const amountMarked = await keyBackup.markAllForBackup(writeTxn);
log.set("amountMarkedForBackup", amountMarked);
}
} catch (err) { } catch (err) {
writeTxn.abort(); writeTxn.abort();
throw err; throw err;
@ -232,38 +254,53 @@ export class Session {
throw err; throw err;
} }
await writeTxn.complete(); await writeTxn.complete();
if (this._sessionBackup) { if (this._keyBackup.get()) {
for (const room of this._rooms.values()) { for (const room of this._rooms.values()) {
if (room.isEncrypted) { if (room.isEncrypted) {
room.enableSessionBackup(undefined); room.enableKeyBackup(undefined);
} }
} }
this._sessionBackup?.dispose(); this._keyBackup.get().dispose();
this._sessionBackup = undefined; this._keyBackup.set(null);
} }
this._hasSecretStorageKey.set(false);
} }
async _createSessionBackup(ssssKey, txn) { _createKeyBackup(ssssKey, txn, log) {
return log.wrap("enable key backup", async log => {
try {
const secretStorage = new SecretStorage({key: ssssKey, platform: this._platform}); const secretStorage = new SecretStorage({key: ssssKey, platform: this._platform});
this._sessionBackup = await SessionBackup.fromSecretStorage({ const keyBackup = await KeyBackup.fromSecretStorage(
platform: this._platform, this._platform,
olm: this._olm, secretStorage, this._olm,
hsApi: this._hsApi, secretStorage,
this._hsApi,
this._keyLoader,
this._storage,
txn txn
}); );
if (this._sessionBackup) { if (keyBackup) {
for (const room of this._rooms.values()) { for (const room of this._rooms.values()) {
if (room.isEncrypted) { if (room.isEncrypted) {
room.enableSessionBackup(this._sessionBackup); room.enableKeyBackup(keyBackup);
} }
} }
this._keyBackup.set(keyBackup);
return true;
} }
this.needsSessionBackup.set(false); } catch (err) {
log.catch(err);
}
return false;
});
} }
get sessionBackup() { /**
return this._sessionBackup; * @type {ObservableValue<KeyBackup | undefined | null}
* - `undefined` means, we're not done with catchup sync yet and haven't checked yet if key backup is configured
* - `null` means we've checked and key backup hasn't been configured correctly or at all.
*/
get keyBackup() {
return this._keyBackup;
} }
get hasIdentity() { get hasIdentity() {
@ -401,8 +438,8 @@ export class Session {
dispose() { dispose() {
this._olmWorker?.dispose(); this._olmWorker?.dispose();
this._olmWorker = undefined; this._olmWorker = undefined;
this._sessionBackup?.dispose(); this._keyBackup.get()?.dispose();
this._sessionBackup = undefined; this._keyBackup.set(undefined);
this._megolmDecryption?.dispose(); this._megolmDecryption?.dispose();
this._megolmDecryption = undefined; this._megolmDecryption = undefined;
this._e2eeAccount?.dispose(); this._e2eeAccount?.dispose();
@ -430,7 +467,7 @@ export class Session {
await txn.complete(); await txn.complete();
} }
// enable session backup, this requests the latest backup version // enable session backup, this requests the latest backup version
if (!this._sessionBackup) { if (!this._keyBackup.get()) {
if (dehydratedDevice) { if (dehydratedDevice) {
await log.wrap("SSSSKeyFromDehydratedDeviceKey", async log => { await log.wrap("SSSSKeyFromDehydratedDeviceKey", async log => {
const ssssKey = await createSSSSKeyFromDehydratedDeviceKey(dehydratedDevice.key, this._storage, this._platform); const ssssKey = await createSSSSKeyFromDehydratedDeviceKey(dehydratedDevice.key, this._storage, this._platform);
@ -438,7 +475,7 @@ export class Session {
log.set("success", true); log.set("success", true);
await this._writeSSSSKey(ssssKey); await this._writeSSSSKey(ssssKey);
} }
}) });
} }
const txn = await this._storage.readTxn([ const txn = await this._storage.readTxn([
this._storage.storeNames.session, this._storage.storeNames.session,
@ -448,9 +485,15 @@ export class Session {
const ssssKey = await ssssReadKey(txn); const ssssKey = await ssssReadKey(txn);
if (ssssKey) { if (ssssKey) {
// txn will end here as this does a network request // txn will end here as this does a network request
await this._createSessionBackup(ssssKey, txn); if (await this._createKeyBackup(ssssKey, txn, log)) {
this._keyBackup.get()?.flush(log);
}
}
if (!this._keyBackup.get()) {
// null means key backup isn't configured yet
// as opposed to undefined, which means we're still checking
this._keyBackup.set(null);
} }
this._hasSecretStorageKey.set(!!ssssKey);
} }
// restore unfinished operations, like sending out room keys // restore unfinished operations, like sending out room keys
const opsTxn = await this._storage.readWriteTxn([ const opsTxn = await this._storage.readWriteTxn([
@ -555,7 +598,7 @@ export class Session {
async writeSync(syncResponse, syncFilterId, preparation, txn, log) { async writeSync(syncResponse, syncFilterId, preparation, txn, log) {
const changes = { const changes = {
syncInfo: null, syncInfo: null,
e2eeAccountChanges: null, e2eeAccountChanges: null
}; };
const syncToken = syncResponse.next_batch; const syncToken = syncResponse.next_batch;
if (syncToken !== this.syncToken) { if (syncToken !== this.syncToken) {
@ -576,7 +619,7 @@ export class Session {
} }
if (preparation) { if (preparation) {
await log.wrap("deviceMsgs", log => this._deviceMessageHandler.writeSync(preparation, txn, log)); changes.hasNewRoomKeys = await log.wrap("deviceMsgs", log => this._deviceMessageHandler.writeSync(preparation, txn, log));
} }
// store account data // store account data
@ -614,6 +657,9 @@ export class Session {
await log.wrap("uploadKeys", log => this._e2eeAccount.uploadKeys(this._storage, false, log)); await log.wrap("uploadKeys", log => this._e2eeAccount.uploadKeys(this._storage, false, log));
} }
} }
if (changes.hasNewRoomKeys) {
this._keyBackup.get()?.flush(log);
}
} }
applyRoomCollectionChangesAfterSync(inviteStates, roomStates, archivedRoomStates) { applyRoomCollectionChangesAfterSync(inviteStates, roomStates, archivedRoomStates) {

View file

@ -28,7 +28,7 @@ const MIN_PRESHARE_INTERVAL = 60 * 1000; // 1min
// TODO: this class is a good candidate for splitting up into encryption and decryption, there doesn't seem to be much overlap // TODO: this class is a good candidate for splitting up into encryption and decryption, there doesn't seem to be much overlap
export class RoomEncryption { export class RoomEncryption {
constructor({room, deviceTracker, olmEncryption, megolmEncryption, megolmDecryption, encryptionParams, storage, sessionBackup, notifyMissingMegolmSession, clock}) { constructor({room, deviceTracker, olmEncryption, megolmEncryption, megolmDecryption, encryptionParams, storage, keyBackup, notifyMissingMegolmSession, clock}) {
this._room = room; this._room = room;
this._deviceTracker = deviceTracker; this._deviceTracker = deviceTracker;
this._olmEncryption = olmEncryption; this._olmEncryption = olmEncryption;
@ -39,7 +39,7 @@ export class RoomEncryption {
// caches devices to verify events // caches devices to verify events
this._senderDeviceCache = new Map(); this._senderDeviceCache = new Map();
this._storage = storage; this._storage = storage;
this._sessionBackup = sessionBackup; this._keyBackup = keyBackup;
this._notifyMissingMegolmSession = notifyMissingMegolmSession; this._notifyMissingMegolmSession = notifyMissingMegolmSession;
this._clock = clock; this._clock = clock;
this._isFlushingRoomKeyShares = false; this._isFlushingRoomKeyShares = false;
@ -48,11 +48,11 @@ export class RoomEncryption {
this._disposed = false; this._disposed = false;
} }
enableSessionBackup(sessionBackup) { enableKeyBackup(keyBackup) {
if (this._sessionBackup && !!sessionBackup) { if (this._keyBackup && !!keyBackup) {
return; return;
} }
this._sessionBackup = sessionBackup; this._keyBackup = keyBackup;
} }
async restoreMissingSessionsFromBackup(entries, log) { async restoreMissingSessionsFromBackup(entries, log) {
@ -130,7 +130,7 @@ export class RoomEncryption {
})); }));
} }
if (!this._sessionBackup) { if (!this._keyBackup) {
return; return;
} }
@ -174,7 +174,7 @@ export class RoomEncryption {
async _requestMissingSessionFromBackup(senderKey, sessionId, log) { async _requestMissingSessionFromBackup(senderKey, sessionId, log) {
// show prompt to enable secret storage // show prompt to enable secret storage
if (!this._sessionBackup) { if (!this._keyBackup) {
log.set("enabled", false); log.set("enabled", false);
this._notifyMissingMegolmSession(); this._notifyMissingMegolmSession();
return; return;
@ -182,9 +182,7 @@ export class RoomEncryption {
log.set("id", sessionId); log.set("id", sessionId);
log.set("senderKey", senderKey); log.set("senderKey", senderKey);
try { try {
const session = await this._sessionBackup.getSession(this._room.id, sessionId, log); const roomKey = await this._keyBackup.getRoomKey(this._room.id, sessionId, log);
if (session?.algorithm === MEGOLM_ALGORITHM) {
let roomKey = this._megolmDecryption.roomKeyFromBackup(this._room.id, sessionId, session);
if (roomKey) { if (roomKey) {
if (roomKey.senderKey !== senderKey) { if (roomKey.senderKey !== senderKey) {
log.set("wrong_sender_key", roomKey.senderKey); log.set("wrong_sender_key", roomKey.senderKey);
@ -209,9 +207,6 @@ export class RoomEncryption {
await log.wrap("retryDecryption", log => this._room.notifyRoomKey(roomKey, retryEventIds || [], log)); await log.wrap("retryDecryption", log => this._room.notifyRoomKey(roomKey, retryEventIds || [], log));
} }
} }
} else if (session?.algorithm) {
log.set("unknown algorithm", session.algorithm);
}
} catch (err) { } catch (err) {
if (!(err.name === "HomeServerError" && err.errcode === "M_NOT_FOUND")) { if (!(err.name === "HomeServerError" && err.errcode === "M_NOT_FOUND")) {
log.set("not_found", true); log.set("not_found", true);
@ -241,6 +236,7 @@ export class RoomEncryption {
this._keySharePromise = (async () => { this._keySharePromise = (async () => {
const roomKeyMessage = await this._megolmEncryption.ensureOutboundSession(this._room.id, this._encryptionParams); const roomKeyMessage = await this._megolmEncryption.ensureOutboundSession(this._room.id, this._encryptionParams);
if (roomKeyMessage) { if (roomKeyMessage) {
this._keyBackup?.flush(log);
await log.wrap("share key", log => this._shareNewRoomKey(roomKeyMessage, hsApi, log)); await log.wrap("share key", log => this._shareNewRoomKey(roomKeyMessage, hsApi, log));
} }
})(); })();
@ -259,6 +255,7 @@ export class RoomEncryption {
} }
const megolmResult = await log.wrap("megolm encrypt", () => this._megolmEncryption.encrypt(this._room.id, type, content, this._encryptionParams)); const megolmResult = await log.wrap("megolm encrypt", () => this._megolmEncryption.encrypt(this._room.id, type, content, this._encryptionParams));
if (megolmResult.roomKeyMessage) { if (megolmResult.roomKeyMessage) {
this._keyBackup?.flush(log);
await log.wrap("share key", log => this._shareNewRoomKey(megolmResult.roomKeyMessage, hsApi, log)); await log.wrap("share key", log => this._shareNewRoomKey(megolmResult.roomKeyMessage, hsApi, log));
} }
return { return {

View file

@ -15,12 +15,14 @@ limitations under the License.
*/ */
import {MEGOLM_ALGORITHM} from "../common.js"; import {MEGOLM_ALGORITHM} from "../common.js";
import {OutboundRoomKey} from "./decryption/RoomKey";
export class Encryption { export class Encryption {
constructor({pickleKey, olm, account, storage, now, ownDeviceId}) { constructor({pickleKey, olm, account, keyLoader, storage, now, ownDeviceId}) {
this._pickleKey = pickleKey; this._pickleKey = pickleKey;
this._olm = olm; this._olm = olm;
this._account = account; this._account = account;
this._keyLoader = keyLoader;
this._storage = storage; this._storage = storage;
this._now = now; this._now = now;
this._ownDeviceId = ownDeviceId; this._ownDeviceId = ownDeviceId;
@ -64,7 +66,7 @@ export class Encryption {
let roomKeyMessage; let roomKeyMessage;
try { try {
let sessionEntry = await txn.outboundGroupSessions.get(roomId); let sessionEntry = await txn.outboundGroupSessions.get(roomId);
roomKeyMessage = this._readOrCreateSession(session, sessionEntry, roomId, encryptionParams, txn); roomKeyMessage = await this._readOrCreateSession(session, sessionEntry, roomId, encryptionParams, txn);
if (roomKeyMessage) { if (roomKeyMessage) {
this._writeSession(this._now(), session, roomId, txn); this._writeSession(this._now(), session, roomId, txn);
} }
@ -79,7 +81,7 @@ export class Encryption {
} }
} }
_readOrCreateSession(session, sessionEntry, roomId, encryptionParams, txn) { async _readOrCreateSession(session, sessionEntry, roomId, encryptionParams, txn) {
if (sessionEntry) { if (sessionEntry) {
session.unpickle(this._pickleKey, sessionEntry.session); session.unpickle(this._pickleKey, sessionEntry.session);
} }
@ -91,7 +93,8 @@ export class Encryption {
} }
session.create(); session.create();
const roomKeyMessage = this._createRoomKeyMessage(session, roomId); const roomKeyMessage = this._createRoomKeyMessage(session, roomId);
this._storeAsInboundSession(session, roomId, txn); const roomKey = new OutboundRoomKey(roomId, session, this._account.identityKeys);
await roomKey.write(this._keyLoader, txn);
return roomKeyMessage; return roomKeyMessage;
} }
} }
@ -123,7 +126,7 @@ export class Encryption {
let encryptedContent; let encryptedContent;
try { try {
let sessionEntry = await txn.outboundGroupSessions.get(roomId); let sessionEntry = await txn.outboundGroupSessions.get(roomId);
roomKeyMessage = this._readOrCreateSession(session, sessionEntry, roomId, encryptionParams, txn); roomKeyMessage = await this._readOrCreateSession(session, sessionEntry, roomId, encryptionParams, txn);
encryptedContent = this._encryptContent(roomId, session, type, content); encryptedContent = this._encryptContent(roomId, session, type, content);
// update timestamp when a new session is created // update timestamp when a new session is created
const createdAt = roomKeyMessage ? this._now() : sessionEntry.createdAt; const createdAt = roomKeyMessage ? this._now() : sessionEntry.createdAt;
@ -190,26 +193,6 @@ export class Encryption {
chain_index: session.message_index() chain_index: session.message_index()
} }
} }
_storeAsInboundSession(outboundSession, roomId, txn) {
const {identityKeys} = this._account;
const claimedKeys = {ed25519: identityKeys.ed25519};
const session = new this._olm.InboundGroupSession();
try {
session.create(outboundSession.session_key());
const sessionEntry = {
roomId,
senderKey: identityKeys.curve25519,
sessionId: session.session_id(),
session: session.pickle(this._pickleKey),
claimedKeys,
};
txn.inboundGroupSessions.set(sessionEntry);
return sessionEntry;
} finally {
session.free();
}
}
} }
/** /**

View file

@ -1,62 +0,0 @@
/*
Copyright 2020 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
export class SessionBackup {
constructor({backupInfo, decryption, hsApi}) {
this._backupInfo = backupInfo;
this._decryption = decryption;
this._hsApi = hsApi;
}
async getSession(roomId, sessionId, log) {
const sessionResponse = await this._hsApi.roomKeyForRoomAndSession(this._backupInfo.version, roomId, sessionId, {log}).response();
const sessionInfo = this._decryption.decrypt(
sessionResponse.session_data.ephemeral,
sessionResponse.session_data.mac,
sessionResponse.session_data.ciphertext,
);
return JSON.parse(sessionInfo);
}
get version() {
return this._backupInfo.version;
}
dispose() {
this._decryption.free();
}
static async fromSecretStorage({platform, olm, secretStorage, hsApi, txn}) {
const base64PrivateKey = await secretStorage.readSecret("m.megolm_backup.v1", txn);
if (base64PrivateKey) {
const privateKey = new Uint8Array(platform.encoding.base64.decode(base64PrivateKey));
const backupInfo = await hsApi.roomKeysVersion().response();
const expectedPubKey = backupInfo.auth_data.public_key;
const decryption = new olm.PkDecryption();
try {
const pubKey = decryption.init_with_private_key(privateKey);
if (pubKey !== expectedPubKey) {
throw new Error(`Bad backup key, public key does not match. Calculated ${pubKey} but expected ${expectedPubKey}`);
}
} catch(err) {
decryption.free();
throw err;
}
return new SessionBackup({backupInfo, decryption, hsApi});
}
}
}

View file

@ -17,25 +17,14 @@ limitations under the License.
import {isBetterThan, IncomingRoomKey} from "./RoomKey"; import {isBetterThan, IncomingRoomKey} from "./RoomKey";
import {BaseLRUCache} from "../../../../utils/LRUCache"; import {BaseLRUCache} from "../../../../utils/LRUCache";
import type {RoomKey} from "./RoomKey"; import type {RoomKey} from "./RoomKey";
import type * as OlmNamespace from "@matrix-org/olm";
type Olm = typeof OlmNamespace;
export declare class OlmDecryptionResult { export declare class OlmDecryptionResult {
readonly plaintext: string; readonly plaintext: string;
readonly message_index: number; readonly message_index: number;
} }
export declare class OlmInboundGroupSession {
constructor();
free(): void;
pickle(key: string | Uint8Array): string;
unpickle(key: string | Uint8Array, pickle: string);
create(session_key: string): string;
import_session(session_key: string): string;
decrypt(message: string): OlmDecryptionResult;
session_id(): string;
first_known_index(): number;
export_session(message_index: number): string;
}
/* /*
Because Olm only has very limited memory available when compiled to wasm, Because Olm only has very limited memory available when compiled to wasm,
we limit the amount of sessions held in memory. we limit the amount of sessions held in memory.
@ -43,11 +32,11 @@ we limit the amount of sessions held in memory.
export class KeyLoader extends BaseLRUCache<KeyOperation> { export class KeyLoader extends BaseLRUCache<KeyOperation> {
private pickleKey: string; private pickleKey: string;
private olm: any; private olm: Olm;
private resolveUnusedOperation?: () => void; private resolveUnusedOperation?: () => void;
private operationBecomesUnusedPromise?: Promise<void>; private operationBecomesUnusedPromise?: Promise<void>;
constructor(olm: any, pickleKey: string, limit: number) { constructor(olm: Olm, pickleKey: string, limit: number) {
super(limit); super(limit);
this.pickleKey = pickleKey; this.pickleKey = pickleKey;
this.olm = olm; this.olm = olm;
@ -60,7 +49,7 @@ export class KeyLoader extends BaseLRUCache<KeyOperation> {
} }
} }
async useKey<T>(key: RoomKey, callback: (session: OlmInboundGroupSession, pickleKey: string) => Promise<T> | T): Promise<T> { async useKey<T>(key: RoomKey, callback: (session: Olm.InboundGroupSession, pickleKey: string) => Promise<T> | T): Promise<T> {
const keyOp = await this.allocateOperation(key); const keyOp = await this.allocateOperation(key);
try { try {
return await callback(keyOp.session, this.pickleKey); return await callback(keyOp.session, this.pickleKey);
@ -186,11 +175,11 @@ export class KeyLoader extends BaseLRUCache<KeyOperation> {
} }
class KeyOperation { class KeyOperation {
session: OlmInboundGroupSession; session: Olm.InboundGroupSession;
key: RoomKey; key: RoomKey;
refCount: number; refCount: number;
constructor(key: RoomKey, session: OlmInboundGroupSession) { constructor(key: RoomKey, session: Olm.InboundGroupSession) {
this.key = key; this.key = key;
this.session = session; this.session = session;
this.refCount = 1; this.refCount = 1;
@ -224,6 +213,9 @@ class KeyOperation {
} }
} }
import {KeySource} from "../../../storage/idb/stores/InboundGroupSessionStore";
export function tests() { export function tests() {
let instances = 0; let instances = 0;
@ -248,7 +240,9 @@ export function tests() {
get serializationKey(): string { return `key-${this.sessionId}-${this._firstKnownIndex}`; } get serializationKey(): string { return `key-${this.sessionId}-${this._firstKnownIndex}`; }
get serializationType(): string { return "type"; } get serializationType(): string { return "type"; }
get eventIds(): string[] | undefined { return undefined; } get eventIds(): string[] | undefined { return undefined; }
loadInto(session: OlmInboundGroupSession) { get keySource(): KeySource { return KeySource.DeviceMessage; }
loadInto(session: Olm.InboundGroupSession) {
const mockSession = session as MockInboundSession; const mockSession = session as MockInboundSession;
mockSession.sessionId = this.sessionId; mockSession.sessionId = this.sessionId;
mockSession.firstKnownIndex = this._firstKnownIndex; mockSession.firstKnownIndex = this._firstKnownIndex;
@ -284,7 +278,7 @@ export function tests() {
return { return {
"load key gives correct session": async assert => { "load key gives correct session": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 2); const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 2);
let callback1Called = false; let callback1Called = false;
let callback2Called = false; let callback2Called = false;
const p1 = loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => { const p1 = loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => {
@ -305,7 +299,7 @@ export function tests() {
assert(callback2Called); assert(callback2Called);
}, },
"keys with different first index are kept separate": async assert => { "keys with different first index are kept separate": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 2); const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 2);
let callback1Called = false; let callback1Called = false;
let callback2Called = false; let callback2Called = false;
const p1 = loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => { const p1 = loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => {
@ -326,7 +320,7 @@ export function tests() {
assert(callback2Called); assert(callback2Called);
}, },
"useKey blocks as long as no free sessions are available": async assert => { "useKey blocks as long as no free sessions are available": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 1); const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 1);
let resolve; let resolve;
let callbackCalled = false; let callbackCalled = false;
loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => { loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => {
@ -343,7 +337,7 @@ export function tests() {
assert.equal(callbackCalled, true); assert.equal(callbackCalled, true);
}, },
"cache hit while key in use, then replace (check refCount works properly)": async assert => { "cache hit while key in use, then replace (check refCount works properly)": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 1); const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 1);
let resolve1, resolve2; let resolve1, resolve2;
const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1); const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1);
const p1 = loader.useKey(key1, async session => { const p1 = loader.useKey(key1, async session => {
@ -371,7 +365,7 @@ export function tests() {
assert.equal(callbackCalled, true); assert.equal(callbackCalled, true);
}, },
"cache hit while key not in use": async assert => { "cache hit while key not in use": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 2); const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 2);
let resolve1, resolve2, invocations = 0; let resolve1, resolve2, invocations = 0;
const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1); const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1);
await loader.useKey(key1, async session => { invocations += 1; }); await loader.useKey(key1, async session => { invocations += 1; });
@ -385,7 +379,7 @@ export function tests() {
}, },
"dispose calls free on all sessions": async assert => { "dispose calls free on all sessions": async assert => {
instances = 0; instances = 0;
const loader = new KeyLoader(olm, PICKLE_KEY, 2); const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 2);
await loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => {}); await loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId1, 1), async session => {});
await loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId2, 1), async session => {}); await loader.useKey(new MockRoomKey(roomId, aliceSenderKey, sessionId2, 1), async session => {});
assert.equal(instances, 2); assert.equal(instances, 2);
@ -395,7 +389,7 @@ export function tests() {
assert.strictEqual(loader.size, 0, "loader.size"); assert.strictEqual(loader.size, 0, "loader.size");
}, },
"checkBetterThanKeyInStorage false with cache": async assert => { "checkBetterThanKeyInStorage false with cache": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 2); const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 2);
const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2); const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2);
await loader.useKey(key1, async session => {}); await loader.useKey(key1, async session => {});
// fake we've checked with storage that this is the best key, // fake we've checked with storage that this is the best key,
@ -409,7 +403,7 @@ export function tests() {
assert.strictEqual(key2.isBetter, false); assert.strictEqual(key2.isBetter, false);
}, },
"checkBetterThanKeyInStorage true with cache": async assert => { "checkBetterThanKeyInStorage true with cache": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 2); const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 2);
const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2); const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2);
key1.isBetter = true; // fake we've check with storage so far (not including key2) this is the best key key1.isBetter = true; // fake we've check with storage so far (not including key2) this is the best key
await loader.useKey(key1, async session => {}); await loader.useKey(key1, async session => {});
@ -420,7 +414,7 @@ export function tests() {
assert.strictEqual(key2.isBetter, true); assert.strictEqual(key2.isBetter, true);
}, },
"prefer to remove worst key for a session from cache": async assert => { "prefer to remove worst key for a session from cache": async assert => {
const loader = new KeyLoader(olm, PICKLE_KEY, 2); const loader = new KeyLoader(olm as any as Olm, PICKLE_KEY, 2);
const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2); const key1 = new MockRoomKey(roomId, aliceSenderKey, sessionId1, 2);
await loader.useKey(key1, async session => {}); await loader.useKey(key1, async session => {});
key1.isBetter = true; // set to true just so it gets returned from getCachedKey key1.isBetter = true; // set to true just so it gets returned from getCachedKey

View file

@ -14,10 +14,13 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
import {BackupStatus, KeySource} from "../../../storage/idb/stores/InboundGroupSessionStore";
import type {InboundGroupSessionEntry} from "../../../storage/idb/stores/InboundGroupSessionStore"; import type {InboundGroupSessionEntry} from "../../../storage/idb/stores/InboundGroupSessionStore";
import type {Transaction} from "../../../storage/idb/Transaction"; import type {Transaction} from "../../../storage/idb/Transaction";
import type {DecryptionResult} from "../../DecryptionResult"; import type {DecryptionResult} from "../../DecryptionResult";
import type {KeyLoader, OlmInboundGroupSession} from "./KeyLoader"; import type {KeyLoader} from "./KeyLoader";
import type * as OlmNamespace from "@matrix-org/olm";
type Olm = typeof OlmNamespace;
export abstract class RoomKey { export abstract class RoomKey {
private _isBetter: boolean | undefined; private _isBetter: boolean | undefined;
@ -33,7 +36,7 @@ export abstract class RoomKey {
abstract get serializationKey(): string; abstract get serializationKey(): string;
abstract get serializationType(): string; abstract get serializationType(): string;
abstract get eventIds(): string[] | undefined; abstract get eventIds(): string[] | undefined;
abstract loadInto(session: OlmInboundGroupSession, pickleKey: string): void; abstract loadInto(session: Olm.InboundGroupSession, pickleKey: string): void;
/* Whether the key has been checked against storage (or is from storage) /* Whether the key has been checked against storage (or is from storage)
* to be the better key for a given session. Given that all keys are checked to be better * to be the better key for a given session. Given that all keys are checked to be better
* as part of writing, we can trust that when this returns true, it really is the best key * as part of writing, we can trust that when this returns true, it really is the best key
@ -44,7 +47,7 @@ export abstract class RoomKey {
set isBetter(value: boolean | undefined) { this._isBetter = value; } set isBetter(value: boolean | undefined) { this._isBetter = value; }
} }
export function isBetterThan(newSession: OlmInboundGroupSession, existingSession: OlmInboundGroupSession) { export function isBetterThan(newSession: Olm.InboundGroupSession, existingSession: Olm.InboundGroupSession) {
return newSession.first_known_index() < existingSession.first_known_index(); return newSession.first_known_index() < existingSession.first_known_index();
} }
@ -57,7 +60,7 @@ export abstract class IncomingRoomKey extends RoomKey {
async write(loader: KeyLoader, txn: Transaction): Promise<boolean> { async write(loader: KeyLoader, txn: Transaction): Promise<boolean> {
// we checked already and we had a better session in storage, so don't write // we checked already and we had a better session in storage, so don't write
let pickledSession; let pickledSession: string | undefined;
if (this.isBetter === undefined) { if (this.isBetter === undefined) {
// if this key wasn't used to decrypt any messages in the same sync, // if this key wasn't used to decrypt any messages in the same sync,
// we haven't checked if this is the best key yet, // we haven't checked if this is the best key yet,
@ -79,6 +82,8 @@ export abstract class IncomingRoomKey extends RoomKey {
senderKey: this.senderKey, senderKey: this.senderKey,
sessionId: this.sessionId, sessionId: this.sessionId,
session: pickledSession, session: pickledSession,
backup: this.backupStatus,
source: this.keySource,
claimedKeys: {"ed25519": this.claimedEd25519Key}, claimedKeys: {"ed25519": this.claimedEd25519Key},
}; };
txn.inboundGroupSessions.set(sessionEntry); txn.inboundGroupSessions.set(sessionEntry);
@ -87,7 +92,7 @@ export abstract class IncomingRoomKey extends RoomKey {
get eventIds() { return this._eventIds; } get eventIds() { return this._eventIds; }
private async _checkBetterThanKeyInStorage(loader: KeyLoader, callback: (((session: OlmInboundGroupSession, pickleKey: string) => void) | undefined), txn: Transaction): Promise<boolean> { private async _checkBetterThanKeyInStorage(loader: KeyLoader, callback: (((session: Olm.InboundGroupSession, pickleKey: string) => void) | undefined), txn: Transaction): Promise<boolean> {
if (this.isBetter !== undefined) { if (this.isBetter !== undefined) {
return this.isBetter; return this.isBetter;
} }
@ -123,6 +128,12 @@ export abstract class IncomingRoomKey extends RoomKey {
} }
return this.isBetter!; return this.isBetter!;
} }
protected get backupStatus(): BackupStatus {
return BackupStatus.NotBackedUp;
}
protected abstract get keySource(): KeySource;
} }
class DeviceMessageRoomKey extends IncomingRoomKey { class DeviceMessageRoomKey extends IncomingRoomKey {
@ -139,22 +150,48 @@ class DeviceMessageRoomKey extends IncomingRoomKey {
get claimedEd25519Key() { return this._decryptionResult.claimedEd25519Key; } get claimedEd25519Key() { return this._decryptionResult.claimedEd25519Key; }
get serializationKey(): string { return this._decryptionResult.event.content?.["session_key"]; } get serializationKey(): string { return this._decryptionResult.event.content?.["session_key"]; }
get serializationType(): string { return "create"; } get serializationType(): string { return "create"; }
protected get keySource(): KeySource { return KeySource.DeviceMessage; }
loadInto(session) { loadInto(session) {
session.create(this.serializationKey); session.create(this.serializationKey);
} }
} }
class BackupRoomKey extends IncomingRoomKey { // a room key we send out ourselves,
private _roomId: string; // here adapted to write it as an incoming key
private _sessionId: string; // as we don't send it to ourself with a to_device msg
private _backupInfo: string; export class OutboundRoomKey extends IncomingRoomKey {
private _sessionKey: string;
constructor(roomId, sessionId, backupInfo) { constructor(
private readonly _roomId: string,
private readonly outboundSession: Olm.OutboundGroupSession,
private readonly identityKeys: {[algo: string]: string}
) {
super();
// this is a new key, so always better than what might be in storage, no need to check
this.isBetter = true;
// cache this, as it is used by key loader to find a matching key and
// this calls into WASM so is not just reading a prop
this._sessionKey = this.outboundSession.session_key();
}
get roomId(): string { return this._roomId; }
get senderKey(): string { return this.identityKeys.curve25519; }
get sessionId(): string { return this.outboundSession.session_id(); }
get claimedEd25519Key(): string { return this.identityKeys.ed25519; }
get serializationKey(): string { return this._sessionKey; }
get serializationType(): string { return "create"; }
protected get keySource(): KeySource { return KeySource.Outbound; }
loadInto(session: Olm.InboundGroupSession) {
session.create(this.serializationKey);
}
}
class BackupRoomKey extends IncomingRoomKey {
constructor(private _roomId: string, private _sessionId: string, private _backupInfo: object) {
super(); super();
this._roomId = roomId;
this._sessionId = sessionId;
this._backupInfo = backupInfo;
} }
get roomId() { return this._roomId; } get roomId() { return this._roomId; }
@ -163,13 +200,18 @@ class BackupRoomKey extends IncomingRoomKey {
get claimedEd25519Key() { return this._backupInfo["sender_claimed_keys"]?.["ed25519"]; } get claimedEd25519Key() { return this._backupInfo["sender_claimed_keys"]?.["ed25519"]; }
get serializationKey(): string { return this._backupInfo["session_key"]; } get serializationKey(): string { return this._backupInfo["session_key"]; }
get serializationType(): string { return "import_session"; } get serializationType(): string { return "import_session"; }
protected get keySource(): KeySource { return KeySource.Backup; }
loadInto(session) { loadInto(session) {
session.import_session(this.serializationKey); session.import_session(this.serializationKey);
} }
protected get backupStatus(): BackupStatus {
return BackupStatus.BackedUp;
}
} }
class StoredRoomKey extends RoomKey { export class StoredRoomKey extends RoomKey {
private storageEntry: InboundGroupSessionEntry; private storageEntry: InboundGroupSessionEntry;
constructor(storageEntry: InboundGroupSessionEntry) { constructor(storageEntry: InboundGroupSessionEntry) {

View file

@ -17,7 +17,7 @@ limitations under the License.
import {DecryptionResult} from "../../DecryptionResult.js"; import {DecryptionResult} from "../../DecryptionResult.js";
import {DecryptionError} from "../../common.js"; import {DecryptionError} from "../../common.js";
import {ReplayDetectionEntry} from "./ReplayDetectionEntry"; import {ReplayDetectionEntry} from "./ReplayDetectionEntry";
import type {RoomKey} from "./RoomKey.js"; import type {RoomKey} from "./RoomKey";
import type {KeyLoader, OlmDecryptionResult} from "./KeyLoader"; import type {KeyLoader, OlmDecryptionResult} from "./KeyLoader";
import type {OlmWorker} from "../../OlmWorker"; import type {OlmWorker} from "../../OlmWorker";
import type {TimelineEvent} from "../../../storage/types"; import type {TimelineEvent} from "../../../storage/types";
@ -61,7 +61,7 @@ export class SessionDecryption {
this.decryptionRequests!.push(request); this.decryptionRequests!.push(request);
decryptionResult = await request.response(); decryptionResult = await request.response();
} else { } else {
decryptionResult = session.decrypt(ciphertext); decryptionResult = session.decrypt(ciphertext) as OlmDecryptionResult;
} }
const {plaintext} = decryptionResult!; const {plaintext} = decryptionResult!;
let payload; let payload;

View file

@ -0,0 +1,91 @@
/*
Copyright 2022 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import {MEGOLM_ALGORITHM} from "../../common";
import type {RoomKey} from "../decryption/RoomKey";
import type {BaseBackupInfo, SignatureMap, SessionKeyInfo} from "./types";
import type * as OlmNamespace from "@matrix-org/olm";
type Olm = typeof OlmNamespace;
export const Algorithm = "m.megolm_backup.v1.curve25519-aes-sha2";
export type BackupInfo = BaseBackupInfo & {
algorithm: typeof Algorithm,
auth_data: AuthData,
}
type AuthData = {
public_key: string,
signatures: SignatureMap
}
export type SessionData = {
ciphertext: string,
mac: string,
ephemeral: string,
}
export class BackupEncryption {
constructor(
private encryption?: Olm.PkEncryption,
private decryption?: Olm.PkDecryption
) {}
static fromAuthData(authData: AuthData, privateKey: Uint8Array, olm: Olm): BackupEncryption {
const expectedPubKey = authData.public_key;
const decryption = new olm.PkDecryption();
const encryption = new olm.PkEncryption();
try {
const pubKey = decryption.init_with_private_key(privateKey);
if (pubKey !== expectedPubKey) {
throw new Error(`Bad backup key, public key does not match. Calculated ${pubKey} but expected ${expectedPubKey}`);
}
encryption.set_recipient_key(pubKey);
} catch(err) {
decryption.free();
throw err;
}
return new BackupEncryption(encryption, decryption);
}
decryptRoomKey(sessionData: SessionData): SessionKeyInfo {
const sessionInfo = this.decryption!.decrypt(
sessionData.ephemeral,
sessionData.mac,
sessionData.ciphertext,
);
return JSON.parse(sessionInfo) as SessionKeyInfo;
}
encryptRoomKey(key: RoomKey, sessionKey: string): SessionData {
const sessionInfo: SessionKeyInfo = {
algorithm: MEGOLM_ALGORITHM,
sender_key: key.senderKey,
sender_claimed_keys: {ed25519: key.claimedEd25519Key},
forwarding_curve25519_key_chain: [],
session_key: sessionKey
};
return this.encryption!.encrypt(JSON.stringify(sessionInfo)) as SessionData;
}
dispose() {
this.decryption?.free();
this.decryption = undefined;
this.encryption?.free();
this.encryption = undefined;
}
}

View file

@ -0,0 +1,209 @@
/*
Copyright 2020 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import {StoreNames} from "../../../storage/common";
import {StoredRoomKey, keyFromBackup} from "../decryption/RoomKey";
import {MEGOLM_ALGORITHM} from "../../common";
import * as Curve25519 from "./Curve25519";
import {AbortableOperation} from "../../../../utils/AbortableOperation";
import {ObservableValue} from "../../../../observable/ObservableValue";
import {SetAbortableFn} from "../../../../utils/AbortableOperation";
import type {BackupInfo, SessionData, SessionKeyInfo, SessionInfo, KeyBackupPayload} from "./types";
import type {HomeServerApi} from "../../../net/HomeServerApi";
import type {IncomingRoomKey, RoomKey} from "../decryption/RoomKey";
import type {KeyLoader} from "../decryption/KeyLoader";
import type {SecretStorage} from "../../../ssss/SecretStorage";
import type {Storage} from "../../../storage/idb/Storage";
import type {ILogItem} from "../../../../logging/types";
import type {Platform} from "../../../../platform/web/Platform";
import type {Transaction} from "../../../storage/idb/Transaction";
import type * as OlmNamespace from "@matrix-org/olm";
type Olm = typeof OlmNamespace;
const KEYS_PER_REQUEST = 200;
export class KeyBackup {
public readonly operationInProgress = new ObservableValue<AbortableOperation<Promise<void>, Progress> | undefined>(undefined);
private _stopped = false;
private _needsNewKey = false;
private _hasBackedUpAllKeys = false;
private _error?: Error;
constructor(
private readonly backupInfo: BackupInfo,
private readonly crypto: Curve25519.BackupEncryption,
private readonly hsApi: HomeServerApi,
private readonly keyLoader: KeyLoader,
private readonly storage: Storage,
private readonly platform: Platform,
private readonly maxDelay: number = 10000
) {}
get hasStopped(): boolean { return this._stopped; }
get error(): Error | undefined { return this._error; }
get version(): string { return this.backupInfo.version; }
get needsNewKey(): boolean { return this._needsNewKey; }
get hasBackedUpAllKeys(): boolean { return this._hasBackedUpAllKeys; }
async getRoomKey(roomId: string, sessionId: string, log: ILogItem): Promise<IncomingRoomKey | undefined> {
const sessionResponse = await this.hsApi.roomKeyForRoomAndSession(this.backupInfo.version, roomId, sessionId, {log}).response();
if (!sessionResponse.session_data) {
return;
}
const sessionKeyInfo = this.crypto.decryptRoomKey(sessionResponse.session_data as SessionData);
if (sessionKeyInfo?.algorithm === MEGOLM_ALGORITHM) {
return keyFromBackup(roomId, sessionId, sessionKeyInfo);
} else if (sessionKeyInfo?.algorithm) {
log.set("unknown algorithm", sessionKeyInfo.algorithm);
}
}
markAllForBackup(txn: Transaction): Promise<number> {
return txn.inboundGroupSessions.markAllAsNotBackedUp();
}
flush(log: ILogItem): void {
if (!this.operationInProgress.get()) {
log.wrapDetached("flush key backup", async log => {
if (this._needsNewKey) {
log.set("needsNewKey", this._needsNewKey);
return;
}
this._stopped = false;
this._error = undefined;
this._hasBackedUpAllKeys = false;
const operation = this._runFlushOperation(log);
this.operationInProgress.set(operation);
try {
await operation.result;
this._hasBackedUpAllKeys = true;
} catch (err) {
this._stopped = true;
if (err.name === "HomeServerError" && (err.errcode === "M_WRONG_ROOM_KEYS_VERSION" || err.errcode === "M_NOT_FOUND")) {
log.set("wrong_version", true);
this._needsNewKey = true;
} else {
// TODO should really also use AbortError in storage
if (err.name !== "AbortError" || (err.name === "StorageError" && err.errcode === "AbortError")) {
this._error = err;
}
}
log.catch(err);
}
this.operationInProgress.set(undefined);
});
}
}
private _runFlushOperation(log: ILogItem): AbortableOperation<Promise<void>, Progress> {
return new AbortableOperation(async (setAbortable, setProgress) => {
let total = 0;
let amountFinished = 0;
while (true) {
const waitMs = this.platform.random() * this.maxDelay;
const timeout = this.platform.clock.createTimeout(waitMs);
setAbortable(timeout);
await timeout.elapsed();
const txn = await this.storage.readTxn([StoreNames.inboundGroupSessions]);
setAbortable(txn);
// fetch total again on each iteration as while we are flushing, sync might be adding keys
total = amountFinished + await txn.inboundGroupSessions.countNonBackedUpSessions();
setProgress(new Progress(total, amountFinished));
const keysNeedingBackup = (await txn.inboundGroupSessions.getFirstNonBackedUpSessions(KEYS_PER_REQUEST))
.map(entry => new StoredRoomKey(entry));
if (keysNeedingBackup.length === 0) {
return;
}
const payload = await this.encodeKeysForBackup(keysNeedingBackup);
const uploadRequest = this.hsApi.uploadRoomKeysToBackup(this.backupInfo.version, payload, {log});
setAbortable(uploadRequest);
await uploadRequest.response();
await this.markKeysAsBackedUp(keysNeedingBackup, setAbortable);
amountFinished += keysNeedingBackup.length;
setProgress(new Progress(total, amountFinished));
}
});
}
private async encodeKeysForBackup(roomKeys: RoomKey[]): Promise<KeyBackupPayload> {
const payload: KeyBackupPayload = { rooms: {} };
const payloadRooms = payload.rooms;
for (const key of roomKeys) {
let roomPayload = payloadRooms[key.roomId];
if (!roomPayload) {
roomPayload = payloadRooms[key.roomId] = { sessions: {} };
}
roomPayload.sessions[key.sessionId] = await this.encodeRoomKey(key);
}
return payload;
}
private async markKeysAsBackedUp(roomKeys: RoomKey[], setAbortable: SetAbortableFn) {
const txn = await this.storage.readWriteTxn([
StoreNames.inboundGroupSessions,
]);
setAbortable(txn);
try {
await Promise.all(roomKeys.map(key => {
return txn.inboundGroupSessions.markAsBackedUp(key.roomId, key.senderKey, key.sessionId);
}));
} catch (err) {
txn.abort();
throw err;
}
await txn.complete();
}
private async encodeRoomKey(roomKey: RoomKey): Promise<SessionInfo> {
return await this.keyLoader.useKey(roomKey, session => {
const firstMessageIndex = session.first_known_index();
const sessionKey = session.export_session(firstMessageIndex);
return {
first_message_index: firstMessageIndex,
forwarded_count: 0,
is_verified: false,
session_data: this.crypto.encryptRoomKey(roomKey, sessionKey)
};
});
}
dispose() {
this.crypto.dispose();
}
static async fromSecretStorage(platform: Platform, olm: Olm, secretStorage: SecretStorage, hsApi: HomeServerApi, keyLoader: KeyLoader, storage: Storage, txn: Transaction): Promise<KeyBackup | undefined> {
const base64PrivateKey = await secretStorage.readSecret("m.megolm_backup.v1", txn);
if (base64PrivateKey) {
const privateKey = new Uint8Array(platform.encoding.base64.decode(base64PrivateKey));
const backupInfo = await hsApi.roomKeysVersion().response() as BackupInfo;
if (backupInfo.algorithm === Curve25519.Algorithm) {
const crypto = Curve25519.BackupEncryption.fromAuthData(backupInfo.auth_data, privateKey, olm);
return new KeyBackup(backupInfo, crypto, hsApi, keyLoader, storage, platform);
} else {
throw new Error(`Unknown backup algorithm: ${backupInfo.algorithm}`);
}
}
}
}
export class Progress {
constructor(
public readonly total: number,
public readonly finished: number
) {}
}

View file

@ -0,0 +1,61 @@
/*
Copyright 2022 The Matrix.org Foundation C.I.C.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
import type * as Curve25519 from "./Curve25519";
import type {MEGOLM_ALGORITHM} from "../../common";
export type SignatureMap = {
[userId: string]: {[deviceIdAndAlgorithm: string]: string}
}
export type BaseBackupInfo = {
version: string,
etag: string,
count: number,
}
export type OtherBackupInfo = BaseBackupInfo & {
algorithm: "other"
};
export type BackupInfo = Curve25519.BackupInfo | OtherBackupInfo;
export type SessionData = Curve25519.SessionData;
export type SessionInfo = {
first_message_index: number,
forwarded_count: number,
is_verified: boolean,
session_data: SessionData
}
export type MegOlmSessionKeyInfo = {
algorithm: MEGOLM_ALGORITHM,
sender_key: string,
sender_claimed_keys: {[algorithm: string]: string},
forwarding_curve25519_key_chain: string[],
session_key: string
}
// the type that session_data decrypts from / encrypts to
export type SessionKeyInfo = MegOlmSessionKeyInfo | {algorithm: string};
export type KeyBackupPayload = {
rooms: {
[roomId: string]: {
sessions: {[sessionId: string]: SessionInfo}
}
}
}

View file

@ -227,6 +227,10 @@ export class HomeServerApi {
return this._get(`/room_keys/keys/${encodeURIComponent(roomId)}/${encodeURIComponent(sessionId)}`, {version}, undefined, options); return this._get(`/room_keys/keys/${encodeURIComponent(roomId)}/${encodeURIComponent(sessionId)}`, {version}, undefined, options);
} }
uploadRoomKeysToBackup(version: string, payload: Record<string, any>, options?: IRequestOptions): IHomeServerRequest {
return this._put(`/room_keys/keys`, {version}, payload, options);
}
uploadAttachment(blob: Blob, filename: string, options?: IRequestOptions): IHomeServerRequest { uploadAttachment(blob: Blob, filename: string, options?: IRequestOptions): IHomeServerRequest {
return this._authedRequest("POST", `${this._homeserver}/_matrix/media/r0/upload`, {filename}, blob, options); return this._authedRequest("POST", `${this._homeserver}/_matrix/media/r0/upload`, {filename}, blob, options);
} }

View file

@ -461,11 +461,11 @@ export class BaseRoom extends EventEmitter {
return observable; return observable;
} }
enableSessionBackup(sessionBackup) { enableKeyBackup(keyBackup) {
this._roomEncryption?.enableSessionBackup(sessionBackup); this._roomEncryption?.enableKeyBackup(keyBackup);
// TODO: do we really want to do this every time you open the app? // TODO: do we really want to do this every time you open the app?
if (this._timeline && sessionBackup) { if (this._timeline && keyBackup) {
this._platform.logger.run("enableSessionBackup", log => { this._platform.logger.run("enableKeyBackup", log => {
return this._roomEncryption.restoreMissingSessionsFromBackup(this._timeline.remoteEntries, log); return this._roomEncryption.restoreMissingSessionsFromBackup(this._timeline.remoteEntries, log);
}); });
} }

View file

@ -27,6 +27,7 @@ import type * as OlmNamespace from "@matrix-org/olm"
type Olm = typeof OlmNamespace; type Olm = typeof OlmNamespace;
const SSSS_KEY = `${SESSION_E2EE_KEY_PREFIX}ssssKey`; const SSSS_KEY = `${SESSION_E2EE_KEY_PREFIX}ssssKey`;
const BACKUPVERSION_KEY = `${SESSION_E2EE_KEY_PREFIX}keyBackupVersion`;
export enum KeyType { export enum KeyType {
"RecoveryKey", "RecoveryKey",
@ -49,8 +50,11 @@ async function readDefaultKeyDescription(storage: Storage): Promise<KeyDescripti
return new KeyDescription(id, keyAccountData.content as KeyDescriptionData); return new KeyDescription(id, keyAccountData.content as KeyDescriptionData);
} }
export async function writeKey(key: Key, txn: Transaction): Promise<void> { export async function writeKey(key: Key, keyBackupVersion: number, txn: Transaction): Promise<number | undefined> {
const existingVersion: number | undefined = await txn.session.get(BACKUPVERSION_KEY);
txn.session.set(BACKUPVERSION_KEY, keyBackupVersion);
txn.session.set(SSSS_KEY, {id: key.id, binaryKey: key.binaryKey}); txn.session.set(SSSS_KEY, {id: key.id, binaryKey: key.binaryKey});
return existingVersion;
} }
export async function readKey(txn: Transaction): Promise<Key | undefined> { export async function readKey(txn: Transaction): Promise<Key | undefined> {

View file

@ -37,7 +37,8 @@ interface QueryTargetInterface<T> {
openKeyCursor(range?: IDBQuery, direction?: IDBCursorDirection | undefined): IDBRequest<IDBCursor | null>; openKeyCursor(range?: IDBQuery, direction?: IDBCursorDirection | undefined): IDBRequest<IDBCursor | null>;
supports(method: string): boolean; supports(method: string): boolean;
keyPath: string | string[]; keyPath: string | string[];
get(key: IDBValidKey | IDBKeyRange): IDBRequest<T | null>; count(keyRange?: IDBKeyRange): IDBRequest<number>;
get(key: IDBValidKey | IDBKeyRange): IDBRequest<T | undefined>;
getKey(key: IDBValidKey | IDBKeyRange): IDBRequest<IDBValidKey | undefined>; getKey(key: IDBValidKey | IDBKeyRange): IDBRequest<IDBValidKey | undefined>;
} }
@ -78,7 +79,11 @@ export class QueryTarget<T> {
return this._target.supports(methodName); return this._target.supports(methodName);
} }
get(key: IDBValidKey | IDBKeyRange): Promise<T | null> { count(keyRange?: IDBKeyRange): Promise<number> {
return reqAsPromise(this._target.count(keyRange));
}
get(key: IDBValidKey | IDBKeyRange): Promise<T | undefined> {
return reqAsPromise(this._target.get(key)); return reqAsPromise(this._target.get(key));
} }

View file

@ -91,7 +91,7 @@ export class QueryTargetWrapper<T> {
} }
} }
get(key: IDBValidKey | IDBKeyRange): IDBRequest<T | null> { get(key: IDBValidKey | IDBKeyRange): IDBRequest<T | undefined> {
try { try {
LOG_REQUESTS && logRequest("get", [key], this._qt); LOG_REQUESTS && logRequest("get", [key], this._qt);
return this._qt.get(key); return this._qt.get(key);
@ -118,6 +118,14 @@ export class QueryTargetWrapper<T> {
} }
} }
count(keyRange?: IDBKeyRange): IDBRequest<number> {
try {
return this._qt.count(keyRange);
} catch(err) {
throw new IDBRequestAttemptError("count", this._qt, err, [keyRange]);
}
}
index(name: string): IDBIndex { index(name: string): IDBIndex {
try { try {
return this._qtStore.index(name); return this._qtStore.index(name);

View file

@ -6,6 +6,7 @@ import {addRoomToIdentity} from "../../e2ee/DeviceTracker.js";
import {SESSION_E2EE_KEY_PREFIX} from "../../e2ee/common.js"; import {SESSION_E2EE_KEY_PREFIX} from "../../e2ee/common.js";
import {SummaryData} from "../../room/RoomSummary"; import {SummaryData} from "../../room/RoomSummary";
import {RoomMemberStore, MemberData} from "./stores/RoomMemberStore"; import {RoomMemberStore, MemberData} from "./stores/RoomMemberStore";
import {InboundGroupSessionStore, InboundGroupSessionEntry, BackupStatus, KeySource} from "./stores/InboundGroupSessionStore";
import {RoomStateEntry} from "./stores/RoomStateStore"; import {RoomStateEntry} from "./stores/RoomStateStore";
import {SessionStore} from "./stores/SessionStore"; import {SessionStore} from "./stores/SessionStore";
import {Store} from "./Store"; import {Store} from "./Store";
@ -31,13 +32,29 @@ export const schema: MigrationFunc[] = [
fixMissingRoomsInUserIdentities, fixMissingRoomsInUserIdentities,
changeSSSSKeyPrefix, changeSSSSKeyPrefix,
backupAndRestoreE2EEAccountToLocalStorage, backupAndRestoreE2EEAccountToLocalStorage,
clearAllStores clearAllStores,
addInboundSessionBackupIndex
]; ];
// TODO: how to deal with git merge conflicts of this array? // TODO: how to deal with git merge conflicts of this array?
// TypeScript note: for now, do not bother introducing interfaces / alias // TypeScript note: for now, do not bother introducing interfaces / alias
// for old schemas. Just take them as `any`. // for old schemas. Just take them as `any`.
function createDatabaseNameHelper(db: IDBDatabase): ITransaction {
// the Store object gets passed in several things through the Transaction class (a wrapper around IDBTransaction),
// the only thing we should need here is the databaseName though, so we mock it out.
// ideally we should have an easier way to go from the idb primitive layer to the specific store classes where
// we implement logic, but for now we need this.
const databaseNameHelper: ITransaction = {
databaseName: db.name,
get idbFactory(): IDBFactory { throw new Error("unused");},
get IDBKeyRange(): typeof IDBKeyRange { throw new Error("unused");},
addWriteError() {},
};
return databaseNameHelper;
}
// how do we deal with schema updates vs existing data migration in a way that // how do we deal with schema updates vs existing data migration in a way that
//v1 //v1
function createInitialStores(db: IDBDatabase): void { function createInitialStores(db: IDBDatabase): void {
@ -222,17 +239,7 @@ async function changeSSSSKeyPrefix(db: IDBDatabase, txn: IDBTransaction) {
// v13 // v13
async function backupAndRestoreE2EEAccountToLocalStorage(db: IDBDatabase, txn: IDBTransaction, localStorage: IDOMStorage, log: ILogItem) { async function backupAndRestoreE2EEAccountToLocalStorage(db: IDBDatabase, txn: IDBTransaction, localStorage: IDOMStorage, log: ILogItem) {
const session = txn.objectStore("session"); const session = txn.objectStore("session");
// the Store object gets passed in several things through the Transaction class (a wrapper around IDBTransaction), const sessionStore = new SessionStore(new Store(session, createDatabaseNameHelper(db)), localStorage);
// the only thing we should need here is the databaseName though, so we mock it out.
// ideally we should have an easier way to go from the idb primitive layer to the specific store classes where
// we implement logic, but for now we need this.
const databaseNameHelper: ITransaction = {
databaseName: db.name,
get idbFactory(): IDBFactory { throw new Error("unused");},
get IDBKeyRange(): typeof IDBKeyRange { throw new Error("unused");},
addWriteError() {},
};
const sessionStore = new SessionStore(new Store(session, databaseNameHelper), localStorage);
// if we already have an e2ee identity, write a backup to local storage. // if we already have an e2ee identity, write a backup to local storage.
// further updates to e2ee keys in the session store will also write to local storage from 0.2.15 on, // further updates to e2ee keys in the session store will also write to local storage from 0.2.15 on,
// but here we make sure a backup is immediately created after installing the update and we don't wait until // but here we make sure a backup is immediately created after installing the update and we don't wait until
@ -270,3 +277,18 @@ async function clearAllStores(db: IDBDatabase, txn: IDBTransaction) {
} }
} }
} }
// v15 add backup index to inboundGroupSessions
async function addInboundSessionBackupIndex(db: IDBDatabase, txn: IDBTransaction, localStorage: IDOMStorage, log: ILogItem): Promise<void> {
const inboundGroupSessions = txn.objectStore("inboundGroupSessions");
await iterateCursor<InboundGroupSessionEntry>(inboundGroupSessions.openCursor(), (value, key, cursor) => {
value.backup = BackupStatus.NotBackedUp;
// we'll also have backup keys in here, we can't tell,
// but the worst thing that can happen is that we try
// to backup keys that were already in backup, which
// the server will ignore
value.source = KeySource.DeviceMessage;
return NOT_DONE;
});
inboundGroupSessions.createIndex("byBackup", "backup", {unique: false});
}

View file

@ -28,7 +28,7 @@ export class AccountDataStore {
this._store = store; this._store = store;
} }
async get(type: string): Promise<AccountDataEntry | null> { async get(type: string): Promise<AccountDataEntry | undefined> {
return await this._store.get(type); return await this._store.get(type);
} }

View file

@ -17,7 +17,7 @@ limitations under the License.
import {MAX_UNICODE, MIN_UNICODE} from "./common"; import {MAX_UNICODE, MIN_UNICODE} from "./common";
import {Store} from "../Store"; import {Store} from "../Store";
interface DeviceIdentity { export interface DeviceIdentity {
userId: string; userId: string;
deviceId: string; deviceId: string;
ed25519Key: string; ed25519Key: string;
@ -65,7 +65,7 @@ export class DeviceIdentityStore {
return deviceIds; return deviceIds;
} }
get(userId: string, deviceId: string): Promise<DeviceIdentity | null> { get(userId: string, deviceId: string): Promise<DeviceIdentity | undefined> {
return this._store.get(encodeKey(userId, deviceId)); return this._store.get(encodeKey(userId, deviceId));
} }
@ -74,7 +74,7 @@ export class DeviceIdentityStore {
this._store.put(deviceIdentity); this._store.put(deviceIdentity);
} }
getByCurve25519Key(curve25519Key: string): Promise<DeviceIdentity | null> { getByCurve25519Key(curve25519Key: string): Promise<DeviceIdentity | undefined> {
return this._store.index("byCurve25519Key").get(curve25519Key); return this._store.index("byCurve25519Key").get(curve25519Key);
} }

View file

@ -35,7 +35,7 @@ export class GroupSessionDecryptionStore {
this._store = store; this._store = store;
} }
get(roomId: string, sessionId: string, messageIndex: number): Promise<GroupSessionDecryption | null> { get(roomId: string, sessionId: string, messageIndex: number): Promise<GroupSessionDecryption | undefined> {
return this._store.get(encodeKey(roomId, sessionId, messageIndex)); return this._store.get(encodeKey(roomId, sessionId, messageIndex));
} }

View file

@ -17,6 +17,17 @@ limitations under the License.
import {MIN_UNICODE, MAX_UNICODE} from "./common"; import {MIN_UNICODE, MAX_UNICODE} from "./common";
import {Store} from "../Store"; import {Store} from "../Store";
export enum BackupStatus {
NotBackedUp = 0,
BackedUp = 1
}
export enum KeySource {
DeviceMessage = 1,
Backup,
Outbound
}
export interface InboundGroupSessionEntry { export interface InboundGroupSessionEntry {
roomId: string; roomId: string;
senderKey: string; senderKey: string;
@ -24,6 +35,8 @@ export interface InboundGroupSessionEntry {
session?: string; session?: string;
claimedKeys?: { [algorithm : string] : string }; claimedKeys?: { [algorithm : string] : string };
eventIds?: string[]; eventIds?: string[];
backup: BackupStatus,
source: KeySource
} }
type InboundGroupSessionStorageEntry = InboundGroupSessionEntry & { key: string }; type InboundGroupSessionStorageEntry = InboundGroupSessionEntry & { key: string };
@ -46,7 +59,7 @@ export class InboundGroupSessionStore {
return key === fetchedKey; return key === fetchedKey;
} }
get(roomId: string, senderKey: string, sessionId: string): Promise<InboundGroupSessionEntry | null> { get(roomId: string, senderKey: string, sessionId: string): Promise<InboundGroupSessionEntry | undefined> {
return this._store.get(encodeKey(roomId, senderKey, sessionId)); return this._store.get(encodeKey(roomId, senderKey, sessionId));
} }
@ -63,4 +76,31 @@ export class InboundGroupSessionStore {
); );
this._store.delete(range); this._store.delete(range);
} }
countNonBackedUpSessions(): Promise<number> {
return this._store.index("byBackup").count(this._store.IDBKeyRange.only(BackupStatus.NotBackedUp));
}
getFirstNonBackedUpSessions(amount: number): Promise<InboundGroupSessionEntry[]> {
return this._store.index("byBackup").selectLimit(this._store.IDBKeyRange.only(BackupStatus.NotBackedUp), amount);
}
async markAsBackedUp(roomId: string, senderKey: string, sessionId: string): Promise<void> {
const entry = await this._store.get(encodeKey(roomId, senderKey, sessionId));
if (entry) {
entry.backup = BackupStatus.BackedUp;
this._store.put(entry);
}
}
async markAllAsNotBackedUp(): Promise<number> {
const backedUpKey = this._store.IDBKeyRange.only(BackupStatus.BackedUp);
let count = 0;
await this._store.index("byBackup").iterateValues(backedUpKey, (val: InboundGroupSessionEntry, key: IDBValidKey, cur: IDBCursorWithValue) => {
val.backup = BackupStatus.NotBackedUp;
cur.update(val);
count += 1;
return false;
});
return count;
}
} }

View file

@ -62,7 +62,7 @@ export class OlmSessionStore {
}); });
} }
get(senderKey: string, sessionId: string): Promise<OlmSession | null> { get(senderKey: string, sessionId: string): Promise<OlmSession | undefined> {
return this._store.get(encodeKey(senderKey, sessionId)); return this._store.get(encodeKey(senderKey, sessionId));
} }

View file

@ -32,7 +32,7 @@ export class OutboundGroupSessionStore {
this._store.delete(roomId); this._store.delete(roomId);
} }
get(roomId: string): Promise<OutboundSession | null> { get(roomId: string): Promise<OutboundSession | undefined> {
return this._store.get(roomId); return this._store.get(roomId);
} }

View file

@ -46,7 +46,7 @@ export class RoomMemberStore {
this._roomMembersStore = roomMembersStore; this._roomMembersStore = roomMembersStore;
} }
get(roomId: string, userId: string): Promise<MemberStorageEntry | null> { get(roomId: string, userId: string): Promise<MemberStorageEntry | undefined> {
return this._roomMembersStore.get(encodeKey(roomId, userId)); return this._roomMembersStore.get(encodeKey(roomId, userId));
} }

View file

@ -36,7 +36,7 @@ export class RoomStateStore {
this._roomStateStore = idbStore; this._roomStateStore = idbStore;
} }
get(roomId: string, type: string, stateKey: string): Promise<RoomStateEntry | null> { get(roomId: string, type: string, stateKey: string): Promise<RoomStateEntry | undefined> {
const key = encodeKey(roomId, type, stateKey); const key = encodeKey(roomId, type, stateKey);
return this._roomStateStore.get(key); return this._roomStateStore.get(key);
} }

View file

@ -301,11 +301,11 @@ export class TimelineEventStore {
this._timelineStore.put(entry as TimelineEventStorageEntry); this._timelineStore.put(entry as TimelineEventStorageEntry);
} }
get(roomId: string, eventKey: EventKey): Promise<TimelineEventEntry | null> { get(roomId: string, eventKey: EventKey): Promise<TimelineEventEntry | undefined> {
return this._timelineStore.get(encodeKey(roomId, eventKey.fragmentId, eventKey.eventIndex)); return this._timelineStore.get(encodeKey(roomId, eventKey.fragmentId, eventKey.eventIndex));
} }
getByEventId(roomId: string, eventId: string): Promise<TimelineEventEntry | null> { getByEventId(roomId: string, eventId: string): Promise<TimelineEventEntry | undefined> {
return this._timelineStore.index("byEventId").get(encodeEventIdKey(roomId, eventId)); return this._timelineStore.index("byEventId").get(encodeEventIdKey(roomId, eventId));
} }

View file

@ -83,7 +83,7 @@ export class TimelineFragmentStore {
this._store.put(fragment); this._store.put(fragment);
} }
get(roomId: string, fragmentId: number): Promise<FragmentEntry | null> { get(roomId: string, fragmentId: number): Promise<FragmentEntry | undefined> {
return this._store.get(encodeKey(roomId, fragmentId)); return this._store.get(encodeKey(roomId, fragmentId));
} }

View file

@ -28,7 +28,7 @@ export class UserIdentityStore {
this._store = store; this._store = store;
} }
get(userId: string): Promise<UserIdentity | null> { get(userId: string): Promise<UserIdentity | undefined> {
return this._store.get(userId); return this._store.get(userId);
} }

View file

@ -16,6 +16,7 @@ limitations under the License.
import {AbortError} from "../utils/error"; import {AbortError} from "../utils/error";
import {BaseObservable} from "./BaseObservable"; import {BaseObservable} from "./BaseObservable";
import type {SubscriptionHandle} from "./BaseObservable";
// like an EventEmitter, but doesn't have an event type // like an EventEmitter, but doesn't have an event type
export abstract class BaseObservableValue<T> extends BaseObservable<(value: T) => void> { export abstract class BaseObservableValue<T> extends BaseObservable<(value: T) => void> {
@ -34,6 +35,10 @@ export abstract class BaseObservableValue<T> extends BaseObservable<(value: T) =
return new WaitForHandle(this, predicate); return new WaitForHandle(this, predicate);
} }
} }
flatMap<C>(mapper: (value: T) => (BaseObservableValue<C> | undefined)): BaseObservableValue<C | undefined> {
return new FlatMapObservableValue<T, C>(this, mapper);
}
} }
interface IWaitHandle<T> { interface IWaitHandle<T> {
@ -114,6 +119,61 @@ export class RetainedObservableValue<T> extends ObservableValue<T> {
} }
} }
export class FlatMapObservableValue<P, C> extends BaseObservableValue<C | undefined> {
private sourceSubscription?: SubscriptionHandle;
private targetSubscription?: SubscriptionHandle;
constructor(
private readonly source: BaseObservableValue<P>,
private readonly mapper: (value: P) => (BaseObservableValue<C> | undefined)
) {
super();
}
onUnsubscribeLast() {
super.onUnsubscribeLast();
this.sourceSubscription = this.sourceSubscription!();
if (this.targetSubscription) {
this.targetSubscription = this.targetSubscription();
}
}
onSubscribeFirst() {
super.onSubscribeFirst();
this.sourceSubscription = this.source.subscribe(() => {
this.updateTargetSubscription();
this.emit(this.get());
});
this.updateTargetSubscription();
}
private updateTargetSubscription() {
const sourceValue = this.source.get();
if (sourceValue) {
const target = this.mapper(sourceValue);
if (target) {
if (!this.targetSubscription) {
this.targetSubscription = target.subscribe(() => this.emit(this.get()));
}
return;
}
}
// if no sourceValue or target
if (this.targetSubscription) {
this.targetSubscription = this.targetSubscription();
}
}
get(): C | undefined {
const sourceValue = this.source.get();
if (!sourceValue) {
return undefined;
}
const mapped = this.mapper(sourceValue);
return mapped?.get();
}
}
export function tests() { export function tests() {
return { return {
"set emits an update": assert => { "set emits an update": assert => {
@ -155,5 +215,34 @@ export function tests() {
}); });
await assert.rejects(handle.promise, AbortError); await assert.rejects(handle.promise, AbortError);
}, },
"flatMap.get": assert => {
const a = new ObservableValue<undefined | {count: ObservableValue<number>}>(undefined);
const countProxy = a.flatMap(a => a!.count);
assert.strictEqual(countProxy.get(), undefined);
const count = new ObservableValue<number>(0);
a.set({count});
assert.strictEqual(countProxy.get(), 0);
},
"flatMap update from source": assert => {
const a = new ObservableValue<undefined | {count: ObservableValue<number>}>(undefined);
const updates: (number | undefined)[] = [];
a.flatMap(a => a!.count).subscribe(count => {
updates.push(count);
});
const count = new ObservableValue<number>(0);
a.set({count});
assert.deepEqual(updates, [0]);
},
"flatMap update from target": assert => {
const a = new ObservableValue<undefined | {count: ObservableValue<number>}>(undefined);
const updates: (number | undefined)[] = [];
a.flatMap(a => a!.count).subscribe(count => {
updates.push(count);
});
const count = new ObservableValue<number>(0);
a.set({count});
count.set(5);
assert.deepEqual(updates, [0, 5]);
}
} }
} }

View file

@ -15,13 +15,13 @@ limitations under the License.
*/ */
import {TemplateView} from "../general/TemplateView"; import {TemplateView} from "../general/TemplateView";
import {SessionBackupSettingsView} from "../session/settings/SessionBackupSettingsView.js"; import {KeyBackupSettingsView} from "../session/settings/KeyBackupSettingsView.js";
export class AccountSetupView extends TemplateView { export class AccountSetupView extends TemplateView {
render(t, vm) { render(t, vm) {
return t.div({className: "Settings" /* hack for now to get the layout right*/}, [ return t.div({className: "Settings" /* hack for now to get the layout right*/}, [
t.h3(vm.i18n`Restore your encrypted history?`), t.h3(vm.i18n`Restore your encrypted history?`),
t.ifView(vm => vm.decryptDehydratedDeviceViewModel, vm => new SessionBackupSettingsView(vm.decryptDehydratedDeviceViewModel)), t.ifView(vm => vm.decryptDehydratedDeviceViewModel, vm => new KeyBackupSettingsView(vm.decryptDehydratedDeviceViewModel)),
t.map(vm => vm.deviceDecrypted, (decrypted, t) => { t.map(vm => vm.deviceDecrypted, (decrypted, t) => {
if (decrypted) { if (decrypted) {
return t.p(vm.i18n`That worked out, you're good to go!`); return t.p(vm.i18n`That worked out, you're good to go!`);

View file

@ -26,7 +26,7 @@ export class SessionStatusView extends TemplateView {
spinner(t, {hidden: vm => !vm.isWaiting}), spinner(t, {hidden: vm => !vm.isWaiting}),
t.p(vm => vm.statusLabel), t.p(vm => vm.statusLabel),
t.if(vm => vm.isConnectNowShown, t => t.button({className: "link", onClick: () => vm.connectNow()}, "Retry now")), t.if(vm => vm.isConnectNowShown, t => t.button({className: "link", onClick: () => vm.connectNow()}, "Retry now")),
t.if(vm => vm.isSecretStorageShown, t => t.a({href: vm.setupSessionBackupUrl}, "Go to settings")), t.if(vm => vm.isSecretStorageShown, t => t.a({href: vm.setupKeyBackupUrl}, "Go to settings")),
t.if(vm => vm.canDismiss, t => t.div({className: "end"}, t.button({className: "dismiss", onClick: () => vm.dismiss()}))), t.if(vm => vm.canDismiss, t => t.div({className: "end"}, t.button({className: "dismiss", onClick: () => vm.dismiss()}))),
]); ]);
} }

View file

@ -42,7 +42,8 @@ export class TextMessageView extends BaseMessageView {
} }
})); }));
const shouldRemove = (element) => element?.nodeType === Node.ELEMENT_NODE && element.className !== "ReplyPreviewView"; // exclude comment nodes as they are used by t.map and friends for placeholders
const shouldRemove = (element) => element?.nodeType !== Node.COMMENT_NODE && element.className !== "ReplyPreviewView";
t.mapSideEffect(vm => vm.body, body => { t.mapSideEffect(vm => vm.body, body => {
while (shouldRemove(container.lastChild)) { while (shouldRemove(container.lastChild)) {

View file

@ -14,25 +14,53 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
import {TemplateView, InlineTemplateView} from "../../general/TemplateView"; import {TemplateView} from "../../general/TemplateView";
import {StaticView} from "../../general/StaticView.js";
export class SessionBackupSettingsView extends TemplateView { export class KeyBackupSettingsView extends TemplateView {
render(t, vm) { render(t) {
return t.mapView(vm => vm.status, status => { return t.div([
t.map(vm => vm.status, (status, t, vm) => {
switch (status) { switch (status) {
case "Enabled": return new InlineTemplateView(vm, renderEnabled) case "Enabled": return renderEnabled(t, vm);
case "SetupKey": return new InlineTemplateView(vm, renderEnableFromKey) case "NewVersionAvailable": return renderNewVersionAvailable(t, vm);
case "SetupPhrase": return new InlineTemplateView(vm, renderEnableFromPhrase) case "SetupKey": return renderEnableFromKey(t, vm);
case "Pending": return new StaticView(vm, t => t.p(vm.i18n`Waiting to go online…`)) case "SetupPhrase": return renderEnableFromPhrase(t, vm);
case "Pending": return t.p(vm.i18n`Waiting to go online…`);
} }
}),
t.map(vm => vm.backupWriteStatus, (status, t, vm) => {
switch (status) {
case "Writing": {
const progress = t.progress({
min: 0,
max: 100,
value: vm => vm.backupPercentage,
}); });
return t.div([`Backup in progress `, progress, " ", vm => vm.backupInProgressLabel]);
}
case "Stopped": {
let label;
const error = vm.backupError;
if (error) {
label = `Backup has stopped because of an error: ${vm.backupError}`;
} else {
label = `Backup has stopped`;
}
return t.p(label, " ", t.button({onClick: () => vm.startBackup()}, `Backup now`));
}
case "Done":
return t.p(`All keys are backed up.`);
default:
return null;
}
})
]);
} }
} }
function renderEnabled(t, vm) { function renderEnabled(t, vm) {
const items = [ const items = [
t.p([vm.i18n`Session backup is enabled, using backup version ${vm.backupVersion}. `, t.button({onClick: () => vm.disable()}, vm.i18n`Disable`)]) t.p([vm.i18n`Key backup is enabled, using backup version ${vm.backupVersion}. `, t.button({onClick: () => vm.disable()}, vm.i18n`Disable`)])
]; ];
if (vm.dehydratedDeviceId) { if (vm.dehydratedDeviceId) {
items.push(t.p(vm.i18n`A dehydrated device id was set up with id ${vm.dehydratedDeviceId} which you can use during your next login with your secret storage key.`)); items.push(t.p(vm.i18n`A dehydrated device id was set up with id ${vm.dehydratedDeviceId} which you can use during your next login with your secret storage key.`));
@ -40,6 +68,13 @@ function renderEnabled(t, vm) {
return t.div(items); return t.div(items);
} }
function renderNewVersionAvailable(t, vm) {
const items = [
t.p([vm.i18n`A new backup version has been created from another device. Disable key backup and enable it again with the new key.`, t.button({onClick: () => vm.disable()}, vm.i18n`Disable`)])
];
return t.div(items);
}
function renderEnableFromKey(t, vm) { function renderEnableFromKey(t, vm) {
const useASecurityPhrase = t.button({className: "link", onClick: () => vm.showPhraseSetup()}, vm.i18n`use a security phrase`); const useASecurityPhrase = t.button({className: "link", onClick: () => vm.showPhraseSetup()}, vm.i18n`use a security phrase`);
return t.div([ return t.div([
@ -87,7 +122,7 @@ function renderEnableFieldRow(t, vm, label, callback) {
function renderError(t) { function renderError(t) {
return t.if(vm => vm.error, (t, vm) => { return t.if(vm => vm.error, (t, vm) => {
return t.div([ return t.div([
t.p({className: "error"}, vm => vm.i18n`Could not enable session backup: ${vm.error}.`), t.p({className: "error"}, vm => vm.i18n`Could not enable key backup: ${vm.error}.`),
t.p(vm.i18n`Try double checking that you did not mix up your security key, security phrase and login password as explained above.`) t.p(vm.i18n`Try double checking that you did not mix up your security key, security phrase and login password as explained above.`)
]) ])
}); });

View file

@ -15,7 +15,7 @@ limitations under the License.
*/ */
import {TemplateView} from "../../general/TemplateView"; import {TemplateView} from "../../general/TemplateView";
import {SessionBackupSettingsView} from "./SessionBackupSettingsView.js" import {KeyBackupSettingsView} from "./KeyBackupSettingsView.js"
export class SettingsView extends TemplateView { export class SettingsView extends TemplateView {
render(t, vm) { render(t, vm) {
@ -47,8 +47,8 @@ export class SettingsView extends TemplateView {
}, vm.i18n`Log out`)), }, vm.i18n`Log out`)),
); );
settingNodes.push( settingNodes.push(
t.h3("Session Backup"), t.h3("Key backup"),
t.view(new SessionBackupSettingsView(vm.sessionBackupViewModel)) t.view(new KeyBackupSettingsView(vm.keyBackupViewModel))
); );
settingNodes.push( settingNodes.push(

View file

@ -14,27 +14,40 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
import {BaseObservableValue, ObservableValue} from "../observable/ObservableValue";
export interface IAbortable { export interface IAbortable {
abort(); abort();
} }
type RunFn<T> = (setAbortable: (a: IAbortable) => typeof a) => T; export type SetAbortableFn = (a: IAbortable) => typeof a;
export type SetProgressFn<P> = (progress: P) => void;
type RunFn<T, P> = (setAbortable: SetAbortableFn, setProgress: SetProgressFn<P>) => T;
export class AbortableOperation<T> { export class AbortableOperation<T, P = void> implements IAbortable {
public readonly result: T; public readonly result: T;
private _abortable: IAbortable | null; private _abortable?: IAbortable;
private _progress: ObservableValue<P | undefined>;
constructor(run: RunFn<T>) { constructor(run: RunFn<T, P>) {
this._abortable = null; this._abortable = undefined;
const setAbortable = abortable => { const setAbortable: SetAbortableFn = abortable => {
this._abortable = abortable; this._abortable = abortable;
return abortable; return abortable;
}; };
this.result = run(setAbortable); this._progress = new ObservableValue<P | undefined>(undefined);
const setProgress: SetProgressFn<P> = (progress: P) => {
this._progress.set(progress);
};
this.result = run(setAbortable, setProgress);
}
get progress(): BaseObservableValue<P | undefined> {
return this._progress;
} }
abort() { abort() {
this._abortable?.abort(); this._abortable?.abort();
this._abortable = null; this._abortable = undefined;
} }
} }