Skip to content

Commit 8ccbf47

Browse files
committed
Fix race condition in pair select
This bug was introduced by #1524. The crux of problem is that TryOffer/PollDesc.onPrepare method is no longer allowed to update fields in these classes (like "resumeToken" and "pollResult") after call to tryResumeSend/Receive method, because the latter will complete the ongoing atomic operation and helper method might find it complete and try reading "resumeToken" which was not initialized yet. This change removes "pollResult" field which was not really needed ("result.pollResult" field is used) and removes "resumeToken" by exploiting the fact that current implementation of CancellableContinuationImpl does not need a token anymore. However, CancellableContinuation.tryResume/completeResume ABI is left intact, because it is used by 3rd party code. This fix lead to overall simplification of the code. A number of fields and an auxiliary IdempotentTokenValue class are removed, tokens used to indicate various results are consolidated, so that resume success is now consistently indicated by a single RESUME_TOKEN symbol. Fixes #1561
1 parent 3dbe82b commit 8ccbf47

File tree

5 files changed

+74
-108
lines changed

5 files changed

+74
-108
lines changed

kotlinx-coroutines-core/common/src/CancellableContinuationImpl.kt

+12-7
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ private const val UNDECIDED = 0
1414
private const val SUSPENDED = 1
1515
private const val RESUMED = 2
1616

17+
@JvmField
18+
@SharedImmutable
19+
internal val RESUME_TOKEN = Symbol("RESUME_TOKEN")
20+
1721
/**
1822
* @suppress **This is unstable API and it is subject to change.**
1923
*/
@@ -285,20 +289,21 @@ internal open class CancellableContinuationImpl<in T>(
285289
}
286290
}
287291

292+
// Note: Always returns RESUME_TOKEN | null
288293
override fun tryResume(value: T, idempotent: Any?): Any? {
289294
_state.loop { state ->
290295
when (state) {
291296
is NotCompleted -> {
292297
val update: Any? = if (idempotent == null) value else
293-
CompletedIdempotentResult(idempotent, value, state)
298+
CompletedIdempotentResult(idempotent, value)
294299
if (!_state.compareAndSet(state, update)) return@loop // retry on cas failure
295300
disposeParentHandle()
296-
return state
301+
return RESUME_TOKEN
297302
}
298303
is CompletedIdempotentResult -> {
299304
return if (state.idempotentResume === idempotent) {
300305
assert { state.result === value } // "Non-idempotent resume"
301-
state.token
306+
RESUME_TOKEN
302307
} else {
303308
null
304309
}
@@ -315,15 +320,16 @@ internal open class CancellableContinuationImpl<in T>(
315320
val update = CompletedExceptionally(exception)
316321
if (!_state.compareAndSet(state, update)) return@loop // retry on cas failure
317322
disposeParentHandle()
318-
return state
323+
return RESUME_TOKEN
319324
}
320325
else -> return null // cannot resume -- not active anymore
321326
}
322327
}
323328
}
324329

330+
// note: token is always RESUME_TOKEN
325331
override fun completeResume(token: Any) {
326-
// note: We don't actually use token anymore, because handler needs to be invoked on cancellation only
332+
assert { token === RESUME_TOKEN }
327333
dispatchResume(resumeMode)
328334
}
329335

@@ -375,8 +381,7 @@ private class InvokeOnCancel( // Clashes with InvokeOnCancellation
375381

376382
private class CompletedIdempotentResult(
377383
@JvmField val idempotentResume: Any?,
378-
@JvmField val result: Any?,
379-
@JvmField val token: NotCompleted
384+
@JvmField val result: Any?
380385
) {
381386
override fun toString(): String = "CompletedIdempotentResult[$result]"
382387
}

kotlinx-coroutines-core/common/src/channels/AbstractChannel.kt

+34-67
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ internal abstract class AbstractSendChannel<E> : SendChannel<E> {
4848
val receive = takeFirstReceiveOrPeekClosed() ?: return OFFER_FAILED
4949
val token = receive.tryResumeReceive(element, null)
5050
if (token != null) {
51-
receive.completeResumeReceive(token)
51+
assert { token === RESUME_TOKEN }
52+
receive.completeResumeReceive(element)
5253
return receive.offerResult
5354
}
5455
}
@@ -65,7 +66,7 @@ internal abstract class AbstractSendChannel<E> : SendChannel<E> {
6566
val failure = select.performAtomicTrySelect(offerOp)
6667
if (failure != null) return failure
6768
val receive = offerOp.result
68-
receive.completeResumeReceive(offerOp.resumeToken!!)
69+
receive.completeResumeReceive(element)
6970
return receive.offerResult
7071
}
7172

@@ -354,8 +355,6 @@ internal abstract class AbstractSendChannel<E> : SendChannel<E> {
354355
@JvmField val element: E,
355356
queue: LockFreeLinkedListHead
356357
) : RemoveFirstDesc<ReceiveOrClosed<E>>(queue) {
357-
@JvmField var resumeToken: Any? = null
358-
359358
override fun failure(affected: LockFreeLinkedListNode): Any? = when (affected) {
360359
is Closed<*> -> affected
361360
!is ReceiveOrClosed<*> -> OFFER_FAILED
@@ -367,7 +366,7 @@ internal abstract class AbstractSendChannel<E> : SendChannel<E> {
367366
val affected = prepareOp.affected as ReceiveOrClosed<E> // see "failure" impl
368367
val token = affected.tryResumeReceive(element, prepareOp) ?: return REMOVE_PREPARED
369368
if (token === RETRY_ATOMIC) return RETRY_ATOMIC
370-
resumeToken = token
369+
assert { token === RESUME_TOKEN }
371370
return null
372371
}
373372
}
@@ -454,8 +453,7 @@ internal abstract class AbstractSendChannel<E> : SendChannel<E> {
454453
override fun tryResumeSend(otherOp: PrepareOp?): Any? =
455454
select.trySelectOther(otherOp)
456455

457-
override fun completeResumeSend(token: Any) {
458-
assert { token === SELECT_STARTED }
456+
override fun completeResumeSend() {
459457
block.startCoroutine(receiver = channel, completion = select.completion)
460458
}
461459

@@ -475,8 +473,8 @@ internal abstract class AbstractSendChannel<E> : SendChannel<E> {
475473
@JvmField val element: E
476474
) : Send() {
477475
override val pollResult: Any? get() = element
478-
override fun tryResumeSend(otherOp: PrepareOp?): Any? = SEND_RESUMED.also { otherOp?.finishPrepare() }
479-
override fun completeResumeSend(token: Any) { assert { token === SEND_RESUMED } }
476+
override fun tryResumeSend(otherOp: PrepareOp?): Any? = RESUME_TOKEN.also { otherOp?.finishPrepare() }
477+
override fun completeResumeSend() {}
480478
override fun resumeSendClosed(closed: Closed<*>) {}
481479
}
482480
}
@@ -511,7 +509,8 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
511509
val send = takeFirstSendOrPeekClosed() ?: return POLL_FAILED
512510
val token = send.tryResumeSend(null)
513511
if (token != null) {
514-
send.completeResumeSend(token)
512+
assert { token === RESUME_TOKEN }
513+
send.completeResumeSend()
515514
return send.pollResult
516515
}
517516
}
@@ -528,8 +527,8 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
528527
val failure = select.performAtomicTrySelect(pollOp)
529528
if (failure != null) return failure
530529
val send = pollOp.result
531-
send.completeResumeSend(pollOp.resumeToken!!)
532-
return pollOp.pollResult
530+
send.completeResumeSend()
531+
return pollOp.result.pollResult
533532
}
534533

535534
// ------ state functions & helpers for concrete implementations ------
@@ -673,9 +672,6 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
673672
* @suppress **This is unstable API and it is subject to change.**
674673
*/
675674
protected class TryPollDesc<E>(queue: LockFreeLinkedListHead) : RemoveFirstDesc<Send>(queue) {
676-
@JvmField var resumeToken: Any? = null
677-
@JvmField var pollResult: E? = null
678-
679675
override fun failure(affected: LockFreeLinkedListNode): Any? = when (affected) {
680676
is Closed<*> -> affected
681677
!is Send -> POLL_FAILED
@@ -687,8 +683,7 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
687683
val affected = prepareOp.affected as Send // see "failure" impl
688684
val token = affected.tryResumeSend(prepareOp) ?: return REMOVE_PREPARED
689685
if (token === RETRY_ATOMIC) return RETRY_ATOMIC
690-
resumeToken = token
691-
pollResult = affected.pollResult as E
686+
assert { token === RESUME_TOKEN }
692687
return null
693688
}
694689
}
@@ -908,7 +903,8 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
908903
return cont.tryResume(resumeValue(value), otherOp?.desc)
909904
}
910905

911-
override fun completeResumeReceive(token: Any) = cont.completeResume(token)
906+
override fun completeResumeReceive(value: E) = cont.completeResume(RESUME_TOKEN)
907+
912908
override fun resumeReceiveClosed(closed: Closed<*>) {
913909
when {
914910
receiveMode == RECEIVE_NULL_ON_CLOSE && closed.closeCause == null -> cont.resume(null)
@@ -925,25 +921,16 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
925921
) : Receive<E>() {
926922
override fun tryResumeReceive(value: E, otherOp: PrepareOp?): Any? {
927923
otherOp?.finishPrepare()
928-
val token = cont.tryResume(true, otherOp?.desc)
929-
if (token != null) {
930-
/*
931-
When otherOp != null this invocation can be stale and we cannot directly update iterator.result
932-
Instead, we save both token & result into a temporary IdempotentTokenValue object and
933-
set iterator result only in completeResumeReceive that is going to be invoked just once
934-
*/
935-
if (otherOp != null) return IdempotentTokenValue(token, value)
936-
iterator.result = value
937-
}
938-
return token
924+
return cont.tryResume(true, otherOp?.desc)
939925
}
940926

941-
override fun completeResumeReceive(token: Any) {
942-
if (token is IdempotentTokenValue<*>) {
943-
iterator.result = token.value
944-
cont.completeResume(token.token)
945-
} else
946-
cont.completeResume(token)
927+
override fun completeResumeReceive(value: E) {
928+
/*
929+
When otherOp != null invocation of tryResumeReceive can happen multiple times and much later,
930+
but completeResumeReceive is called once so we set iterator result here.
931+
*/
932+
iterator.result = value
933+
cont.completeResume(RESUME_TOKEN)
947934
}
948935

949936
override fun resumeReceiveClosed(closed: Closed<*>) {
@@ -966,14 +953,11 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
966953
@JvmField val block: suspend (Any?) -> R,
967954
@JvmField val receiveMode: Int
968955
) : Receive<E>(), DisposableHandle {
969-
override fun tryResumeReceive(value: E, otherOp: PrepareOp?): Any? {
970-
val result = select.trySelectOther(otherOp)
971-
return if (result === SELECT_STARTED) value ?: NULL_VALUE else result
972-
}
956+
override fun tryResumeReceive(value: E, otherOp: PrepareOp?): Any? =
957+
select.trySelectOther(otherOp)
973958

974959
@Suppress("UNCHECKED_CAST")
975-
override fun completeResumeReceive(token: Any) {
976-
val value: E = NULL_VALUE.unbox<E>(token)
960+
override fun completeResumeReceive(value: E) {
977961
block.startCoroutine(if (receiveMode == RECEIVE_RESULT) ValueOrClosed.value(value) else value, select.completion)
978962
}
979963

@@ -997,11 +981,6 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
997981

998982
override fun toString(): String = "ReceiveSelect[$select,receiveMode=$receiveMode]"
999983
}
1000-
1001-
private class IdempotentTokenValue<out E>(
1002-
@JvmField val token: Any,
1003-
@JvmField val value: E
1004-
)
1005984
}
1006985

1007986
// receiveMode values
@@ -1025,18 +1004,6 @@ internal val POLL_FAILED: Any = Symbol("POLL_FAILED")
10251004
@SharedImmutable
10261005
internal val ENQUEUE_FAILED: Any = Symbol("ENQUEUE_FAILED")
10271006

1028-
@JvmField
1029-
@SharedImmutable
1030-
internal val NULL_VALUE: Symbol = Symbol("NULL_VALUE")
1031-
1032-
@JvmField
1033-
@SharedImmutable
1034-
internal val CLOSE_RESUMED: Any = Symbol("CLOSE_RESUMED")
1035-
1036-
@JvmField
1037-
@SharedImmutable
1038-
internal val SEND_RESUMED: Any = Symbol("SEND_RESUMED")
1039-
10401007
@JvmField
10411008
@SharedImmutable
10421009
internal val HANDLER_INVOKED: Any = Symbol("ON_CLOSE_HANDLER_INVOKED")
@@ -1050,10 +1017,10 @@ internal abstract class Send : LockFreeLinkedListNode() {
10501017
abstract val pollResult: Any? // E | Closed
10511018
// Returns: null - failure,
10521019
// RETRY_ATOMIC for retry (only when otherOp != null),
1053-
// otherwise token for completeResumeSend
1020+
// RESUME_TOKEN on success (call completeResumeSend)
10541021
// Must call otherOp?.finishPrepare() before deciding on result other than RETRY_ATOMIC
10551022
abstract fun tryResumeSend(otherOp: PrepareOp?): Any?
1056-
abstract fun completeResumeSend(token: Any)
1023+
abstract fun completeResumeSend()
10571024
abstract fun resumeSendClosed(closed: Closed<*>)
10581025
}
10591026

@@ -1064,10 +1031,10 @@ internal interface ReceiveOrClosed<in E> {
10641031
val offerResult: Any // OFFER_SUCCESS | Closed
10651032
// Returns: null - failure,
10661033
// RETRY_ATOMIC for retry (only when otherOp != null),
1067-
// otherwise token for completeResumeReceive
1034+
// RESUME_TOKEN on success (call completeResumeReceive)
10681035
// Must call otherOp?.finishPrepare() before deciding on result other than RETRY_ATOMIC
10691036
fun tryResumeReceive(value: E, otherOp: PrepareOp?): Any?
1070-
fun completeResumeReceive(token: Any)
1037+
fun completeResumeReceive(value: E)
10711038
}
10721039

10731040
/**
@@ -1082,7 +1049,7 @@ internal class SendElement(
10821049
otherOp?.finishPrepare()
10831050
return cont.tryResume(Unit, otherOp?.desc)
10841051
}
1085-
override fun completeResumeSend(token: Any) = cont.completeResume(token)
1052+
override fun completeResumeSend() = cont.completeResume(RESUME_TOKEN)
10861053
override fun resumeSendClosed(closed: Closed<*>) = cont.resumeWithException(closed.sendException)
10871054
override fun toString(): String = "SendElement($pollResult)"
10881055
}
@@ -1098,10 +1065,10 @@ internal class Closed<in E>(
10981065

10991066
override val offerResult get() = this
11001067
override val pollResult get() = this
1101-
override fun tryResumeSend(otherOp: PrepareOp?): Any? = CLOSE_RESUMED.also { otherOp?.finishPrepare() }
1102-
override fun completeResumeSend(token: Any) { assert { token === CLOSE_RESUMED } }
1103-
override fun tryResumeReceive(value: E, otherOp: PrepareOp?): Any? = CLOSE_RESUMED.also { otherOp?.finishPrepare() }
1104-
override fun completeResumeReceive(token: Any) { assert { token === CLOSE_RESUMED } }
1068+
override fun tryResumeSend(otherOp: PrepareOp?): Any? = RESUME_TOKEN.also { otherOp?.finishPrepare() }
1069+
override fun completeResumeSend() {}
1070+
override fun tryResumeReceive(value: E, otherOp: PrepareOp?): Any? = RESUME_TOKEN.also { otherOp?.finishPrepare() }
1071+
override fun completeResumeReceive(value: E) {}
11051072
override fun resumeSendClosed(closed: Closed<*>) = assert { false } // "Should be never invoked"
11061073
override fun toString(): String = "Closed[$closeCause]"
11071074
}

kotlinx-coroutines-core/common/src/channels/ArrayBroadcastChannel.kt

+8-8
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,6 @@ internal class ArrayBroadcastChannel<E>(
134134
private tailrec fun updateHead(addSub: Subscriber<E>? = null, removeSub: Subscriber<E>? = null) {
135135
// update head in a tail rec loop
136136
var send: Send? = null
137-
var token: Any? = null
138137
bufferLock.withLock {
139138
if (addSub != null) {
140139
addSub.subHead = tail // start from last element
@@ -163,8 +162,9 @@ internal class ArrayBroadcastChannel<E>(
163162
while (true) {
164163
send = takeFirstSendOrPeekClosed() ?: break // when when no sender
165164
if (send is Closed<*>) break // break when closed for send
166-
token = send!!.tryResumeSend(null)
165+
val token = send!!.tryResumeSend(null)
167166
if (token != null) {
167+
assert { token === RESUME_TOKEN }
168168
// put sent element to the buffer
169169
buffer[(tail % capacity).toInt()] = (send as Send).pollResult
170170
this.size = size + 1
@@ -177,7 +177,7 @@ internal class ArrayBroadcastChannel<E>(
177177
return // done updating here -> return
178178
}
179179
// we only get out of the lock normally when there is a sender to resume
180-
send!!.completeResumeSend(token!!)
180+
send!!.completeResumeSend()
181181
// since we've just sent an element, we might need to resume some receivers
182182
checkSubOffers()
183183
// tailrec call to recheck
@@ -229,9 +229,9 @@ internal class ArrayBroadcastChannel<E>(
229229
// it means that `checkOffer` must be retried after every `unlock`
230230
if (!subLock.tryLock()) break
231231
val receive: ReceiveOrClosed<E>?
232-
val token: Any?
232+
var result: Any?
233233
try {
234-
val result = peekUnderLock()
234+
result = peekUnderLock()
235235
when {
236236
result === POLL_FAILED -> continue@loop // must retest `needsToCheckOfferWithoutLock` outside of the lock
237237
result is Closed<*> -> {
@@ -242,15 +242,15 @@ internal class ArrayBroadcastChannel<E>(
242242
// find a receiver for an element
243243
receive = takeFirstReceiveOrPeekClosed() ?: break // break when no one's receiving
244244
if (receive is Closed<*>) break // noting more to do if this sub already closed
245-
token = receive.tryResumeReceive(result as E, null)
246-
if (token == null) continue // bail out here to next iteration (see for next receiver)
245+
val token = receive.tryResumeReceive(result as E, null) ?: continue
246+
assert { token === RESUME_TOKEN }
247247
val subHead = this.subHead
248248
this.subHead = subHead + 1 // retrieved element for this subscriber
249249
updated = true
250250
} finally {
251251
subLock.unlock()
252252
}
253-
receive!!.completeResumeReceive(token!!)
253+
receive!!.completeResumeReceive(result as E)
254254
}
255255
// do close outside of lock if needed
256256
closed?.also { close(cause = it.closeCause) }

0 commit comments

Comments
 (0)