Skip to content

Commit 9d2a4c8

Browse files
committed
Prevent setting Dispatchers.Main concurrently
1 parent ded719f commit 9d2a4c8

File tree

5 files changed

+67
-22
lines changed

5 files changed

+67
-22
lines changed

kotlinx-coroutines-test/common/src/TestCoroutineDispatchers.kt

+5-7
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ package kotlinx.coroutines.test
77
import kotlinx.coroutines.*
88
import kotlinx.coroutines.channels.*
99
import kotlinx.coroutines.flow.*
10-
import kotlinx.coroutines.test.internal.*
1110
import kotlinx.coroutines.test.internal.TestMainDispatcher
1211
import kotlin.coroutines.*
1312

@@ -84,7 +83,8 @@ import kotlin.coroutines.*
8483
public fun UnconfinedTestDispatcher(
8584
scheduler: TestCoroutineScheduler? = null,
8685
name: String? = null
87-
): TestDispatcher = UnconfinedTestDispatcherImpl(scheduler ?: mainTestScheduler ?: TestCoroutineScheduler(), name)
86+
): TestDispatcher = UnconfinedTestDispatcherImpl(
87+
scheduler ?: TestMainDispatcher.currentTestScheduler ?: TestCoroutineScheduler(), name)
8888

8989
private class UnconfinedTestDispatcherImpl(
9090
override val scheduler: TestCoroutineScheduler,
@@ -141,7 +141,8 @@ private class UnconfinedTestDispatcherImpl(
141141
public fun StandardTestDispatcher(
142142
scheduler: TestCoroutineScheduler? = null,
143143
name: String? = null
144-
): TestDispatcher = StandardTestDispatcherImpl(scheduler ?: mainTestScheduler ?: TestCoroutineScheduler(), name)
144+
): TestDispatcher = StandardTestDispatcherImpl(
145+
scheduler ?: TestMainDispatcher.currentTestScheduler ?: TestCoroutineScheduler(), name)
145146

146147
private class StandardTestDispatcherImpl(
147148
override val scheduler: TestCoroutineScheduler = TestCoroutineScheduler(),
@@ -154,7 +155,4 @@ private class StandardTestDispatcherImpl(
154155
}
155156

156157
override fun toString(): String = "${name ?: "StandardTestDispatcher"}[scheduler=$scheduler]"
157-
}
158-
159-
private val mainTestScheduler
160-
get() = ((Dispatchers.Main as? TestMainDispatcher)?.delegate as? TestDispatcher)?.scheduler
158+
}

kotlinx-coroutines-test/common/src/TestDispatchers.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import kotlin.jvm.*
2121
@ExperimentalCoroutinesApi
2222
public fun Dispatchers.setMain(dispatcher: CoroutineDispatcher) {
2323
require(dispatcher !is TestMainDispatcher) { "Dispatchers.setMain(Dispatchers.Main) is prohibited, probably Dispatchers.resetMain() should be used instead" }
24-
getTestMainDispatcher().delegate = dispatcher
24+
getTestMainDispatcher().setDispatcher(dispatcher)
2525
}
2626

2727
/**

kotlinx-coroutines-test/common/src/internal/TestMainDispatcher.kt

+57-8
Original file line numberDiff line numberDiff line change
@@ -3,40 +3,89 @@
33
*/
44

55
package kotlinx.coroutines.test.internal
6+
7+
import kotlinx.atomicfu.*
68
import kotlinx.coroutines.*
9+
import kotlinx.coroutines.test.*
710
import kotlin.coroutines.*
811

912
/**
1013
* The testable main dispatcher used by kotlinx-coroutines-test.
1114
* It is a [MainCoroutineDispatcher] that delegates all actions to a settable delegate.
1215
*/
13-
internal class TestMainDispatcher(var delegate: CoroutineDispatcher):
16+
internal class TestMainDispatcher(delegate: CoroutineDispatcher):
1417
MainCoroutineDispatcher(),
1518
Delay
1619
{
17-
private val mainDispatcher = delegate // the initial value passed to the constructor
20+
private val mainDispatcher = delegate
21+
private var delegate = NonConcurrentlyModifiable(mainDispatcher, "Dispatchers.Main")
1822

1923
private val delay
20-
get() = delegate as? Delay ?: defaultDelay
24+
get() = delegate.value as? Delay ?: defaultDelay
2125

2226
override val immediate: MainCoroutineDispatcher
23-
get() = (delegate as? MainCoroutineDispatcher)?.immediate ?: this
27+
get() = (delegate.value as? MainCoroutineDispatcher)?.immediate ?: this
28+
29+
override fun dispatch(context: CoroutineContext, block: Runnable) = delegate.value.dispatch(context, block)
2430

25-
override fun dispatch(context: CoroutineContext, block: Runnable) = delegate.dispatch(context, block)
31+
override fun isDispatchNeeded(context: CoroutineContext): Boolean = delegate.value.isDispatchNeeded(context)
2632

27-
override fun isDispatchNeeded(context: CoroutineContext): Boolean = delegate.isDispatchNeeded(context)
33+
override fun dispatchYield(context: CoroutineContext, block: Runnable) = delegate.value.dispatchYield(context, block)
2834

29-
override fun dispatchYield(context: CoroutineContext, block: Runnable) = delegate.dispatchYield(context, block)
35+
fun setDispatcher(dispatcher: CoroutineDispatcher) {
36+
delegate.value = dispatcher
37+
}
3038

3139
fun resetDispatcher() {
32-
delegate = mainDispatcher
40+
delegate.value = mainDispatcher
3341
}
3442

3543
override fun scheduleResumeAfterDelay(timeMillis: Long, continuation: CancellableContinuation<Unit>) =
3644
delay.scheduleResumeAfterDelay(timeMillis, continuation)
3745

3846
override fun invokeOnTimeout(timeMillis: Long, block: Runnable, context: CoroutineContext): DisposableHandle =
3947
delay.invokeOnTimeout(timeMillis, block, context)
48+
49+
companion object {
50+
internal val currentTestDispatcher
51+
get() = (Dispatchers.Main as? TestMainDispatcher)?.delegate?.value as? TestDispatcher
52+
53+
internal val currentTestScheduler
54+
get() = currentTestDispatcher?.scheduler
55+
}
56+
57+
/**
58+
* A wrapper around a value that attempts to throw when writing happens concurrently with reading.
59+
*
60+
* The read operations never throw. Instead, the failures detected inside them will be remembered and thrown on the
61+
* next modification.
62+
*/
63+
private class NonConcurrentlyModifiable<T>(private val initialValue: T, private val name: String) {
64+
private val readers = atomic(0) // number of concurrent readers
65+
private val isWriting = atomic(false) // a modification is happening currently
66+
private val exceptionWhenReading: AtomicRef<Throwable?> = atomic(null) // exception from reading
67+
private val _value = atomic(initialValue) // the backing field for the value
68+
69+
private fun concurrentWW() = IllegalStateException("$name is modified concurrently")
70+
private fun concurrentRW() = IllegalStateException("$name is used concurrently with setting it")
71+
72+
var value: T
73+
get() {
74+
readers.incrementAndGet()
75+
if (isWriting.value) exceptionWhenReading.value = concurrentRW()
76+
val result = _value.value
77+
readers.decrementAndGet()
78+
return result
79+
}
80+
set(value: T) {
81+
exceptionWhenReading.getAndSet(null)?.let { throw it }
82+
if (readers.value != 0) throw concurrentRW()
83+
if (!isWriting.compareAndSet(expect = false, update = true)) throw concurrentWW()
84+
_value.value = value
85+
isWriting.value = false
86+
if (readers.value != 0) throw concurrentRW()
87+
}
88+
}
4089
}
4190

4291
@Suppress("INVISIBLE_MEMBER")

kotlinx-coroutines-test/common/test/TestDispatchersTest.kt

+2-4
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class TestDispatchersTest: OrderedExecutionTestBase() {
2424
@NoJs
2525
@Test
2626
fun testMainMocking() = runTest {
27-
val mainAtStart = mainTestDispatcher
27+
val mainAtStart = TestMainDispatcher.currentTestDispatcher
2828
assertNotNull(mainAtStart)
2929
withContext(Dispatchers.Main) {
3030
delay(10)
@@ -35,7 +35,7 @@ class TestDispatchersTest: OrderedExecutionTestBase() {
3535
withContext(Dispatchers.Main) {
3636
delay(10)
3737
}
38-
assertSame(mainAtStart, mainTestDispatcher)
38+
assertSame(mainAtStart, TestMainDispatcher.currentTestDispatcher)
3939
}
4040

4141
/** Tests that the mocked [Dispatchers.Main] correctly forwards [Delay] methods. */
@@ -96,5 +96,3 @@ class TestDispatchersTest: OrderedExecutionTestBase() {
9696
}
9797
}
9898
}
99-
100-
private val mainTestDispatcher get() = ((Dispatchers.Main as? TestMainDispatcher)?.delegate as? TestDispatcher)

kotlinx-coroutines-test/js/test/FailingTests.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@ class FailingTests {
2525
@Test
2626
fun testAfterTestIsConcurrent() = runTest {
2727
try {
28-
val mainAtStart = (Dispatchers.Main as? TestMainDispatcher)?.delegate as? TestDispatcher ?: return@runTest
28+
val mainAtStart = TestMainDispatcher.currentTestDispatcher ?: return@runTest
2929
withContext(Dispatchers.Default) {
3030
// context switch
3131
}
32-
assertNotSame(mainAtStart, (Dispatchers.Main as TestMainDispatcher).delegate)
32+
assertNotSame(mainAtStart, TestMainDispatcher.currentTestDispatcher!!)
3333
} finally {
3434
assertTrue(tearDownEntered)
3535
}

0 commit comments

Comments
 (0)