Skip to content

Commit 6c4ed63

Browse files
committed
Make the semaphore implementation linearizable (ignoring extra suspensions)
1 parent 7c49d4b commit 6c4ed63

File tree

1 file changed

+62
-49
lines changed

1 file changed

+62
-49
lines changed

kotlinx-coroutines-core/common/src/sync/Semaphore.kt

+62-49
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import kotlinx.coroutines.*
99
import kotlinx.coroutines.internal.*
1010
import kotlin.coroutines.*
1111
import kotlin.math.*
12-
import kotlin.native.concurrent.SharedImmutable
12+
import kotlin.native.concurrent.*
1313

1414
/**
1515
* A counting semaphore for coroutines that logically maintains a number of available permits.
@@ -121,76 +121,80 @@ private class SemaphoreImpl(private val permits: Int, acquiredPermits: Int) : Se
121121
}
122122

123123
override suspend fun acquire() {
124-
val p = _availablePermits.getAndDecrement()
125-
if (p > 0) return // permit acquired
126-
addToQueueAndSuspend()
124+
while (true) {
125+
val p = _availablePermits.getAndDecrement()
126+
if (p > 0) return // permit acquired
127+
if (addToQueueAndSuspend()) return
128+
}
127129
}
128130

129131
override fun release() {
130-
val p = incPermits()
131-
if (p >= 0) return // no waiters
132-
resumeNextFromQueue()
133-
}
134-
135-
fun incPermits() = _availablePermits.getAndUpdate { cur ->
136-
check(cur < permits) { "The number of released permits cannot be greater than $permits" }
137-
cur + 1
132+
while (true) {
133+
val p = _availablePermits.getAndUpdate { cur ->
134+
check(cur < permits) { "The number of released permits cannot be greater than $permits" }
135+
cur + 1
136+
}
137+
if (p >= 0) return
138+
if (tryResumeNextFromQueue()) return
139+
}
138140
}
139141

140-
private suspend fun addToQueueAndSuspend() = suspendAtomicCancellableCoroutineReusable<Unit> sc@{ cont ->
142+
/**
143+
* Returns `false` if the received permit cannot be used and the calling operation should restart.
144+
*/
145+
private suspend fun addToQueueAndSuspend() = suspendAtomicCancellableCoroutineReusable<Boolean> sc@{ cont ->
141146
val curTail = this.tail.value
142147
val enqIdx = enqIdx.getAndIncrement()
143148
val segment = this.tail.findSegmentAndMoveForward(id = enqIdx / SEGMENT_SIZE, startFrom = curTail,
144-
createNewSegment = ::createSegment).run { segment } // cannot be closed
149+
createNewSegment = ::createSegment).segment // cannot be closed
145150
val i = (enqIdx % SEGMENT_SIZE).toInt()
146-
if (segment.get(i) === RESUMED || !segment.cas(i, null, cont)) {
147-
// already resumed
148-
cont.resume(Unit)
151+
if (segment.get(i) === PERMIT || !segment.cas(i, null, cont)) {
152+
// The permit is already in the queue, try to grab it
153+
cont.resume(segment.cas(i, PERMIT, TAKEN))
149154
return@sc
150155
}
151-
cont.invokeOnCancellation(CancelSemaphoreAcquisitionHandler(this, segment, i).asHandler)
156+
cont.invokeOnCancellation(CancelSemaphoreAcquisitionHandler(segment, i).asHandler)
152157
}
153158

154159
@Suppress("UNCHECKED_CAST")
155-
internal fun resumeNextFromQueue() {
156-
try_again@ while (true) {
157-
val curHead = this.head.value
158-
val deqIdx = deqIdx.getAndIncrement()
159-
val id = deqIdx / SEGMENT_SIZE
160-
val segment = this.head.findSegmentAndMoveForward(id, startFrom = curHead,
161-
createNewSegment = ::createSegment).run { segment } // cannot be closed
162-
segment.cleanPrev()
163-
if (segment.id > id) {
164-
this.deqIdx.updateIfLower(segment.id * SEGMENT_SIZE)
165-
continue@try_again
160+
private fun tryResumeNextFromQueue(): Boolean {
161+
val curHead = this.head.value
162+
val deqIdx = deqIdx.getAndIncrement()
163+
val id = deqIdx / SEGMENT_SIZE
164+
val segment = this.head.findSegmentAndMoveForward(id, startFrom = curHead,
165+
createNewSegment = ::createSegment).segment // cannot be closed
166+
segment.cleanPrev()
167+
if (segment.id > id) return false
168+
val i = (deqIdx % SEGMENT_SIZE).toInt()
169+
val cont = segment.getAndSet(i, PERMIT)
170+
if (cont === CANCELLED) return false
171+
if (cont === null) {
172+
// Wait until an opposite operation comes for a bounded time
173+
repeat(MAX_SPIN_CYCLES) {
174+
if (segment.get(i) === TAKEN) return true
166175
}
167-
val i = (deqIdx % SEGMENT_SIZE).toInt()
168-
val cont = segment.getAndSet(i, RESUMED)
169-
if (cont === null) return // just resumed
170-
if (cont === CANCELLED) continue@try_again
171-
(cont as CancellableContinuation<Unit>).resume(Unit)
172-
return
176+
// Try to break the slot in order not to wait
177+
return !segment.cas(i, PERMIT, BROKEN)
173178
}
179+
return (cont as CancellableContinuation<Boolean>).tryResume()
174180
}
175181
}
176182

177-
private inline fun AtomicLong.updateIfLower(value: Long): Unit = loop { cur ->
178-
if (cur >= value || compareAndSet(cur, value)) return
183+
private fun CancellableContinuation<Boolean>.tryResume(): Boolean {
184+
val token = tryResume(true) ?: return false
185+
completeResume(token)
186+
return true
179187
}
180188

181189
private class CancelSemaphoreAcquisitionHandler(
182-
private val semaphore: SemaphoreImpl,
183190
private val segment: SemaphoreSegment,
184191
private val index: Int
185192
) : CancelHandler() {
186193
override fun invoke(cause: Throwable?) {
187-
val p = semaphore.incPermits()
188-
if (p >= 0) return
189-
if (segment.cancel(index)) return
190-
semaphore.resumeNextFromQueue()
194+
segment.cancel(index)
191195
}
192196

193-
override fun toString() = "CancelSemaphoreAcquisitionHandler[$semaphore, $segment, $index]"
197+
override fun toString() = "CancelSemaphoreAcquisitionHandler[$segment, $index]"
194198
}
195199

196200
private fun createSegment(id: Long, prev: SemaphoreSegment?) = SemaphoreSegment(id, prev, 0)
@@ -202,6 +206,11 @@ private class SemaphoreSegment(id: Long, prev: SemaphoreSegment?, pointers: Int)
202206
@Suppress("NOTHING_TO_INLINE")
203207
inline fun get(index: Int): Any? = acquirers[index].value
204208

209+
@Suppress("NOTHING_TO_INLINE")
210+
inline fun set(index: Int, value: Any?) {
211+
acquirers[index].value = value
212+
}
213+
205214
@Suppress("NOTHING_TO_INLINE")
206215
inline fun cas(index: Int, expected: Any?, value: Any?): Boolean = acquirers[index].compareAndSet(expected, value)
207216

@@ -210,19 +219,23 @@ private class SemaphoreSegment(id: Long, prev: SemaphoreSegment?, pointers: Int)
210219

211220
// Cleans the acquirer slot located by the specified index
212221
// and removes this segment physically if all slots are cleaned.
213-
fun cancel(index: Int): Boolean {
214-
// Try to cancel the slot
215-
val cancelled = getAndSet(index, CANCELLED) !== RESUMED
222+
fun cancel(index: Int) {
223+
// Clean the slot
224+
set(index, CANCELLED)
216225
// Remove this segment if needed
217226
onSlotCleaned()
218-
return cancelled
219227
}
220228

221229
override fun toString() = "SemaphoreSegment[id=$id, hashCode=${hashCode()}]"
222230
}
223-
224231
@SharedImmutable
225-
private val RESUMED = Symbol("RESUMED")
232+
private val MAX_SPIN_CYCLES = systemProp("kotlinx.coroutines.semaphore.maxSpinCycles", 100_000)
233+
@SharedImmutable
234+
private val PERMIT = Symbol("PERMIT")
235+
@SharedImmutable
236+
private val TAKEN = Symbol("TAKEN")
237+
@SharedImmutable
238+
private val BROKEN = Symbol("TAKEN")
226239
@SharedImmutable
227240
private val CANCELLED = Symbol("CANCELLED")
228241
@SharedImmutable

0 commit comments

Comments
 (0)