Skip to content

Commit 3d3987c

Browse files
author
Trol
committed
(non-intrusive) Implement optional thread interrupt on coroutine cancellation (Kotlin#57)
This is implementation of issue Kotlin#57 and non-intrusive variant of Kotlin#1922 Signed-off-by: Trol <[email protected]>
1 parent 5eaf83c commit 3d3987c

File tree

3 files changed

+263
-0
lines changed

3 files changed

+263
-0
lines changed

kotlinx-coroutines-core/api/kotlinx-coroutines-core.api

+4
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ public final class kotlinx/coroutines/CancellableContinuationKt {
8686
public static final fun suspendCancellableCoroutine (Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
8787
}
8888

89+
public final class kotlinx/coroutines/CancellationPointKt {
90+
public static final fun interruptible (Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
91+
}
92+
8993
public abstract interface class kotlinx/coroutines/ChildHandle : kotlinx/coroutines/DisposableHandle {
9094
public abstract fun childCancelled (Ljava/lang/Throwable;)Z
9195
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
package kotlinx.coroutines
2+
3+
import kotlinx.atomicfu.AtomicRef
4+
import kotlinx.atomicfu.atomic
5+
import kotlinx.atomicfu.loop
6+
import kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn
7+
8+
/**
9+
* Makes a blocking code block cancellable (become a cancellation point of the coroutine).
10+
*
11+
* The blocking code block will be interrupted and this function will throw [CancellationException]
12+
* if the coroutine is cancelled.
13+
*
14+
* Example:
15+
* ```
16+
* GlobalScope.launch(Dispatchers.IO) {
17+
* async {
18+
* // This function will throw [CancellationException].
19+
* interruptible {
20+
* doSomethingUseful()
21+
*
22+
* // This blocking procedure will be interrupted when this coroutine is canceled
23+
* // by Exception thrown by the below async block.
24+
* doSomethingElseUsefulInterruptible()
25+
* }
26+
* }
27+
*
28+
* async {
29+
* delay(500L)
30+
* throw Exception()
31+
* }
32+
* }
33+
* ```
34+
*/
35+
public suspend fun <T> interruptible(block: () -> T): T = suspendCoroutineUninterceptedOrReturn sc@{ uCont ->
36+
try {
37+
// fast path: no job
38+
val job = uCont.context[Job] ?: return@sc block()
39+
// slow path
40+
val threadState = ThreadState().apply { initInterrupt(job) }
41+
try {
42+
block()
43+
} finally {
44+
threadState.clearInterrupt()
45+
}
46+
} catch (e: InterruptedException) {
47+
throw CancellationException()
48+
}
49+
}
50+
51+
private class ThreadState {
52+
53+
fun initInterrupt(job: Job) {
54+
// starts with Init
55+
if (state.value !== Init) throw IllegalStateException("impossible state")
56+
// remembers this running thread
57+
state.value = Working(Thread.currentThread(), null)
58+
// watches the job for cancellation
59+
val cancelHandle =
60+
job.invokeOnCompletion(onCancelling = true, invokeImmediately = true, handler = CancelHandler())
61+
// remembers the cancel handle or drops it
62+
state.loop { s ->
63+
when {
64+
s is Working -> if (state.compareAndSet(s, Working(s.thread, cancelHandle))) return
65+
s === Interrupting || s === Interrupted -> return
66+
s === Init || s === Finish -> throw IllegalStateException("impossible state")
67+
else -> throw IllegalStateException("unknown state")
68+
}
69+
}
70+
}
71+
72+
fun clearInterrupt() {
73+
state.loop { s ->
74+
when {
75+
s is Working -> if (state.compareAndSet(s, Finish)) { s.cancelHandle!!.dispose(); return }
76+
s === Interrupting -> Thread.yield() // eases the thread
77+
s === Interrupted -> { Thread.interrupted(); return } // no interrupt leak
78+
s === Init || s === Finish -> throw IllegalStateException("impossible state")
79+
else -> throw IllegalStateException("unknown state")
80+
}
81+
}
82+
}
83+
84+
private inner class CancelHandler : CompletionHandler {
85+
override fun invoke(cause: Throwable?) {
86+
state.loop { s ->
87+
when {
88+
s is Working -> {
89+
if (state.compareAndSet(s, Interrupting)) {
90+
s.thread!!.interrupt()
91+
state.value = Interrupted
92+
return
93+
}
94+
}
95+
s === Finish -> return
96+
s === Interrupting || s === Interrupted -> return
97+
s === Init -> throw IllegalStateException("impossible state")
98+
else -> throw IllegalStateException("unknown state")
99+
}
100+
}
101+
}
102+
}
103+
104+
private val state: AtomicRef<State> = atomic(Init)
105+
106+
private interface State
107+
// initial state
108+
private object Init : State
109+
// cancellation watching is setup and/or the continuation is running
110+
private data class Working(val thread: Thread?, val cancelHandle: DisposableHandle?) : State
111+
// the continuation done running without interruption
112+
private object Finish : State
113+
// interrupting this thread
114+
private object Interrupting: State
115+
// done interrupting
116+
private object Interrupted: State
117+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
/*
2+
* Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.coroutines
6+
7+
import org.junit.Test
8+
import java.io.IOException
9+
import java.util.concurrent.Executors
10+
import java.util.concurrent.atomic.AtomicBoolean
11+
import java.util.concurrent.atomic.AtomicInteger
12+
import kotlin.test.assertEquals
13+
import kotlin.test.assertFalse
14+
15+
class InterruptibleCancellationPointTest: TestBase() {
16+
17+
@Test
18+
fun testNormalRun() = runBlocking {
19+
var result = interruptible {
20+
var x = doSomethingUsefulBlocking(1, 1)
21+
var y = doSomethingUsefulBlocking(1, 2)
22+
x + y
23+
}
24+
assertEquals(3, result)
25+
}
26+
27+
@Test
28+
fun testInterrupt() {
29+
val count = AtomicInteger(0)
30+
try {
31+
expect(1)
32+
runBlocking {
33+
launch(Dispatchers.IO) {
34+
async {
35+
try {
36+
// `interruptible` makes a blocking block cancelable (become a cancellation point)
37+
// by interrupting it on cancellation and throws CancellationException
38+
interruptible {
39+
try {
40+
doSomethingUsefulBlocking(100, 1)
41+
doSomethingUsefulBlocking(Long.MAX_VALUE, 0)
42+
} catch (e: InterruptedException) {
43+
expect(3)
44+
throw e
45+
}
46+
}
47+
} catch (e: CancellationException) {
48+
expect(4)
49+
}
50+
}
51+
52+
async {
53+
delay(500L)
54+
expect(2)
55+
throw IOException()
56+
}
57+
}
58+
}
59+
} catch (e: IOException) {
60+
expect(5)
61+
}
62+
finish(6)
63+
}
64+
65+
@Test
66+
fun testNoInterruptLeak() = runBlocking {
67+
var interrupted = true
68+
69+
var task = launch(Dispatchers.IO) {
70+
try {
71+
interruptible {
72+
doSomethingUsefulBlocking(Long.MAX_VALUE, 0)
73+
}
74+
} finally {
75+
interrupted = Thread.currentThread().isInterrupted
76+
}
77+
}
78+
79+
delay(500)
80+
task.cancel()
81+
task.join()
82+
assertFalse(interrupted)
83+
}
84+
85+
@Test
86+
fun testStress() {
87+
val REPEAT_TIMES = 2_000
88+
89+
Executors.newCachedThreadPool().asCoroutineDispatcher().use { dispatcher ->
90+
val interruptLeak = AtomicBoolean(false)
91+
val enterCount = AtomicInteger(0)
92+
val interruptedCount = AtomicInteger(0)
93+
val otherExceptionCount = AtomicInteger(0)
94+
95+
runBlocking {
96+
repeat(REPEAT_TIMES) { repeat ->
97+
var job = launch(start = CoroutineStart.LAZY, context = dispatcher) {
98+
try {
99+
interruptible {
100+
enterCount.incrementAndGet()
101+
try {
102+
doSomethingUsefulBlocking(Long.MAX_VALUE, 0)
103+
} catch (e: InterruptedException) {
104+
interruptedCount.incrementAndGet()
105+
throw e
106+
}
107+
}
108+
} catch (e: CancellationException) {
109+
} catch (e: Throwable) {
110+
otherExceptionCount.incrementAndGet()
111+
} finally {
112+
interruptLeak.set(interruptLeak.get() || Thread.currentThread().isInterrupted)
113+
}
114+
}
115+
116+
var cancelJob = launch(start = CoroutineStart.LAZY, context = dispatcher) {
117+
job.cancel()
118+
}
119+
120+
launch (dispatcher) {
121+
delay((REPEAT_TIMES - repeat).toLong())
122+
job.start()
123+
}
124+
125+
launch (dispatcher) {
126+
delay(repeat.toLong())
127+
cancelJob.start()
128+
}
129+
}
130+
}
131+
132+
assertFalse(interruptLeak.get())
133+
assertEquals(enterCount.get(), interruptedCount.get())
134+
assertEquals(0, otherExceptionCount.get())
135+
}
136+
}
137+
138+
private fun doSomethingUsefulBlocking(timeUseMillis: Long, result: Int): Int {
139+
Thread.sleep(timeUseMillis)
140+
return result
141+
}
142+
}

0 commit comments

Comments
 (0)