Skip to content

Fix race condition in pair select #1562

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 26, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions kotlinx-coroutines-core/common/src/CancellableContinuationImpl.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ private const val UNDECIDED = 0
private const val SUSPENDED = 1
private const val RESUMED = 2

@JvmField
@SharedImmutable
internal val RESUME_TOKEN = Symbol("RESUME_TOKEN")

/**
* @suppress **This is unstable API and it is subject to change.**
*/
Expand Down Expand Up @@ -347,20 +351,21 @@ internal open class CancellableContinuationImpl<in T>(
parentHandle = NonDisposableHandle
}

// Note: Always returns RESUME_TOKEN | null
override fun tryResume(value: T, idempotent: Any?): Any? {
_state.loop { state ->
when (state) {
is NotCompleted -> {
val update: Any? = if (idempotent == null) value else
CompletedIdempotentResult(idempotent, value, state)
CompletedIdempotentResult(idempotent, value)
if (!_state.compareAndSet(state, update)) return@loop // retry on cas failure
detachChildIfNonResuable()
return state
return RESUME_TOKEN
}
is CompletedIdempotentResult -> {
return if (state.idempotentResume === idempotent) {
assert { state.result === value } // "Non-idempotent resume"
state.token
RESUME_TOKEN
} else {
null
}
Expand All @@ -377,15 +382,16 @@ internal open class CancellableContinuationImpl<in T>(
val update = CompletedExceptionally(exception)
if (!_state.compareAndSet(state, update)) return@loop // retry on cas failure
detachChildIfNonResuable()
return state
return RESUME_TOKEN
}
else -> return null // cannot resume -- not active anymore
}
}
}

// note: token is always RESUME_TOKEN
override fun completeResume(token: Any) {
// note: We don't actually use token anymore, because handler needs to be invoked on cancellation only
assert { token === RESUME_TOKEN }
dispatchResume(resumeMode)
}

Expand Down Expand Up @@ -437,8 +443,7 @@ private class InvokeOnCancel( // Clashes with InvokeOnCancellation

private class CompletedIdempotentResult(
@JvmField val idempotentResume: Any?,
@JvmField val result: Any?,
@JvmField val token: NotCompleted
@JvmField val result: Any?
) {
override fun toString(): String = "CompletedIdempotentResult[$result]"
}
Expand Down
101 changes: 34 additions & 67 deletions kotlinx-coroutines-core/common/src/channels/AbstractChannel.kt
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ internal abstract class AbstractSendChannel<E> : SendChannel<E> {
val receive = takeFirstReceiveOrPeekClosed() ?: return OFFER_FAILED
val token = receive.tryResumeReceive(element, null)
if (token != null) {
receive.completeResumeReceive(token)
assert { token === RESUME_TOKEN }
receive.completeResumeReceive(element)
return receive.offerResult
}
}
Expand All @@ -65,7 +66,7 @@ internal abstract class AbstractSendChannel<E> : SendChannel<E> {
val failure = select.performAtomicTrySelect(offerOp)
if (failure != null) return failure
val receive = offerOp.result
receive.completeResumeReceive(offerOp.resumeToken!!)
receive.completeResumeReceive(element)
return receive.offerResult
}

Expand Down Expand Up @@ -354,8 +355,6 @@ internal abstract class AbstractSendChannel<E> : SendChannel<E> {
@JvmField val element: E,
queue: LockFreeLinkedListHead
) : RemoveFirstDesc<ReceiveOrClosed<E>>(queue) {
@JvmField var resumeToken: Any? = null

override fun failure(affected: LockFreeLinkedListNode): Any? = when (affected) {
is Closed<*> -> affected
!is ReceiveOrClosed<*> -> OFFER_FAILED
Expand All @@ -367,7 +366,7 @@ internal abstract class AbstractSendChannel<E> : SendChannel<E> {
val affected = prepareOp.affected as ReceiveOrClosed<E> // see "failure" impl
val token = affected.tryResumeReceive(element, prepareOp) ?: return REMOVE_PREPARED
if (token === RETRY_ATOMIC) return RETRY_ATOMIC
resumeToken = token
assert { token === RESUME_TOKEN }
return null
}
}
Expand Down Expand Up @@ -454,8 +453,7 @@ internal abstract class AbstractSendChannel<E> : SendChannel<E> {
override fun tryResumeSend(otherOp: PrepareOp?): Any? =
select.trySelectOther(otherOp)

override fun completeResumeSend(token: Any) {
assert { token === SELECT_STARTED }
override fun completeResumeSend() {
block.startCoroutine(receiver = channel, completion = select.completion)
}

Expand All @@ -475,8 +473,8 @@ internal abstract class AbstractSendChannel<E> : SendChannel<E> {
@JvmField val element: E
) : Send() {
override val pollResult: Any? get() = element
override fun tryResumeSend(otherOp: PrepareOp?): Any? = SEND_RESUMED.also { otherOp?.finishPrepare() }
override fun completeResumeSend(token: Any) { assert { token === SEND_RESUMED } }
override fun tryResumeSend(otherOp: PrepareOp?): Any? = RESUME_TOKEN.also { otherOp?.finishPrepare() }
override fun completeResumeSend() {}
override fun resumeSendClosed(closed: Closed<*>) {}
}
}
Expand Down Expand Up @@ -511,7 +509,8 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
val send = takeFirstSendOrPeekClosed() ?: return POLL_FAILED
val token = send.tryResumeSend(null)
if (token != null) {
send.completeResumeSend(token)
assert { token === RESUME_TOKEN }
send.completeResumeSend()
return send.pollResult
}
}
Expand All @@ -528,8 +527,8 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
val failure = select.performAtomicTrySelect(pollOp)
if (failure != null) return failure
val send = pollOp.result
send.completeResumeSend(pollOp.resumeToken!!)
return pollOp.pollResult
send.completeResumeSend()
return pollOp.result.pollResult
}

// ------ state functions & helpers for concrete implementations ------
Expand Down Expand Up @@ -673,9 +672,6 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
* @suppress **This is unstable API and it is subject to change.**
*/
protected class TryPollDesc<E>(queue: LockFreeLinkedListHead) : RemoveFirstDesc<Send>(queue) {
@JvmField var resumeToken: Any? = null
@JvmField var pollResult: E? = null

override fun failure(affected: LockFreeLinkedListNode): Any? = when (affected) {
is Closed<*> -> affected
!is Send -> POLL_FAILED
Expand All @@ -687,8 +683,7 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
val affected = prepareOp.affected as Send // see "failure" impl
val token = affected.tryResumeSend(prepareOp) ?: return REMOVE_PREPARED
if (token === RETRY_ATOMIC) return RETRY_ATOMIC
resumeToken = token
pollResult = affected.pollResult as E
assert { token === RESUME_TOKEN }
return null
}
}
Expand Down Expand Up @@ -908,7 +903,8 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
return cont.tryResume(resumeValue(value), otherOp?.desc)
}

override fun completeResumeReceive(token: Any) = cont.completeResume(token)
override fun completeResumeReceive(value: E) = cont.completeResume(RESUME_TOKEN)

override fun resumeReceiveClosed(closed: Closed<*>) {
when {
receiveMode == RECEIVE_NULL_ON_CLOSE && closed.closeCause == null -> cont.resume(null)
Expand All @@ -925,25 +921,16 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
) : Receive<E>() {
override fun tryResumeReceive(value: E, otherOp: PrepareOp?): Any? {
otherOp?.finishPrepare()
val token = cont.tryResume(true, otherOp?.desc)
if (token != null) {
/*
When otherOp != null this invocation can be stale and we cannot directly update iterator.result
Instead, we save both token & result into a temporary IdempotentTokenValue object and
set iterator result only in completeResumeReceive that is going to be invoked just once
*/
if (otherOp != null) return IdempotentTokenValue(token, value)
iterator.result = value
}
return token
return cont.tryResume(true, otherOp?.desc)
}

override fun completeResumeReceive(token: Any) {
if (token is IdempotentTokenValue<*>) {
iterator.result = token.value
cont.completeResume(token.token)
} else
cont.completeResume(token)
override fun completeResumeReceive(value: E) {
/*
When otherOp != null invocation of tryResumeReceive can happen multiple times and much later,
but completeResumeReceive is called once so we set iterator result here.
*/
iterator.result = value
cont.completeResume(RESUME_TOKEN)
}

override fun resumeReceiveClosed(closed: Closed<*>) {
Expand All @@ -966,14 +953,11 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
@JvmField val block: suspend (Any?) -> R,
@JvmField val receiveMode: Int
) : Receive<E>(), DisposableHandle {
override fun tryResumeReceive(value: E, otherOp: PrepareOp?): Any? {
val result = select.trySelectOther(otherOp)
return if (result === SELECT_STARTED) value ?: NULL_VALUE else result
}
override fun tryResumeReceive(value: E, otherOp: PrepareOp?): Any? =
select.trySelectOther(otherOp)

@Suppress("UNCHECKED_CAST")
override fun completeResumeReceive(token: Any) {
val value: E = NULL_VALUE.unbox<E>(token)
override fun completeResumeReceive(value: E) {
block.startCoroutine(if (receiveMode == RECEIVE_RESULT) ValueOrClosed.value(value) else value, select.completion)
}

Expand All @@ -997,11 +981,6 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E

override fun toString(): String = "ReceiveSelect[$select,receiveMode=$receiveMode]"
}

private class IdempotentTokenValue<out E>(
@JvmField val token: Any,
@JvmField val value: E
)
}

// receiveMode values
Expand All @@ -1025,18 +1004,6 @@ internal val POLL_FAILED: Any = Symbol("POLL_FAILED")
@SharedImmutable
internal val ENQUEUE_FAILED: Any = Symbol("ENQUEUE_FAILED")

@JvmField
@SharedImmutable
internal val NULL_VALUE: Symbol = Symbol("NULL_VALUE")

@JvmField
@SharedImmutable
internal val CLOSE_RESUMED: Any = Symbol("CLOSE_RESUMED")

@JvmField
@SharedImmutable
internal val SEND_RESUMED: Any = Symbol("SEND_RESUMED")

@JvmField
@SharedImmutable
internal val HANDLER_INVOKED: Any = Symbol("ON_CLOSE_HANDLER_INVOKED")
Expand All @@ -1050,10 +1017,10 @@ internal abstract class Send : LockFreeLinkedListNode() {
abstract val pollResult: Any? // E | Closed
// Returns: null - failure,
// RETRY_ATOMIC for retry (only when otherOp != null),
// otherwise token for completeResumeSend
// RESUME_TOKEN on success (call completeResumeSend)
// Must call otherOp?.finishPrepare() before deciding on result other than RETRY_ATOMIC
abstract fun tryResumeSend(otherOp: PrepareOp?): Any?
abstract fun completeResumeSend(token: Any)
abstract fun completeResumeSend()
abstract fun resumeSendClosed(closed: Closed<*>)
}

Expand All @@ -1064,10 +1031,10 @@ internal interface ReceiveOrClosed<in E> {
val offerResult: Any // OFFER_SUCCESS | Closed
// Returns: null - failure,
// RETRY_ATOMIC for retry (only when otherOp != null),
// otherwise token for completeResumeReceive
// RESUME_TOKEN on success (call completeResumeReceive)
// Must call otherOp?.finishPrepare() before deciding on result other than RETRY_ATOMIC
fun tryResumeReceive(value: E, otherOp: PrepareOp?): Any?
fun completeResumeReceive(token: Any)
fun completeResumeReceive(value: E)
}

/**
Expand All @@ -1082,7 +1049,7 @@ internal class SendElement(
otherOp?.finishPrepare()
return cont.tryResume(Unit, otherOp?.desc)
}
override fun completeResumeSend(token: Any) = cont.completeResume(token)
override fun completeResumeSend() = cont.completeResume(RESUME_TOKEN)
override fun resumeSendClosed(closed: Closed<*>) = cont.resumeWithException(closed.sendException)
override fun toString(): String = "SendElement($pollResult)"
}
Expand All @@ -1098,10 +1065,10 @@ internal class Closed<in E>(

override val offerResult get() = this
override val pollResult get() = this
override fun tryResumeSend(otherOp: PrepareOp?): Any? = CLOSE_RESUMED.also { otherOp?.finishPrepare() }
override fun completeResumeSend(token: Any) { assert { token === CLOSE_RESUMED } }
override fun tryResumeReceive(value: E, otherOp: PrepareOp?): Any? = CLOSE_RESUMED.also { otherOp?.finishPrepare() }
override fun completeResumeReceive(token: Any) { assert { token === CLOSE_RESUMED } }
override fun tryResumeSend(otherOp: PrepareOp?): Any? = RESUME_TOKEN.also { otherOp?.finishPrepare() }
override fun completeResumeSend() {}
override fun tryResumeReceive(value: E, otherOp: PrepareOp?): Any? = RESUME_TOKEN.also { otherOp?.finishPrepare() }
override fun completeResumeReceive(value: E) {}
override fun resumeSendClosed(closed: Closed<*>) = assert { false } // "Should be never invoked"
override fun toString(): String = "Closed[$closeCause]"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ internal class ArrayBroadcastChannel<E>(
private tailrec fun updateHead(addSub: Subscriber<E>? = null, removeSub: Subscriber<E>? = null) {
// update head in a tail rec loop
var send: Send? = null
var token: Any? = null
bufferLock.withLock {
if (addSub != null) {
addSub.subHead = tail // start from last element
Expand Down Expand Up @@ -172,8 +171,9 @@ internal class ArrayBroadcastChannel<E>(
while (true) {
send = takeFirstSendOrPeekClosed() ?: break // when when no sender
if (send is Closed<*>) break // break when closed for send
token = send!!.tryResumeSend(null)
val token = send!!.tryResumeSend(null)
if (token != null) {
assert { token === RESUME_TOKEN }
// put sent element to the buffer
buffer[(tail % capacity).toInt()] = (send as Send).pollResult
this.size = size + 1
Expand All @@ -186,7 +186,7 @@ internal class ArrayBroadcastChannel<E>(
return // done updating here -> return
}
// we only get out of the lock normally when there is a sender to resume
send!!.completeResumeSend(token!!)
send!!.completeResumeSend()
// since we've just sent an element, we might need to resume some receivers
checkSubOffers()
// tailrec call to recheck
Expand Down Expand Up @@ -239,9 +239,9 @@ internal class ArrayBroadcastChannel<E>(
// it means that `checkOffer` must be retried after every `unlock`
if (!subLock.tryLock()) break
val receive: ReceiveOrClosed<E>?
val token: Any?
var result: Any?
try {
val result = peekUnderLock()
result = peekUnderLock()
when {
result === POLL_FAILED -> continue@loop // must retest `needsToCheckOfferWithoutLock` outside of the lock
result is Closed<*> -> {
Expand All @@ -252,15 +252,15 @@ internal class ArrayBroadcastChannel<E>(
// find a receiver for an element
receive = takeFirstReceiveOrPeekClosed() ?: break // break when no one's receiving
if (receive is Closed<*>) break // noting more to do if this sub already closed
token = receive.tryResumeReceive(result as E, null)
if (token == null) continue // bail out here to next iteration (see for next receiver)
val token = receive.tryResumeReceive(result as E, null) ?: continue
assert { token === RESUME_TOKEN }
val subHead = this.subHead
this.subHead = subHead + 1 // retrieved element for this subscriber
updated = true
} finally {
subLock.unlock()
}
receive!!.completeResumeReceive(token!!)
receive!!.completeResumeReceive(result as E)
}
// do close outside of lock if needed
closed?.also { close(cause = it.closeCause) }
Expand Down
Loading