Skip to content

Commit afb3dea

Browse files
author
Anton Spaans
committed
Adding a test-helper class TestCoroutineContext.
1 parent 20dbd9f commit afb3dea

File tree

4 files changed

+721
-0
lines changed

4 files changed

+721
-0
lines changed

core/kotlinx-coroutines-core/README.md

+10
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ This module provides debugging facilities for coroutines (run JVM with `-ea` or
6969
and [newCoroutineContext] function to write user-defined coroutine builders that work with these
7070
debugging facilities.
7171

72+
This module provides a special CoroutineContext type [TestCoroutineCoroutineContext][kotlinx.coroutines.experimental.test.TestCoroutineContext] that
73+
allows the writer of code that contains Coroutines with delays and timeouts to write non-flaky unit-tests for that code allowing these tests to
74+
terminate in near zero time. See the documentation for this class for more information.
75+
7276
# Package kotlinx.coroutines.experimental
7377

7478
General-purpose coroutine builders, contexts, and helper functions.
@@ -93,6 +97,10 @@ Low-level primitives for finer-grained control of coroutines.
9397

9498
Optional time unit support for multiplatform projects.
9599

100+
# Package kotlinx.coroutines.experimental.test
101+
102+
Components to ease writing unit-tests for code that contains coroutines with delays and timeouts.
103+
96104
<!--- MODULE kotlinx-coroutines-core -->
97105
<!--- INDEX kotlinx.coroutines.experimental -->
98106
[launch]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines.experimental/launch.html
@@ -148,4 +156,6 @@ Optional time unit support for multiplatform projects.
148156
<!--- INDEX kotlinx.coroutines.experimental.selects -->
149157
[kotlinx.coroutines.experimental.selects.select]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines.experimental.selects/select.html
150158
[kotlinx.coroutines.experimental.selects.SelectBuilder.onTimeout]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines.experimental.selects/-select-builder/on-timeout.html
159+
<!--- INDEX kotlinx.coroutines.experimental.test -->
160+
[kotlinx.coroutines.experimental.test.TestCoroutineContext]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines.experimental.test/-test-coroutine-context/index.html
151161
<!--- END -->

core/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/internal/ThreadSafeHeap.kt

+7
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package kotlinx.coroutines.experimental.internal
1818

19+
import java.util.*
20+
1921
/**
2022
* @suppress **This is unstable API and it is subject to change.**
2123
*/
@@ -36,6 +38,11 @@ public class ThreadSafeHeap<T> where T: ThreadSafeHeapNode, T: Comparable<T> {
3638

3739
public val isEmpty: Boolean get() = size == 0
3840

41+
public fun clear() = synchronized(this) {
42+
Arrays.fill(a, 0, size, null)
43+
size = 0
44+
}
45+
3946
public fun peek(): T? = synchronized(this) { firstImpl() }
4047

4148
public fun removeFirstOrNull(): T? = synchronized(this) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
/*
2+
* Copyright 2016-2018 JetBrains s.r.o.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package kotlinx.coroutines.experimental.test
18+
19+
import kotlinx.coroutines.experimental.*
20+
import kotlinx.coroutines.experimental.internal.*
21+
import java.util.concurrent.TimeUnit
22+
import kotlin.coroutines.experimental.*
23+
24+
/**
25+
* This [CoroutineContext] dispatcher can be used to simulate virtual time to speed up
26+
* code, especially tests, that deal with delays and timeouts in Coroutines.
27+
*
28+
* Provide an instance of this TestCoroutineContext when calling the *non-blocking* [launch] or [async]
29+
* and then advance time or trigger the actions to make the co-routines execute as soon as possible.
30+
*
31+
* This works much like the *TestScheduler* in RxJava2, which allows to speed up tests that deal
32+
* with non-blocking Rx chains that contain delays, timeouts, intervals and such.
33+
*
34+
* This dispatcher can also handle *blocking* coroutines that are started by [runBlocking].
35+
* This dispatcher's virtual time will be automatically advanced based based on the delayed actions
36+
* within the Coroutine(s).
37+
*
38+
* @param name A user-readable name for debugging purposes.
39+
*/
40+
class TestCoroutineContext(private val name: String? = null) : CoroutineContext {
41+
private val uncaughtExceptions = mutableListOf<Throwable>()
42+
43+
private val ctxDispatcher = Dispatcher()
44+
45+
private val ctxHandler = CoroutineExceptionHandler { _, exception ->
46+
uncaughtExceptions += exception
47+
}
48+
49+
// The ordered queue for the runnable tasks.
50+
private val queue = ThreadSafeHeap<TimedRunnable>()
51+
52+
// The per-scheduler global order counter.
53+
private var counter = 0L
54+
55+
// Storing time in nanoseconds internally.
56+
private var time = 0L
57+
58+
/**
59+
* Exceptions that were caught during a [launch] or a [async] + [Deferred.await].
60+
*/
61+
public val exceptions: List<Throwable> get() = uncaughtExceptions
62+
63+
// -- CoroutineContext implementation
64+
65+
public override fun <R> fold(initial: R, operation: (R, CoroutineContext.Element) -> R): R =
66+
operation(operation(initial, ctxDispatcher), ctxHandler)
67+
68+
@Suppress("UNCHECKED_CAST")
69+
public override fun <E : CoroutineContext.Element> get(key: CoroutineContext.Key<E>): E? = when {
70+
key === ContinuationInterceptor -> ctxDispatcher as E
71+
key === CoroutineExceptionHandler -> ctxHandler as E
72+
else -> null
73+
}
74+
75+
public override fun minusKey(key: CoroutineContext.Key<*>): CoroutineContext = when {
76+
key === ContinuationInterceptor -> ctxHandler
77+
key === CoroutineExceptionHandler -> ctxDispatcher
78+
else -> this
79+
}
80+
81+
/**
82+
* Returns the current virtual clock-time as it is known to this CoroutineContext.
83+
*
84+
* @param unit The [TimeUnit] in which the clock-time must be returned.
85+
* @return The virtual clock-time
86+
*/
87+
public fun now(unit: TimeUnit = TimeUnit.MILLISECONDS)=
88+
unit.convert(time, TimeUnit.NANOSECONDS)
89+
90+
/**
91+
* Moves the CoroutineContext's virtual clock forward by a specified amount of time.
92+
*
93+
* The returned delay-time can be larger than the specified delay-time if the code
94+
* under test contains *blocking* Coroutines.
95+
*
96+
* @param delayTime The amount of time to move the CoroutineContext's clock forward.
97+
* @param unit The [TimeUnit] in which [delayTime] and the return value is expressed.
98+
* @return The amount of delay-time that this CoroutinesContext's clock has been forwarded.
99+
*/
100+
public fun advanceTimeBy(delayTime: Long, unit: TimeUnit = TimeUnit.MILLISECONDS): Long {
101+
val oldTime = time
102+
advanceTimeTo(oldTime + unit.toNanos(delayTime), TimeUnit.NANOSECONDS)
103+
return unit.convert(time - oldTime, TimeUnit.NANOSECONDS)
104+
}
105+
106+
/**
107+
* Moves the CoroutineContext's clock-time to a particular moment in time.
108+
*
109+
* @param targetTime The point in time to which to move the CoroutineContext's clock.
110+
* @param unit The [TimeUnit] in which [targetTime] is expressed.
111+
*/
112+
fun advanceTimeTo(targetTime: Long, unit: TimeUnit = TimeUnit.MILLISECONDS) {
113+
val nanoTime = unit.toNanos(targetTime)
114+
triggerActions(nanoTime)
115+
if (nanoTime > time) time = nanoTime
116+
}
117+
118+
/**
119+
* Triggers any actions that have not yet been triggered and that are scheduled to be triggered at or
120+
* before this CoroutineContext's present virtual clock-time.
121+
*/
122+
public fun triggerActions() = triggerActions(time)
123+
124+
/**
125+
* Cancels all not yet triggered actions. Be careful calling this, since it can seriously
126+
* mess with your coroutines work. This method should usually be called on tear-down of a
127+
* unit test.
128+
*/
129+
public fun cancelAllActions() {
130+
// An 'is-empty' test is required to avoid a NullPointerException in the 'clear()' method
131+
if (!queue.isEmpty) queue.clear()
132+
}
133+
134+
/**
135+
* This method does nothing if there is one unhandled exception that satisfies the given predicate.
136+
* Otherwise it throws an [AssertionError] with the given message.
137+
*
138+
* (this method will clear the list of unhandled exceptions)
139+
*
140+
* @param message Message of the [AssertionError]. Defaults to an empty String.
141+
* @param predicate The predicate that must be satisfied.
142+
*/
143+
public fun assertUnhandledException(message: String = "", predicate: (Throwable) -> Boolean) {
144+
if (uncaughtExceptions.size != 1 || !predicate(uncaughtExceptions[0])) throw AssertionError(message)
145+
uncaughtExceptions.clear()
146+
}
147+
148+
/**
149+
* This method does nothing if there are no unhandled exceptions or all of them satisfy the given predicate.
150+
* Otherwise it throws an [AssertionError] with the given message.
151+
*
152+
* (this method will clear the list of unhandled exceptions)
153+
*
154+
* @param message Message of the [AssertionError]. Defaults to an empty String.
155+
* @param predicate The predicate that must be satisfied.
156+
*/
157+
public fun assertAllUnhandledExceptions(message: String = "", predicate: (Throwable) -> Boolean) {
158+
if (!uncaughtExceptions.all(predicate)) throw AssertionError(message)
159+
uncaughtExceptions.clear()
160+
}
161+
162+
/**
163+
* This method does nothing if one or more unhandled exceptions satisfy the given predicate.
164+
* Otherwise it throws an [AssertionError] with the given message.
165+
*
166+
* (this method will clear the list of unhandled exceptions)
167+
*
168+
* @param message Message of the [AssertionError]. Defaults to an empty String.
169+
* @param predicate The predicate that must be satisfied.
170+
*/
171+
public fun assertAnyUnhandledException(message: String = "", predicate: (Throwable) -> Boolean) {
172+
if (!uncaughtExceptions.any(predicate)) throw AssertionError(message)
173+
uncaughtExceptions.clear()
174+
}
175+
176+
/**
177+
* This method does nothing if the list of unhandled exceptions satisfy the given predicate.
178+
* Otherwise it throws an [AssertionError] with the given message.
179+
*
180+
* (this method will clear the list of unhandled exceptions)
181+
*
182+
* @param message Message of the [AssertionError]. Defaults to an empty String.
183+
* @param predicate The predicate that must be satisfied.
184+
*/
185+
public fun assertExceptions(message: String = "", predicate: (List<Throwable>) -> Boolean) {
186+
if (!predicate(uncaughtExceptions)) throw AssertionError(message)
187+
uncaughtExceptions.clear()
188+
}
189+
190+
private fun post(block: Runnable) =
191+
queue.addLast(TimedRunnable(block, counter++))
192+
193+
private fun postDelayed(block: Runnable, delayTime: Long) =
194+
TimedRunnable(block, counter++, time + TimeUnit.MILLISECONDS.toNanos(delayTime))
195+
.also {
196+
queue.addLast(it)
197+
}
198+
199+
private fun processNextEvent(): Long {
200+
val current = queue.peek()
201+
if (current != null) {
202+
/** Automatically advance time for [EventLoop]-callbacks */
203+
triggerActions(current.time)
204+
}
205+
return if (queue.isEmpty) Long.MAX_VALUE else 0L
206+
}
207+
208+
private fun triggerActions(targetTime: Long) {
209+
while (true) {
210+
val current = queue.removeFirstIf { it.time <= targetTime } ?: break
211+
// If the scheduled time is 0 (immediate) use current virtual time
212+
if (current.time != 0L) time = current.time
213+
current.run()
214+
}
215+
}
216+
217+
public override fun toString(): String = name ?: "TestCoroutineContext@$hexAddress"
218+
219+
private inner class Dispatcher : CoroutineDispatcher(), Delay, EventLoop {
220+
override fun dispatch(context: CoroutineContext, block: Runnable) = post(block)
221+
222+
override fun scheduleResumeAfterDelay(time: Long, unit: TimeUnit, continuation: CancellableContinuation<Unit>) {
223+
postDelayed(Runnable {
224+
with(continuation) { resumeUndispatched(Unit) }
225+
}, unit.toMillis(time))
226+
}
227+
228+
override fun invokeOnTimeout(time: Long, unit: TimeUnit, block: Runnable): DisposableHandle {
229+
val node = postDelayed(block, unit.toMillis(time))
230+
return object : DisposableHandle {
231+
override fun dispose() {
232+
queue.remove(node)
233+
}
234+
}
235+
}
236+
237+
override fun processNextEvent() = this@TestCoroutineContext.processNextEvent()
238+
239+
public override fun toString(): String = "Dispatcher(${this@TestCoroutineContext})"
240+
}
241+
}
242+
243+
private class TimedRunnable(
244+
private val run: Runnable,
245+
private val count: Long = 0,
246+
@JvmField internal val time: Long = 0
247+
) : Comparable<TimedRunnable>, Runnable by run, ThreadSafeHeapNode {
248+
override var index: Int = 0
249+
250+
override fun run() = run.run()
251+
252+
override fun compareTo(other: TimedRunnable) = if (time == other.time) {
253+
count.compareTo(other.count)
254+
} else {
255+
time.compareTo(other.time)
256+
}
257+
258+
override fun toString() = "TimedRunnable(time=$time, run=$run)"
259+
}
260+
261+
/**
262+
* Executes a block of code in which a unit-test can be written using the provided [TestCoroutineContext]. The provided
263+
* [TestCoroutineContext] is available in the [testBody] as the `this` receiver.
264+
*
265+
* The [testBody] is executed and an [AssertionError] is thrown if the list of unhandled exceptions is not empty and
266+
* contains any exception that is not a [CancellationException].
267+
*
268+
* If the [testBody] successfully executes one of the [TestCoroutineContext.assertAllUnhandledExceptions],
269+
* [TestCoroutineContext.assertAnyUnhandledException], [TestCoroutineContext.assertUnhandledException] or
270+
* [TestCoroutineContext.assertExceptions], the list of unhandled exceptions will have been cleared and this method will
271+
* not throw an [AssertionError].
272+
*
273+
* @param testContext The provided [TestCoroutineContext]. If not specified, a default [TestCoroutineContext] will be
274+
* provided instead.
275+
* @param testBody The code of the unit-test.
276+
*/
277+
public fun withTestContext(testContext: TestCoroutineContext = TestCoroutineContext(), testBody: TestCoroutineContext.() -> Unit) {
278+
with (testContext) {
279+
testBody()
280+
281+
if (!exceptions.all { it is CancellationException }) {
282+
throw AssertionError("Coroutine encountered unhandled exceptions:\n${exceptions}")
283+
}
284+
}
285+
}

0 commit comments

Comments
 (0)