Skip to content

Commit f410298

Browse files
author
Trol
committed
Support thread interrupting blocking functions (#1947)
This is implementation of issue #1947 Signed-off-by: Trol <[email protected]>
1 parent 5eaf83c commit f410298

File tree

3 files changed

+322
-0
lines changed

3 files changed

+322
-0
lines changed

kotlinx-coroutines-core/api/kotlinx-coroutines-core.api

+5
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ public final class kotlinx/coroutines/CancellableContinuationKt {
8686
public static final fun suspendCancellableCoroutine (Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
8787
}
8888

89+
public final class kotlinx/coroutines/CancellationPointKt {
90+
public static final fun runInterruptible (Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
91+
public static synthetic fun runInterruptible$default (Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
92+
}
93+
8994
public abstract interface class kotlinx/coroutines/ChildHandle : kotlinx/coroutines/DisposableHandle {
9095
public abstract fun childCancelled (Ljava/lang/Throwable;)Z
9196
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
/*
2+
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.coroutines
6+
7+
import kotlinx.atomicfu.AtomicRef
8+
import kotlinx.atomicfu.atomic
9+
import kotlinx.atomicfu.loop
10+
import kotlin.coroutines.CoroutineContext
11+
import kotlin.coroutines.EmptyCoroutineContext
12+
import kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn
13+
14+
/**
15+
* Makes a blocking code block cancellable (become a cancellation point of the coroutine).
16+
*
17+
* The blocking code block will be interrupted and this function will throw [CancellationException]
18+
* if the coroutine is cancelled.
19+
*
20+
* Example:
21+
* ```
22+
* GlobalScope.launch(Dispatchers.IO) {
23+
* async {
24+
* // This function will throw [CancellationException].
25+
* runInterruptible {
26+
* doSomethingUseful()
27+
*
28+
* // This blocking procedure will be interrupted when this coroutine is canceled
29+
* // by Exception thrown by the below async block.
30+
* doSomethingElseUsefulInterruptible()
31+
* }
32+
* }
33+
*
34+
* async {
35+
* delay(500L)
36+
* throw Exception()
37+
* }
38+
* }
39+
* ```
40+
*
41+
* There is also an optional context parameter to this function to enable single-call conversion of
42+
* interruptible Java methods into main-safe suspending functions like this:
43+
* ```
44+
* // With one call here we are moving the call to Dispatchers.IO and supporting interruption.
45+
* suspend fun <T> BlockingQueue<T>.awaitTake(): T =
46+
* runInterruptible(Dispatchers.IO) { queue.take() }
47+
* ```
48+
*
49+
* @param context additional to [CoroutineScope.coroutineContext] context of the coroutine.
50+
* @param block regular blocking block that will be interrupted on coroutine cancellation.
51+
*/
52+
public suspend fun <T> runInterruptible(
53+
context: CoroutineContext = EmptyCoroutineContext,
54+
block: () -> T
55+
): T {
56+
// fast path: empty context
57+
if (context === EmptyCoroutineContext) { return runInterruptibleInExpectedContext(block) }
58+
// slow path:
59+
return withContext(context) { runInterruptibleInExpectedContext(block) }
60+
}
61+
62+
private suspend fun <T> runInterruptibleInExpectedContext(block: () -> T): T =
63+
suspendCoroutineUninterceptedOrReturn sc@{ uCont ->
64+
try {
65+
// fast path: no job
66+
val job = uCont.context[Job] ?: return@sc block()
67+
// slow path
68+
val threadState = ThreadState().apply { initInterrupt(job) }
69+
try {
70+
block()
71+
} finally {
72+
threadState.clearInterrupt()
73+
}
74+
} catch (e: InterruptedException) {
75+
throw CancellationException()
76+
}
77+
}
78+
79+
private class ThreadState {
80+
81+
fun initInterrupt(job: Job) {
82+
// starts with Init
83+
if (state.value !== Init) throw IllegalStateException("impossible state")
84+
// remembers this running thread
85+
state.value = Working(Thread.currentThread(), null)
86+
// watches the job for cancellation
87+
val cancelHandle =
88+
job.invokeOnCompletion(onCancelling = true, invokeImmediately = true, handler = CancelHandler())
89+
// remembers the cancel handle or drops it
90+
state.loop { s ->
91+
when {
92+
s is Working -> if (state.compareAndSet(s, Working(s.thread, cancelHandle))) return
93+
s === Interrupting || s === Interrupted -> return
94+
s === Init || s === Finish -> throw IllegalStateException("impossible state")
95+
else -> throw IllegalStateException("unknown state")
96+
}
97+
}
98+
}
99+
100+
fun clearInterrupt() {
101+
state.loop { s ->
102+
when {
103+
s is Working -> if (state.compareAndSet(s, Finish)) { s.cancelHandle!!.dispose(); return }
104+
s === Interrupting -> Thread.yield() // eases the thread
105+
s === Interrupted -> { Thread.interrupted(); return } // no interrupt leak
106+
s === Init || s === Finish -> throw IllegalStateException("impossible state")
107+
else -> throw IllegalStateException("unknown state")
108+
}
109+
}
110+
}
111+
112+
private inner class CancelHandler : CompletionHandler {
113+
override fun invoke(cause: Throwable?) {
114+
state.loop { s ->
115+
when {
116+
s is Working -> {
117+
if (state.compareAndSet(s, Interrupting)) {
118+
s.thread!!.interrupt()
119+
state.value = Interrupted
120+
return
121+
}
122+
}
123+
s === Finish -> return
124+
s === Interrupting || s === Interrupted -> return
125+
s === Init -> throw IllegalStateException("impossible state")
126+
else -> throw IllegalStateException("unknown state")
127+
}
128+
}
129+
}
130+
}
131+
132+
private val state: AtomicRef<State> = atomic(Init)
133+
134+
private interface State
135+
// initial state
136+
private object Init : State
137+
// cancellation watching is setup and/or the continuation is running
138+
private data class Working(val thread: Thread?, val cancelHandle: DisposableHandle?) : State
139+
// the continuation done running without interruption
140+
private object Finish : State
141+
// interrupting this thread
142+
private object Interrupting: State
143+
// done interrupting
144+
private object Interrupted: State
145+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
/*
2+
* Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.coroutines
6+
7+
import java.io.IOException
8+
import java.util.concurrent.Executors
9+
import java.util.concurrent.atomic.AtomicBoolean
10+
import java.util.concurrent.atomic.AtomicInteger
11+
import kotlin.test.*
12+
13+
class InterruptibleCancellationPointTest: TestBase() {
14+
15+
@Test
16+
fun testNormalRun() = runBlocking {
17+
var result = runInterruptible {
18+
var x = doSomethingUsefulBlocking(1, 1)
19+
var y = doSomethingUsefulBlocking(1, 2)
20+
x + y
21+
}
22+
assertEquals(3, result)
23+
}
24+
25+
@Test
26+
fun testExceptionThrow() = runBlocking {
27+
val exception = Exception()
28+
try {
29+
runInterruptible {
30+
throw exception
31+
}
32+
} catch (e: Throwable) {
33+
assertEquals(exception, e)
34+
return@runBlocking
35+
}
36+
fail()
37+
}
38+
39+
@Test
40+
fun testRunWithContext() = runBlocking {
41+
var runThread =
42+
runInterruptible (Dispatchers.IO) {
43+
Thread.currentThread()
44+
}
45+
assertNotEquals(runThread, Thread.currentThread())
46+
}
47+
48+
@Test
49+
fun testRunWithContextFastPath() = runBlocking {
50+
var runThread : Thread =
51+
runInterruptible {
52+
Thread.currentThread()
53+
}
54+
assertEquals(runThread, Thread.currentThread())
55+
}
56+
57+
@Test
58+
fun testInterrupt() {
59+
val count = AtomicInteger(0)
60+
try {
61+
expect(1)
62+
runBlocking {
63+
launch(Dispatchers.IO) {
64+
async {
65+
try {
66+
// `runInterruptible` makes a blocking block cancelable (become a cancellation point)
67+
// by interrupting it on cancellation and throws CancellationException
68+
runInterruptible {
69+
try {
70+
doSomethingUsefulBlocking(100, 1)
71+
doSomethingUsefulBlocking(Long.MAX_VALUE, 0)
72+
} catch (e: InterruptedException) {
73+
expect(3)
74+
throw e
75+
}
76+
}
77+
} catch (e: CancellationException) {
78+
expect(4)
79+
}
80+
}
81+
82+
async {
83+
delay(500L)
84+
expect(2)
85+
throw IOException()
86+
}
87+
}
88+
}
89+
} catch (e: IOException) {
90+
expect(5)
91+
}
92+
finish(6)
93+
}
94+
95+
@Test
96+
fun testNoInterruptLeak() = runBlocking {
97+
var interrupted = true
98+
99+
var task = launch(Dispatchers.IO) {
100+
try {
101+
runInterruptible {
102+
doSomethingUsefulBlocking(Long.MAX_VALUE, 0)
103+
}
104+
} finally {
105+
interrupted = Thread.currentThread().isInterrupted
106+
}
107+
}
108+
109+
delay(500)
110+
task.cancel()
111+
task.join()
112+
assertFalse(interrupted)
113+
}
114+
115+
@Test
116+
fun testStress() {
117+
val REPEAT_TIMES = 2_000
118+
119+
Executors.newCachedThreadPool().asCoroutineDispatcher().use { dispatcher ->
120+
val interruptLeak = AtomicBoolean(false)
121+
val enterCount = AtomicInteger(0)
122+
val interruptedCount = AtomicInteger(0)
123+
val otherExceptionCount = AtomicInteger(0)
124+
125+
runBlocking {
126+
repeat(REPEAT_TIMES) { repeat ->
127+
var job = launch(start = CoroutineStart.LAZY, context = dispatcher) {
128+
try {
129+
runInterruptible {
130+
enterCount.incrementAndGet()
131+
try {
132+
doSomethingUsefulBlocking(Long.MAX_VALUE, 0)
133+
} catch (e: InterruptedException) {
134+
interruptedCount.incrementAndGet()
135+
throw e
136+
}
137+
}
138+
} catch (e: CancellationException) {
139+
} catch (e: Throwable) {
140+
otherExceptionCount.incrementAndGet()
141+
} finally {
142+
interruptLeak.set(interruptLeak.get() || Thread.currentThread().isInterrupted)
143+
}
144+
}
145+
146+
var cancelJob = launch(start = CoroutineStart.LAZY, context = dispatcher) {
147+
job.cancel()
148+
}
149+
150+
launch (dispatcher) {
151+
delay((REPEAT_TIMES - repeat).toLong())
152+
job.start()
153+
}
154+
155+
launch (dispatcher) {
156+
delay(repeat.toLong())
157+
cancelJob.start()
158+
}
159+
}
160+
}
161+
162+
assertFalse(interruptLeak.get())
163+
assertEquals(enterCount.get(), interruptedCount.get())
164+
assertEquals(0, otherExceptionCount.get())
165+
}
166+
}
167+
168+
private fun doSomethingUsefulBlocking(timeUseMillis: Long, result: Int): Int {
169+
Thread.sleep(timeUseMillis)
170+
return result
171+
}
172+
}

0 commit comments

Comments
 (0)