Merge pull request #76 from vector-im/bwindels/maintain-otks

Maintain OTKs above max/2
commit 2b6530b459
Author: Bruno Windels
Date:   2020-08-28 12:03:29 +00:00
Committed by: GitHub
4 changed files with 106 additions and 20 deletions
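
In short: every sync reports how many one-time keys (OTKs) the server still holds for this device. Once that count, plus any keys generated locally but not yet published, drops below half of what the olm account can hold, the account is topped up to the maximum and the new keys are uploaded. A minimal sketch of that rule; the function name, parameters, and numbers are made up for illustration:

    // Illustrative sketch of the top-up rule this PR implements (not the actual API):
    // keep the total OTK count (server-side plus unpublished local keys)
    // above half of the olm account's maximum.
    function otksToGenerate(maxOTKs, serverOTKCount, unpublishedOTKCount) {
        const limit = maxOTKs / 2;
        const total = serverOTKCount + unpublishedOTKCount;
        // below the halfway watermark: refill all the way up to the maximum
        return total < limit ? maxOTKs - total : 0;
    }

    // e.g. with maxOTKs = 100, 30 keys on the server and 5 unpublished:
    // total = 35 < 50, so 65 new keys are generated and later uploaded.
    otksToGenerate(100, 30, 5); // 65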


@@ -60,6 +60,7 @@ export class Session {
                 }
                 await txn.complete();
             }
+            await this._e2eeAccount.generateOTKsIfNeeded(this._storage);
             await this._e2eeAccount.uploadKeys(this._storage);
         }
     }
@@ -151,20 +152,40 @@ export class Session {
         return room;
     }
 
-    writeSync(syncToken, syncFilterId, accountData, txn) {
+    writeSync(syncResponse, syncFilterId, txn) {
+        const changes = {};
+        const syncToken = syncResponse.next_batch;
+        const deviceOneTimeKeysCount = syncResponse.device_one_time_keys_count;
+        if (this._e2eeAccount && deviceOneTimeKeysCount) {
+            changes.e2eeAccountChanges = this._e2eeAccount.writeSync(deviceOneTimeKeysCount, txn);
+        }
         if (syncToken !== this.syncToken) {
             const syncInfo = {token: syncToken, filterId: syncFilterId};
             // don't modify `this` because transaction might still fail
             txn.session.set("sync", syncInfo);
-            return syncInfo;
+            changes.syncInfo = syncInfo;
         }
+        return changes;
     }
 
-    afterSync(syncInfo) {
+    afterSync({syncInfo, e2eeAccountChanges}) {
         if (syncInfo) {
             // sync transaction succeeded, modify object state now
             this._syncInfo = syncInfo;
         }
+        if (this._e2eeAccount && e2eeAccountChanges) {
+            this._e2eeAccount.afterSync(e2eeAccountChanges);
+        }
+    }
+
+    async afterSyncCompleted() {
+        const needsToUploadOTKs = await this._e2eeAccount.generateOTKsIfNeeded(this._storage);
+        if (needsToUploadOTKs) {
+            // TODO: we could do this in parallel with sync if it proves to be too slow
+            // but I'm not sure how to not swallow errors in that case
+            await this._e2eeAccount.uploadKeys(this._storage);
+        }
     }
 
     get syncToken() {
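
The Session changes follow the write/after pattern used elsewhere in the codebase: writeSync only stages changes in the storage transaction and returns them as a plain object, and afterSync applies them to in-memory state once the transaction has committed. A minimal sketch of the pattern with a hypothetical counter component, not hydrogen's actual API:

    // Sketch of the write/after pattern (hypothetical component for illustration):
    class Counter {
        constructor() {
            this._count = 0;
        }
        writeSync(increment, txn) {
            const newCount = this._count + increment; // don't touch `this` yet:
            txn.session.set("count", newCount);       // the txn might still fail
            return newCount;                          // staged change for afterSync
        }
        afterSync(newCount) {
            if (newCount !== undefined) {
                this._count = newCount; // txn committed, safe to mutate memory
            }
        }
    }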


@@ -237,10 +237,16 @@ export class SessionContainer {
     }
 
     stop() {
-        this._reconnectSubscription();
-        this._reconnectSubscription = null;
-        this._sync.stop();
-        this._session.stop();
+        if (this._reconnectSubscription) {
+            this._reconnectSubscription();
+            this._reconnectSubscription = null;
+        }
+        if (this._sync) {
+            this._sync.stop();
+        }
+        if (this._session) {
+            this._session.stop();
+        }
         if (this._waitForFirstSyncHandle) {
             this._waitForFirstSyncHandle.dispose();
             this._waitForFirstSyncHandle = null;


@@ -100,6 +100,12 @@ export class Sync {
                 this._status.set(SyncStatus.Stopped);
             }
         }
+        try {
+            await this._session.afterSyncCompleted();
+        } catch (err) {
+            console.error("error during after sync completed, continuing to sync.", err.stack);
+            // swallowing error here apart from logging
+        }
     }
 }
@@ -127,7 +133,7 @@ export class Sync {
         const roomChanges = [];
         let sessionChanges;
         try {
-            sessionChanges = this._session.writeSync(syncToken, syncFilterId, response.account_data, syncTxn);
+            sessionChanges = this._session.writeSync(response, syncFilterId, syncTxn);
             // to_device
             // presence
             if (response.rooms) {
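
Sync now hands the whole /sync response to Session.writeSync, which picks out the fields it needs itself. For reference, the relevant shape of a /sync response per the Matrix client-server spec; the field values here are illustrative:

    // Relevant /sync response fields for these changes (values illustrative):
    const response = {
        next_batch: "s72595_4483_1934",   // token to pass to the next sync
        device_one_time_keys_count: {
            signed_curve25519: 38         // OTKs the server still holds for this device
        },
        rooms: { /* ... */ }
    };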


@@ -16,8 +16,13 @@ limitations under the License.
 
 import anotherjson from "../../../lib/another-json/index.js";
 
-const ACCOUNT_SESSION_KEY = "olmAccount";
-const DEVICE_KEY_FLAG_SESSION_KEY = "areDeviceKeysUploaded";
+// use common prefix so it's easy to clear properties that are not e2ee related during session clear
+export const SESSION_KEY_PREFIX = "e2ee:";
+const ACCOUNT_SESSION_KEY = SESSION_KEY_PREFIX + "olmAccount";
+const DEVICE_KEY_FLAG_SESSION_KEY = SESSION_KEY_PREFIX + "areDeviceKeysUploaded";
+const SERVER_OTK_COUNT_SESSION_KEY = SESSION_KEY_PREFIX + "serverOTKCount";
+const OLM_ALGORITHM = "m.olm.v1.curve25519-aes-sha2";
+const MEGOLM_ALGORITHM = "m.megolm.v1.aes-sha2";
 
 export class Account {
     static async load({olm, pickleKey, hsApi, userId, deviceId, txn}) {

@@ -26,7 +31,9 @@ export class Account {
             const account = new olm.Account();
             const areDeviceKeysUploaded = await txn.session.get(DEVICE_KEY_FLAG_SESSION_KEY);
             account.unpickle(pickleKey, pickledAccount);
-            return new Account({pickleKey, hsApi, account, userId, deviceId, areDeviceKeysUploaded});
+            const serverOTKCount = await txn.session.get(SERVER_OTK_COUNT_SESSION_KEY);
+            return new Account({pickleKey, hsApi, account, userId,
+                deviceId, areDeviceKeysUploaded, serverOTKCount});
         }
     }
@@ -40,16 +47,19 @@ export class Account {
         const areDeviceKeysUploaded = false;
         await txn.session.add(ACCOUNT_SESSION_KEY, pickledAccount);
         await txn.session.add(DEVICE_KEY_FLAG_SESSION_KEY, areDeviceKeysUploaded);
-        return new Account({pickleKey, hsApi, account, userId, deviceId, areDeviceKeysUploaded});
+        await txn.session.add(SERVER_OTK_COUNT_SESSION_KEY, 0);
+        return new Account({pickleKey, hsApi, account, userId,
+            deviceId, areDeviceKeysUploaded, serverOTKCount: 0});
     }
 
-    constructor({pickleKey, hsApi, account, userId, deviceId, areDeviceKeysUploaded}) {
+    constructor({pickleKey, hsApi, account, userId, deviceId, areDeviceKeysUploaded, serverOTKCount}) {
         this._pickleKey = pickleKey;
         this._hsApi = hsApi;
         this._account = account;
         this._userId = userId;
         this._deviceId = deviceId;
         this._areDeviceKeysUploaded = areDeviceKeysUploaded;
+        this._serverOTKCount = serverOTKCount;
     }
 
     async uploadKeys(storage) {
@@ -65,12 +75,17 @@ export class Account {
         if (oneTimeKeysEntries.length) {
             payload.one_time_keys = this._oneTimeKeysPayload(oneTimeKeysEntries);
         }
-        await this._hsApi.uploadKeys(payload);
+        const response = await this._hsApi.uploadKeys(payload).response();
+        this._serverOTKCount = response?.one_time_key_counts?.signed_curve25519;
+        // TODO: should we not modify this in the txn like we do elsewhere?
+        // we'd have to pickle and unpickle the account to clone it though ...
+        // and the upload has succeeded at this point, so in-memory would be correct
+        // but in-storage not if the txn fails.
         await this._updateSessionStorage(storage, sessionStore => {
             if (oneTimeKeysEntries.length) {
                 this._account.mark_keys_as_published();
                 sessionStore.set(ACCOUNT_SESSION_KEY, this._account.pickle(this._pickleKey));
+                sessionStore.set(SERVER_OTK_COUNT_SESSION_KEY, this._serverOTKCount);
             }
             if (!this._areDeviceKeysUploaded) {
                 this._areDeviceKeysUploaded = true;
@@ -80,14 +95,52 @@ export class Account {
         }
     }
 
+    async generateOTKsIfNeeded(storage) {
+        const maxOTKs = this._account.max_number_of_one_time_keys();
+        const limit = maxOTKs / 2;
+        if (this._serverOTKCount < limit) {
+            // TODO: cache unpublishedOTKCount, so we don't have to parse this JSON on every sync iteration
+            // for now, we only determine it when serverOTKCount is sufficiently low, which it should rarely be,
+            // and recheck
+            const oneTimeKeys = JSON.parse(this._account.one_time_keys());
+            const oneTimeKeysEntries = Object.entries(oneTimeKeys.curve25519);
+            const unpublishedOTKCount = oneTimeKeysEntries.length;
+            const totalOTKCount = this._serverOTKCount + unpublishedOTKCount;
+            if (totalOTKCount < limit) {
+                // we could in theory also generate the keys and store them in
+                // writeSync, but then we would have to clone the account to avoid side-effects.
+                await this._updateSessionStorage(storage, sessionStore => {
+                    const newKeyCount = maxOTKs - totalOTKCount;
+                    this._account.generate_one_time_keys(newKeyCount);
+                    sessionStore.set(ACCOUNT_SESSION_KEY, this._account.pickle(this._pickleKey));
+                });
+                return true;
+            }
+        }
+        return false;
+    }
+
+    writeSync(deviceOneTimeKeysCount, txn) {
+        // we only upload signed_curve25519 otks
+        const otkCount = deviceOneTimeKeysCount.signed_curve25519;
+        if (Number.isSafeInteger(otkCount) && otkCount !== this._serverOTKCount) {
+            txn.session.set(SERVER_OTK_COUNT_SESSION_KEY, otkCount);
+            return otkCount;
+        }
+    }
+
+    afterSync(otkCount) {
+        // could also be undefined
+        if (Number.isSafeInteger(otkCount)) {
+            this._serverOTKCount = otkCount;
+        }
+    }
+
     _deviceKeysPayload(identityKeys) {
         const obj = {
             user_id: this._userId,
             device_id: this._deviceId,
-            algorithms: [
-                "m.olm.v1.curve25519-aes-sha2",
-                "m.megolm.v1.aes-sha2"
-            ],
+            algorithms: [OLM_ALGORITHM, MEGOLM_ALGORITHM],
             keys: {}
         };
         for (const [algorithm, pubKey] of Object.entries(identityKeys)) {
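
A worked example of the arithmetic in generateOTKsIfNeeded, with illustrative numbers (olm's actual maximum comes from max_number_of_one_time_keys() and is implementation-defined):

    // Worked example of generateOTKsIfNeeded's arithmetic (illustrative numbers):
    const maxOTKs = 100;            // this._account.max_number_of_one_time_keys()
    const limit = maxOTKs / 2;      // 50: replenish below this watermark
    const serverOTKCount = 30;      // from device_one_time_keys_count.signed_curve25519
    const unpublishedOTKCount = 12; // generated locally but not yet uploaded
    const totalOTKCount = serverOTKCount + unpublishedOTKCount; // 42 < 50
    const newKeyCount = maxOTKs - totalOTKCount;                // generate 58 keys
    // uploadKeys() then publishes them and refreshes serverOTKCount
    // from the upload response.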
@@ -114,7 +167,7 @@ export class Account {
             storage.storeNames.session
         ]);
         try {
-            callback(txn.session);
+            await callback(txn.session);
         } catch (err) {
             txn.abort();
             throw err;
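
The `await` added to the callback in _updateSessionStorage matters: without it, a rejection from an async callback escapes the try/catch and the transaction is never aborted. A minimal illustration with a hypothetical txn object:

    // Why the added `await` matters (hypothetical txn for illustration):
    async function run(txn, callback) {
        try {
            await callback(txn); // routes async failures into the catch below
        } catch (err) {
            txn.abort();         // now reached for sync *and* async errors
            throw err;
        }
    }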