Skip to content

Commit 85725e8

Browse files
elizarovqwwdfsad
authored andcommitted
Fixed rescheduling, synchronization & task disposal in EventLoop
* CRITICAL: Fixed synchronization of ThreadSafeHeap.removeFirstIf * Reworked code to make sure that memory for disposed task never leaks without violating ThreadSafeHeap encapsulation
1 parent 7764e43 commit 85725e8

File tree

4 files changed

+55
-34
lines changed

4 files changed

+55
-34
lines changed

core/kotlinx-coroutines-core/src/EventLoop.kt

+42-25
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,12 @@ public fun EventLoop(thread: Thread = Thread.currentThread(), parentJob: Job? =
7272
public fun EventLoop_Deprecated(thread: Thread = Thread.currentThread(), parentJob: Job? = null): CoroutineDispatcher =
7373
EventLoop(thread, parentJob) as CoroutineDispatcher
7474

75-
internal const val DELAYED = 0
76-
internal const val REMOVED = 1
77-
internal const val RESCHEDULED = 2
75+
private val DISPOSED_TASK = Symbol("REMOVED_TASK")
76+
77+
// results for scheduleImpl
78+
private const val SCHEDULE_OK = 0
79+
private const val SCHEDULE_COMPLETED = 1
80+
private const val SCHEDULE_DISPOSED = 2
7881

7982
private const val MS_TO_NS = 1_000_000L
8083
private const val MAX_MS = Long.MAX_VALUE / MS_TO_NS
@@ -242,22 +245,23 @@ internal abstract class EventLoopBase: CoroutineDispatcher(), Delay, EventLoop {
242245
}
243246

244247
internal fun schedule(delayedTask: DelayedTask) {
245-
if (scheduleImpl(delayedTask)) {
246-
if (shouldUnpark(delayedTask)) unpark()
247-
} else {
248-
DefaultExecutor.schedule(delayedTask)
248+
when (scheduleImpl(delayedTask)) {
249+
SCHEDULE_OK -> if (shouldUnpark(delayedTask)) unpark()
250+
SCHEDULE_COMPLETED -> DefaultExecutor.schedule(delayedTask)
251+
SCHEDULE_DISPOSED -> {} // do nothing -- task was already disposed
252+
else -> error("unexpected result")
249253
}
250254
}
251255

252256
private fun shouldUnpark(task: DelayedTask): Boolean = _delayed.value?.peek() === task
253257

254-
private fun scheduleImpl(delayedTask: DelayedTask): Boolean {
255-
if (isCompleted) return false
258+
private fun scheduleImpl(delayedTask: DelayedTask): Int {
259+
if (isCompleted) return SCHEDULE_COMPLETED
256260
val delayed = _delayed.value ?: run {
257261
_delayed.compareAndSet(null, ThreadSafeHeap())
258262
_delayed.value!!
259263
}
260-
return delayed.addLastIf(delayedTask) { !isCompleted }
264+
return delayedTask.schedule(delayed)
261265
}
262266

263267
internal fun removeDelayedImpl(delayedTask: DelayedTask) {
@@ -273,6 +277,13 @@ internal abstract class EventLoopBase: CoroutineDispatcher(), Delay, EventLoop {
273277
// This is a "soft" (normal) shutdown
274278
protected fun rescheduleAllDelayed() {
275279
while (true) {
280+
/*
281+
* `removeFirstOrNull` below is the only operation on DelayedTask & ThreadSafeHeap that is not
282+
* synchronized on DelayedTask itself. All other operation are synchronized both on
283+
* DelayedTask & ThreadSafeHeap instances (in this order). It is still safe, because `dispose`
284+
* first removes DelayedTask from the heap (under synchronization) then
285+
* assign "_heap = DISPOSED_TASK", so there cannot be ever a race to _heap reference update.
286+
*/
276287
val delayedTask = _delayed.value?.removeFirstOrNull() ?: break
277288
delayedTask.rescheduleOnShutdown()
278289
}
@@ -281,8 +292,17 @@ internal abstract class EventLoopBase: CoroutineDispatcher(), Delay, EventLoop {
281292
internal abstract inner class DelayedTask(
282293
timeMillis: Long
283294
) : Runnable, Comparable<DelayedTask>, DisposableHandle, ThreadSafeHeapNode {
295+
private var _heap: Any? = null // null | ThreadSafeHeap | DISPOSED_TASK
296+
297+
override var heap: ThreadSafeHeap<*>?
298+
get() = _heap as? ThreadSafeHeap<*>
299+
set(value) {
300+
require(_heap !== DISPOSED_TASK) // this can never happen, it is always checked before adding/removing
301+
_heap = value
302+
}
303+
284304
override var index: Int = -1
285-
@JvmField var state = DELAYED // Guarded by by lock on this task for reschedule/dispose purposes
305+
286306
@JvmField val nanoTime: Long = timeSource.nanoTime() + delayToNanos(timeMillis)
287307

288308
override fun compareTo(other: DelayedTask): Int {
@@ -297,24 +317,21 @@ internal abstract class EventLoopBase: CoroutineDispatcher(), Delay, EventLoop {
297317
fun timeToExecute(now: Long): Boolean = now - nanoTime >= 0L
298318

299319
@Synchronized
300-
fun rescheduleOnShutdown() {
301-
if (state != DELAYED) return
302-
if (_delayed.value!!.remove(this)) {
303-
state = RESCHEDULED
304-
DefaultExecutor.schedule(this)
305-
} else {
306-
state = REMOVED
307-
}
320+
fun schedule(delayed: ThreadSafeHeap<DelayedTask>): Int {
321+
if (_heap === DISPOSED_TASK) return SCHEDULE_DISPOSED // don't add -- was already disposed
322+
return if (delayed.addLastIf(this) { !isCompleted }) SCHEDULE_OK else SCHEDULE_COMPLETED
308323
}
309324

325+
// note: DefaultExecutor.schedule performs `schedule` (above) which does sync & checks for DISPOSED_TASK
326+
fun rescheduleOnShutdown() = DefaultExecutor.schedule(this)
327+
310328
@Synchronized
311329
final override fun dispose() {
312-
when (state) {
313-
DELAYED -> _delayed.value?.remove(this)
314-
RESCHEDULED -> DefaultExecutor.removeDelayedImpl(this)
315-
else -> return
316-
}
317-
state = REMOVED
330+
val heap = _heap
331+
if (heap === DISPOSED_TASK) return // already disposed
332+
@Suppress("UNCHECKED_CAST")
333+
(heap as? ThreadSafeHeap<DelayedTask>)?.remove(this) // remove if it is in heap (first)
334+
_heap = DISPOSED_TASK // never add again to any heap
318335
}
319336

320337
override fun toString(): String = "Delayed[nanos=$nanoTime]"

core/kotlinx-coroutines-core/src/internal/ThreadSafeHeap.kt

+11-9
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import java.util.*
1111
* @suppress **This is unstable API and it is subject to change.**
1212
*/
1313
public interface ThreadSafeHeapNode {
14+
public var heap: ThreadSafeHeap<*>?
1415
public var index: Int
1516
}
1617

@@ -44,8 +45,8 @@ public class ThreadSafeHeap<T> : SynchronizedObject() where T: ThreadSafeHeapNod
4445
null
4546
}
4647

47-
@Synchronized
48-
public inline fun removeFirstIf(predicate: (T) -> Boolean): T? {
48+
// @Synchronized // NOTE! NOTE! NOTE! inline fun cannot be @Synchronized
49+
public inline fun removeFirstIf(predicate: (T) -> Boolean): T? = synchronized(this) {
4950
val first = firstImpl() ?: return null
5051
return if (predicate(first)) {
5152
removeAtImpl(0)
@@ -68,10 +69,12 @@ public class ThreadSafeHeap<T> : SynchronizedObject() where T: ThreadSafeHeapNod
6869

6970
@Synchronized
7071
public fun remove(node: T): Boolean {
71-
return if (node.index < 0) {
72+
return if (node.heap == null) {
7273
false
7374
} else {
74-
removeAtImpl(node.index)
75+
val index = node.index
76+
check(index >= 0)
77+
removeAtImpl(index)
7578
true
7679
}
7780
}
@@ -95,18 +98,17 @@ public class ThreadSafeHeap<T> : SynchronizedObject() where T: ThreadSafeHeapNod
9598
}
9699
}
97100
val result = a[size]!!
101+
check(result.heap === this)
102+
result.heap = null
98103
result.index = -1
99104
a[size] = null
100105
return result
101106
}
102107

103108
@PublishedApi
104109
internal fun addImpl(node: T) {
105-
// TODO remove this after #541 when ThreadSafeHeapNode is gone
106-
if (node is EventLoopBase.DelayedTask && node.state == REMOVED) {
107-
return
108-
}
109-
110+
check(node.heap == null)
111+
node.heap = this
110112
val a = realloc()
111113
val i = size++
112114
a[i] = node

core/kotlinx-coroutines-core/src/test_/TestCoroutineContext.kt

+1
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ private class TimedRunnable(
239239
private val count: Long = 0,
240240
@JvmField internal val time: Long = 0
241241
) : Comparable<TimedRunnable>, Runnable by run, ThreadSafeHeapNode {
242+
override var heap: ThreadSafeHeap<*>? = null
242243
override var index: Int = 0
243244

244245
override fun run() = run.run()

core/kotlinx-coroutines-core/test/internal/ThreadSafeHeapTest.kt

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import java.util.*
1010

1111
class ThreadSafeHeapTest : TestBase() {
1212
class Node(val value: Int) : ThreadSafeHeapNode, Comparable<Node> {
13+
override var heap: ThreadSafeHeap<*>? = null
1314
override var index = -1
1415
override fun compareTo(other: Node): Int = value.compareTo(other.value)
1516
override fun equals(other: Any?): Boolean = other is Node && other.value == value

0 commit comments

Comments
 (0)