Merge pull request #105 from vector-im/bwindels/fix-otk-sync-race

Fix race between next /sync and upload OTKs
This commit is contained in:
Bruno Windels 2020-09-21 16:04:30 +00:00 committed by GitHub
commit f84f06758c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 49 additions and 20 deletions

View file

@ -384,16 +384,22 @@ export class Session {
// sync transaction succeeded, modify object state now // sync transaction succeeded, modify object state now
this._syncInfo = syncInfo; this._syncInfo = syncInfo;
} }
if (this._e2eeAccount && e2eeAccountChanges) { if (this._e2eeAccount) {
this._e2eeAccount.afterSync(e2eeAccountChanges); this._e2eeAccount.afterSync(e2eeAccountChanges);
} }
} }
async afterSyncCompleted() { async afterSyncCompleted(isCatchupSync) {
const needsToUploadOTKs = await this._e2eeAccount.generateOTKsIfNeeded(this._storage);
const promises = [this._deviceMessageHandler.decryptPending(this.rooms)]; const promises = [this._deviceMessageHandler.decryptPending(this.rooms)];
if (needsToUploadOTKs) { // we don't start uploading one-time keys until we've caught up with
promises.push(this._e2eeAccount.uploadKeys(this._storage)); // to-device messages, to help us avoid throwing away one-time-keys that we
// are about to receive messages for
// (https://github.com/vector-im/riot-web/issues/2782).
if (!isCatchupSync) {
const needsToUploadOTKs = await this._e2eeAccount.generateOTKsIfNeeded(this._storage);
if (needsToUploadOTKs) {
promises.push(this._e2eeAccount.uploadKeys(this._storage));
}
} }
// run key upload and decryption in parallel // run key upload and decryption in parallel
await Promise.all(promises); await Promise.all(promises);

View file

@ -93,17 +93,32 @@ export class Sync {
} }
async _syncLoop(syncToken) { async _syncLoop(syncToken) {
let afterSyncCompletedPromise = Promise.resolve();
// if syncToken is falsy, it will first do an initial sync ... // if syncToken is falsy, it will first do an initial sync ...
while(this._status.get() !== SyncStatus.Stopped) { while(this._status.get() !== SyncStatus.Stopped) {
let roomStates; let roomStates;
try { try {
console.log(`starting sync request with since ${syncToken} ...`); console.log(`starting sync request with since ${syncToken} ...`);
const timeout = syncToken ? INCREMENTAL_TIMEOUT : undefined; // unless we are happily syncing already, we want the server to return
const syncResult = await this._syncRequest(syncToken, timeout, afterSyncCompletedPromise); // as quickly as possible, even if there are no events queued. This
// serves two purposes:
//
// * When the connection dies, we want to know asap when it comes back,
// so that we can hide the error from the user. (We don't want to
// have to wait for an event or a timeout).
//
// * We want to know if the server has any to_device messages queued up
// for us. We do that by calling it with a zero timeout until it
// doesn't give us any more to_device messages.
const timeout = this._status.get() === SyncStatus.Syncing ? INCREMENTAL_TIMEOUT : 0;
const syncResult = await this._syncRequest(syncToken, timeout);
syncToken = syncResult.syncToken; syncToken = syncResult.syncToken;
roomStates = syncResult.roomStates; roomStates = syncResult.roomStates;
this._status.set(SyncStatus.Syncing); // initial sync or catchup sync
if (this._status.get() !== SyncStatus.Syncing && syncResult.hadToDeviceMessages) {
this._status.set(SyncStatus.CatchupSync);
} else {
this._status.set(SyncStatus.Syncing);
}
} catch (err) { } catch (err) {
if (!(err instanceof AbortError)) { if (!(err instanceof AbortError)) {
console.warn("stopping sync because of error"); console.warn("stopping sync because of error");
@ -113,15 +128,16 @@ export class Sync {
} }
} }
if (this._status.get() !== SyncStatus.Stopped) { if (this._status.get() !== SyncStatus.Stopped) {
afterSyncCompletedPromise = this._runAfterSyncCompleted(roomStates); await this._runAfterSyncCompleted(roomStates);
} }
} }
} }
async _runAfterSyncCompleted(roomStates) { async _runAfterSyncCompleted(roomStates) {
const isCatchupSync = this._status.get() === SyncStatus.CatchupSync;
const sessionPromise = (async () => { const sessionPromise = (async () => {
try { try {
await this._session.afterSyncCompleted(); await this._session.afterSyncCompleted(isCatchupSync);
} catch (err) { } catch (err) {
console.error("error during session afterSyncCompleted, continuing", err.stack); console.error("error during session afterSyncCompleted, continuing", err.stack);
} }
@ -144,7 +160,7 @@ export class Sync {
await Promise.all(roomsPromises.concat(sessionPromise)); await Promise.all(roomsPromises.concat(sessionPromise));
} }
async _syncRequest(syncToken, timeout, prevAfterSyncCompletedPromise) { async _syncRequest(syncToken, timeout) {
let {syncFilterId} = this._session; let {syncFilterId} = this._session;
if (typeof syncFilterId !== "string") { if (typeof syncFilterId !== "string") {
this._currentRequest = this._hsApi.createFilter(this._session.user.id, {room: {state: {lazy_load_members: true}}}); this._currentRequest = this._hsApi.createFilter(this._session.user.id, {room: {state: {lazy_load_members: true}}});
@ -153,9 +169,6 @@ export class Sync {
const totalRequestTimeout = timeout + (80 * 1000); // same as riot-web, don't get stuck on wedged long requests const totalRequestTimeout = timeout + (80 * 1000); // same as riot-web, don't get stuck on wedged long requests
this._currentRequest = this._hsApi.sync(syncToken, syncFilterId, timeout, {timeout: totalRequestTimeout}); this._currentRequest = this._hsApi.sync(syncToken, syncFilterId, timeout, {timeout: totalRequestTimeout});
const response = await this._currentRequest.response(); const response = await this._currentRequest.response();
// wait here for the afterSyncCompleted step of the previous sync to complete
// before we continue processing this sync response
await prevAfterSyncCompletedPromise;
const isInitialSync = !syncToken; const isInitialSync = !syncToken;
syncToken = response.next_batch; syncToken = response.next_batch;
@ -190,7 +203,12 @@ export class Sync {
rs.room.afterSync(rs.changes); rs.room.afterSync(rs.changes);
} }
return {syncToken, roomStates}; const toDeviceEvents = response.to_device?.events;
return {
syncToken,
roomStates,
hadToDeviceMessages: Array.isArray(toDeviceEvents) && toDeviceEvents.length > 0,
};
} }
async _openPrepareSyncTxn() { async _openPrepareSyncTxn() {

View file

@ -171,7 +171,7 @@ export class Account {
writeSync(deviceOneTimeKeysCount, txn) { writeSync(deviceOneTimeKeysCount, txn) {
// we only upload signed_curve25519 otks // we only upload signed_curve25519 otks
const otkCount = deviceOneTimeKeysCount.signed_curve25519; const otkCount = deviceOneTimeKeysCount.signed_curve25519 || 0;
if (Number.isSafeInteger(otkCount) && otkCount !== this._serverOTKCount) { if (Number.isSafeInteger(otkCount) && otkCount !== this._serverOTKCount) {
txn.session.set(SERVER_OTK_COUNT_SESSION_KEY, otkCount); txn.session.set(SERVER_OTK_COUNT_SESSION_KEY, otkCount);
return otkCount; return otkCount;

View file

@ -115,7 +115,12 @@ export class Decryption {
} }
// could not decrypt with any existing session // could not decrypt with any existing session
if (typeof plaintext !== "string" && isPreKeyMessage(message)) { if (typeof plaintext !== "string" && isPreKeyMessage(message)) {
const createResult = this._createSessionAndDecrypt(senderKey, message, timestamp); let createResult;
try {
createResult = this._createSessionAndDecrypt(senderKey, message, timestamp);
} catch (error) {
throw new DecryptionError(`Could not create inbound olm session: ${error.message}`, event, {senderKey, error});
}
senderKeyDecryption.addNewSession(createResult.session); senderKeyDecryption.addNewSession(createResult.session);
plaintext = createResult.plaintext; plaintext = createResult.plaintext;
} }
@ -123,8 +128,8 @@ export class Decryption {
let payload; let payload;
try { try {
payload = JSON.parse(plaintext); payload = JSON.parse(plaintext);
} catch (err) { } catch (error) {
throw new DecryptionError("PLAINTEXT_NOT_JSON", event, {plaintext, err}); throw new DecryptionError("PLAINTEXT_NOT_JSON", event, {plaintext, error});
} }
this._validatePayload(payload, event); this._validatePayload(payload, event);
return new DecryptionResult(payload, senderKey, payload.keys); return new DecryptionResult(payload, senderKey, payload.keys);