1
+ package kotlinx.coroutines.sync
2
+
3
+ import kotlinx.atomicfu.*
4
+ import kotlinx.coroutines.CancellableContinuation
5
+ import kotlinx.coroutines.internal.SharedImmutable
6
+ import kotlinx.coroutines.internal.Symbol
7
+ import kotlinx.coroutines.internal.systemProp
8
+ import kotlinx.coroutines.suspendAtomicCancellableCoroutine
9
+ import kotlin.coroutines.resume
10
+ import kotlin.jvm.JvmField
11
+ import kotlin.math.max
12
+
13
+ public interface Semaphore {
14
+ public val availablePermits: Int
15
+ public fun tryAcquire (): Boolean
16
+ public suspend fun acquire ()
17
+ public fun release ()
18
+ }
19
+
20
+ public fun Semaphore (maxPermits : Int ): Semaphore = SemaphoreImpl (maxPermits)
21
+
22
+ public suspend inline fun <T > Semaphore.withSemaphore (action : () -> T ): T {
23
+ acquire()
24
+ try {
25
+ return action()
26
+ } finally {
27
+ release()
28
+ }
29
+ }
30
+
31
+ private class SemaphoreImpl (@JvmField val maxPermits : Int ): Semaphore {
32
+ init {
33
+ require(maxPermits > 0 ) { " Semaphore should have at least 1 permit" }
34
+ }
35
+
36
+ private val _availablePermits = atomic(maxPermits)
37
+ override val availablePermits: Int get() = max(_availablePermits .value.toInt(), 0 )
38
+
39
+ private val enqIdx = atomic(0L )
40
+ private val deqIdx = atomic(0L )
41
+
42
+ private val head: AtomicRef <Segment >
43
+ private val tail: AtomicRef <Segment >
44
+
45
+ init {
46
+ val emptyNode = Segment (0 , null )
47
+ head = atomic(emptyNode)
48
+ tail = atomic(emptyNode)
49
+ }
50
+
51
+ override fun tryAcquire (): Boolean {
52
+ _availablePermits .loop { p ->
53
+ if (p <= 0 ) return false
54
+ if (_availablePermits .compareAndSet(p, p - 1 )) return true
55
+ }
56
+ }
57
+
58
+ override suspend fun acquire () {
59
+ val p = _availablePermits .getAndDecrement()
60
+ if (p > 0 ) return // permit acquired
61
+ addToQueueAndSuspend()
62
+ }
63
+
64
+ override fun release () {
65
+ val p = _availablePermits .getAndUpdate { cur ->
66
+ check(cur < maxPermits) { " Cannot" }
67
+ cur + 1
68
+ }
69
+ if (p >= 0 ) return // no waiters
70
+ resumeNextFromQueue()
71
+ }
72
+
73
+ private suspend fun addToQueueAndSuspend () = suspendAtomicCancellableCoroutine<Unit > sc@ { cont ->
74
+ val tail = this .tail.value
75
+ val enqIdx = enqIdx.getAndIncrement()
76
+ val segment = findOrCreateSegment(enqIdx / SEGMENT_SIZE , tail)
77
+ val i = (enqIdx % SEGMENT_SIZE ).toInt()
78
+ if (segment == = null || segment[i].value == = RESUMED || ! segment[i].compareAndSet(null , cont)) {
79
+ cont.resume(Unit )
80
+ return @sc
81
+ }
82
+ cont.invokeOnCancellation {
83
+ segment.clean(i)
84
+ release()
85
+ }
86
+ }
87
+
88
+ private fun resumeNextFromQueue () {
89
+ val head = this .head.value
90
+ val deqIdx = deqIdx.getAndIncrement()
91
+ val segment = getHeadAndUpdate(deqIdx / SEGMENT_SIZE , head) ? : return
92
+ val i = (deqIdx % SEGMENT_SIZE ).toInt()
93
+ val cont = segment[i].getAndUpdate {
94
+ if (it == = CLEANED ) it else RESUMED
95
+ }
96
+ if (cont == = CLEANED ) return
97
+ cont as CancellableContinuation <Unit >
98
+ cont.resume(Unit )
99
+ }
100
+
101
+ /* *
102
+ * Finds or creates segment similarly to [findOrCreateSegment],
103
+ * but updates the [head] reference to the found segment as well.
104
+ */
105
+ private fun getHeadAndUpdate (id : Long , headOrOutdated : Segment ): Segment ? {
106
+ // Check whether the provided segment has the required `id`
107
+ // and just return it in this case.
108
+ if (headOrOutdated.id == id) {
109
+ return headOrOutdated
110
+ }
111
+ // Find (or even create) the required segment
112
+ // and update the `head` pointer.
113
+ val head = findOrCreateSegment(id, headOrOutdated) ? : return null
114
+ moveHeadForward(head)
115
+ // We should clean `prev` references on `head` updates,
116
+ // so they do not reference to the old segments. However,
117
+ // it is fine to clean the `prev` reference of the new head only.
118
+ // The previous "chain" of segments becomes no longer available from
119
+ // segment queue structure and can be collected by GC.
120
+ //
121
+ // Note, that in practice it would be better to clean `next` references as well,
122
+ // since it helps some GC (on JVM). However, this breaks the algorithm.
123
+ head.prev.value = null
124
+ return head
125
+ }
126
+
127
+ /* *
128
+ * Finds or creates a segment with the specified [id] if it exists,
129
+ * or with a minimal but greater than the specified `id`
130
+ * (`segment.id >= id`) if the required segment was removed
131
+ * This method starts search from the provided [cur] segment,
132
+ * going by `next` references. Returns `null` if this channels is closed
133
+ * and a new segment should be added.
134
+ */
135
+ private fun findOrCreateSegment (id : Long , cur : Segment ): Segment ? {
136
+ if (cur.id > id) return null
137
+ // This method goes through `next` references and
138
+ // adds new segments if needed, similarly to the `push` in
139
+ // the Michael-Scott queue algorithm.
140
+ var cur = cur
141
+ while (cur.id < id) {
142
+ var curNext = cur.next.value
143
+ if (curNext == null ) {
144
+ // Add a new segment.
145
+ val newTail = Segment (cur.id + 1 , cur)
146
+ curNext = if (cur.next.compareAndSet(null , newTail)) {
147
+ if (cur.removed) {
148
+ cur.remove()
149
+ }
150
+ moveTailForward(newTail)
151
+ newTail
152
+ } else {
153
+ cur.next.value!!
154
+ }
155
+ }
156
+ cur = curNext
157
+ }
158
+ return cur
159
+ }
160
+
161
+ /* *
162
+ * Updates [head] to the specified segment
163
+ * if its `id` is greater.
164
+ */
165
+ private fun moveHeadForward (new : Segment ) {
166
+ while (true ) {
167
+ val cur = head.value
168
+ if (cur.id > new.id) return
169
+ if (this .head.compareAndSet(cur, new)) return
170
+ }
171
+ }
172
+
173
+ /* *
174
+ * Updates [tail] to the specified segment
175
+ * if its `id` is greater.
176
+ */
177
+ private fun moveTailForward (new : Segment ) {
178
+ while (true ) {
179
+ val cur = this .tail.value
180
+ if (cur.id > new.id) return
181
+ if (this .tail.compareAndSet(cur, new)) return
182
+ }
183
+ }
184
+
185
+ }
186
+
187
+ private class Segment (@JvmField val id : Long ) {
188
+ constructor (id: Long , prev: Segment ? ) : this (id) {
189
+ this .prev.value = prev
190
+ }
191
+
192
+ // == Waiters Array ==
193
+ private val waiters = atomicArrayOfNulls<Any ?>(SEGMENT_SIZE )
194
+
195
+ operator fun get (index : Int ): AtomicRef <Any ?> = waiters[index]
196
+
197
+ // == Michael-Scott Queue + Fast Removing from the Middle ==
198
+
199
+ // Pointer to the next segments, updates
200
+ // similarly to the Michael-Scott queue algorithm.
201
+ val next = atomic<Segment ?>(null ) // null (not set) | Segment | CLOSED
202
+ // Pointer to the previous non-empty segment (can be null!),
203
+ // updates lazily (see `remove()` function).
204
+ val prev = atomic<Segment ?>(null )
205
+ // Number of cleaned waiters in this segment.
206
+ private val cleaned = atomic(0 )
207
+ val removed get() = cleaned.value == SEGMENT_SIZE
208
+
209
+ /* *
210
+ * Cleans the waiter located by the specified index in this segment.
211
+ */
212
+ fun clean (index : Int ) {
213
+ // Clean the specified waiter and
214
+ // check if all node items are cleaned.
215
+ waiters[index].value = CLEANED
216
+ if (cleaned.incrementAndGet() < SEGMENT_SIZE ) return
217
+ // Remove this node
218
+ remove()
219
+ }
220
+
221
+ /* *
222
+ * Removes this node from the waiting queue and cleans all references to it.
223
+ */
224
+ fun remove () {
225
+ var next = this .next.value ? : return // tail can't be removed
226
+ // Find the first non-removed node (tail is always non-removed)
227
+ while (next.removed) {
228
+ next = this .next.value ? : return
229
+ }
230
+ // Find the first non-removed `prev` and remove this node
231
+ var prev = prev.value
232
+ while (true ) {
233
+ if (prev == null ) {
234
+ next.prev.value = null
235
+ return
236
+ }
237
+ if (prev.removed) {
238
+ prev = prev.prev.value
239
+ continue
240
+ }
241
+ next.movePrevToLeft(prev)
242
+ prev.movePrevNextToRight(next)
243
+ if (next.removed || ! prev.removed) return
244
+ prev = prev.prev.value
245
+ }
246
+ }
247
+
248
+ /* *
249
+ * Update [Segment.next] pointer to the specified one if
250
+ * the `id` of the specified segment is greater.
251
+ */
252
+ private fun movePrevNextToRight (next : Segment ) {
253
+ while (true ) {
254
+ val curNext = this .next.value as Segment
255
+ if (next.id <= curNext.id) return
256
+ if (this .next.compareAndSet(curNext, next)) return
257
+ }
258
+ }
259
+
260
+ /* *
261
+ * Update [Segment.prev] pointer to the specified segment if
262
+ * its `id` is lower.
263
+ */
264
+ private fun movePrevToLeft (prev : Segment ) {
265
+ while (true ) {
266
+ val curPrev = this .prev.value ? : return
267
+ if (curPrev.id <= prev.id) return
268
+ if (this .prev.compareAndSet(curPrev, prev)) return
269
+ }
270
+ }
271
+ }
272
+
273
+ @SharedImmutable
274
+ private val RESUMED = Symbol (" RESUMED" )
275
+ @SharedImmutable
276
+ private val CLEANED = Symbol (" CLEANED" )
277
+ @SharedImmutable
278
+ private val SEGMENT_SIZE = systemProp(" kotlinx.coroutines.semaphore.segmentSize" , 32 )
0 commit comments