diff --git a/kotlinx-coroutines-core/common/src/CoroutineDispatcher.kt b/kotlinx-coroutines-core/common/src/CoroutineDispatcher.kt index 45573f30cc..340737b1f6 100644 --- a/kotlinx-coroutines-core/common/src/CoroutineDispatcher.kt +++ b/kotlinx-coroutines-core/common/src/CoroutineDispatcher.kt @@ -201,6 +201,7 @@ public abstract class CoroutineDispatcher : * * This method should generally be exception-safe. An exception thrown from this method * may leave the coroutines that use this dispatcher in an inconsistent and hard-to-debug state. + * It is assumed that if any exceptions do get thrown from this method, then [block] will not be executed. * * This method must not immediately call [block]. Doing so may result in `StackOverflowError` * when `dispatch` is invoked repeatedly, for example when [yield] is called in a loop. diff --git a/kotlinx-coroutines-core/common/src/internal/LimitedDispatcher.kt b/kotlinx-coroutines-core/common/src/internal/LimitedDispatcher.kt index 488331fc37..1361327b6f 100644 --- a/kotlinx-coroutines-core/common/src/internal/LimitedDispatcher.kt +++ b/kotlinx-coroutines-core/common/src/internal/LimitedDispatcher.kt @@ -65,7 +65,17 @@ internal class LimitedDispatcher( // `runningWorkers` when they observed an empty queue. if (!tryAllocateWorker()) return val task = obtainTaskOrDeallocateWorker() ?: return - startWorker(Worker(task)) + try { + startWorker(Worker(task)) + } catch (e: Throwable) { + /* If we failed to start a worker, we should decrement the counter. + The queue is in an inconsistent state--it's non-empty despite the target parallelism not having been + reached--but at least a properly functioning worker will have a chance to correct this if some future + dispatch does succeed. + If we don't decrement the counter, it will be impossible to ever reach the target parallelism again. */ + runningWorkers.decrementAndGet() + throw e + } } /** @@ -107,21 +117,29 @@ internal class LimitedDispatcher( */ private inner class Worker(private var currentTask: Runnable) : Runnable { override fun run() { - var fairnessCounter = 0 - while (true) { - try { - currentTask.run() - } catch (e: Throwable) { - handleCoroutineException(EmptyCoroutineContext, e) + try { + var fairnessCounter = 0 + while (true) { + try { + currentTask.run() + } catch (e: Throwable) { + handleCoroutineException(EmptyCoroutineContext, e) + } + currentTask = obtainTaskOrDeallocateWorker() ?: return + // 16 is our out-of-thin-air constant to emulate fairness. Used in JS dispatchers as well + if (++fairnessCounter >= 16 && dispatcher.safeIsDispatchNeeded(this@LimitedDispatcher)) { + // Do "yield" to let other views execute their runnable as well + // Note that we do not decrement 'runningWorkers' as we are still committed to our part of work + dispatcher.safeDispatch(this@LimitedDispatcher, this) + return + } } - currentTask = obtainTaskOrDeallocateWorker() ?: return - // 16 is our out-of-thin-air constant to emulate fairness. Used in JS dispatchers as well - if (++fairnessCounter >= 16 && dispatcher.safeIsDispatchNeeded(this@LimitedDispatcher)) { - // Do "yield" to let other views execute their runnable as well - // Note that we do not decrement 'runningWorkers' as we are still committed to our part of work - dispatcher.safeDispatch(this@LimitedDispatcher, this) - return + } catch (e: Throwable) { + // If the worker failed, we should deallocate its slot + synchronized(workerAllocationLock) { + runningWorkers.decrementAndGet() } + throw e } } } @@ -132,4 +150,4 @@ internal fun Int.checkParallelism() = require(this >= 1) { "Expected positive pa internal fun CoroutineDispatcher.namedOrThis(name: String?): CoroutineDispatcher { if (name != null) return NamedDispatcher(this, name) return this -} \ No newline at end of file +} diff --git a/kotlinx-coroutines-core/common/test/LimitedParallelismSharedTest.kt b/kotlinx-coroutines-core/common/test/LimitedParallelismSharedTest.kt index d5b84edf19..2d3dea634d 100644 --- a/kotlinx-coroutines-core/common/test/LimitedParallelismSharedTest.kt +++ b/kotlinx-coroutines-core/common/test/LimitedParallelismSharedTest.kt @@ -1,6 +1,8 @@ package kotlinx.coroutines import kotlinx.coroutines.testing.* +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.EmptyCoroutineContext import kotlin.test.* class LimitedParallelismSharedTest : TestBase() { @@ -28,4 +30,30 @@ class LimitedParallelismSharedTest : TestBase() { assertFailsWith { Dispatchers.Default.limitedParallelism(Int.MIN_VALUE) } Dispatchers.Default.limitedParallelism(Int.MAX_VALUE) } + + /** + * Checks that even if the dispatcher sporadically fails, the limited dispatcher will still allow reaching the + * target parallelism level. + */ + @Test + fun testLimitedParallelismOfOccasionallyFailingDispatcher() { + val limit = 5 + var doFail = false + val workerQueue = mutableListOf() + val limited = object: CoroutineDispatcher() { + override fun dispatch(context: CoroutineContext, block: Runnable) { + if (doFail) throw TestException() + workerQueue.add(block) + } + }.limitedParallelism(limit) + repeat(6 * limit) { + try { + limited.dispatch(EmptyCoroutineContext, Runnable { /* do nothing */ }) + } catch (_: DispatchException) { + // ignore + } + doFail = !doFail + } + assertEquals(limit, workerQueue.size) + } } diff --git a/kotlinx-coroutines-core/jvm/test/LimitedParallelismStressTest.kt b/kotlinx-coroutines-core/jvm/test/LimitedParallelismStressTest.kt index 211ff04cdc..c5d0fbeef6 100644 --- a/kotlinx-coroutines-core/jvm/test/LimitedParallelismStressTest.kt +++ b/kotlinx-coroutines-core/jvm/test/LimitedParallelismStressTest.kt @@ -7,6 +7,8 @@ import org.junit.runner.* import org.junit.runners.* import java.util.concurrent.* import java.util.concurrent.atomic.* +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.EmptyCoroutineContext import kotlin.test.* @RunWith(Parameterized::class) @@ -84,6 +86,58 @@ class LimitedParallelismStressTest(private val targetParallelism: Int) : TestBas } } + /** + * Checks that dispatcher failures during fairness redispatches don't prevent reaching the target parallelism. + */ + @Test + fun testLimitedFailingDispatcherReachesTargetParallelism() = runTest { + val keepFailing = AtomicBoolean(true) + val occasionallyFailing = object: CoroutineDispatcher() { + override fun dispatch(context: CoroutineContext, block: Runnable) { + if (keepFailing.get() && ThreadLocalRandom.current().nextBoolean()) throw TestException() + executor.dispatch(context, block) + } + }.limitedParallelism(targetParallelism) + doStress { + repeat(1000) { + keepFailing.set(true) // we want the next tasks to sporadically fail + // Start some tasks to make sure redispatching for fairness is happening + repeat(targetParallelism * 16 + 1) { + // targetParallelism * 16 + 1 because we need at least one worker to go through a fairness yield + // with high probability. + try { + occasionallyFailing.dispatch(EmptyCoroutineContext, Runnable { + // do nothing. + }) + } catch (_: DispatchException) { + // ignore + } + } + keepFailing.set(false) // we want the next tasks to succeed + val barrier = CyclicBarrier(targetParallelism + 1) + repeat(targetParallelism) { + launch(occasionallyFailing) { + barrier.await() + } + } + val success = launch(Dispatchers.Default) { + // Successfully awaited parallelism + 1 + barrier.await() + } + // Feed the dispatcher with more tasks to make sure it's not stuck + while (success.isActive) { + Thread.sleep(1) + repeat(targetParallelism) { + occasionallyFailing.dispatch(EmptyCoroutineContext, Runnable { + // do nothing. + }) + } + } + coroutineContext.job.children.toList().joinAll() + } + } + } + private suspend inline fun doStress(crossinline block: suspend CoroutineScope.() -> Unit) { repeat(stressTestMultiplier) { coroutineScope {