Skip to content

Commit be64c79

Browse files
committed
TestCoroutineContext optimized
* Direct implementation of CoroutineContext (without delegation) * Removed extra Handler class * Switched to atomicfu for atomics * Use string templates instead of String.format * Use ThreadSafeHeap instead of PriorityQueue * Removed MAX_DELAY
1 parent 10a3cec commit be64c79

File tree

2 files changed

+91
-127
lines changed

2 files changed

+91
-127
lines changed

core/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/internal/ThreadSafeHeap.kt

+7
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package kotlinx.coroutines.experimental.internal
1818

19+
import java.util.*
20+
1921
/**
2022
* @suppress **This is unstable API and it is subject to change.**
2123
*/
@@ -36,6 +38,11 @@ public class ThreadSafeHeap<T> where T: ThreadSafeHeapNode, T: Comparable<T> {
3638

3739
public val isEmpty: Boolean get() = size == 0
3840

41+
public fun clear() = synchronized(this) {
42+
Arrays.fill(a, 0, size, null)
43+
size = 0
44+
}
45+
3946
public fun peek(): T? = synchronized(this) { firstImpl() }
4047

4148
public fun removeFirstOrNull(): T? = synchronized(this) {

core/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/test/TestCoroutineContext.kt

+84-127
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,11 @@
1616

1717
package kotlinx.coroutines.experimental.test
1818

19+
import kotlinx.atomicfu.*
1920
import kotlinx.coroutines.experimental.*
20-
import java.util.concurrent.PriorityBlockingQueue
21+
import kotlinx.coroutines.experimental.internal.*
2122
import java.util.concurrent.TimeUnit
22-
import java.util.concurrent.atomic.AtomicLong
23-
import kotlin.coroutines.experimental.CoroutineContext
24-
25-
private const val MAX_DELAY = Long.MAX_VALUE - 1
23+
import kotlin.coroutines.experimental.*
2624

2725
/**
2826
* This [CoroutineContext] dispatcher can be used to simulate virtual time to speed up
@@ -41,31 +39,54 @@ private const val MAX_DELAY = Long.MAX_VALUE - 1
4139
* @param name A user-readable name for debugging purposes.
4240
*/
4341
class TestCoroutineContext(private val name: String? = null) : CoroutineContext {
44-
private val caughtExceptions = mutableListOf<Throwable>()
42+
private val uncaughtExceptions = mutableListOf<Throwable>()
43+
44+
private val ctxDispatcher = Dispatcher()
45+
46+
private val ctxHandler = CoroutineExceptionHandler { _, exception ->
47+
uncaughtExceptions += exception
48+
}
4549

46-
private val context = Dispatcher() + CoroutineExceptionHandler(this::handleException)
50+
// The ordered queue for the runnable tasks.
51+
private val queue = ThreadSafeHeap<TimedRunnable>()
4752

48-
private val handler = TestHandler()
53+
// The per-scheduler global order counter.
54+
private val counter = atomic(0L)
55+
56+
// Storing time in nanoseconds internally.
57+
private val time = atomic(0L)
4958

5059
/**
5160
* Exceptions that were caught during a [launch] or a [async] + [Deferred.await].
5261
*/
53-
val exceptions: List<Throwable> get() = caughtExceptions
62+
public val exceptions: List<Throwable> get() = uncaughtExceptions
5463

55-
override fun <R> fold(initial: R, operation: (R, CoroutineContext.Element) -> R): R =
56-
context.fold(initial, operation)
64+
// -- CoroutineContext implementation
5765

58-
override fun <E : CoroutineContext.Element> get(key: CoroutineContext.Key<E>): E? = context[key]
66+
public override fun <R> fold(initial: R, operation: (R, CoroutineContext.Element) -> R): R =
67+
operation(operation(initial, ctxDispatcher), ctxHandler)
5968

60-
override fun minusKey(key: CoroutineContext.Key<*>): CoroutineContext = context.minusKey(key)
69+
@Suppress("UNCHECKED_CAST")
70+
public override fun <E : CoroutineContext.Element> get(key: CoroutineContext.Key<E>): E? = when {
71+
key === ContinuationInterceptor -> ctxDispatcher as E
72+
key === CoroutineExceptionHandler -> ctxHandler as E
73+
else -> null
74+
}
6175

76+
public override fun minusKey(key: CoroutineContext.Key<*>): CoroutineContext = when {
77+
key === ContinuationInterceptor -> ctxHandler
78+
key === CoroutineExceptionHandler -> ctxDispatcher
79+
else -> this
80+
}
81+
6282
/**
6383
* Returns the current virtual clock-time as it is known to this CoroutineContext.
6484
*
6585
* @param unit The [TimeUnit] in which the clock-time must be returned.
6686
* @return The virtual clock-time
6787
*/
68-
fun now(unit: TimeUnit = TimeUnit.MILLISECONDS): Long = handler.now(unit)
88+
public fun now(unit: TimeUnit = TimeUnit.MILLISECONDS)=
89+
unit.convert(time.value, TimeUnit.NANOSECONDS)
6990

7091
/**
7192
* Moves the CoroutineContext's virtual clock forward by a specified amount of time.
@@ -77,8 +98,11 @@ class TestCoroutineContext(private val name: String? = null) : CoroutineContext
7798
* @param unit The [TimeUnit] in which [delayTime] and the return value is expressed.
7899
* @return The amount of delay-time that this CoroutinesContext's clock has been forwarded.
79100
*/
80-
fun advanceTimeBy(delayTime: Long, unit: TimeUnit = TimeUnit.MILLISECONDS) =
81-
handler.advanceTimeBy(delayTime, unit)
101+
public fun advanceTimeBy(delayTime: Long, unit: TimeUnit = TimeUnit.MILLISECONDS): Long {
102+
val oldTime = time.value
103+
advanceTimeTo(oldTime + unit.toNanos(delayTime), TimeUnit.NANOSECONDS)
104+
return unit.convert(time.value - oldTime, TimeUnit.NANOSECONDS)
105+
}
82106

83107
/**
84108
* Moves the CoroutineContext's clock-time to a particular moment in time.
@@ -87,158 +111,91 @@ class TestCoroutineContext(private val name: String? = null) : CoroutineContext
87111
* @param unit The [TimeUnit] in which [targetTime] is expressed.
88112
*/
89113
fun advanceTimeTo(targetTime: Long, unit: TimeUnit = TimeUnit.MILLISECONDS) {
90-
handler.advanceTimeTo(targetTime, unit)
114+
val nanoTime = unit.toNanos(targetTime)
115+
triggerActions(nanoTime)
116+
if (nanoTime > time.value) time.value = nanoTime
91117
}
92118

93119
/**
94120
* Triggers any actions that have not yet been triggered and that are scheduled to be triggered at or
95121
* before this CoroutineContext's present virtual clock-time.
96122
*/
97-
fun triggerActions() {
98-
handler.triggerActions()
99-
}
123+
public fun triggerActions() = triggerActions(time.value)
100124

101125
/**
102126
* Cancels all not yet triggered actions. Be careful calling this, since it can seriously
103127
* mess with your coroutines work. This method should usually be called on tear-down of a
104128
* unit test.
105129
*/
106-
fun cancelAllActions() {
107-
handler.cancelAllActions()
108-
}
130+
public fun cancelAllActions() = queue.clear()
109131

110-
override fun toString(): String = name ?: super.toString()
132+
private fun post(block: Runnable) =
133+
queue.addLast(TimedRunnable(block, counter.getAndIncrement()))
111134

112-
override fun equals(other: Any?): Boolean = (other is TestCoroutineContext) && (other.handler === handler)
135+
private fun postDelayed(block: Runnable, delayTime: Long) =
136+
TimedRunnable(block, counter.getAndIncrement(), time.value + TimeUnit.MILLISECONDS.toNanos(delayTime))
137+
.also {
138+
queue.addLast(it)
139+
}
113140

114-
override fun hashCode(): Int = System.identityHashCode(handler)
141+
private fun processNextEvent(): Long {
142+
val current = queue.peek()
143+
if (current != null) {
144+
/** Automatically advance time for [EventLoop]-callbacks */
145+
triggerActions(current.time)
146+
}
147+
return if (queue.isEmpty) Long.MAX_VALUE else 0L
148+
}
115149

116-
private fun handleException(@Suppress("UNUSED_PARAMETER") context: CoroutineContext, exception: Throwable) {
117-
caughtExceptions += exception
150+
private fun triggerActions(targetTime: Long) {
151+
while (true) {
152+
val current = queue.removeFirstIf { it.time <= targetTime } ?: break
153+
// If the scheduled time is 0 (immediate) use current virtual time
154+
if (current.time != 0L) time.value = current.time
155+
current.run()
156+
}
118157
}
119158

159+
public override fun toString(): String = name ?: "TestCoroutineContext@$hexAddress"
160+
120161
private inner class Dispatcher : CoroutineDispatcher(), Delay, EventLoop {
121-
override fun dispatch(context: CoroutineContext, block: Runnable) {
122-
handler.post(block)
123-
}
162+
override fun dispatch(context: CoroutineContext, block: Runnable) = post(block)
124163

125164
override fun scheduleResumeAfterDelay(time: Long, unit: TimeUnit, continuation: CancellableContinuation<Unit>) {
126-
handler.postDelayed(Runnable {
165+
postDelayed(Runnable {
127166
with(continuation) { resumeUndispatched(Unit) }
128-
}, unit.toMillis(time).coerceAtMost(MAX_DELAY))
167+
}, unit.toMillis(time))
129168
}
130169

131170
override fun invokeOnTimeout(time: Long, unit: TimeUnit, block: Runnable): DisposableHandle {
132-
handler.postDelayed(block, unit.toMillis(time).coerceAtMost(MAX_DELAY))
171+
val node = postDelayed(block, unit.toMillis(time))
133172
return object : DisposableHandle {
134173
override fun dispose() {
135-
handler.removeCallbacks(block)
174+
queue.remove(node)
136175
}
137176
}
138177
}
139178

140-
override fun processNextEvent() = handler.processNextEvent()
141-
}
142-
}
143-
144-
private class TestHandler {
145-
// The ordered queue for the runnable tasks.
146-
private val queue = PriorityBlockingQueue<TimedRunnable>(16)
147-
148-
// The per-scheduler global order counter.
149-
private var counter = AtomicLong(0L)
150-
151-
// Storing time in nanoseconds internally.
152-
private var time = AtomicLong(0L)
153-
154-
private val nextEventTime get() = if (queue.isEmpty()) Long.MAX_VALUE else 0L
155-
156-
internal fun post(block: Runnable) {
157-
queue.add(TimedRunnable(block, counter.getAndIncrement()))
158-
}
159-
160-
internal fun postDelayed(block: Runnable, delayTime: Long) {
161-
queue.add(TimedRunnable(block, counter.getAndIncrement(), time.get() + TimeUnit.MILLISECONDS.toNanos(delayTime)))
162-
}
163-
164-
internal fun removeCallbacks(block: Runnable) {
165-
queue.remove(TimedRunnable(block))
166-
}
167-
168-
internal fun now(unit: TimeUnit) = unit.convert(time.get(), TimeUnit.NANOSECONDS)
169-
170-
internal fun advanceTimeBy(delayTime: Long, unit: TimeUnit): Long {
171-
val oldTime = time.get()
172-
173-
advanceTimeTo(oldTime + unit.toNanos(delayTime), TimeUnit.NANOSECONDS)
174-
175-
return unit.convert(time.get() - oldTime, TimeUnit.NANOSECONDS)
176-
}
177-
178-
internal fun advanceTimeTo(targetTime: Long, unit: TimeUnit) {
179-
val nanoTime = unit.toNanos(targetTime)
180-
181-
triggerActions(nanoTime)
182-
183-
if (nanoTime > time.get()) {
184-
time.set(nanoTime)
185-
}
186-
}
187-
188-
internal fun triggerActions() {
189-
triggerActions(time.get())
190-
}
179+
override fun processNextEvent() = this@TestCoroutineContext.processNextEvent()
191180

192-
internal fun cancelAllActions() {
193-
queue.clear()
194-
}
195-
196-
internal fun processNextEvent(): Long {
197-
val current = queue.peek()
198-
if (current != null) {
199-
/** Automatically advance time for [EventLoop]-callbacks */
200-
triggerActions(current.time)
201-
}
202-
203-
return nextEventTime
204-
}
205-
206-
private fun triggerActions(targetTime: Long) {
207-
while (true) {
208-
val current = queue.peek()
209-
if (current == null || current.time > targetTime) {
210-
break
211-
}
212-
213-
// If the scheduled time is 0 (immediate) use current virtual time
214-
if (current.time != 0L) {
215-
time.set(current.time)
216-
}
217-
218-
queue.remove(current)
219-
current.run()
220-
}
181+
public override fun toString(): String = "Dispatcher(${this@TestCoroutineContext})"
221182
}
222183
}
223184

224185
private class TimedRunnable(
225-
private val run: Runnable,
226-
private val count: Long = 0,
227-
internal val time: Long = 0
228-
) : Comparable<TimedRunnable>, Runnable {
229-
override fun run() {
230-
run.run()
231-
}
186+
private val run: Runnable,
187+
private val count: Long = 0,
188+
@JvmField internal val time: Long = 0
189+
) : Comparable<TimedRunnable>, Runnable by run, ThreadSafeHeapNode {
190+
override var index: Int = 0
191+
192+
override fun run() = run.run()
232193

233194
override fun compareTo(other: TimedRunnable) = if (time == other.time) {
234195
count.compareTo(other.count)
235196
} else {
236197
time.compareTo(other.time)
237198
}
238199

239-
override fun hashCode() = run.hashCode()
240-
241-
override fun equals(other: Any?) = other is TimedRunnable && (run == other.run)
242-
243-
override fun toString() = String.format("TimedRunnable(time = %d, run = %s)", time, run.toString())
200+
override fun toString() = "TimedRunnable(time=$time, run=$run)"
244201
}

0 commit comments

Comments
 (0)