Merge pull request #105 from vector-im/bwindels/fix-otk-sync-race

Fix race between next /sync and upload OTKs
This commit is contained in:
Bruno Windels 2020-09-21 16:04:30 +00:00 committed by GitHub
commit f84f06758c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 49 additions and 20 deletions

View file

@ -384,16 +384,22 @@ export class Session {
// sync transaction succeeded, modify object state now
this._syncInfo = syncInfo;
}
if (this._e2eeAccount && e2eeAccountChanges) {
if (this._e2eeAccount) {
this._e2eeAccount.afterSync(e2eeAccountChanges);
}
}
async afterSyncCompleted() {
const needsToUploadOTKs = await this._e2eeAccount.generateOTKsIfNeeded(this._storage);
async afterSyncCompleted(isCatchupSync) {
const promises = [this._deviceMessageHandler.decryptPending(this.rooms)];
if (needsToUploadOTKs) {
promises.push(this._e2eeAccount.uploadKeys(this._storage));
// we don't start uploading one-time keys until we've caught up with
// to-device messages, to help us avoid throwing away one-time-keys that we
// are about to receive messages for
// (https://github.com/vector-im/riot-web/issues/2782).
if (!isCatchupSync) {
const needsToUploadOTKs = await this._e2eeAccount.generateOTKsIfNeeded(this._storage);
if (needsToUploadOTKs) {
promises.push(this._e2eeAccount.uploadKeys(this._storage));
}
}
// run key upload and decryption in parallel
await Promise.all(promises);

View file

@ -93,17 +93,32 @@ export class Sync {
}
async _syncLoop(syncToken) {
let afterSyncCompletedPromise = Promise.resolve();
// if syncToken is falsy, it will first do an initial sync ...
while(this._status.get() !== SyncStatus.Stopped) {
let roomStates;
try {
console.log(`starting sync request with since ${syncToken} ...`);
const timeout = syncToken ? INCREMENTAL_TIMEOUT : undefined;
const syncResult = await this._syncRequest(syncToken, timeout, afterSyncCompletedPromise);
// unless we are happily syncing already, we want the server to return
// as quickly as possible, even if there are no events queued. This
// serves two purposes:
//
// * When the connection dies, we want to know asap when it comes back,
// so that we can hide the error from the user. (We don't want to
// have to wait for an event or a timeout).
//
// * We want to know if the server has any to_device messages queued up
// for us. We do that by calling it with a zero timeout until it
// doesn't give us any more to_device messages.
const timeout = this._status.get() === SyncStatus.Syncing ? INCREMENTAL_TIMEOUT : 0;
const syncResult = await this._syncRequest(syncToken, timeout);
syncToken = syncResult.syncToken;
roomStates = syncResult.roomStates;
this._status.set(SyncStatus.Syncing);
// initial sync or catchup sync
if (this._status.get() !== SyncStatus.Syncing && syncResult.hadToDeviceMessages) {
this._status.set(SyncStatus.CatchupSync);
} else {
this._status.set(SyncStatus.Syncing);
}
} catch (err) {
if (!(err instanceof AbortError)) {
console.warn("stopping sync because of error");
@ -113,15 +128,16 @@ export class Sync {
}
}
if (this._status.get() !== SyncStatus.Stopped) {
afterSyncCompletedPromise = this._runAfterSyncCompleted(roomStates);
await this._runAfterSyncCompleted(roomStates);
}
}
}
async _runAfterSyncCompleted(roomStates) {
const isCatchupSync = this._status.get() === SyncStatus.CatchupSync;
const sessionPromise = (async () => {
try {
await this._session.afterSyncCompleted();
await this._session.afterSyncCompleted(isCatchupSync);
} catch (err) {
console.error("error during session afterSyncCompleted, continuing", err.stack);
}
@ -144,7 +160,7 @@ export class Sync {
await Promise.all(roomsPromises.concat(sessionPromise));
}
async _syncRequest(syncToken, timeout, prevAfterSyncCompletedPromise) {
async _syncRequest(syncToken, timeout) {
let {syncFilterId} = this._session;
if (typeof syncFilterId !== "string") {
this._currentRequest = this._hsApi.createFilter(this._session.user.id, {room: {state: {lazy_load_members: true}}});
@ -153,9 +169,6 @@ export class Sync {
const totalRequestTimeout = timeout + (80 * 1000); // same as riot-web, don't get stuck on wedged long requests
this._currentRequest = this._hsApi.sync(syncToken, syncFilterId, timeout, {timeout: totalRequestTimeout});
const response = await this._currentRequest.response();
// wait here for the afterSyncCompleted step of the previous sync to complete
// before we continue processing this sync response
await prevAfterSyncCompletedPromise;
const isInitialSync = !syncToken;
syncToken = response.next_batch;
@ -190,7 +203,12 @@ export class Sync {
rs.room.afterSync(rs.changes);
}
return {syncToken, roomStates};
const toDeviceEvents = response.to_device?.events;
return {
syncToken,
roomStates,
hadToDeviceMessages: Array.isArray(toDeviceEvents) && toDeviceEvents.length > 0,
};
}
async _openPrepareSyncTxn() {

View file

@ -171,7 +171,7 @@ export class Account {
writeSync(deviceOneTimeKeysCount, txn) {
// we only upload signed_curve25519 otks
const otkCount = deviceOneTimeKeysCount.signed_curve25519;
const otkCount = deviceOneTimeKeysCount.signed_curve25519 || 0;
if (Number.isSafeInteger(otkCount) && otkCount !== this._serverOTKCount) {
txn.session.set(SERVER_OTK_COUNT_SESSION_KEY, otkCount);
return otkCount;

View file

@ -115,7 +115,12 @@ export class Decryption {
}
// could not decrypt with any existing session
if (typeof plaintext !== "string" && isPreKeyMessage(message)) {
const createResult = this._createSessionAndDecrypt(senderKey, message, timestamp);
let createResult;
try {
createResult = this._createSessionAndDecrypt(senderKey, message, timestamp);
} catch (error) {
throw new DecryptionError(`Could not create inbound olm session: ${error.message}`, event, {senderKey, error});
}
senderKeyDecryption.addNewSession(createResult.session);
plaintext = createResult.plaintext;
}
@ -123,8 +128,8 @@ export class Decryption {
let payload;
try {
payload = JSON.parse(plaintext);
} catch (err) {
throw new DecryptionError("PLAINTEXT_NOT_JSON", event, {plaintext, err});
} catch (error) {
throw new DecryptionError("PLAINTEXT_NOT_JSON", event, {plaintext, error});
}
this._validatePayload(payload, event);
return new DecryptionResult(payload, senderKey, payload.keys);