Skip to content

Commit ac144ef

Browse files
committed
Add Semaphore
Fixes #1088
1 parent 64be795 commit ac144ef

File tree

4 files changed

+451
-0
lines changed

4 files changed

+451
-0
lines changed

binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt

+12
Original file line numberDiff line numberDiff line change
@@ -978,6 +978,18 @@ public final class kotlinx/coroutines/sync/MutexKt {
978978
public static synthetic fun withLock$default (Lkotlinx/coroutines/sync/Mutex;Ljava/lang/Object;Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
979979
}
980980

981+
public abstract interface class kotlinx/coroutines/sync/Semaphore {
982+
public abstract fun acquire (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
983+
public abstract fun getAvailablePermits ()I
984+
public abstract fun release ()V
985+
public abstract fun tryAcquire ()Z
986+
}
987+
988+
public final class kotlinx/coroutines/sync/SemaphoreKt {
989+
public static final fun Semaphore (I)Lkotlinx/coroutines/sync/Semaphore;
990+
public static final fun withSemaphore (Lkotlinx/coroutines/sync/Semaphore;Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
991+
}
992+
981993
public final class kotlinx/coroutines/test/TestCoroutineContext : kotlin/coroutines/CoroutineContext {
982994
public fun <init> ()V
983995
public fun <init> (Ljava/lang/String;)V
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
package kotlinx.coroutines.sync
2+
3+
import kotlinx.atomicfu.*
4+
import kotlinx.coroutines.CancellableContinuation
5+
import kotlinx.coroutines.internal.SharedImmutable
6+
import kotlinx.coroutines.internal.Symbol
7+
import kotlinx.coroutines.internal.systemProp
8+
import kotlinx.coroutines.suspendAtomicCancellableCoroutine
9+
import kotlin.coroutines.resume
10+
import kotlin.jvm.JvmField
11+
import kotlin.math.max
12+
13+
public interface Semaphore {
14+
public val availablePermits: Int
15+
public fun tryAcquire(): Boolean
16+
public suspend fun acquire()
17+
public fun release()
18+
}
19+
20+
public fun Semaphore(maxPermits: Int): Semaphore = SemaphoreImpl(maxPermits)
21+
22+
public suspend inline fun <T> Semaphore.withSemaphore(action: () -> T): T {
23+
acquire()
24+
try {
25+
return action()
26+
} finally {
27+
release()
28+
}
29+
}
30+
31+
private class SemaphoreImpl(@JvmField val maxPermits: Int): Semaphore {
32+
init {
33+
require(maxPermits > 0) { "Semaphore should have at least 1 permit"}
34+
}
35+
36+
private val _availablePermits = atomic(maxPermits)
37+
override val availablePermits: Int get() = max(_availablePermits.value.toInt(), 0)
38+
39+
private val enqIdx = atomic(0L)
40+
private val deqIdx = atomic(0L)
41+
42+
private val head: AtomicRef<Segment>
43+
private val tail: AtomicRef<Segment>
44+
45+
init {
46+
val emptyNode = Segment(0, null)
47+
head = atomic(emptyNode)
48+
tail = atomic(emptyNode)
49+
}
50+
51+
override fun tryAcquire(): Boolean {
52+
_availablePermits.loop { p ->
53+
if (p <= 0) return false
54+
if (_availablePermits.compareAndSet(p, p - 1)) return true
55+
}
56+
}
57+
58+
override suspend fun acquire() {
59+
val p = _availablePermits.getAndDecrement()
60+
if (p > 0) return // permit acquired
61+
addToQueueAndSuspend()
62+
}
63+
64+
override fun release() {
65+
val p = _availablePermits.getAndUpdate { cur ->
66+
check(cur < maxPermits) { "Cannot" }
67+
cur + 1
68+
}
69+
if (p >= 0) return // no waiters
70+
resumeNextFromQueue()
71+
}
72+
73+
private suspend fun addToQueueAndSuspend() = suspendAtomicCancellableCoroutine<Unit> sc@ { cont ->
74+
val tail = this.tail.value
75+
val enqIdx = enqIdx.getAndIncrement()
76+
val segment = findOrCreateSegment(enqIdx / SEGMENT_SIZE, tail)
77+
val i = (enqIdx % SEGMENT_SIZE).toInt()
78+
if (segment === null || segment[i].value === RESUMED || !segment[i].compareAndSet(null, cont)) {
79+
cont.resume(Unit)
80+
return@sc
81+
}
82+
cont.invokeOnCancellation {
83+
segment.clean(i)
84+
release()
85+
}
86+
}
87+
88+
private fun resumeNextFromQueue() {
89+
val head = this.head.value
90+
val deqIdx = deqIdx.getAndIncrement()
91+
val segment = getHeadAndUpdate(deqIdx / SEGMENT_SIZE, head) ?: return
92+
val i = (deqIdx % SEGMENT_SIZE).toInt()
93+
val cont = segment[i].getAndUpdate {
94+
if (it === CLEANED) it else RESUMED
95+
}
96+
if (cont === CLEANED) return
97+
cont as CancellableContinuation<Unit>
98+
cont.resume(Unit)
99+
}
100+
101+
/**
102+
* Finds or creates segment similarly to [findOrCreateSegment],
103+
* but updates the [head] reference to the found segment as well.
104+
*/
105+
private fun getHeadAndUpdate(id: Long, headOrOutdated: Segment): Segment? {
106+
// Check whether the provided segment has the required `id`
107+
// and just return it in this case.
108+
if (headOrOutdated.id == id) {
109+
return headOrOutdated
110+
}
111+
// Find (or even create) the required segment
112+
// and update the `head` pointer.
113+
val head = findOrCreateSegment(id, headOrOutdated) ?: return null
114+
moveHeadForward(head)
115+
// We should clean `prev` references on `head` updates,
116+
// so they do not reference to the old segments. However,
117+
// it is fine to clean the `prev` reference of the new head only.
118+
// The previous "chain" of segments becomes no longer available from
119+
// segment queue structure and can be collected by GC.
120+
//
121+
// Note, that in practice it would be better to clean `next` references as well,
122+
// since it helps some GC (on JVM). However, this breaks the algorithm.
123+
head.prev.value = null
124+
return head
125+
}
126+
127+
/**
128+
* Finds or creates a segment with the specified [id] if it exists,
129+
* or with a minimal but greater than the specified `id`
130+
* (`segment.id >= id`) if the required segment was removed
131+
* This method starts search from the provided [cur] segment,
132+
* going by `next` references. Returns `null` if this channels is closed
133+
* and a new segment should be added.
134+
*/
135+
private fun findOrCreateSegment(id: Long, cur: Segment): Segment? {
136+
if (cur.id > id) return null
137+
// This method goes through `next` references and
138+
// adds new segments if needed, similarly to the `push` in
139+
// the Michael-Scott queue algorithm.
140+
var cur = cur
141+
while (cur.id < id) {
142+
var curNext = cur.next.value
143+
if (curNext == null) {
144+
// Add a new segment.
145+
val newTail = Segment(cur.id + 1, cur)
146+
curNext = if (cur.next.compareAndSet(null, newTail)) {
147+
if (cur.removed) {
148+
cur.remove()
149+
}
150+
moveTailForward(newTail)
151+
newTail
152+
} else {
153+
cur.next.value!!
154+
}
155+
}
156+
cur = curNext
157+
}
158+
return cur
159+
}
160+
161+
/**
162+
* Updates [head] to the specified segment
163+
* if its `id` is greater.
164+
*/
165+
private fun moveHeadForward(new: Segment) {
166+
while (true) {
167+
val cur = head.value
168+
if (cur.id > new.id) return
169+
if (this.head.compareAndSet(cur, new)) return
170+
}
171+
}
172+
173+
/**
174+
* Updates [tail] to the specified segment
175+
* if its `id` is greater.
176+
*/
177+
private fun moveTailForward(new: Segment) {
178+
while (true) {
179+
val cur = this.tail.value
180+
if (cur.id > new.id) return
181+
if (this.tail.compareAndSet(cur, new)) return
182+
}
183+
}
184+
185+
}
186+
187+
private class Segment(@JvmField val id: Long) {
188+
constructor(id: Long, prev: Segment?) : this(id) {
189+
this.prev.value = prev
190+
}
191+
192+
// == Waiters Array ==
193+
private val waiters = atomicArrayOfNulls<Any?>(SEGMENT_SIZE)
194+
195+
operator fun get(index: Int): AtomicRef<Any?> = waiters[index]
196+
197+
// == Michael-Scott Queue + Fast Removing from the Middle ==
198+
199+
// Pointer to the next segments, updates
200+
// similarly to the Michael-Scott queue algorithm.
201+
val next = atomic<Segment?>(null) // null (not set) | Segment | CLOSED
202+
// Pointer to the previous non-empty segment (can be null!),
203+
// updates lazily (see `remove()` function).
204+
val prev = atomic<Segment?>(null)
205+
// Number of cleaned waiters in this segment.
206+
private val cleaned = atomic(0)
207+
val removed get() = cleaned.value == SEGMENT_SIZE
208+
209+
/**
210+
* Cleans the waiter located by the specified index in this segment.
211+
*/
212+
fun clean(index: Int) {
213+
// Clean the specified waiter and
214+
// check if all node items are cleaned.
215+
waiters[index].value = CLEANED
216+
if (cleaned.incrementAndGet() < SEGMENT_SIZE) return
217+
// Remove this node
218+
remove()
219+
}
220+
221+
/**
222+
* Removes this node from the waiting queue and cleans all references to it.
223+
*/
224+
fun remove() {
225+
var next = this.next.value ?: return // tail can't be removed
226+
// Find the first non-removed node (tail is always non-removed)
227+
while (next.removed) {
228+
next = this.next.value ?: return
229+
}
230+
// Find the first non-removed `prev` and remove this node
231+
var prev = prev.value
232+
while (true) {
233+
if (prev == null) {
234+
next.prev.value = null
235+
return
236+
}
237+
if (prev.removed) {
238+
prev = prev.prev.value
239+
continue
240+
}
241+
next.movePrevToLeft(prev)
242+
prev.movePrevNextToRight(next)
243+
if (next.removed || !prev.removed) return
244+
prev = prev.prev.value
245+
}
246+
}
247+
248+
/**
249+
* Update [Segment.next] pointer to the specified one if
250+
* the `id` of the specified segment is greater.
251+
*/
252+
private fun movePrevNextToRight(next: Segment) {
253+
while (true) {
254+
val curNext = this.next.value as Segment
255+
if (next.id <= curNext.id) return
256+
if (this.next.compareAndSet(curNext, next)) return
257+
}
258+
}
259+
260+
/**
261+
* Update [Segment.prev] pointer to the specified segment if
262+
* its `id` is lower.
263+
*/
264+
private fun movePrevToLeft(prev: Segment) {
265+
while (true) {
266+
val curPrev = this.prev.value ?: return
267+
if (curPrev.id <= prev.id) return
268+
if (this.prev.compareAndSet(curPrev, prev)) return
269+
}
270+
}
271+
}
272+
273+
@SharedImmutable
274+
private val RESUMED = Symbol("RESUMED")
275+
@SharedImmutable
276+
private val CLEANED = Symbol("CLEANED")
277+
@SharedImmutable
278+
private val SEGMENT_SIZE = systemProp("kotlinx.coroutines.semaphore.segmentSize", 32)

0 commit comments

Comments
 (0)