|
3 | 3 | */
|
4 | 4 |
|
5 | 5 | package kotlinx.coroutines.test.internal
|
| 6 | + |
| 7 | +import kotlinx.atomicfu.* |
6 | 8 | import kotlinx.coroutines.*
|
| 9 | +import kotlinx.coroutines.test.* |
7 | 10 | import kotlin.coroutines.*
|
8 | 11 |
|
9 | 12 | /**
|
10 | 13 | * The testable main dispatcher used by kotlinx-coroutines-test.
|
11 | 14 | * It is a [MainCoroutineDispatcher] that delegates all actions to a settable delegate.
|
12 | 15 | */
|
13 |
| -internal class TestMainDispatcher(var delegate: CoroutineDispatcher): |
| 16 | +internal class TestMainDispatcher(delegate: CoroutineDispatcher): |
14 | 17 | MainCoroutineDispatcher(),
|
15 | 18 | Delay
|
16 | 19 | {
|
17 |
| - private val mainDispatcher = delegate // the initial value passed to the constructor |
| 20 | + private val mainDispatcher = delegate |
| 21 | + private var delegate = NonConcurrentlyModifiable(mainDispatcher, "Dispatchers.Main") |
18 | 22 |
|
19 | 23 | private val delay
|
20 |
| - get() = delegate as? Delay ?: defaultDelay |
| 24 | + get() = delegate.value as? Delay ?: defaultDelay |
21 | 25 |
|
22 | 26 | override val immediate: MainCoroutineDispatcher
|
23 |
| - get() = (delegate as? MainCoroutineDispatcher)?.immediate ?: this |
| 27 | + get() = (delegate.value as? MainCoroutineDispatcher)?.immediate ?: this |
| 28 | + |
| 29 | + override fun dispatch(context: CoroutineContext, block: Runnable) = delegate.value.dispatch(context, block) |
24 | 30 |
|
25 |
| - override fun dispatch(context: CoroutineContext, block: Runnable) = delegate.dispatch(context, block) |
| 31 | + override fun isDispatchNeeded(context: CoroutineContext): Boolean = delegate.value.isDispatchNeeded(context) |
26 | 32 |
|
27 |
| - override fun isDispatchNeeded(context: CoroutineContext): Boolean = delegate.isDispatchNeeded(context) |
| 33 | + override fun dispatchYield(context: CoroutineContext, block: Runnable) = delegate.value.dispatchYield(context, block) |
28 | 34 |
|
29 |
| - override fun dispatchYield(context: CoroutineContext, block: Runnable) = delegate.dispatchYield(context, block) |
| 35 | + fun setDispatcher(dispatcher: CoroutineDispatcher) { |
| 36 | + delegate.value = dispatcher |
| 37 | + } |
30 | 38 |
|
31 | 39 | fun resetDispatcher() {
|
32 |
| - delegate = mainDispatcher |
| 40 | + delegate.value = mainDispatcher |
33 | 41 | }
|
34 | 42 |
|
35 | 43 | override fun scheduleResumeAfterDelay(timeMillis: Long, continuation: CancellableContinuation<Unit>) =
|
36 | 44 | delay.scheduleResumeAfterDelay(timeMillis, continuation)
|
37 | 45 |
|
38 | 46 | override fun invokeOnTimeout(timeMillis: Long, block: Runnable, context: CoroutineContext): DisposableHandle =
|
39 | 47 | delay.invokeOnTimeout(timeMillis, block, context)
|
| 48 | + |
| 49 | + companion object { |
| 50 | + internal val currentTestDispatcher |
| 51 | + get() = (Dispatchers.Main as? TestMainDispatcher)?.delegate?.value as? TestDispatcher |
| 52 | + |
| 53 | + internal val currentTestScheduler |
| 54 | + get() = currentTestDispatcher?.scheduler |
| 55 | + } |
| 56 | + |
| 57 | + /** |
| 58 | + * A wrapper around a value that attempts to throw when writing happens concurrently with reading. |
| 59 | + * |
| 60 | + * The read operations never throw. Instead, the failures detected inside them will be remembered and thrown on the |
| 61 | + * next modification. |
| 62 | + */ |
| 63 | + private class NonConcurrentlyModifiable<T>(private val initialValue: T, private val name: String) { |
| 64 | + private val readers = atomic(0) // number of concurrent readers |
| 65 | + private val isWriting = atomic(false) // a modification is happening currently |
| 66 | + private val exceptionWhenReading: AtomicRef<Throwable?> = atomic(null) // exception from reading |
| 67 | + private val _value = atomic(initialValue) // the backing field for the value |
| 68 | + |
| 69 | + private fun concurrentWW() = IllegalStateException("$name is modified concurrently") |
| 70 | + private fun concurrentRW() = IllegalStateException("$name is used concurrently with setting it") |
| 71 | + |
| 72 | + var value: T |
| 73 | + get() { |
| 74 | + readers.incrementAndGet() |
| 75 | + if (isWriting.value) exceptionWhenReading.value = concurrentRW() |
| 76 | + val result = _value.value |
| 77 | + readers.decrementAndGet() |
| 78 | + return result |
| 79 | + } |
| 80 | + set(value: T) { |
| 81 | + exceptionWhenReading.getAndSet(null)?.let { throw it } |
| 82 | + if (readers.value != 0) throw concurrentRW() |
| 83 | + if (!isWriting.compareAndSet(expect = false, update = true)) throw concurrentWW() |
| 84 | + _value.value = value |
| 85 | + isWriting.value = false |
| 86 | + if (readers.value != 0) throw concurrentRW() |
| 87 | + } |
| 88 | + } |
40 | 89 | }
|
41 | 90 |
|
42 | 91 | @Suppress("INVISIBLE_MEMBER")
|
|
0 commit comments