Skip to content

Add Semaphore #1101

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions benchmarks/src/jmh/kotlin/benchmarks/SemaphoreBenchmark.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package benchmarks

import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.scheduling.ExperimentalCoroutineDispatcher
import kotlinx.coroutines.sync.Semaphore
import kotlinx.coroutines.sync.withPermit
import org.openjdk.jmh.annotations.*
import java.util.concurrent.ForkJoinPool
import java.util.concurrent.ThreadLocalRandom
import java.util.concurrent.TimeUnit

@Warmup(iterations = 3, time = 500, timeUnit = TimeUnit.MICROSECONDS)
@Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MICROSECONDS)
@Fork(value = 1)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
open class SemaphoreBenchmark {
@Param
private var _1_dispatcher: SemaphoreBenchDispatcherCreator = SemaphoreBenchDispatcherCreator.FORK_JOIN

@Param("0", "1000")
private var _2_coroutines: Int = 0

@Param("1", "2", "4", "8", "32", "128", "100000")
private var _3_maxPermits: Int = 0

@Param("1", "2", "4") // local machine
// @Param("1", "2", "4", "8", "16", "32", "64", "128", "144") // dasquad
// @Param("1", "2", "4", "8", "16", "32", "64", "96") // Google Cloud
private var _4_parallelism: Int = 0

private lateinit var dispatcher: CoroutineDispatcher
private var coroutines = 0

@InternalCoroutinesApi
@Setup
fun setup() {
dispatcher = _1_dispatcher.create(_4_parallelism)
coroutines = if (_2_coroutines == 0) _4_parallelism else _2_coroutines
}

@Benchmark
fun semaphore() = runBlocking {
val n = BATCH_SIZE / coroutines
val semaphore = Semaphore(_3_maxPermits)
val jobs = ArrayList<Job>(coroutines)
repeat(coroutines) {
jobs += GlobalScope.launch {
repeat(n) {
semaphore.withPermit {
doWork(WORK_INSIDE)
}
doWork(WORK_OUTSIDE)
}
}
}
jobs.forEach { it.join() }
}

@Benchmark
fun channelAsSemaphore() = runBlocking {
val n = BATCH_SIZE / coroutines
val semaphore = Channel<Unit>(_3_maxPermits)
val jobs = ArrayList<Job>(coroutines)
repeat(coroutines) {
jobs += GlobalScope.launch {
repeat(n) {
semaphore.send(Unit) // acquire
doWork(WORK_INSIDE)
semaphore.receive() // release
doWork(WORK_OUTSIDE)
}
}
}
jobs.forEach { it.join() }
}
}

enum class SemaphoreBenchDispatcherCreator(val create: (parallelism: Int) -> CoroutineDispatcher) {
FORK_JOIN({ parallelism -> ForkJoinPool(parallelism).asCoroutineDispatcher() }),
EXPERIMENTAL({ parallelism -> ExperimentalCoroutineDispatcher(corePoolSize = parallelism, maxPoolSize = parallelism) })
}

private fun doWork(work: Int) {
// We use geometric distribution here
val p = 1.0 / work
val r = ThreadLocalRandom.current()
while (true) {
if (r.nextDouble() < p) break
}
}

private const val WORK_INSIDE = 80
private const val WORK_OUTSIDE = 40
private const val BATCH_SIZE = 1000000
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,19 @@ public final class kotlinx/coroutines/sync/MutexKt {
public static synthetic fun withLock$default (Lkotlinx/coroutines/sync/Mutex;Ljava/lang/Object;Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
}

public abstract interface class kotlinx/coroutines/sync/Semaphore {
public abstract fun acquire (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public abstract fun getAvailablePermits ()I
public abstract fun release ()V
public abstract fun tryAcquire ()Z
}

public final class kotlinx/coroutines/sync/SemaphoreKt {
public static final fun Semaphore (II)Lkotlinx/coroutines/sync/Semaphore;
public static synthetic fun Semaphore$default (IIILjava/lang/Object;)Lkotlinx/coroutines/sync/Semaphore;
public static final fun withPermit (Lkotlinx/coroutines/sync/Semaphore;Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
}

public final class kotlinx/coroutines/test/TestCoroutineContext : kotlin/coroutines/CoroutineContext {
public fun <init> ()V
public fun <init> (Ljava/lang/String;)V
Expand Down
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ kotlin_version=1.3.31

# Dependencies
junit_version=4.12
atomicfu_version=0.12.7
atomicfu_version=0.12.8
html_version=0.6.8
lincheck_version=2.0
dokka_version=0.9.16-rdev-2-mpp-hacks
Expand Down
8 changes: 4 additions & 4 deletions kotlinx-coroutines-core/common/src/flow/operators/Merge.kt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import kotlinx.coroutines.channels.*
import kotlinx.coroutines.channels.Channel.Factory.OPTIONAL_CHANNEL
import kotlinx.coroutines.flow.internal.*
import kotlinx.coroutines.internal.*
import kotlinx.coroutines.sync.*
import kotlin.coroutines.*
import kotlin.jvm.*
import kotlinx.coroutines.flow.unsafeFlow as flow
Expand Down Expand Up @@ -149,16 +150,15 @@ private class ChannelFlowMerge<T>(

// The actual merge implementation with concurrency limit
private suspend fun mergeImpl(scope: CoroutineScope, collector: ConcurrentFlowCollector<T>) {
val semaphore = Channel<Unit>(concurrency)
val semaphore = Semaphore(concurrency)
@Suppress("UNCHECKED_CAST")
flow.collect { inner ->
// TODO real semaphore (#94)
semaphore.send(Unit) // Acquire concurrency permit
semaphore.acquire() // Acquire concurrency permit
scope.launch {
try {
inner.collect(collector)
} finally {
semaphore.receive() // Release concurrency permit
semaphore.release() // Release concurrency permit
}
}
}
Expand Down
176 changes: 176 additions & 0 deletions kotlinx-coroutines-core/common/src/internal/SegmentQueue.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
package kotlinx.coroutines.internal

import kotlinx.atomicfu.AtomicRef
import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.loop

/**
* Essentially, this segment queue is an infinite array of segments, which is represented as
* a Michael-Scott queue of them. All segments are instances of [Segment] class and
* follow in natural order (see [Segment.id]) in the queue.
*/
internal abstract class SegmentQueue<S: Segment<S>>() {
private val _head: AtomicRef<S>
/**
* Returns the first segment in the queue.
*/
protected val head: S get() = _head.value

private val _tail: AtomicRef<S>
/**
* Returns the last segment in the queue.
*/
protected val tail: S get() = _tail.value

init {
val initialSegment = newSegment(0)
_head = atomic(initialSegment)
_tail = atomic(initialSegment)
}

/**
* The implementation should create an instance of segment [S] with the specified id
* and initial reference to the previous one.
*/
abstract fun newSegment(id: Long, prev: S? = null): S

/**
* Finds a segment with the specified [id] following by next references from the
* [startFrom] segment. The typical use-case is reading [tail] or [head], doing some
* synchronization, and invoking [getSegment] or [getSegmentAndMoveHead] correspondingly
* to find the required segment.
*/
protected fun getSegment(startFrom: S, id: Long): S? {
// Go through `next` references and add new segments if needed,
// similarly to the `push` in the Michael-Scott queue algorithm.
// The only difference is that `CAS failure` means that the
// required segment has already been added, so the algorithm just
// uses it. This way, only one segment with each id can be in the queue.
var cur: S = startFrom
while (cur.id < id) {
var curNext = cur.next
if (curNext == null) {
// Add a new segment.
val newTail = newSegment(cur.id + 1, cur)
curNext = if (cur.casNext(null, newTail)) {
if (cur.removed) {
cur.remove()
}
moveTailForward(newTail)
newTail
} else {
cur.next!!
}
}
cur = curNext
}
if (cur.id != id) return null
return cur
}

/**
* Invokes [getSegment] and replaces [head] with the result if its [id] is greater.
*/
protected fun getSegmentAndMoveHead(startFrom: S, id: Long): S? {
@Suppress("LeakingThis")
if (startFrom.id == id) return startFrom
val s = getSegment(startFrom, id) ?: return null
moveHeadForward(s)
return s
}

/**
* Updates [head] to the specified segment
* if its `id` is greater.
*/
private fun moveHeadForward(new: S) {
_head.loop { curHead ->
if (curHead.id > new.id) return
if (_head.compareAndSet(curHead, new)) {
new.prev.value = null
return
}
}
}

/**
* Updates [tail] to the specified segment
* if its `id` is greater.
*/
private fun moveTailForward(new: S) {
_tail.loop { curTail ->
if (curTail.id > new.id) return
if (_tail.compareAndSet(curTail, new)) return
}
}
}

/**
* Each segment in [SegmentQueue] has a unique id and is created by [SegmentQueue.newSegment].
* Essentially, this is a node in the Michael-Scott queue algorithm, but with
* maintaining [prev] pointer for efficient [remove] implementation.
*/
internal abstract class Segment<S: Segment<S>>(val id: Long, prev: S?) {
// Pointer to the next segment, updates similarly to the Michael-Scott queue algorithm.
private val _next = atomic<S?>(null)
val next: S? get() = _next.value
fun casNext(expected: S?, value: S?): Boolean = _next.compareAndSet(expected, value)
// Pointer to the previous segment, updates in [remove] function.
val prev = atomic<S?>(null)

/**
* Returns `true` if this segment is logically removed from the queue.
* The [remove] function should be called right after it becomes logically removed.
*/
abstract val removed: Boolean

init {
this.prev.value = prev
}

/**
* Removes this segment physically from the segment queue. The segment should be
* logically removed (so [removed] returns `true`) at the point of invocation.
*/
fun remove() {
check(removed) { " The segment should be logically removed at first "}
// Read `next` and `prev` pointers.
var next = this._next.value ?: return // tail cannot be removed
var prev = prev.value ?: return // head cannot be removed
// Link `next` and `prev`.
prev.moveNextToRight(next)
while (prev.removed) {
prev = prev.prev.value ?: break
prev.moveNextToRight(next)
}
next.movePrevToLeft(prev)
while (next.removed) {
next = next.next ?: break
next.movePrevToLeft(prev)
}
}

/**
* Updates [next] pointer to the specified segment if
* the [id] of the specified segment is greater.
*/
private fun moveNextToRight(next: S) {
while (true) {
val curNext = this._next.value as S
if (next.id <= curNext.id) return
if (this._next.compareAndSet(curNext, next)) return
}
}

/**
* Updates [prev] pointer to the specified segment if
* the [id] of the specified segment is lower.
*/
private fun movePrevToLeft(prev: S) {
while (true) {
val curPrev = this.prev.value ?: return
if (curPrev.id <= prev.id) return
if (this.prev.compareAndSet(curPrev, prev)) return
}
}
}
Loading