16
16
17
17
package kotlinx.coroutines.experimental.test
18
18
19
+ import kotlinx.atomicfu.*
19
20
import kotlinx.coroutines.experimental.*
20
- import java.util.concurrent.PriorityBlockingQueue
21
+ import kotlinx.coroutines.experimental.internal.*
21
22
import java.util.concurrent.TimeUnit
22
- import java.util.concurrent.atomic.AtomicLong
23
- import kotlin.coroutines.experimental.CoroutineContext
24
-
25
- private const val MAX_DELAY = Long .MAX_VALUE - 1
23
+ import kotlin.coroutines.experimental.*
26
24
27
25
/* *
28
26
* This [CoroutineContext] dispatcher can be used to simulate virtual time to speed up
@@ -41,31 +39,54 @@ private const val MAX_DELAY = Long.MAX_VALUE - 1
41
39
* @param name A user-readable name for debugging purposes.
42
40
*/
43
41
class TestCoroutineContext (private val name : String? = null ) : CoroutineContext {
44
- private val caughtExceptions = mutableListOf<Throwable >()
42
+ private val uncaughtExceptions = mutableListOf<Throwable >()
43
+
44
+ private val ctxDispatcher = Dispatcher ()
45
+
46
+ private val ctxHandler = CoroutineExceptionHandler { _, exception ->
47
+ uncaughtExceptions + = exception
48
+ }
45
49
46
- private val context = Dispatcher () + CoroutineExceptionHandler (this ::handleException)
50
+ // The ordered queue for the runnable tasks.
51
+ private val queue = ThreadSafeHeap <TimedRunnable >()
47
52
48
- private val handler = TestHandler ()
53
+ // The per-scheduler global order counter.
54
+ private val counter = atomic(0L )
55
+
56
+ // Storing time in nanoseconds internally.
57
+ private val time = atomic(0L )
49
58
50
59
/* *
51
60
* Exceptions that were caught during a [launch] or a [async] + [Deferred.await].
52
61
*/
53
- val exceptions: List <Throwable > get() = caughtExceptions
62
+ public val exceptions: List <Throwable > get() = uncaughtExceptions
54
63
55
- override fun <R > fold (initial : R , operation : (R , CoroutineContext .Element ) -> R ): R =
56
- context.fold(initial, operation)
64
+ // -- CoroutineContext implementation
57
65
58
- override fun <E : CoroutineContext .Element > get (key : CoroutineContext .Key <E >): E ? = context[key]
66
+ public override fun <R > fold (initial : R , operation : (R , CoroutineContext .Element ) -> R ): R =
67
+ operation(operation(initial, ctxDispatcher), ctxHandler)
59
68
60
- override fun minusKey (key : CoroutineContext .Key <* >): CoroutineContext = context.minusKey(key)
69
+ @Suppress(" UNCHECKED_CAST" )
70
+ public override fun <E : CoroutineContext .Element > get (key : CoroutineContext .Key <E >): E ? = when {
71
+ key == = ContinuationInterceptor -> ctxDispatcher as E
72
+ key == = CoroutineExceptionHandler -> ctxHandler as E
73
+ else -> null
74
+ }
61
75
76
+ public override fun minusKey (key : CoroutineContext .Key <* >): CoroutineContext = when {
77
+ key == = ContinuationInterceptor -> ctxHandler
78
+ key == = CoroutineExceptionHandler -> ctxDispatcher
79
+ else -> this
80
+ }
81
+
62
82
/* *
63
83
* Returns the current virtual clock-time as it is known to this CoroutineContext.
64
84
*
65
85
* @param unit The [TimeUnit] in which the clock-time must be returned.
66
86
* @return The virtual clock-time
67
87
*/
68
- fun now (unit : TimeUnit = TimeUnit .MILLISECONDS ): Long = handler.now(unit)
88
+ public fun now (unit : TimeUnit = TimeUnit .MILLISECONDS )=
89
+ unit.convert(time.value, TimeUnit .NANOSECONDS )
69
90
70
91
/* *
71
92
* Moves the CoroutineContext's virtual clock forward by a specified amount of time.
@@ -77,8 +98,11 @@ class TestCoroutineContext(private val name: String? = null) : CoroutineContext
77
98
* @param unit The [TimeUnit] in which [delayTime] and the return value is expressed.
78
99
* @return The amount of delay-time that this CoroutinesContext's clock has been forwarded.
79
100
*/
80
- fun advanceTimeBy (delayTime : Long , unit : TimeUnit = TimeUnit .MILLISECONDS ) =
81
- handler.advanceTimeBy(delayTime, unit)
101
+ public fun advanceTimeBy (delayTime : Long , unit : TimeUnit = TimeUnit .MILLISECONDS ): Long {
102
+ val oldTime = time.value
103
+ advanceTimeTo(oldTime + unit.toNanos(delayTime), TimeUnit .NANOSECONDS )
104
+ return unit.convert(time.value - oldTime, TimeUnit .NANOSECONDS )
105
+ }
82
106
83
107
/* *
84
108
* Moves the CoroutineContext's clock-time to a particular moment in time.
@@ -87,158 +111,91 @@ class TestCoroutineContext(private val name: String? = null) : CoroutineContext
87
111
* @param unit The [TimeUnit] in which [targetTime] is expressed.
88
112
*/
89
113
fun advanceTimeTo (targetTime : Long , unit : TimeUnit = TimeUnit .MILLISECONDS ) {
90
- handler.advanceTimeTo(targetTime, unit)
114
+ val nanoTime = unit.toNanos(targetTime)
115
+ triggerActions(nanoTime)
116
+ if (nanoTime > time.value) time.value = nanoTime
91
117
}
92
118
93
119
/* *
94
120
* Triggers any actions that have not yet been triggered and that are scheduled to be triggered at or
95
121
* before this CoroutineContext's present virtual clock-time.
96
122
*/
97
- fun triggerActions () {
98
- handler.triggerActions()
99
- }
123
+ public fun triggerActions () = triggerActions(time.value)
100
124
101
125
/* *
102
126
* Cancels all not yet triggered actions. Be careful calling this, since it can seriously
103
127
* mess with your coroutines work. This method should usually be called on tear-down of a
104
128
* unit test.
105
129
*/
106
- fun cancelAllActions () {
107
- handler.cancelAllActions()
108
- }
130
+ public fun cancelAllActions () = queue.clear()
109
131
110
- override fun toString (): String = name ? : super .toString()
132
+ private fun post (block : Runnable ) =
133
+ queue.addLast(TimedRunnable (block, counter.getAndIncrement()))
111
134
112
- override fun equals (other : Any? ): Boolean = (other is TestCoroutineContext ) && (other.handler == = handler)
135
+ private fun postDelayed (block : Runnable , delayTime : Long ) =
136
+ TimedRunnable (block, counter.getAndIncrement(), time.value + TimeUnit .MILLISECONDS .toNanos(delayTime))
137
+ .also {
138
+ queue.addLast(it)
139
+ }
113
140
114
- override fun hashCode (): Int = System .identityHashCode(handler)
141
+ private fun processNextEvent (): Long {
142
+ val current = queue.peek()
143
+ if (current != null ) {
144
+ /* * Automatically advance time for [EventLoop]-callbacks */
145
+ triggerActions(current.time)
146
+ }
147
+ return if (queue.isEmpty) Long .MAX_VALUE else 0L
148
+ }
115
149
116
- private fun handleException (@Suppress(" UNUSED_PARAMETER" ) context : CoroutineContext , exception : Throwable ) {
117
- caughtExceptions + = exception
150
+ private fun triggerActions (targetTime : Long ) {
151
+ while (true ) {
152
+ val current = queue.removeFirstIf { it.time <= targetTime } ? : break
153
+ // If the scheduled time is 0 (immediate) use current virtual time
154
+ if (current.time != 0L ) time.value = current.time
155
+ current.run ()
156
+ }
118
157
}
119
158
159
+ public override fun toString (): String = name ? : " TestCoroutineContext@$hexAddress "
160
+
120
161
private inner class Dispatcher : CoroutineDispatcher (), Delay, EventLoop {
121
- override fun dispatch (context : CoroutineContext , block : Runnable ) {
122
- handler.post(block)
123
- }
162
+ override fun dispatch (context : CoroutineContext , block : Runnable ) = post(block)
124
163
125
164
override fun scheduleResumeAfterDelay (time : Long , unit : TimeUnit , continuation : CancellableContinuation <Unit >) {
126
- handler. postDelayed(Runnable {
165
+ postDelayed(Runnable {
127
166
with (continuation) { resumeUndispatched(Unit ) }
128
- }, unit.toMillis(time).coerceAtMost( MAX_DELAY ) )
167
+ }, unit.toMillis(time))
129
168
}
130
169
131
170
override fun invokeOnTimeout (time : Long , unit : TimeUnit , block : Runnable ): DisposableHandle {
132
- handler. postDelayed(block, unit.toMillis(time).coerceAtMost( MAX_DELAY ))
171
+ val node = postDelayed(block, unit.toMillis(time))
133
172
return object : DisposableHandle {
134
173
override fun dispose () {
135
- handler.removeCallbacks(block )
174
+ queue.remove(node )
136
175
}
137
176
}
138
177
}
139
178
140
- override fun processNextEvent () = handler.processNextEvent()
141
- }
142
- }
143
-
144
- private class TestHandler {
145
- // The ordered queue for the runnable tasks.
146
- private val queue = PriorityBlockingQueue <TimedRunnable >(16 )
147
-
148
- // The per-scheduler global order counter.
149
- private var counter = AtomicLong (0L )
150
-
151
- // Storing time in nanoseconds internally.
152
- private var time = AtomicLong (0L )
153
-
154
- private val nextEventTime get() = if (queue.isEmpty()) Long .MAX_VALUE else 0L
155
-
156
- internal fun post (block : Runnable ) {
157
- queue.add(TimedRunnable (block, counter.getAndIncrement()))
158
- }
159
-
160
- internal fun postDelayed (block : Runnable , delayTime : Long ) {
161
- queue.add(TimedRunnable (block, counter.getAndIncrement(), time.get() + TimeUnit .MILLISECONDS .toNanos(delayTime)))
162
- }
163
-
164
- internal fun removeCallbacks (block : Runnable ) {
165
- queue.remove(TimedRunnable (block))
166
- }
167
-
168
- internal fun now (unit : TimeUnit ) = unit.convert(time.get(), TimeUnit .NANOSECONDS )
169
-
170
- internal fun advanceTimeBy (delayTime : Long , unit : TimeUnit ): Long {
171
- val oldTime = time.get()
172
-
173
- advanceTimeTo(oldTime + unit.toNanos(delayTime), TimeUnit .NANOSECONDS )
174
-
175
- return unit.convert(time.get() - oldTime, TimeUnit .NANOSECONDS )
176
- }
177
-
178
- internal fun advanceTimeTo (targetTime : Long , unit : TimeUnit ) {
179
- val nanoTime = unit.toNanos(targetTime)
180
-
181
- triggerActions(nanoTime)
182
-
183
- if (nanoTime > time.get()) {
184
- time.set(nanoTime)
185
- }
186
- }
187
-
188
- internal fun triggerActions () {
189
- triggerActions(time.get())
190
- }
179
+ override fun processNextEvent () = this @TestCoroutineContext.processNextEvent()
191
180
192
- internal fun cancelAllActions () {
193
- queue.clear()
194
- }
195
-
196
- internal fun processNextEvent (): Long {
197
- val current = queue.peek()
198
- if (current != null ) {
199
- /* * Automatically advance time for [EventLoop]-callbacks */
200
- triggerActions(current.time)
201
- }
202
-
203
- return nextEventTime
204
- }
205
-
206
- private fun triggerActions (targetTime : Long ) {
207
- while (true ) {
208
- val current = queue.peek()
209
- if (current == null || current.time > targetTime) {
210
- break
211
- }
212
-
213
- // If the scheduled time is 0 (immediate) use current virtual time
214
- if (current.time != 0L ) {
215
- time.set(current.time)
216
- }
217
-
218
- queue.remove(current)
219
- current.run ()
220
- }
181
+ public override fun toString (): String = " Dispatcher(${this @TestCoroutineContext} )"
221
182
}
222
183
}
223
184
224
185
private class TimedRunnable (
225
- private val run : Runnable ,
226
- private val count : Long = 0 ,
227
- internal val time : Long = 0
228
- ) : Comparable<TimedRunnable>, Runnable {
229
- override fun run () {
230
- run. run ()
231
- }
186
+ private val run : Runnable ,
187
+ private val count : Long = 0 ,
188
+ @JvmField internal val time : Long = 0
189
+ ) : Comparable<TimedRunnable>, Runnable by run, ThreadSafeHeapNode {
190
+ override var index : Int = 0
191
+
192
+ override fun run () = run. run ()
232
193
233
194
override fun compareTo (other : TimedRunnable ) = if (time == other.time) {
234
195
count.compareTo(other.count)
235
196
} else {
236
197
time.compareTo(other.time)
237
198
}
238
199
239
- override fun hashCode () = run.hashCode()
240
-
241
- override fun equals (other : Any? ) = other is TimedRunnable && (run == other.run)
242
-
243
- override fun toString () = String .format(" TimedRunnable(time = %d, run = %s)" , time, run.toString())
200
+ override fun toString () = " TimedRunnable(time=$time , run=$run )"
244
201
}
0 commit comments