diff --git a/kotlinx-coroutines-core/jvm/src/scheduling/CoroutineScheduler.kt b/kotlinx-coroutines-core/jvm/src/scheduling/CoroutineScheduler.kt index 09e9deb838..3482aa931c 100644 --- a/kotlinx-coroutines-core/jvm/src/scheduling/CoroutineScheduler.kt +++ b/kotlinx-coroutines-core/jvm/src/scheduling/CoroutineScheduler.kt @@ -372,25 +372,34 @@ internal class CoroutineScheduler( * Dispatches execution of a runnable [block] with a hint to a scheduler whether * this [block] may execute blocking operations (IO, system calls, locking primitives etc.) * - * @param taskContext concurrency context of given [block] - * @param fair whether the task should be dispatched fairly (strict FIFO) or not (semi-FIFO) + * [taskContext] -- concurrency context of given [block]. + * [tailDispatch] -- whether this [dispatch] call is the last action the (presumably) worker thread does in its current task. + * If `true`, then the task will be dispatched in a FIFO manner and no additional workers will be requested, + * but only if the current thread is a corresponding worker thread. + * Note that caller cannot be ensured that it is being executed on worker thread for the following reasons: + * * [CoroutineStart.UNDISPATCHED] + * * Concurrent [close] that effectively shutdowns the worker thread */ - fun dispatch(block: Runnable, taskContext: TaskContext = NonBlockingContext, fair: Boolean = false) { + fun dispatch(block: Runnable, taskContext: TaskContext = NonBlockingContext, tailDispatch: Boolean = false) { trackTask() // this is needed for virtual time support val task = createTask(block, taskContext) // try to submit the task to the local queue and act depending on the result - val notAdded = submitToLocalQueue(task, fair) + val currentWorker = currentWorker() + val notAdded = currentWorker.submitToLocalQueue(task, tailDispatch) if (notAdded != null) { if (!addToGlobalQueue(notAdded)) { // Global queue is closed in the last step of close/shutdown -- no more tasks should be accepted throw RejectedExecutionException("$schedulerName was terminated") } } + val skipUnpark = tailDispatch && currentWorker != null // Checking 'task' instead of 'notAdded' is completely okay if (task.mode == TaskMode.NON_BLOCKING) { + if (skipUnpark) return signalCpuWork() } else { - signalBlockingWork() + // Increment blocking tasks anyway + signalBlockingWork(skipUnpark = skipUnpark) } } @@ -404,9 +413,10 @@ internal class CoroutineScheduler( return TaskImpl(block, nanoTime, taskContext) } - private fun signalBlockingWork() { + private fun signalBlockingWork(skipUnpark: Boolean) { // Use state snapshot to avoid thread overprovision val stateSnapshot = incrementBlockingTasks() + if (skipUnpark) return if (tryUnpark()) return if (tryCreateWorker(stateSnapshot)) return tryUnpark() // Try unpark again in case there was race between permit release and parking @@ -481,19 +491,19 @@ internal class CoroutineScheduler( * Returns `null` if task was successfully added or an instance of the * task that was not added or replaced (thus should be added to global queue). */ - private fun submitToLocalQueue(task: Task, fair: Boolean): Task? { - val worker = currentWorker() ?: return task + private fun Worker?.submitToLocalQueue(task: Task, tailDispatch: Boolean): Task? { + if (this == null) return task /* * This worker could have been already terminated from this thread by close/shutdown and it should not * accept any more tasks into its local queue. */ - if (worker.state === WorkerState.TERMINATED) return task + if (state === WorkerState.TERMINATED) return task // Do not add CPU tasks in local queue if we are not able to execute it - if (task.mode === TaskMode.NON_BLOCKING && worker.state === WorkerState.BLOCKING) { + if (task.mode === TaskMode.NON_BLOCKING && state === WorkerState.BLOCKING) { return task } - worker.mayHaveLocalTasks = true - return worker.localQueue.add(task, fair = fair) + mayHaveLocalTasks = true + return localQueue.add(task, fair = tailDispatch) } private fun currentWorker(): Worker? = (Thread.currentThread() as? Worker)?.takeIf { it.scheduler == this } diff --git a/kotlinx-coroutines-core/jvm/src/scheduling/Dispatcher.kt b/kotlinx-coroutines-core/jvm/src/scheduling/Dispatcher.kt index bd1ba95dd8..bbc2b35b16 100644 --- a/kotlinx-coroutines-core/jvm/src/scheduling/Dispatcher.kt +++ b/kotlinx-coroutines-core/jvm/src/scheduling/Dispatcher.kt @@ -65,7 +65,7 @@ open class ExperimentalCoroutineDispatcher( override fun dispatchYield(context: CoroutineContext, block: Runnable): Unit = try { - coroutineScheduler.dispatch(block, fair = true) + coroutineScheduler.dispatch(block, tailDispatch = true) } catch (e: RejectedExecutionException) { DefaultExecutor.dispatchYield(context, block) } @@ -101,9 +101,9 @@ open class ExperimentalCoroutineDispatcher( return LimitingDispatcher(this, parallelism, TaskMode.NON_BLOCKING) } - internal fun dispatchWithContext(block: Runnable, context: TaskContext, fair: Boolean) { + internal fun dispatchWithContext(block: Runnable, context: TaskContext, tailDispatch: Boolean) { try { - coroutineScheduler.dispatch(block, context, fair) + coroutineScheduler.dispatch(block, context, tailDispatch) } catch (e: RejectedExecutionException) { // Context shouldn't be lost here to properly invoke before/after task DefaultExecutor.enqueue(coroutineScheduler.createTask(block, context)) @@ -147,7 +147,7 @@ private class LimitingDispatcher( override fun dispatch(context: CoroutineContext, block: Runnable) = dispatch(block, false) - private fun dispatch(block: Runnable, fair: Boolean) { + private fun dispatch(block: Runnable, tailDispatch: Boolean) { var taskToSchedule = block while (true) { // Commit in-flight tasks slot @@ -155,7 +155,7 @@ private class LimitingDispatcher( // Fast path, if parallelism limit is not reached, dispatch task and return if (inFlight <= parallelism) { - dispatcher.dispatchWithContext(taskToSchedule, this, fair) + dispatcher.dispatchWithContext(taskToSchedule, this, tailDispatch) return } @@ -185,6 +185,10 @@ private class LimitingDispatcher( } } + override fun dispatchYield(context: CoroutineContext, block: Runnable) { + dispatch(block, tailDispatch = true) + } + override fun toString(): String { return "${super.toString()}[dispatcher = $dispatcher]" } diff --git a/kotlinx-coroutines-core/jvm/test/scheduling/BlockingCoroutineDispatcherTest.kt b/kotlinx-coroutines-core/jvm/test/scheduling/BlockingCoroutineDispatcherTest.kt index 66b93be9cf..f31752c8b5 100644 --- a/kotlinx-coroutines-core/jvm/test/scheduling/BlockingCoroutineDispatcherTest.kt +++ b/kotlinx-coroutines-core/jvm/test/scheduling/BlockingCoroutineDispatcherTest.kt @@ -194,10 +194,10 @@ class BlockingCoroutineDispatcherTest : SchedulerTestBase() { fun testYield() = runBlocking { corePoolSize = 1 maxPoolSize = 1 - val ds = blockingDispatcher(1) - val outerJob = launch(ds) { + val bd = blockingDispatcher(1) + val outerJob = launch(bd) { expect(1) - val innerJob = launch(ds) { + val innerJob = launch(bd) { // Do nothing expect(3) } @@ -215,6 +215,21 @@ class BlockingCoroutineDispatcherTest : SchedulerTestBase() { finish(5) } + @Test + fun testUndispatchedYield() = runTest { + expect(1) + corePoolSize = 1 + maxPoolSize = 1 + val blockingDispatcher = blockingDispatcher(1) + val job = launch(blockingDispatcher, CoroutineStart.UNDISPATCHED) { + expect(2) + yield() + } + expect(3) + job.join() + finish(4) + } + @Test(expected = IllegalArgumentException::class) fun testNegativeParallelism() { blockingDispatcher(-1) diff --git a/kotlinx-coroutines-core/jvm/test/scheduling/BlockingCoroutineDispatcherThreadLimitStressTest.kt b/kotlinx-coroutines-core/jvm/test/scheduling/BlockingCoroutineDispatcherThreadLimitStressTest.kt index 123fe3c9c4..c1fda44487 100644 --- a/kotlinx-coroutines-core/jvm/test/scheduling/BlockingCoroutineDispatcherThreadLimitStressTest.kt +++ b/kotlinx-coroutines-core/jvm/test/scheduling/BlockingCoroutineDispatcherThreadLimitStressTest.kt @@ -21,7 +21,6 @@ class BlockingCoroutineDispatcherThreadLimitStressTest : SchedulerTestBase() { private val concurrentWorkers = AtomicInteger(0) @Test - @Ignore fun testLimitParallelismToOne() = runTest { val limitingDispatcher = blockingDispatcher(1) // Do in bursts to avoid OOM diff --git a/kotlinx-coroutines-core/jvm/test/scheduling/CoroutineDispatcherTest.kt b/kotlinx-coroutines-core/jvm/test/scheduling/CoroutineDispatcherTest.kt index 062b849c0a..3cd77da74a 100644 --- a/kotlinx-coroutines-core/jvm/test/scheduling/CoroutineDispatcherTest.kt +++ b/kotlinx-coroutines-core/jvm/test/scheduling/CoroutineDispatcherTest.kt @@ -117,6 +117,18 @@ class CoroutineDispatcherTest : SchedulerTestBase() { finish(5) } + @Test + fun testUndispatchedYield() = runTest { + expect(1) + val job = launch(dispatcher, CoroutineStart.UNDISPATCHED) { + expect(2) + yield() + } + expect(3) + job.join() + finish(4) + } + @Test fun testThreadName() = runBlocking { val initialCount = Thread.getAllStackTraces().keys.asSequence() diff --git a/kotlinx-coroutines-core/jvm/test/scheduling/CoroutineSchedulerCloseStressTest.kt b/kotlinx-coroutines-core/jvm/test/scheduling/CoroutineSchedulerCloseStressTest.kt index f91b0a9131..473b429283 100644 --- a/kotlinx-coroutines-core/jvm/test/scheduling/CoroutineSchedulerCloseStressTest.kt +++ b/kotlinx-coroutines-core/jvm/test/scheduling/CoroutineSchedulerCloseStressTest.kt @@ -10,7 +10,6 @@ import org.junit.Test import org.junit.runner.* import org.junit.runners.* import java.util.* -import java.util.concurrent.* import kotlin.test.* @RunWith(Parameterized::class) @@ -79,6 +78,10 @@ class CoroutineSchedulerCloseStressTest(private val mode: Mode) : TestBase() { } else { if (rnd.nextBoolean()) { delay(1000) + val t = Thread.currentThread() + if (!t.name.contains("DefaultDispatcher-worker")) { + val a = 2 + } } else { yield() } diff --git a/kotlinx-coroutines-core/jvm/test/scheduling/CoroutineSchedulerTest.kt b/kotlinx-coroutines-core/jvm/test/scheduling/CoroutineSchedulerTest.kt index ff831950b5..38145af8c9 100644 --- a/kotlinx-coroutines-core/jvm/test/scheduling/CoroutineSchedulerTest.kt +++ b/kotlinx-coroutines-core/jvm/test/scheduling/CoroutineSchedulerTest.kt @@ -82,7 +82,7 @@ class CoroutineSchedulerTest : TestBase() { it.dispatch(Runnable { expect(2) finishLatch.countDown() - }, fair = true) + }, tailDispatch = true) }) startLatch.countDown()