Skip to content

Commit 8befcd6

Browse files
author
Trol
committed
Support thread interrupting blocking functions (#1947)
This is implementation of issue #1947 Signed-off-by: Trol <[email protected]>
1 parent 5eaf83c commit 8befcd6

File tree

3 files changed

+330
-0
lines changed

3 files changed

+330
-0
lines changed

kotlinx-coroutines-core/api/kotlinx-coroutines-core.api

+5
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,11 @@ public final class kotlinx/coroutines/GlobalScope : kotlinx/coroutines/Coroutine
328328
public abstract interface annotation class kotlinx/coroutines/InternalCoroutinesApi : java/lang/annotation/Annotation {
329329
}
330330

331+
public final class kotlinx/coroutines/InterruptibleKt {
332+
public static final fun runInterruptible (Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
333+
public static synthetic fun runInterruptible$default (Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
334+
}
335+
331336
public abstract interface class kotlinx/coroutines/Job : kotlin/coroutines/CoroutineContext$Element {
332337
public static final field Key Lkotlinx/coroutines/Job$Key;
333338
public abstract fun attachChild (Lkotlinx/coroutines/ChildJob;)Lkotlinx/coroutines/ChildHandle;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
/*
2+
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.coroutines
6+
7+
import kotlinx.atomicfu.AtomicRef
8+
import kotlinx.atomicfu.atomic
9+
import kotlinx.atomicfu.loop
10+
import kotlin.coroutines.CoroutineContext
11+
import kotlin.coroutines.EmptyCoroutineContext
12+
import kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn
13+
14+
/**
15+
* Makes a blocking code block cancellable (become a cancellation point of the coroutine).
16+
*
17+
* The blocking code block will be interrupted and this function will throw [CancellationException]
18+
* if the coroutine is cancelled.
19+
*
20+
* Example:
21+
* ```
22+
* GlobalScope.launch(Dispatchers.IO) {
23+
* async {
24+
* // This function will throw [CancellationException].
25+
* runInterruptible {
26+
* doSomethingUseful()
27+
*
28+
* // This blocking procedure will be interrupted when this coroutine is canceled
29+
* // by Exception thrown by the below async block.
30+
* doSomethingElseUsefulInterruptible()
31+
* }
32+
* }
33+
*
34+
* async {
35+
* delay(500L)
36+
* throw Exception()
37+
* }
38+
* }
39+
* ```
40+
*
41+
* There is also an optional context parameter to this function to enable single-call conversion of
42+
* interruptible Java methods into main-safe suspending functions like this:
43+
* ```
44+
* // With one call here we are moving the call to Dispatchers.IO and supporting interruption.
45+
* suspend fun <T> BlockingQueue<T>.awaitTake(): T =
46+
* runInterruptible(Dispatchers.IO) { queue.take() }
47+
* ```
48+
*
49+
* @param context additional to [CoroutineScope.coroutineContext] context of the coroutine.
50+
* @param block regular blocking block that will be interrupted on coroutine cancellation.
51+
*/
52+
public suspend fun <T> runInterruptible(
53+
context: CoroutineContext = EmptyCoroutineContext,
54+
block: () -> T
55+
): T = withContext(context) { runInterruptibleInExpectedContext(block) }
56+
57+
private suspend fun <T> runInterruptibleInExpectedContext(block: () -> T): T =
58+
suspendCoroutineUninterceptedOrReturn sc@{ uCont ->
59+
try {
60+
// fast path: no job
61+
val job = uCont.context[Job] ?: return@sc block()
62+
// slow path
63+
val threadState = ThreadState(job)
64+
try {
65+
block()
66+
} finally {
67+
threadState.clear()
68+
}
69+
} catch (e: InterruptedException) {
70+
throw CancellationException("runInterruptible: interrupted").initCause(e)
71+
}
72+
}
73+
74+
private const val WORKING = 0
75+
private const val FINISH = 1
76+
private const val INTERRUPTING = 2
77+
private const val INTERRUPTED = 3
78+
79+
private class ThreadState : CompletionHandler {
80+
/*
81+
=== States ===
82+
83+
WORKING: running normally
84+
FINISH: complete normally
85+
INTERRUPTING: canceled, going to interrupt this thread
86+
INTERRUPTED: this thread is interrupted
87+
88+
89+
=== Possible Transitions ===
90+
91+
+----------------+ remember +-------------------------+
92+
| WORKING | cancellation listener | WORKING |
93+
| (thread, null) | -------------------------> | (thread, cancel handle) |
94+
+----------------+ +-------------------------+
95+
| | |
96+
| cancel cancel | | complete
97+
| | |
98+
V | |
99+
+---------------+ | |
100+
| INTERRUPTING | <--------------------------------------+ |
101+
+---------------+ |
102+
| |
103+
| interrupt |
104+
| |
105+
V V
106+
+---------------+ +-------------------------+
107+
| INTERRUPTED | | FINISH |
108+
+---------------+ +-------------------------+
109+
*/
110+
private val state: AtomicRef<State>
111+
112+
private data class State(val state: Int, val thread: Thread? = null, val cancelHandle: DisposableHandle? = null)
113+
114+
// We're using a non-primary constructor instead of init block of a primary constructor here, because
115+
// we need to `return`.
116+
constructor (job: Job) {
117+
state = atomic(State(WORKING, Thread.currentThread()))
118+
// watches the job for cancellation
119+
val cancelHandle =
120+
job.invokeOnCompletion(onCancelling = true, invokeImmediately = true, handler = this)
121+
// remembers the cancel handle or drops it
122+
state.loop { s ->
123+
when(s.state) {
124+
WORKING -> if (state.compareAndSet(s, State(WORKING, s.thread, cancelHandle))) return
125+
INTERRUPTING, INTERRUPTED -> return
126+
FINISH -> throw IllegalStateException("impossible state")
127+
else -> throw IllegalStateException("unknown state")
128+
}
129+
}
130+
}
131+
132+
fun clear() {
133+
state.loop { s ->
134+
when(s.state) {
135+
WORKING -> if (state.compareAndSet(s, State(FINISH))) { s.cancelHandle!!.dispose(); return }
136+
INTERRUPTING -> { /* spin */ }
137+
INTERRUPTED -> { Thread.interrupted(); return } // no interrupt leak
138+
FINISH -> throw IllegalStateException("impossible state")
139+
else -> throw IllegalStateException("unknown state")
140+
}
141+
}
142+
}
143+
144+
override fun invoke(cause: Throwable?) = onCancel(cause)
145+
146+
private inline fun onCancel(cause: Throwable?) {
147+
state.loop { s ->
148+
when(s.state) {
149+
WORKING -> {
150+
if (state.compareAndSet(s, State(INTERRUPTING))) {
151+
s.thread!!.interrupt()
152+
state.value = State(INTERRUPTED)
153+
return
154+
}
155+
}
156+
FINISH -> return
157+
INTERRUPTING, INTERRUPTED -> return
158+
else -> throw IllegalStateException("unknown state")
159+
}
160+
}
161+
}
162+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
/*
2+
* Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.coroutines
6+
7+
import java.io.IOException
8+
import java.util.concurrent.Executors
9+
import java.util.concurrent.atomic.AtomicBoolean
10+
import java.util.concurrent.atomic.AtomicInteger
11+
import kotlin.test.*
12+
13+
class InterruptibleTest: TestBase() {
14+
@Test
15+
fun testNormalRun() = runBlocking {
16+
var result = runInterruptible {
17+
var x = doSomethingUsefulBlocking(1, 1)
18+
var y = doSomethingUsefulBlocking(1, 2)
19+
x + y
20+
}
21+
assertEquals(3, result)
22+
}
23+
24+
@Test
25+
fun testExceptionThrow() = runBlocking {
26+
try {
27+
runInterruptible {
28+
throw TestException()
29+
}
30+
} catch (e: Throwable) {
31+
assertTrue(e is TestException)
32+
return@runBlocking
33+
}
34+
fail()
35+
}
36+
37+
@Test
38+
fun testRunWithContext() = runBlocking {
39+
var runThread =
40+
runInterruptible (Dispatchers.IO) {
41+
Thread.currentThread()
42+
}
43+
assertNotEquals(runThread, Thread.currentThread())
44+
}
45+
46+
@Test
47+
fun testInterrupt() {
48+
val count = AtomicInteger(0)
49+
try {
50+
expect(1)
51+
runBlocking {
52+
launch(Dispatchers.IO) {
53+
async {
54+
try {
55+
// `runInterruptible` makes a blocking block cancelable (become a cancellation point)
56+
// by interrupting it on cancellation and throws CancellationException
57+
runInterruptible {
58+
try {
59+
doSomethingUsefulBlocking(100, 1)
60+
doSomethingUsefulBlocking(Long.MAX_VALUE, 0)
61+
} catch (e: InterruptedException) {
62+
expect(3)
63+
throw e
64+
}
65+
}
66+
} catch (e: CancellationException) {
67+
expect(4)
68+
}
69+
}
70+
71+
async {
72+
delay(500L)
73+
expect(2)
74+
throw IOException()
75+
}
76+
}
77+
}
78+
} catch (e: IOException) {
79+
expect(5)
80+
}
81+
finish(6)
82+
}
83+
84+
@Test
85+
fun testNoInterruptLeak() = runBlocking {
86+
var interrupted = true
87+
88+
var task = launch(Dispatchers.IO) {
89+
try {
90+
runInterruptible {
91+
doSomethingUsefulBlocking(Long.MAX_VALUE, 0)
92+
}
93+
} finally {
94+
interrupted = Thread.currentThread().isInterrupted
95+
}
96+
}
97+
98+
delay(500)
99+
task.cancel()
100+
task.join()
101+
assertFalse(interrupted)
102+
}
103+
104+
@Test
105+
fun testStress() {
106+
val REPEAT_TIMES = 2_000
107+
108+
Executors.newCachedThreadPool().asCoroutineDispatcher().use { dispatcher ->
109+
val interruptLeak = AtomicBoolean(false)
110+
val enterCount = AtomicInteger(0)
111+
val interruptedCount = AtomicInteger(0)
112+
val otherExceptionCount = AtomicInteger(0)
113+
114+
runBlocking {
115+
repeat(REPEAT_TIMES) { repeat ->
116+
var job = launch(start = CoroutineStart.LAZY, context = dispatcher) {
117+
try {
118+
runInterruptible {
119+
enterCount.incrementAndGet()
120+
try {
121+
doSomethingUsefulBlocking(Long.MAX_VALUE, 0)
122+
} catch (e: InterruptedException) {
123+
interruptedCount.incrementAndGet()
124+
throw e
125+
}
126+
}
127+
} catch (e: CancellationException) {
128+
} catch (e: Throwable) {
129+
otherExceptionCount.incrementAndGet()
130+
} finally {
131+
interruptLeak.set(interruptLeak.get() || Thread.currentThread().isInterrupted)
132+
}
133+
}
134+
135+
var cancelJob = launch(start = CoroutineStart.LAZY, context = dispatcher) {
136+
job.cancel()
137+
}
138+
139+
launch (dispatcher) {
140+
delay((REPEAT_TIMES - repeat).toLong())
141+
job.start()
142+
}
143+
144+
launch (dispatcher) {
145+
delay(repeat.toLong())
146+
cancelJob.start()
147+
}
148+
}
149+
}
150+
151+
assertFalse(interruptLeak.get())
152+
assertEquals(enterCount.get(), interruptedCount.get())
153+
assertEquals(0, otherExceptionCount.get())
154+
}
155+
}
156+
157+
private fun doSomethingUsefulBlocking(timeUseMillis: Long, result: Int): Int {
158+
Thread.sleep(timeUseMillis)
159+
return result
160+
}
161+
162+
private class TestException : Exception()
163+
}

0 commit comments

Comments
 (0)