Skip to content

Commit 2f8bff1

Browse files
committed
Fix race condition in pair select
This bug was introduced by #1524. The crux of the problem is that the TryOffer/PollDesc.onPrepare method is no longer allowed to update fields in these classes (like "resumeToken" and "pollResult") after the call to the tryResumeSend/Receive method, because the latter will complete the ongoing atomic operation, and a helper method might find it complete and try reading "resumeToken", which was not initialized yet. This change removes the "pollResult" field, which was not really needed (the "result.pollResult" field is used instead), and removes "resumeToken" by exploiting the fact that the current implementation of CancellableContinuationImpl does not need a token anymore. However, the CancellableContinuation.tryResume/completeResume ABI is left intact, because it is used by 3rd party code. This fix led to an overall simplification of the code. A number of fields and an auxiliary IdempotentTokenValue class are removed, and the tokens used to indicate various results are consolidated, so that resume success is now consistently indicated by a single RESUME_TOKEN symbol. Fixes #1561
1 parent fd27d55 commit 2f8bff1

File tree

5 files changed

+74
-108
lines changed

5 files changed

+74
-108
lines changed

kotlinx-coroutines-core/common/src/CancellableContinuationImpl.kt

+12-7
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ private const val UNDECIDED = 0
1414
private const val SUSPENDED = 1
1515
private const val RESUMED = 2
1616

17+
@JvmField
18+
@SharedImmutable
19+
internal val RESUME_TOKEN = Symbol("RESUME_TOKEN")
20+
1721
/**
1822
* @suppress **This is unstable API and it is subject to change.**
1923
*/
@@ -347,20 +351,21 @@ internal open class CancellableContinuationImpl<in T>(
347351
parentHandle = NonDisposableHandle
348352
}
349353

354+
// Note: Always returns RESUME_TOKEN | null
350355
override fun tryResume(value: T, idempotent: Any?): Any? {
351356
_state.loop { state ->
352357
when (state) {
353358
is NotCompleted -> {
354359
val update: Any? = if (idempotent == null) value else
355-
CompletedIdempotentResult(idempotent, value, state)
360+
CompletedIdempotentResult(idempotent, value)
356361
if (!_state.compareAndSet(state, update)) return@loop // retry on cas failure
357362
detachChildIfNonResuable()
358-
return state
363+
return RESUME_TOKEN
359364
}
360365
is CompletedIdempotentResult -> {
361366
return if (state.idempotentResume === idempotent) {
362367
assert { state.result === value } // "Non-idempotent resume"
363-
state.token
368+
RESUME_TOKEN
364369
} else {
365370
null
366371
}
@@ -377,15 +382,16 @@ internal open class CancellableContinuationImpl<in T>(
377382
val update = CompletedExceptionally(exception)
378383
if (!_state.compareAndSet(state, update)) return@loop // retry on cas failure
379384
detachChildIfNonResuable()
380-
return state
385+
return RESUME_TOKEN
381386
}
382387
else -> return null // cannot resume -- not active anymore
383388
}
384389
}
385390
}
386391

392+
// note: token is always RESUME_TOKEN
387393
override fun completeResume(token: Any) {
388-
// note: We don't actually use token anymore, because handler needs to be invoked on cancellation only
394+
assert { token === RESUME_TOKEN }
389395
dispatchResume(resumeMode)
390396
}
391397

@@ -437,8 +443,7 @@ private class InvokeOnCancel( // Clashes with InvokeOnCancellation
437443

438444
private class CompletedIdempotentResult(
439445
@JvmField val idempotentResume: Any?,
440-
@JvmField val result: Any?,
441-
@JvmField val token: NotCompleted
446+
@JvmField val result: Any?
442447
) {
443448
override fun toString(): String = "CompletedIdempotentResult[$result]"
444449
}

kotlinx-coroutines-core/common/src/channels/AbstractChannel.kt

+34-67
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ internal abstract class AbstractSendChannel<E> : SendChannel<E> {
4848
val receive = takeFirstReceiveOrPeekClosed() ?: return OFFER_FAILED
4949
val token = receive.tryResumeReceive(element, null)
5050
if (token != null) {
51-
receive.completeResumeReceive(token)
51+
assert { token === RESUME_TOKEN }
52+
receive.completeResumeReceive(element)
5253
return receive.offerResult
5354
}
5455
}
@@ -65,7 +66,7 @@ internal abstract class AbstractSendChannel<E> : SendChannel<E> {
6566
val failure = select.performAtomicTrySelect(offerOp)
6667
if (failure != null) return failure
6768
val receive = offerOp.result
68-
receive.completeResumeReceive(offerOp.resumeToken!!)
69+
receive.completeResumeReceive(element)
6970
return receive.offerResult
7071
}
7172

@@ -354,8 +355,6 @@ internal abstract class AbstractSendChannel<E> : SendChannel<E> {
354355
@JvmField val element: E,
355356
queue: LockFreeLinkedListHead
356357
) : RemoveFirstDesc<ReceiveOrClosed<E>>(queue) {
357-
@JvmField var resumeToken: Any? = null
358-
359358
override fun failure(affected: LockFreeLinkedListNode): Any? = when (affected) {
360359
is Closed<*> -> affected
361360
!is ReceiveOrClosed<*> -> OFFER_FAILED
@@ -367,7 +366,7 @@ internal abstract class AbstractSendChannel<E> : SendChannel<E> {
367366
val affected = prepareOp.affected as ReceiveOrClosed<E> // see "failure" impl
368367
val token = affected.tryResumeReceive(element, prepareOp) ?: return REMOVE_PREPARED
369368
if (token === RETRY_ATOMIC) return RETRY_ATOMIC
370-
resumeToken = token
369+
assert { token === RESUME_TOKEN }
371370
return null
372371
}
373372
}
@@ -454,8 +453,7 @@ internal abstract class AbstractSendChannel<E> : SendChannel<E> {
454453
override fun tryResumeSend(otherOp: PrepareOp?): Any? =
455454
select.trySelectOther(otherOp)
456455

457-
override fun completeResumeSend(token: Any) {
458-
assert { token === SELECT_STARTED }
456+
override fun completeResumeSend() {
459457
block.startCoroutine(receiver = channel, completion = select.completion)
460458
}
461459

@@ -475,8 +473,8 @@ internal abstract class AbstractSendChannel<E> : SendChannel<E> {
475473
@JvmField val element: E
476474
) : Send() {
477475
override val pollResult: Any? get() = element
478-
override fun tryResumeSend(otherOp: PrepareOp?): Any? = SEND_RESUMED.also { otherOp?.finishPrepare() }
479-
override fun completeResumeSend(token: Any) { assert { token === SEND_RESUMED } }
476+
override fun tryResumeSend(otherOp: PrepareOp?): Any? = RESUME_TOKEN.also { otherOp?.finishPrepare() }
477+
override fun completeResumeSend() {}
480478
override fun resumeSendClosed(closed: Closed<*>) {}
481479
}
482480
}
@@ -511,7 +509,8 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
511509
val send = takeFirstSendOrPeekClosed() ?: return POLL_FAILED
512510
val token = send.tryResumeSend(null)
513511
if (token != null) {
514-
send.completeResumeSend(token)
512+
assert { token === RESUME_TOKEN }
513+
send.completeResumeSend()
515514
return send.pollResult
516515
}
517516
}
@@ -528,8 +527,8 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
528527
val failure = select.performAtomicTrySelect(pollOp)
529528
if (failure != null) return failure
530529
val send = pollOp.result
531-
send.completeResumeSend(pollOp.resumeToken!!)
532-
return pollOp.pollResult
530+
send.completeResumeSend()
531+
return pollOp.result.pollResult
533532
}
534533

535534
// ------ state functions & helpers for concrete implementations ------
@@ -673,9 +672,6 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
673672
* @suppress **This is unstable API and it is subject to change.**
674673
*/
675674
protected class TryPollDesc<E>(queue: LockFreeLinkedListHead) : RemoveFirstDesc<Send>(queue) {
676-
@JvmField var resumeToken: Any? = null
677-
@JvmField var pollResult: E? = null
678-
679675
override fun failure(affected: LockFreeLinkedListNode): Any? = when (affected) {
680676
is Closed<*> -> affected
681677
!is Send -> POLL_FAILED
@@ -687,8 +683,7 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
687683
val affected = prepareOp.affected as Send // see "failure" impl
688684
val token = affected.tryResumeSend(prepareOp) ?: return REMOVE_PREPARED
689685
if (token === RETRY_ATOMIC) return RETRY_ATOMIC
690-
resumeToken = token
691-
pollResult = affected.pollResult as E
686+
assert { token === RESUME_TOKEN }
692687
return null
693688
}
694689
}
@@ -908,7 +903,8 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
908903
return cont.tryResume(resumeValue(value), otherOp?.desc)
909904
}
910905

911-
override fun completeResumeReceive(token: Any) = cont.completeResume(token)
906+
override fun completeResumeReceive(value: E) = cont.completeResume(RESUME_TOKEN)
907+
912908
override fun resumeReceiveClosed(closed: Closed<*>) {
913909
when {
914910
receiveMode == RECEIVE_NULL_ON_CLOSE && closed.closeCause == null -> cont.resume(null)
@@ -925,25 +921,16 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
925921
) : Receive<E>() {
926922
override fun tryResumeReceive(value: E, otherOp: PrepareOp?): Any? {
927923
otherOp?.finishPrepare()
928-
val token = cont.tryResume(true, otherOp?.desc)
929-
if (token != null) {
930-
/*
931-
When otherOp != null this invocation can be stale and we cannot directly update iterator.result
932-
Instead, we save both token & result into a temporary IdempotentTokenValue object and
933-
set iterator result only in completeResumeReceive that is going to be invoked just once
934-
*/
935-
if (otherOp != null) return IdempotentTokenValue(token, value)
936-
iterator.result = value
937-
}
938-
return token
924+
return cont.tryResume(true, otherOp?.desc)
939925
}
940926

941-
override fun completeResumeReceive(token: Any) {
942-
if (token is IdempotentTokenValue<*>) {
943-
iterator.result = token.value
944-
cont.completeResume(token.token)
945-
} else
946-
cont.completeResume(token)
927+
override fun completeResumeReceive(value: E) {
928+
/*
929+
When otherOp != null invocation of tryResumeReceive can happen multiple times and much later,
930+
but completeResumeReceive is called once so we set iterator result here.
931+
*/
932+
iterator.result = value
933+
cont.completeResume(RESUME_TOKEN)
947934
}
948935

949936
override fun resumeReceiveClosed(closed: Closed<*>) {
@@ -966,14 +953,11 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
966953
@JvmField val block: suspend (Any?) -> R,
967954
@JvmField val receiveMode: Int
968955
) : Receive<E>(), DisposableHandle {
969-
override fun tryResumeReceive(value: E, otherOp: PrepareOp?): Any? {
970-
val result = select.trySelectOther(otherOp)
971-
return if (result === SELECT_STARTED) value ?: NULL_VALUE else result
972-
}
956+
override fun tryResumeReceive(value: E, otherOp: PrepareOp?): Any? =
957+
select.trySelectOther(otherOp)
973958

974959
@Suppress("UNCHECKED_CAST")
975-
override fun completeResumeReceive(token: Any) {
976-
val value: E = NULL_VALUE.unbox<E>(token)
960+
override fun completeResumeReceive(value: E) {
977961
block.startCoroutine(if (receiveMode == RECEIVE_RESULT) ValueOrClosed.value(value) else value, select.completion)
978962
}
979963

@@ -997,11 +981,6 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
997981

998982
override fun toString(): String = "ReceiveSelect[$select,receiveMode=$receiveMode]"
999983
}
1000-
1001-
private class IdempotentTokenValue<out E>(
1002-
@JvmField val token: Any,
1003-
@JvmField val value: E
1004-
)
1005984
}
1006985

1007986
// receiveMode values
@@ -1025,18 +1004,6 @@ internal val POLL_FAILED: Any = Symbol("POLL_FAILED")
10251004
@SharedImmutable
10261005
internal val ENQUEUE_FAILED: Any = Symbol("ENQUEUE_FAILED")
10271006

1028-
@JvmField
1029-
@SharedImmutable
1030-
internal val NULL_VALUE: Symbol = Symbol("NULL_VALUE")
1031-
1032-
@JvmField
1033-
@SharedImmutable
1034-
internal val CLOSE_RESUMED: Any = Symbol("CLOSE_RESUMED")
1035-
1036-
@JvmField
1037-
@SharedImmutable
1038-
internal val SEND_RESUMED: Any = Symbol("SEND_RESUMED")
1039-
10401007
@JvmField
10411008
@SharedImmutable
10421009
internal val HANDLER_INVOKED: Any = Symbol("ON_CLOSE_HANDLER_INVOKED")
@@ -1050,10 +1017,10 @@ internal abstract class Send : LockFreeLinkedListNode() {
10501017
abstract val pollResult: Any? // E | Closed
10511018
// Returns: null - failure,
10521019
// RETRY_ATOMIC for retry (only when otherOp != null),
1053-
// otherwise token for completeResumeSend
1020+
// RESUME_TOKEN on success (call completeResumeSend)
10541021
// Must call otherOp?.finishPrepare() before deciding on result other than RETRY_ATOMIC
10551022
abstract fun tryResumeSend(otherOp: PrepareOp?): Any?
1056-
abstract fun completeResumeSend(token: Any)
1023+
abstract fun completeResumeSend()
10571024
abstract fun resumeSendClosed(closed: Closed<*>)
10581025
}
10591026

@@ -1064,10 +1031,10 @@ internal interface ReceiveOrClosed<in E> {
10641031
val offerResult: Any // OFFER_SUCCESS | Closed
10651032
// Returns: null - failure,
10661033
// RETRY_ATOMIC for retry (only when otherOp != null),
1067-
// otherwise token for completeResumeReceive
1034+
// RESUME_TOKEN on success (call completeResumeReceive)
10681035
// Must call otherOp?.finishPrepare() before deciding on result other than RETRY_ATOMIC
10691036
fun tryResumeReceive(value: E, otherOp: PrepareOp?): Any?
1070-
fun completeResumeReceive(token: Any)
1037+
fun completeResumeReceive(value: E)
10711038
}
10721039

10731040
/**
@@ -1082,7 +1049,7 @@ internal class SendElement(
10821049
otherOp?.finishPrepare()
10831050
return cont.tryResume(Unit, otherOp?.desc)
10841051
}
1085-
override fun completeResumeSend(token: Any) = cont.completeResume(token)
1052+
override fun completeResumeSend() = cont.completeResume(RESUME_TOKEN)
10861053
override fun resumeSendClosed(closed: Closed<*>) = cont.resumeWithException(closed.sendException)
10871054
override fun toString(): String = "SendElement($pollResult)"
10881055
}
@@ -1098,10 +1065,10 @@ internal class Closed<in E>(
10981065

10991066
override val offerResult get() = this
11001067
override val pollResult get() = this
1101-
override fun tryResumeSend(otherOp: PrepareOp?): Any? = CLOSE_RESUMED.also { otherOp?.finishPrepare() }
1102-
override fun completeResumeSend(token: Any) { assert { token === CLOSE_RESUMED } }
1103-
override fun tryResumeReceive(value: E, otherOp: PrepareOp?): Any? = CLOSE_RESUMED.also { otherOp?.finishPrepare() }
1104-
override fun completeResumeReceive(token: Any) { assert { token === CLOSE_RESUMED } }
1068+
override fun tryResumeSend(otherOp: PrepareOp?): Any? = RESUME_TOKEN.also { otherOp?.finishPrepare() }
1069+
override fun completeResumeSend() {}
1070+
override fun tryResumeReceive(value: E, otherOp: PrepareOp?): Any? = RESUME_TOKEN.also { otherOp?.finishPrepare() }
1071+
override fun completeResumeReceive(value: E) {}
11051072
override fun resumeSendClosed(closed: Closed<*>) = assert { false } // "Should be never invoked"
11061073
override fun toString(): String = "Closed[$closeCause]"
11071074
}

kotlinx-coroutines-core/common/src/channels/ArrayBroadcastChannel.kt

+8-8
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,6 @@ internal class ArrayBroadcastChannel<E>(
143143
private tailrec fun updateHead(addSub: Subscriber<E>? = null, removeSub: Subscriber<E>? = null) {
144144
// update head in a tail rec loop
145145
var send: Send? = null
146-
var token: Any? = null
147146
bufferLock.withLock {
148147
if (addSub != null) {
149148
addSub.subHead = tail // start from last element
@@ -172,8 +171,9 @@ internal class ArrayBroadcastChannel<E>(
172171
while (true) {
173172
send = takeFirstSendOrPeekClosed() ?: break // when when no sender
174173
if (send is Closed<*>) break // break when closed for send
175-
token = send!!.tryResumeSend(null)
174+
val token = send!!.tryResumeSend(null)
176175
if (token != null) {
176+
assert { token === RESUME_TOKEN }
177177
// put sent element to the buffer
178178
buffer[(tail % capacity).toInt()] = (send as Send).pollResult
179179
this.size = size + 1
@@ -186,7 +186,7 @@ internal class ArrayBroadcastChannel<E>(
186186
return // done updating here -> return
187187
}
188188
// we only get out of the lock normally when there is a sender to resume
189-
send!!.completeResumeSend(token!!)
189+
send!!.completeResumeSend()
190190
// since we've just sent an element, we might need to resume some receivers
191191
checkSubOffers()
192192
// tailrec call to recheck
@@ -239,9 +239,9 @@ internal class ArrayBroadcastChannel<E>(
239239
// it means that `checkOffer` must be retried after every `unlock`
240240
if (!subLock.tryLock()) break
241241
val receive: ReceiveOrClosed<E>?
242-
val token: Any?
242+
var result: Any?
243243
try {
244-
val result = peekUnderLock()
244+
result = peekUnderLock()
245245
when {
246246
result === POLL_FAILED -> continue@loop // must retest `needsToCheckOfferWithoutLock` outside of the lock
247247
result is Closed<*> -> {
@@ -252,15 +252,15 @@ internal class ArrayBroadcastChannel<E>(
252252
// find a receiver for an element
253253
receive = takeFirstReceiveOrPeekClosed() ?: break // break when no one's receiving
254254
if (receive is Closed<*>) break // noting more to do if this sub already closed
255-
token = receive.tryResumeReceive(result as E, null)
256-
if (token == null) continue // bail out here to next iteration (see for next receiver)
255+
val token = receive.tryResumeReceive(result as E, null) ?: continue
256+
assert { token === RESUME_TOKEN }
257257
val subHead = this.subHead
258258
this.subHead = subHead + 1 // retrieved element for this subscriber
259259
updated = true
260260
} finally {
261261
subLock.unlock()
262262
}
263-
receive!!.completeResumeReceive(token!!)
263+
receive!!.completeResumeReceive(result as E)
264264
}
265265
// do close outside of lock if needed
266266
closed?.also { close(cause = it.closeCause) }

0 commit comments

Comments
 (0)