Skip to content

Commit 753850a

Browse files
committed
Review fixes
1 parent 04825a8 commit 753850a

File tree

9 files changed

+314
-276
lines changed

9 files changed

+314
-276
lines changed

binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -991,7 +991,7 @@ public abstract interface class kotlinx/coroutines/sync/Semaphore {
991991
public final class kotlinx/coroutines/sync/SemaphoreKt {
992992
public static final fun Semaphore (II)Lkotlinx/coroutines/sync/Semaphore;
993993
public static synthetic fun Semaphore$default (IIILjava/lang/Object;)Lkotlinx/coroutines/sync/Semaphore;
994-
public static final fun withSemaphore (Lkotlinx/coroutines/sync/Semaphore;Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
994+
public static final fun withPermit (Lkotlinx/coroutines/sync/Semaphore;Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
995995
}
996996

997997
public final class kotlinx/coroutines/test/TestCoroutineContext : kotlin/coroutines/CoroutineContext {

gradle.properties

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ kotlin_version=1.3.30
55

66
# Dependencies
77
junit_version=4.12
8-
atomicfu_version=0.12.5
8+
atomicfu_version=0.12.6
99
html_version=0.6.8
1010
lincheck_version=2.0
1111
dokka_version=0.9.16-rdev-2-mpp-hacks

kotlinx-coroutines-core/common/src/internal/SegmentQueue.kt

+37-31
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@ package kotlinx.coroutines.internal
22

33
import kotlinx.atomicfu.AtomicRef
44
import kotlinx.atomicfu.atomic
5+
import kotlinx.atomicfu.loop
56

67
/**
78
* Essentially, this segment queue is an infinite array of segments, which is represented as
8-
* a Michael-Scott queue of them. All segments are instances of [Segment] interface and
9+
* a Michael-Scott queue of them. All segments are instances of [Segment] class and
910
* follow in natural order (see [Segment.id]) in the queue.
1011
*
1112
* In some data structures, like `Semaphore`, this queue is used for storing suspended continuations
@@ -16,12 +17,15 @@ import kotlinx.atomicfu.atomic
1617
internal abstract class SegmentQueue<S: Segment<S>>(createFirstSegmentLazily: Boolean = false) {
1718
private val _head: AtomicRef<S?>
1819
/**
19-
* Returns the first segment in the queue. All segments with lower [id]
20+
* Returns the first segment in the queue.
2021
*/
21-
protected val first: S? get() = _head.value
22+
protected val head: S? get() = _head.value
2223

2324
private val _tail: AtomicRef<S?>
24-
protected val last: S? get() = _tail.value
25+
/**
26+
* Returns the last segment in the queue.
27+
*/
28+
protected val tail: S? get() = _tail.value
2529

2630
init {
2731
val initialSegment = if (createFirstSegmentLazily) null else newSegment(0)
@@ -37,7 +41,7 @@ internal abstract class SegmentQueue<S: Segment<S>>(createFirstSegmentLazily: Bo
3741

3842
/**
3943
* Finds a segment with the specified [id] following by next references from the
40-
* [startFrom] segment. The typical use-case is reading [last] or [first], doing some
44+
* [startFrom] segment. The typical use-case is reading [tail] or [head], doing some
4145
* synchronization, and invoking [getSegment] or [getSegmentAndMoveFirst] correspondingly
4246
* to find the required segment.
4347
*/
@@ -50,7 +54,7 @@ internal abstract class SegmentQueue<S: Segment<S>>(createFirstSegmentLazily: Bo
5054
if (_head.compareAndSet(null, firstSegment))
5155
startFrom = firstSegment
5256
else {
53-
startFrom = first!!
57+
startFrom = head!!
5458
}
5559
}
5660
if (startFrom.id > id) return null
@@ -61,18 +65,18 @@ internal abstract class SegmentQueue<S: Segment<S>>(createFirstSegmentLazily: Bo
6165
// uses it. This way, only one segment with each id can be in the queue.
6266
var cur: S = startFrom
6367
while (cur.id < id) {
64-
var curNext = cur.next.value
68+
var curNext = cur.next
6569
if (curNext == null) {
6670
// Add a new segment.
6771
val newTail = newSegment(cur.id + 1, cur)
68-
curNext = if (cur.next.compareAndSet(null, newTail)) {
72+
curNext = if (cur.casNext(null, newTail)) {
6973
if (cur.removed) {
7074
cur.remove()
7175
}
7276
moveTailForward(newTail)
7377
newTail
7478
} else {
75-
cur.next.value!!
79+
cur.next!!
7680
}
7781
}
7882
cur = curNext
@@ -82,7 +86,7 @@ internal abstract class SegmentQueue<S: Segment<S>>(createFirstSegmentLazily: Bo
8286
}
8387

8488
/**
85-
* Invokes [getSegment] and replaces [first] with the result if its [id] is greater.
89+
* Invokes [getSegment] and replaces [head] with the result if its [id] is greater.
8690
*/
8791
protected fun getSegmentAndMoveFirst(startFrom: S?, id: Long): S? {
8892
if (startFrom !== null && startFrom.id == id) return startFrom
@@ -96,10 +100,9 @@ internal abstract class SegmentQueue<S: Segment<S>>(createFirstSegmentLazily: Bo
96100
* if its `id` is greater.
97101
*/
98102
private fun moveHeadForward(new: S) {
99-
while (true) {
100-
val cur = first!!
101-
if (cur.id > new.id) return
102-
if (_head.compareAndSet(cur, new)) {
103+
_head.loop { curHead ->
104+
if (curHead!!.id > new.id) return
105+
if (_head.compareAndSet(curHead, new)) {
103106
new.prev.value = null
104107
return
105108
}
@@ -111,10 +114,9 @@ internal abstract class SegmentQueue<S: Segment<S>>(createFirstSegmentLazily: Bo
111114
* if its `id` is greater.
112115
*/
113116
private fun moveTailForward(new: S) {
114-
while (true) {
115-
val cur = last
116-
if (cur !== null && cur.id > new.id) return
117-
if (_tail.compareAndSet(cur, new)) return
117+
_tail.loop { curTail ->
118+
if (curTail !== null && curTail.id > new.id) return
119+
if (_tail.compareAndSet(curTail, new)) return
118120
}
119121
}
120122
}
@@ -126,7 +128,9 @@ internal abstract class SegmentQueue<S: Segment<S>>(createFirstSegmentLazily: Bo
126128
*/
127129
internal abstract class Segment<S: Segment<S>>(val id: Long, prev: S?) {
128130
// Pointer to the next segment, updates similarly to the Michael-Scott queue algorithm.
129-
val next = atomic<S?>(null)
131+
private val _next = atomic<S?>(null)
132+
val next: S? get() = _next.value
133+
fun casNext(expected: S?, value: S?): Boolean = _next.compareAndSet(expected, value)
130134
// Pointer to the previous segment, updates in [remove] function.
131135
val prev = atomic<S?>(null)
132136

@@ -147,28 +151,30 @@ internal abstract class Segment<S: Segment<S>>(val id: Long, prev: S?) {
147151
fun remove() {
148152
check(removed) { " The segment should be logically removed at first "}
149153
// Read `next` and `prev` pointers.
150-
val next = this.next.value ?: return // tail cannot be removed
151-
val prev = prev.value ?: return // head cannot be removed
154+
var next = this._next.value ?: return // tail cannot be removed
155+
var prev = prev.value ?: return // head cannot be removed
152156
// Link `next` and `prev`.
157+
prev.moveNextToRight(next)
158+
while (prev.removed) {
159+
prev = prev.prev.value ?: break
160+
prev.moveNextToRight(next)
161+
}
153162
next.movePrevToLeft(prev)
154-
prev.movePrevNextToRight(next)
155-
// Check whether `prev` and `next` are still in the queue
156-
// and help with removing them if needed.
157-
if (prev.removed)
158-
prev.remove()
159-
if (next.removed)
160-
next.remove()
163+
while (next.removed) {
164+
next = next.next ?: break
165+
next.movePrevToLeft(prev)
166+
}
161167
}
162168

163169
/**
164170
* Updates [next] pointer to the specified segment if
165171
* the [id] of the specified segment is greater.
166172
*/
167-
private fun movePrevNextToRight(next: S) {
173+
private fun moveNextToRight(next: S) {
168174
while (true) {
169-
val curNext = this.next.value as S
175+
val curNext = this._next.value as S
170176
if (next.id <= curNext.id) return
171-
if (this.next.compareAndSet(curNext, next)) return
177+
if (this._next.compareAndSet(curNext, next)) return
172178
}
173179
}
174180

kotlinx-coroutines-core/common/src/sync/Semaphore.kt

+57-31
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
11
package kotlinx.coroutines.sync
22

3-
import kotlinx.atomicfu.*
3+
import kotlinx.atomicfu.atomic
4+
import kotlinx.atomicfu.atomicArrayOfNulls
5+
import kotlinx.atomicfu.getAndUpdate
6+
import kotlinx.atomicfu.loop
47
import kotlinx.coroutines.*
58
import kotlinx.coroutines.internal.*
69
import kotlin.coroutines.resume
7-
import kotlin.jvm.JvmField
810
import kotlin.math.max
911

1012
/**
1113
* A counting semaphore for coroutines. It maintains a number of available permits.
1214
* Each [acquire] suspends if necessary until a permit is available, and then takes it.
1315
* Each [release] adds a permit, potentially releasing a suspended acquirer.
1416
*
15-
* Semaphore with `maxPermits = 1` is essentially a [Mutex].
17+
* Semaphore with `permits = 1` is essentially a [Mutex].
1618
**/
1719
public interface Semaphore {
1820
/**
19-
* Returns the current number of available permits available in this semaphore.
21+
* Returns the current number of permits available in this semaphore.
2022
*/
2123
public val availablePermits: Int
2224

@@ -27,8 +29,8 @@ public interface Semaphore {
2729
* This suspending function is cancellable. If the [Job] of the current coroutine is cancelled or completed while this
2830
* function is suspended, this function immediately resumes with [CancellationException].
2931
*
30-
* *Cancellation of suspended lock invocation is atomic* -- when this function
31-
* throws [CancellationException] it means that the mutex was not locked.
32+
* *Cancellation of suspended semaphore acquisition` is atomic* -- when this function
33+
* throws [CancellationException] it means that the semaphore was not acquired.
3234
*
3335
* Note, that this function does not check for cancellation when it is not suspended.
3436
* Use [yield] or [CoroutineScope.isActive] to periodically check for cancellation in tight loops if needed.
@@ -47,23 +49,28 @@ public interface Semaphore {
4749
/**
4850
* Releases a permit, returning it into this semaphore. Resumes the first
4951
* suspending acquirer if there is one at the point of invocation.
52+
* Throws [IllegalStateException] if there is no acquired permit
53+
* at the point of invocation.
5054
*/
5155
public fun release()
5256
}
5357

5458
/**
5559
* Creates new [Semaphore] instance.
60+
* @param permits the number of permits available in this semaphore.
61+
* @param acquiredPermits the number of already acquired permits,
62+
* should be between `0` and `permits` (inclusively).
5663
*/
5764
@Suppress("FunctionName")
58-
public fun Semaphore(maxPermits: Int, acquiredPermits: Int = 0): Semaphore = SemaphoreImpl(maxPermits, acquiredPermits)
65+
public fun Semaphore(permits: Int, acquiredPermits: Int = 0): Semaphore = SemaphoreImpl(permits, acquiredPermits)
5966

6067
/**
61-
* Executes the given [action] with acquiring a permit from this semaphore at the beginning
68+
* Executes the given [action], acquiring a permit from this semaphore at the beginning
6269
* and releasing it after the [action] is completed.
6370
*
6471
* @return the return value of the [action].
6572
*/
66-
public suspend inline fun <T> Semaphore.withSemaphore(action: () -> T): T {
73+
public suspend inline fun <T> Semaphore.withPermit(action: () -> T): T {
6774
acquire()
6875
try {
6976
return action()
@@ -72,24 +79,24 @@ public suspend inline fun <T> Semaphore.withSemaphore(action: () -> T): T {
7279
}
7380
}
7481

75-
private class SemaphoreImpl(@JvmField val maxPermits: Int, acquiredPermits: Int)
82+
private class SemaphoreImpl(private val permits: Int, acquiredPermits: Int)
7683
: Semaphore, SegmentQueue<SemaphoreSegment>(createFirstSegmentLazily = true)
7784
{
7885
init {
79-
require(maxPermits > 0) { "Semaphore should have at least 1 permit"}
80-
require(acquiredPermits in 0..maxPermits) { "The number of acquired permits should be ≥ 0 and ≤ maxPermits" }
86+
require(permits > 0) { "Semaphore should have at least 1 permit"}
87+
require(acquiredPermits in 0..permits) { "The number of acquired permits should be ≥ 0 and ≤ permits" }
8188
}
8289

8390
override fun newSegment(id: Long, prev: SemaphoreSegment?)= SemaphoreSegment(id, prev)
8491

8592
/**
8693
* This counter indicates a number of available permits if it is non-negative,
8794
* or the size with minus sign otherwise. Note, that 32-bit counter is enough here
88-
* since the maximal number of available permits is [maxPermits] which is [Int],
95+
* since the maximal number of available permits is [permits] which is [Int],
8996
* and the maximum number of waiting acquirers cannot be greater than 2^31 in any
9097
* real application.
9198
*/
92-
private val _availablePermits = atomic(maxPermits)
99+
private val _availablePermits = atomic(permits)
93100
override val availablePermits: Int get() = max(_availablePermits.value, 0)
94101

95102
// The queue of waiting acquirers is essentially an infinite array based on `SegmentQueue`;
@@ -115,51 +122,68 @@ private class SemaphoreImpl(@JvmField val maxPermits: Int, acquiredPermits: Int)
115122

116123
override fun release() {
117124
val p = _availablePermits.getAndUpdate { cur ->
118-
check(cur < maxPermits) { "The number of acquired permits cannot be greater than maxPermits" }
125+
check(cur < permits) { "The number of acquired permits cannot be greater than `permits`" }
119126
cur + 1
120127
}
121128
if (p >= 0) return // no waiters
122129
resumeNextFromQueue()
123130
}
124131

125132
private suspend fun addToQueueAndSuspend() = suspendAtomicCancellableCoroutine<Unit> sc@ { cont ->
126-
val last = this.last
133+
val last = this.tail
127134
val enqIdx = enqIdx.getAndIncrement()
128135
val segment = getSegment(last, enqIdx / SEGMENT_SIZE)
129136
val i = (enqIdx % SEGMENT_SIZE).toInt()
130-
if (segment === null || segment[i].value === RESUMED || !segment[i].compareAndSet(null, cont)) {
137+
if (segment === null || segment.get(i) === RESUMED || !segment.cas(i, null, cont)) {
131138
// already resumed
132139
cont.resume(Unit)
133140
return@sc
134141
}
135-
cont.invokeOnCancellation(handler = object : CancelHandler() {
136-
override fun invoke(cause: Throwable?) {
137-
segment.cancel(i)
138-
release()
139-
}
140-
}.asHandler)
142+
cont.invokeOnCancellation(CancelSemaphoreAcquisitionHandler(this, segment, i).asHandler)
141143
}
142144

145+
@Suppress("UNCHECKED_CAST")
143146
private fun resumeNextFromQueue() {
144-
val first = this.first
147+
val first = this.head
145148
val deqIdx = deqIdx.getAndIncrement()
146149
val segment = getSegmentAndMoveFirst(first, deqIdx / SEGMENT_SIZE) ?: return
147150
val i = (deqIdx % SEGMENT_SIZE).toInt()
148-
val cont = segment[i].getAndUpdate {
149-
if (it === CANCELLED) CANCELLED else RESUMED
151+
val cont = segment.getAndUpdate(i) {
152+
// Cancelled continuation invokes `release`
153+
// and resumes next suspended acquirer if needed.
154+
if (it === CANCELLED) return
155+
RESUMED
150156
}
151157
if (cont === null) return // just resumed
152-
if (cont === CANCELLED) return // Cancelled continuation invokes `release`
153-
// and resumes next suspended acquirer if needed.
154-
cont as CancellableContinuation<Unit>
155-
cont.resume(Unit)
158+
(cont as CancellableContinuation<Unit>).resume(Unit)
156159
}
157160
}
158161

162+
private class CancelSemaphoreAcquisitionHandler(private val semaphore: Semaphore, private val segment: SemaphoreSegment, private val index: Int) : CancelHandler() {
163+
override fun invoke(cause: Throwable?) {
164+
segment.cancel(index)
165+
semaphore.release()
166+
}
167+
168+
override fun toString() = "CancelSemaphoreAcquisitionHandler[$semaphore, $segment, $index]"
169+
}
170+
159171
private class SemaphoreSegment(id: Long, prev: SemaphoreSegment?): Segment<SemaphoreSegment>(id, prev) {
160172
private val acquirers = atomicArrayOfNulls<Any?>(SEGMENT_SIZE)
161173

162-
operator fun get(index: Int): AtomicRef<Any?> = acquirers[index]
174+
@Suppress("NOTHING_TO_INLINE")
175+
inline fun get(index: Int): Any? = acquirers[index].value
176+
177+
@Suppress("NOTHING_TO_INLINE")
178+
inline fun cas(index: Int, expected: Any?, value: Any?): Boolean = acquirers[index].compareAndSet(expected, value)
179+
180+
inline fun getAndUpdate(index: Int, function: (Any?) -> Any?): Any? {
181+
while (true) {
182+
val cur = acquirers[index].value
183+
val upd = function(cur)
184+
if (cas(index, cur, upd)) return cur
185+
}
186+
}
163187

164188
private val cancelledSlots = atomic(0)
165189
override val removed get() = cancelledSlots.value == SEGMENT_SIZE
@@ -173,6 +197,8 @@ private class SemaphoreSegment(id: Long, prev: SemaphoreSegment?): Segment<Semap
173197
if (cancelledSlots.incrementAndGet() == SEGMENT_SIZE)
174198
remove()
175199
}
200+
201+
override fun toString() = "SemaphoreSegment[id=$id, hashCode=${hashCode()}]"
176202
}
177203

178204
@SharedImmutable

kotlinx-coroutines-core/common/test/sync/SemaphoreTest.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class SemaphoreTest : TestBase() {
6666
fun withSemaphoreTest() = runTest {
6767
val semaphore = Semaphore(1)
6868
assertEquals(1, semaphore.availablePermits)
69-
semaphore.withSemaphore {
69+
semaphore.withPermit {
7070
assertEquals(0, semaphore.availablePermits)
7171
}
7272
assertEquals(1, semaphore.availablePermits)

0 commit comments

Comments
 (0)