1
1
package kotlinx.coroutines.sync
2
2
3
- import kotlinx.atomicfu.*
3
+ import kotlinx.atomicfu.atomic
4
+ import kotlinx.atomicfu.atomicArrayOfNulls
5
+ import kotlinx.atomicfu.getAndUpdate
6
+ import kotlinx.atomicfu.loop
4
7
import kotlinx.coroutines.*
5
8
import kotlinx.coroutines.internal.*
6
9
import kotlin.coroutines.resume
7
- import kotlin.jvm.JvmField
8
10
import kotlin.math.max
9
11
10
12
/* *
11
13
* A counting semaphore for coroutines. It maintains a number of available permits.
12
14
* Each [acquire] suspends if necessary until a permit is available, and then takes it.
13
15
* Each [release] adds a permit, potentially releasing a suspended acquirer.
14
16
*
15
- * Semaphore with `maxPermits = 1` is essentially a [Mutex].
17
+ * Semaphore with `permits = 1` is essentially a [Mutex].
16
18
**/
17
19
public interface Semaphore {
18
20
/* *
19
- * Returns the current number of available permits available in this semaphore.
21
+ * Returns the current number of permits available in this semaphore.
20
22
*/
21
23
public val availablePermits: Int
22
24
@@ -27,8 +29,8 @@ public interface Semaphore {
27
29
* This suspending function is cancellable. If the [Job] of the current coroutine is cancelled or completed while this
28
30
* function is suspended, this function immediately resumes with [CancellationException].
29
31
*
30
- * *Cancellation of suspended lock invocation is atomic* -- when this function
31
- * throws [CancellationException] it means that the mutex was not locked .
32
+ * *Cancellation of suspended semaphore acquisition` is atomic* -- when this function
33
+ * throws [CancellationException] it means that the semaphore was not acquired .
32
34
*
33
35
* Note, that this function does not check for cancellation when it is not suspended.
34
36
* Use [yield] or [CoroutineScope.isActive] to periodically check for cancellation in tight loops if needed.
@@ -47,23 +49,28 @@ public interface Semaphore {
47
49
/* *
48
50
* Releases a permit, returning it into this semaphore. Resumes the first
49
51
* suspending acquirer if there is one at the point of invocation.
52
+ * Throws [IllegalStateException] if there is no acquired permit
53
+ * at the point of invocation.
50
54
*/
51
55
public fun release ()
52
56
}
53
57
54
58
/* *
55
59
* Creates new [Semaphore] instance.
60
+ * @param permits the number of permits available in this semaphore.
61
+ * @param acquiredPermits the number of already acquired permits,
62
+ * should be between `0` and `permits` (inclusively).
56
63
*/
57
64
@Suppress(" FunctionName" )
58
- public fun Semaphore (maxPermits : Int , acquiredPermits : Int = 0): Semaphore = SemaphoreImpl (maxPermits , acquiredPermits)
65
+ public fun Semaphore (permits : Int , acquiredPermits : Int = 0): Semaphore = SemaphoreImpl (permits , acquiredPermits)
59
66
60
67
/* *
61
- * Executes the given [action] with acquiring a permit from this semaphore at the beginning
68
+ * Executes the given [action], acquiring a permit from this semaphore at the beginning
62
69
* and releasing it after the [action] is completed.
63
70
*
64
71
* @return the return value of the [action].
65
72
*/
66
- public suspend inline fun <T > Semaphore.withSemaphore (action : () -> T ): T {
73
+ public suspend inline fun <T > Semaphore.withPermit (action : () -> T ): T {
67
74
acquire()
68
75
try {
69
76
return action()
@@ -72,24 +79,24 @@ public suspend inline fun <T> Semaphore.withSemaphore(action: () -> T): T {
72
79
}
73
80
}
74
81
75
- private class SemaphoreImpl (@JvmField val maxPermits : Int , acquiredPermits : Int )
82
+ private class SemaphoreImpl (private val permits : Int , acquiredPermits : Int )
76
83
: Semaphore , SegmentQueue <SemaphoreSegment >(createFirstSegmentLazily = true )
77
84
{
78
85
init {
79
- require(maxPermits > 0 ) { " Semaphore should have at least 1 permit" }
80
- require(acquiredPermits in 0 .. maxPermits ) { " The number of acquired permits should be ≥ 0 and ≤ maxPermits " }
86
+ require(permits > 0 ) { " Semaphore should have at least 1 permit" }
87
+ require(acquiredPermits in 0 .. permits ) { " The number of acquired permits should be ≥ 0 and ≤ permits " }
81
88
}
82
89
83
90
override fun newSegment (id : Long , prev : SemaphoreSegment ? )= SemaphoreSegment (id, prev)
84
91
85
92
/* *
86
93
* This counter indicates a number of available permits if it is non-negative,
87
94
* or the size with minus sign otherwise. Note, that 32-bit counter is enough here
88
- * since the maximal number of available permits is [maxPermits ] which is [Int],
95
+ * since the maximal number of available permits is [permits ] which is [Int],
89
96
* and the maximum number of waiting acquirers cannot be greater than 2^31 in any
90
97
* real application.
91
98
*/
92
- private val _availablePermits = atomic(maxPermits )
99
+ private val _availablePermits = atomic(permits )
93
100
override val availablePermits: Int get() = max(_availablePermits .value, 0 )
94
101
95
102
// The queue of waiting acquirers is essentially an infinite array based on `SegmentQueue`;
@@ -115,51 +122,68 @@ private class SemaphoreImpl(@JvmField val maxPermits: Int, acquiredPermits: Int)
115
122
116
123
override fun release () {
117
124
val p = _availablePermits .getAndUpdate { cur ->
118
- check(cur < maxPermits ) { " The number of acquired permits cannot be greater than maxPermits " }
125
+ check(cur < permits ) { " The number of acquired permits cannot be greater than `permits` " }
119
126
cur + 1
120
127
}
121
128
if (p >= 0 ) return // no waiters
122
129
resumeNextFromQueue()
123
130
}
124
131
125
132
private suspend fun addToQueueAndSuspend () = suspendAtomicCancellableCoroutine<Unit > sc@ { cont ->
126
- val last = this .last
133
+ val last = this .tail
127
134
val enqIdx = enqIdx.getAndIncrement()
128
135
val segment = getSegment(last, enqIdx / SEGMENT_SIZE )
129
136
val i = (enqIdx % SEGMENT_SIZE ).toInt()
130
- if (segment == = null || segment[i].value == = RESUMED || ! segment[i].compareAndSet( null , cont)) {
137
+ if (segment == = null || segment.get(i) == = RESUMED || ! segment.cas(i, null , cont)) {
131
138
// already resumed
132
139
cont.resume(Unit )
133
140
return @sc
134
141
}
135
- cont.invokeOnCancellation(handler = object : CancelHandler () {
136
- override fun invoke (cause : Throwable ? ) {
137
- segment.cancel(i)
138
- release()
139
- }
140
- }.asHandler)
142
+ cont.invokeOnCancellation(CancelSemaphoreAcquisitionHandler (this , segment, i).asHandler)
141
143
}
142
144
145
+ @Suppress(" UNCHECKED_CAST" )
143
146
private fun resumeNextFromQueue () {
144
- val first = this .first
147
+ val first = this .head
145
148
val deqIdx = deqIdx.getAndIncrement()
146
149
val segment = getSegmentAndMoveFirst(first, deqIdx / SEGMENT_SIZE ) ? : return
147
150
val i = (deqIdx % SEGMENT_SIZE ).toInt()
148
- val cont = segment[i].getAndUpdate {
149
- if (it == = CANCELLED ) CANCELLED else RESUMED
151
+ val cont = segment.getAndUpdate(i) {
152
+ // Cancelled continuation invokes `release`
153
+ // and resumes next suspended acquirer if needed.
154
+ if (it == = CANCELLED ) return
155
+ RESUMED
150
156
}
151
157
if (cont == = null ) return // just resumed
152
- if (cont == = CANCELLED ) return // Cancelled continuation invokes `release`
153
- // and resumes next suspended acquirer if needed.
154
- cont as CancellableContinuation <Unit >
155
- cont.resume(Unit )
158
+ (cont as CancellableContinuation <Unit >).resume(Unit )
156
159
}
157
160
}
158
161
162
+ private class CancelSemaphoreAcquisitionHandler (private val semaphore : Semaphore , private val segment : SemaphoreSegment , private val index : Int ) : CancelHandler() {
163
+ override fun invoke (cause : Throwable ? ) {
164
+ segment.cancel(index)
165
+ semaphore.release()
166
+ }
167
+
168
+ override fun toString () = " CancelSemaphoreAcquisitionHandler[$semaphore , $segment , $index ]"
169
+ }
170
+
159
171
private class SemaphoreSegment (id : Long , prev : SemaphoreSegment ? ): Segment<SemaphoreSegment>(id, prev) {
160
172
private val acquirers = atomicArrayOfNulls<Any ?>(SEGMENT_SIZE )
161
173
162
- operator fun get (index : Int ): AtomicRef <Any ?> = acquirers[index]
174
+ @Suppress(" NOTHING_TO_INLINE" )
175
+ inline fun get (index : Int ): Any? = acquirers[index].value
176
+
177
+ @Suppress(" NOTHING_TO_INLINE" )
178
+ inline fun cas (index : Int , expected : Any? , value : Any? ): Boolean = acquirers[index].compareAndSet(expected, value)
179
+
180
+ inline fun getAndUpdate (index : Int , function : (Any? ) -> Any? ): Any? {
181
+ while (true ) {
182
+ val cur = acquirers[index].value
183
+ val upd = function(cur)
184
+ if (cas(index, cur, upd)) return cur
185
+ }
186
+ }
163
187
164
188
private val cancelledSlots = atomic(0 )
165
189
override val removed get() = cancelledSlots.value == SEGMENT_SIZE
@@ -173,6 +197,8 @@ private class SemaphoreSegment(id: Long, prev: SemaphoreSegment?): Segment<Semap
173
197
if (cancelledSlots.incrementAndGet() == SEGMENT_SIZE )
174
198
remove()
175
199
}
200
+
201
+ override fun toString () = " SemaphoreSegment[id=$id , hashCode=${hashCode()} ]"
176
202
}
177
203
178
204
@SharedImmutable
0 commit comments