1
+ /*
2
+ * Copyright 2016-2018 JetBrains s.r.o.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ package kotlinx.coroutines.experimental.test
18
+
19
+ import kotlinx.coroutines.experimental.*
20
+ import kotlinx.coroutines.experimental.internal.*
21
+ import java.util.concurrent.TimeUnit
22
+ import kotlin.coroutines.experimental.*
23
+
24
+ /* *
25
+ * This [CoroutineContext] dispatcher can be used to simulate virtual time to speed up
26
+ * code, especially tests, that deal with delays and timeouts in Coroutines.
27
+ *
28
+ * Provide an instance of this TestCoroutineContext when calling the *non-blocking* [launch] or [async]
29
+ * and then advance time or trigger the actions to make the co-routines execute as soon as possible.
30
+ *
31
+ * This works much like the *TestScheduler* in RxJava2, which allows to speed up tests that deal
32
+ * with non-blocking Rx chains that contain delays, timeouts, intervals and such.
33
+ *
34
+ * This dispatcher can also handle *blocking* coroutines that are started by [runBlocking].
35
+ * This dispatcher's virtual time will be automatically advanced based based on the delayed actions
36
+ * within the Coroutine(s).
37
+ *
38
+ * @param name A user-readable name for debugging purposes.
39
+ */
40
+ class TestCoroutineContext (private val name : String? = null ) : CoroutineContext {
41
+ private val uncaughtExceptions = mutableListOf<Throwable >()
42
+
43
+ private val ctxDispatcher = Dispatcher ()
44
+
45
+ private val ctxHandler = CoroutineExceptionHandler { _, exception ->
46
+ uncaughtExceptions + = exception
47
+ }
48
+
49
+ // The ordered queue for the runnable tasks.
50
+ private val queue = ThreadSafeHeap <TimedRunnable >()
51
+
52
+ // The per-scheduler global order counter.
53
+ private var counter = 0L
54
+
55
+ // Storing time in nanoseconds internally.
56
+ private var time = 0L
57
+
58
+ /* *
59
+ * Exceptions that were caught during a [launch] or a [async] + [Deferred.await].
60
+ */
61
+ public val exceptions: List <Throwable > get() = uncaughtExceptions
62
+
63
+ // -- CoroutineContext implementation
64
+
65
+ public override fun <R > fold (initial : R , operation : (R , CoroutineContext .Element ) -> R ): R =
66
+ operation(operation(initial, ctxDispatcher), ctxHandler)
67
+
68
+ @Suppress(" UNCHECKED_CAST" )
69
+ public override fun <E : CoroutineContext .Element > get (key : CoroutineContext .Key <E >): E ? = when {
70
+ key == = ContinuationInterceptor -> ctxDispatcher as E
71
+ key == = CoroutineExceptionHandler -> ctxHandler as E
72
+ else -> null
73
+ }
74
+
75
+ public override fun minusKey (key : CoroutineContext .Key <* >): CoroutineContext = when {
76
+ key == = ContinuationInterceptor -> ctxHandler
77
+ key == = CoroutineExceptionHandler -> ctxDispatcher
78
+ else -> this
79
+ }
80
+
81
+ /* *
82
+ * Returns the current virtual clock-time as it is known to this CoroutineContext.
83
+ *
84
+ * @param unit The [TimeUnit] in which the clock-time must be returned.
85
+ * @return The virtual clock-time
86
+ */
87
+ public fun now (unit : TimeUnit = TimeUnit .MILLISECONDS )=
88
+ unit.convert(time, TimeUnit .NANOSECONDS )
89
+
90
+ /* *
91
+ * Moves the CoroutineContext's virtual clock forward by a specified amount of time.
92
+ *
93
+ * The returned delay-time can be larger than the specified delay-time if the code
94
+ * under test contains *blocking* Coroutines.
95
+ *
96
+ * @param delayTime The amount of time to move the CoroutineContext's clock forward.
97
+ * @param unit The [TimeUnit] in which [delayTime] and the return value is expressed.
98
+ * @return The amount of delay-time that this CoroutinesContext's clock has been forwarded.
99
+ */
100
+ public fun advanceTimeBy (delayTime : Long , unit : TimeUnit = TimeUnit .MILLISECONDS ): Long {
101
+ val oldTime = time
102
+ advanceTimeTo(oldTime + unit.toNanos(delayTime), TimeUnit .NANOSECONDS )
103
+ return unit.convert(time - oldTime, TimeUnit .NANOSECONDS )
104
+ }
105
+
106
+ /* *
107
+ * Moves the CoroutineContext's clock-time to a particular moment in time.
108
+ *
109
+ * @param targetTime The point in time to which to move the CoroutineContext's clock.
110
+ * @param unit The [TimeUnit] in which [targetTime] is expressed.
111
+ */
112
+ fun advanceTimeTo (targetTime : Long , unit : TimeUnit = TimeUnit .MILLISECONDS ) {
113
+ val nanoTime = unit.toNanos(targetTime)
114
+ triggerActions(nanoTime)
115
+ if (nanoTime > time) time = nanoTime
116
+ }
117
+
118
+ /* *
119
+ * Triggers any actions that have not yet been triggered and that are scheduled to be triggered at or
120
+ * before this CoroutineContext's present virtual clock-time.
121
+ */
122
+ public fun triggerActions () = triggerActions(time)
123
+
124
+ /* *
125
+ * Cancels all not yet triggered actions. Be careful calling this, since it can seriously
126
+ * mess with your coroutines work. This method should usually be called on tear-down of a
127
+ * unit test.
128
+ */
129
+ public fun cancelAllActions () {
130
+ // An 'is-empty' test is required to avoid a NullPointerException in the 'clear()' method
131
+ if (! queue.isEmpty) queue.clear()
132
+ }
133
+
134
+ /* *
135
+ * This method does nothing if there is one unhandled exception that satisfies the given predicate.
136
+ * Otherwise it throws an [AssertionError] with the given message.
137
+ *
138
+ * (this method will clear the list of unhandled exceptions)
139
+ *
140
+ * @param message Message of the [AssertionError]. Defaults to an empty String.
141
+ * @param predicate The predicate that must be satisfied.
142
+ */
143
+ public fun assertUnhandledException (message : String = "", predicate : (Throwable ) -> Boolean ) {
144
+ if (uncaughtExceptions.size != 1 || ! predicate(uncaughtExceptions[0 ])) throw AssertionError (message)
145
+ uncaughtExceptions.clear()
146
+ }
147
+
148
+ /* *
149
+ * This method does nothing if there are no unhandled exceptions or all of them satisfy the given predicate.
150
+ * Otherwise it throws an [AssertionError] with the given message.
151
+ *
152
+ * (this method will clear the list of unhandled exceptions)
153
+ *
154
+ * @param message Message of the [AssertionError]. Defaults to an empty String.
155
+ * @param predicate The predicate that must be satisfied.
156
+ */
157
+ public fun assertAllUnhandledExceptions (message : String = "", predicate : (Throwable ) -> Boolean ) {
158
+ if (! uncaughtExceptions.all(predicate)) throw AssertionError (message)
159
+ uncaughtExceptions.clear()
160
+ }
161
+
162
+ /* *
163
+ * This method does nothing if one or more unhandled exceptions satisfy the given predicate.
164
+ * Otherwise it throws an [AssertionError] with the given message.
165
+ *
166
+ * (this method will clear the list of unhandled exceptions)
167
+ *
168
+ * @param message Message of the [AssertionError]. Defaults to an empty String.
169
+ * @param predicate The predicate that must be satisfied.
170
+ */
171
+ public fun assertAnyUnhandledException (message : String = "", predicate : (Throwable ) -> Boolean ) {
172
+ if (! uncaughtExceptions.any(predicate)) throw AssertionError (message)
173
+ uncaughtExceptions.clear()
174
+ }
175
+
176
+ /* *
177
+ * This method does nothing if the list of unhandled exceptions satisfy the given predicate.
178
+ * Otherwise it throws an [AssertionError] with the given message.
179
+ *
180
+ * (this method will clear the list of unhandled exceptions)
181
+ *
182
+ * @param message Message of the [AssertionError]. Defaults to an empty String.
183
+ * @param predicate The predicate that must be satisfied.
184
+ */
185
+ public fun assertExceptions (message : String = "", predicate : (List <Throwable >) -> Boolean ) {
186
+ if (! predicate(uncaughtExceptions)) throw AssertionError (message)
187
+ uncaughtExceptions.clear()
188
+ }
189
+
190
+ private fun post (block : Runnable ) =
191
+ queue.addLast(TimedRunnable (block, counter++ ))
192
+
193
+ private fun postDelayed (block : Runnable , delayTime : Long ) =
194
+ TimedRunnable (block, counter++ , time + TimeUnit .MILLISECONDS .toNanos(delayTime))
195
+ .also {
196
+ queue.addLast(it)
197
+ }
198
+
199
+ private fun processNextEvent (): Long {
200
+ val current = queue.peek()
201
+ if (current != null ) {
202
+ /* * Automatically advance time for [EventLoop]-callbacks */
203
+ triggerActions(current.time)
204
+ }
205
+ return if (queue.isEmpty) Long .MAX_VALUE else 0L
206
+ }
207
+
208
+ private fun triggerActions (targetTime : Long ) {
209
+ while (true ) {
210
+ val current = queue.removeFirstIf { it.time <= targetTime } ? : break
211
+ // If the scheduled time is 0 (immediate) use current virtual time
212
+ if (current.time != 0L ) time = current.time
213
+ current.run ()
214
+ }
215
+ }
216
+
217
+ public override fun toString (): String = name ? : " TestCoroutineContext@$hexAddress "
218
+
219
+ private inner class Dispatcher : CoroutineDispatcher (), Delay, EventLoop {
220
+ override fun dispatch (context : CoroutineContext , block : Runnable ) = post(block)
221
+
222
+ override fun scheduleResumeAfterDelay (time : Long , unit : TimeUnit , continuation : CancellableContinuation <Unit >) {
223
+ postDelayed(Runnable {
224
+ with (continuation) { resumeUndispatched(Unit ) }
225
+ }, unit.toMillis(time))
226
+ }
227
+
228
+ override fun invokeOnTimeout (time : Long , unit : TimeUnit , block : Runnable ): DisposableHandle {
229
+ val node = postDelayed(block, unit.toMillis(time))
230
+ return object : DisposableHandle {
231
+ override fun dispose () {
232
+ queue.remove(node)
233
+ }
234
+ }
235
+ }
236
+
237
+ override fun processNextEvent () = this @TestCoroutineContext.processNextEvent()
238
+
239
+ public override fun toString (): String = " Dispatcher(${this @TestCoroutineContext} )"
240
+ }
241
+ }
242
+
243
+ private class TimedRunnable (
244
+ private val run : Runnable ,
245
+ private val count : Long = 0 ,
246
+ @JvmField internal val time : Long = 0
247
+ ) : Comparable<TimedRunnable>, Runnable by run, ThreadSafeHeapNode {
248
+ override var index: Int = 0
249
+
250
+ override fun run () = run.run ()
251
+
252
+ override fun compareTo (other : TimedRunnable ) = if (time == other.time) {
253
+ count.compareTo(other.count)
254
+ } else {
255
+ time.compareTo(other.time)
256
+ }
257
+
258
+ override fun toString () = " TimedRunnable(time=$time , run=$run )"
259
+ }
260
+
261
+ /* *
262
+ * Executes a block of code in which a unit-test can be written using the provided [TestCoroutineContext]. The provided
263
+ * [TestCoroutineContext] is available in the [testBody] as the `this` receiver.
264
+ *
265
+ * The [testBody] is executed and an [AssertionError] is thrown if the list of unhandled exceptions is not empty and
266
+ * contains any exception that is not a [CancellationException].
267
+ *
268
+ * If the [testBody] successfully executes one of the [TestCoroutineContext.assertAllUnhandledExceptions],
269
+ * [TestCoroutineContext.assertAnyUnhandledException], [TestCoroutineContext.assertUnhandledException] or
270
+ * [TestCoroutineContext.assertExceptions], the list of unhandled exceptions will have been cleared and this method will
271
+ * not throw an [AssertionError].
272
+ *
273
+ * @param testContext The provided [TestCoroutineContext]. If not specified, a default [TestCoroutineContext] will be
274
+ * provided instead.
275
+ * @param testBody The code of the unit-test.
276
+ */
277
+ public fun withTestContext (testContext : TestCoroutineContext = TestCoroutineContext (), testBody : TestCoroutineContext .() -> Unit ) {
278
+ with (testContext) {
279
+ testBody()
280
+
281
+ if (! exceptions.all { it is CancellationException }) {
282
+ throw AssertionError (" Coroutine encountered unhandled exceptions:\n ${exceptions} " )
283
+ }
284
+ }
285
+ }
286
+
287
+ /* Some helper functions */
288
+ public fun TestCoroutineContext.launch (
289
+ start : CoroutineStart = CoroutineStart .DEFAULT ,
290
+ parent : Job ? = null,
291
+ onCompletion : CompletionHandler ? = null,
292
+ block : suspend CoroutineScope .() -> Unit
293
+ ) = launch(this , start, parent, onCompletion, block)
294
+
295
+ public fun <T > TestCoroutineContext.async (
296
+ start : CoroutineStart = CoroutineStart .DEFAULT ,
297
+ parent : Job ? = null,
298
+ onCompletion : CompletionHandler ? = null,
299
+ block : suspend CoroutineScope .() -> T
300
+
301
+ ) = async(this , start, parent, onCompletion, block)
302
+
303
+ public fun <T > TestCoroutineContext.runBlocking (
304
+ block : suspend CoroutineScope .() -> T
305
+ ) = runBlocking(this , block)
0 commit comments