1
+ /*
2
+ * Copyright 2016-2018 JetBrains s.r.o.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ package kotlinx.coroutines.experimental.test
18
+
19
+ import kotlinx.coroutines.experimental.*
20
+ import java.util.concurrent.PriorityBlockingQueue
21
+ import java.util.concurrent.TimeUnit
22
+ import java.util.concurrent.atomic.AtomicLong
23
+ import kotlin.coroutines.experimental.CoroutineContext
24
+
25
+ private const val MAX_DELAY = Long .MAX_VALUE - 1
26
+
27
+ /* *
28
+ * This [CoroutineContext] dispatcher can be used to simulate virtual time to speed up
29
+ * code, especially tests, that deal with delays and timeouts in Coroutines.
30
+ *
31
+ * Provide an instance of this TestCoroutineContext when calling the *non-blocking* [launch] or [async]
32
+ * and then advance time or trigger the actions to make the co-routines execute as soon as possible.
33
+ *
34
+ * This works much like the *TestScheduler* in RxJava2, which allows to speed up tests that deal
35
+ * with non-blocking Rx chains that contain delays, timeouts, intervals and such.
36
+ *
37
+ * This dispatcher can also handle *blocking* coroutines that are started by [runBlocking].
38
+ * This dispatcher's virtual time will be automatically advanced based based on the delayed actions
39
+ * within the Coroutine(s).
40
+ *
41
+ * @param name A user-readable name for debugging purposes.
42
+ */
43
+ class TestCoroutineContext (private val name : String? = null ) : CoroutineContext {
44
+ private val caughtExceptions = mutableListOf<Throwable >()
45
+
46
+ private val context = Dispatcher () + CoroutineExceptionHandler (this ::handleException)
47
+
48
+ private val handler = TestHandler ()
49
+
50
+ /* *
51
+ * Exceptions that were caught during a [launch] or a [async] + [Deferred.await].
52
+ */
53
+ val exceptions: List <Throwable > get() = caughtExceptions
54
+
55
+ override fun <R > fold (initial : R , operation : (R , CoroutineContext .Element ) -> R ): R =
56
+ context.fold(initial, operation)
57
+
58
+ override fun <E : CoroutineContext .Element > get (key : CoroutineContext .Key <E >): E ? = context[key]
59
+
60
+ override fun minusKey (key : CoroutineContext .Key <* >): CoroutineContext = context.minusKey(key)
61
+
62
+ /* *
63
+ * Returns the current virtual clock-time as it is known to this CoroutineContext.
64
+ *
65
+ * @param unit The [TimeUnit] in which the clock-time must be returned.
66
+ * @return The virtual clock-time
67
+ */
68
+ fun now (unit : TimeUnit = TimeUnit .MILLISECONDS ): Long = handler.now(unit)
69
+
70
+ /* *
71
+ * Moves the CoroutineContext's virtual clock forward by a specified amount of time.
72
+ *
73
+ * The returned delay-time can be larger than the specified delay-time if the code
74
+ * under test contains *blocking* Coroutines.
75
+ *
76
+ * @param delayTime The amount of time to move the CoroutineContext's clock forward.
77
+ * @param unit The [TimeUnit] in which [delayTime] and the return value is expressed.
78
+ * @return The amount of delay-time that this CoroutinesContext's clock has been forwarded.
79
+ */
80
+ fun advanceTimeBy (delayTime : Long , unit : TimeUnit = TimeUnit .MILLISECONDS ) =
81
+ handler.advanceTimeBy(delayTime, unit)
82
+
83
+ /* *
84
+ * Moves the CoroutineContext's clock-time to a particular moment in time.
85
+ *
86
+ * @param targetTime The point in time to which to move the CoroutineContext's clock.
87
+ * @param unit The [TimeUnit] in which [targetTime] is expressed.
88
+ */
89
+ fun advanceTimeTo (targetTime : Long , unit : TimeUnit = TimeUnit .MILLISECONDS ) {
90
+ handler.advanceTimeTo(targetTime, unit)
91
+ }
92
+
93
+ /* *
94
+ * Triggers any actions that have not yet been triggered and that are scheduled to be triggered at or
95
+ * before this CoroutineContext's present virtual clock-time.
96
+ */
97
+ fun triggerActions () {
98
+ handler.triggerActions()
99
+ }
100
+
101
+ /* *
102
+ * Cancels all not yet triggered actions. Be careful calling this, since it can seriously
103
+ * mess with your coroutines work. This method should usually be called on tear-down of a
104
+ * unit test.
105
+ */
106
+ fun cancelAllActions () {
107
+ handler.cancelAllActions()
108
+ }
109
+
110
+ override fun toString (): String = name ? : super .toString()
111
+
112
+ override fun equals (other : Any? ): Boolean = (other is TestCoroutineContext ) && (other.handler == = handler)
113
+
114
+ override fun hashCode (): Int = System .identityHashCode(handler)
115
+
116
+ private fun handleException (@Suppress(" UNUSED_PARAMETER" ) context : CoroutineContext , exception : Throwable ) {
117
+ caughtExceptions + = exception
118
+ }
119
+
120
+ private inner class Dispatcher : CoroutineDispatcher (), Delay, EventLoop {
121
+ override fun dispatch (context : CoroutineContext , block : Runnable ) {
122
+ handler.post(block)
123
+ }
124
+
125
+ override fun scheduleResumeAfterDelay (time : Long , unit : TimeUnit , continuation : CancellableContinuation <Unit >) {
126
+ handler.postDelayed(Runnable {
127
+ with (continuation) { resumeUndispatched(Unit ) }
128
+ }, unit.toMillis(time).coerceAtMost(MAX_DELAY ))
129
+ }
130
+
131
+ override fun invokeOnTimeout (time : Long , unit : TimeUnit , block : Runnable ): DisposableHandle {
132
+ handler.postDelayed(block, unit.toMillis(time).coerceAtMost(MAX_DELAY ))
133
+ return object : DisposableHandle {
134
+ override fun dispose () {
135
+ handler.removeCallbacks(block)
136
+ }
137
+ }
138
+ }
139
+
140
+ override fun processNextEvent () = handler.processNextEvent()
141
+ }
142
+ }
143
+
144
+ private class TestHandler {
145
+ // The ordered queue for the runnable tasks.
146
+ private val queue = PriorityBlockingQueue <TimedRunnable >(16 )
147
+
148
+ // The per-scheduler global order counter.
149
+ private var counter = AtomicLong (0L )
150
+
151
+ // Storing time in nanoseconds internally.
152
+ private var time = AtomicLong (0L )
153
+
154
+ private val nextEventTime get() = if (queue.isEmpty()) Long .MAX_VALUE else 0L
155
+
156
+ internal fun post (block : Runnable ) {
157
+ queue.add(TimedRunnable (block, counter.getAndIncrement()))
158
+ }
159
+
160
+ internal fun postDelayed (block : Runnable , delayTime : Long ) {
161
+ queue.add(TimedRunnable (block, counter.getAndIncrement(), time.get() + TimeUnit .MILLISECONDS .toNanos(delayTime)))
162
+ }
163
+
164
+ internal fun removeCallbacks (block : Runnable ) {
165
+ queue.remove(TimedRunnable (block))
166
+ }
167
+
168
+ internal fun now (unit : TimeUnit ) = unit.convert(time.get(), TimeUnit .NANOSECONDS )
169
+
170
+ internal fun advanceTimeBy (delayTime : Long , unit : TimeUnit ): Long {
171
+ val oldTime = time.get()
172
+
173
+ advanceTimeTo(oldTime + unit.toNanos(delayTime), TimeUnit .NANOSECONDS )
174
+
175
+ return unit.convert(time.get() - oldTime, TimeUnit .NANOSECONDS )
176
+ }
177
+
178
+ internal fun advanceTimeTo (targetTime : Long , unit : TimeUnit ) {
179
+ val nanoTime = unit.toNanos(targetTime)
180
+
181
+ triggerActions(nanoTime)
182
+
183
+ if (nanoTime > time.get()) {
184
+ time.set(nanoTime)
185
+ }
186
+ }
187
+
188
+ internal fun triggerActions () {
189
+ triggerActions(time.get())
190
+ }
191
+
192
+ internal fun cancelAllActions () {
193
+ queue.clear()
194
+ }
195
+
196
+ internal fun processNextEvent (): Long {
197
+ val current = queue.peek()
198
+ if (current != null ) {
199
+ /* * Automatically advance time for [EventLoop]-callbacks */
200
+ triggerActions(current.time)
201
+ }
202
+
203
+ return nextEventTime
204
+ }
205
+
206
+ private fun triggerActions (targetTime : Long ) {
207
+ while (true ) {
208
+ val current = queue.peek()
209
+ if (current == null || current.time > targetTime) {
210
+ break
211
+ }
212
+
213
+ // If the scheduled time is 0 (immediate) use current virtual time
214
+ if (current.time != 0L ) {
215
+ time.set(current.time)
216
+ }
217
+
218
+ queue.remove(current)
219
+ current.run ()
220
+ }
221
+ }
222
+ }
223
+
224
+ private class TimedRunnable (
225
+ private val run : Runnable ,
226
+ private val count : Long = 0 ,
227
+ internal val time : Long = 0
228
+ ) : Comparable<TimedRunnable>, Runnable {
229
+ override fun run () {
230
+ run.run ()
231
+ }
232
+
233
+ override fun compareTo (other : TimedRunnable ) = if (time == other.time) {
234
+ count.compareTo(other.count)
235
+ } else {
236
+ time.compareTo(other.time)
237
+ }
238
+
239
+ override fun hashCode () = run.hashCode()
240
+
241
+ override fun equals (other : Any? ) = other is TimedRunnable && (run == other.run)
242
+
243
+ override fun toString () = String .format(" TimedRunnable(time = %d, run = %s)" , time, run.toString())
244
+ }
0 commit comments