Skip to content

Commit 83ffd17

Browse files
qwwdfsadelizarov
andauthored
Properly cleanup completion in SafeCollector to avoid unintended memo… (#3199)
* Properly cleanup completion in SafeCollector to avoid unintended memory leak that regular coroutines (e.g. unsafe flow) are not prone to Also, FieldWalker is improved to avoid "illegal reflective access" Fixes #3197 Co-authored-by: Roman Elizarov <[email protected]>
1 parent d5f852c commit 83ffd17

File tree

3 files changed

+88
-17
lines changed

3 files changed

+88
-17
lines changed

kotlinx-coroutines-core/jvm/src/flow/internal/SafeCollector.kt

+29-14
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,22 @@ internal actual class SafeCollector<T> actual constructor(
2929

3030
@JvmField // Note, it is non-capturing lambda, so no extra allocation during init of SafeCollector
3131
internal actual val collectContextSize = collectContext.fold(0) { count, _ -> count + 1 }
32+
33+
// Either context of the last emission or wrapper 'DownstreamExceptionContext'
3234
private var lastEmissionContext: CoroutineContext? = null
35+
// Completion if we are currently suspended or within completion body or null otherwise
3336
private var completion: Continuation<Unit>? = null
3437

35-
// ContinuationImpl
38+
/*
39+
* This property is accessed in two places:
40+
* * ContinuationImpl invokes this in its `releaseIntercepted` as `context[ContinuationInterceptor]!!`
41+
* * When we are within a callee, it is used to create its continuation object with this collector as completion
42+
*/
3643
override val context: CoroutineContext
37-
get() = completion?.context ?: EmptyCoroutineContext
44+
get() = lastEmissionContext ?: EmptyCoroutineContext
3845

3946
override fun invokeSuspend(result: Result<Any?>): Any {
40-
result.onFailure { lastEmissionContext = DownstreamExceptionElement(it) }
47+
result.onFailure { lastEmissionContext = DownstreamExceptionContext(it, context) }
4148
completion?.resumeWith(result as Result<Unit>)
4249
return COROUTINE_SUSPENDED
4350
}
@@ -59,7 +66,9 @@ internal actual class SafeCollector<T> actual constructor(
5966
emit(uCont, value)
6067
} catch (e: Throwable) {
6168
// Save the fact that exception from emit (or even check context) has been thrown
62-
lastEmissionContext = DownstreamExceptionElement(e)
69+
// Note, that this can the first emit and lastEmissionContext may not be saved yet,
70+
// hence we use `uCont.context` here.
71+
lastEmissionContext = DownstreamExceptionContext(e, uCont.context)
6372
throw e
6473
}
6574
}
@@ -72,24 +81,32 @@ internal actual class SafeCollector<T> actual constructor(
7281
val previousContext = lastEmissionContext
7382
if (previousContext !== currentContext) {
7483
checkContext(currentContext, previousContext, value)
84+
lastEmissionContext = currentContext
7585
}
7686
completion = uCont
77-
return emitFun(collector as FlowCollector<Any?>, value, this as Continuation<Unit>)
87+
val result = emitFun(collector as FlowCollector<Any?>, value, this as Continuation<Unit>)
88+
/*
89+
* If the callee hasn't suspended, that means that it won't (it's forbidden) call 'resumeWith` (-> `invokeSuspend`)
90+
* and we don't have to retain a strong reference to it to avoid memory leaks.
91+
*/
92+
if (result != COROUTINE_SUSPENDED) {
93+
completion = null
94+
}
95+
return result
7896
}
7997

8098
private fun checkContext(
8199
currentContext: CoroutineContext,
82100
previousContext: CoroutineContext?,
83101
value: T
84102
) {
85-
if (previousContext is DownstreamExceptionElement) {
103+
if (previousContext is DownstreamExceptionContext) {
86104
exceptionTransparencyViolated(previousContext, value)
87105
}
88106
checkContext(currentContext)
89-
lastEmissionContext = currentContext
90107
}
91108

92-
private fun exceptionTransparencyViolated(exception: DownstreamExceptionElement, value: Any?) {
109+
private fun exceptionTransparencyViolated(exception: DownstreamExceptionContext, value: Any?) {
93110
/*
94111
* Exception transparency ensures that if a `collect` block or any intermediate operator
95112
* throws an exception, then no more values will be received by it.
@@ -122,14 +139,12 @@ internal actual class SafeCollector<T> actual constructor(
122139
For a more detailed explanation, please refer to Flow documentation.
123140
""".trimIndent())
124141
}
125-
126142
}
127143

128-
internal class DownstreamExceptionElement(@JvmField val e: Throwable) : CoroutineContext.Element {
129-
companion object Key : CoroutineContext.Key<DownstreamExceptionElement>
130-
131-
override val key: CoroutineContext.Key<*> = Key
132-
}
144+
internal class DownstreamExceptionContext(
145+
@JvmField val e: Throwable,
146+
originalContext: CoroutineContext
147+
) : CoroutineContext by originalContext
133148

134149
private object NoOpContinuation : Continuation<Any?> {
135150
override val context: CoroutineContext = EmptyCoroutineContext

kotlinx-coroutines-core/jvm/test/FieldWalker.kt

+11-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import java.lang.reflect.*
99
import java.text.*
1010
import java.util.*
1111
import java.util.Collections.*
12+
import java.util.concurrent.*
1213
import java.util.concurrent.atomic.*
1314
import java.util.concurrent.locks.*
1415
import kotlin.test.*
@@ -26,11 +27,11 @@ object FieldWalker {
2627
// excluded/terminal classes (don't walk them)
2728
fieldsCache += listOf(
2829
Any::class, String::class, Thread::class, Throwable::class, StackTraceElement::class,
29-
WeakReference::class, ReferenceQueue::class, AbstractMap::class,
30-
ReentrantReadWriteLock::class, SimpleDateFormat::class
30+
WeakReference::class, ReferenceQueue::class, AbstractMap::class, Enum::class,
31+
ReentrantLock::class, ReentrantReadWriteLock::class, SimpleDateFormat::class, ThreadPoolExecutor::class,
3132
)
3233
.map { it.java }
33-
.associateWith { emptyList<Field>() }
34+
.associateWith { emptyList() }
3435
}
3536

3637
/*
@@ -159,6 +160,13 @@ object FieldWalker {
159160
&& !(it.type.isArray && it.type.componentType.isPrimitive)
160161
&& it.name != "previousOut" // System.out from TestBase that we store in a field to restore later
161162
}
163+
check(fields.isEmpty() || !type.name.startsWith("java.")) {
164+
"""
165+
Trying to walk trough JDK's '$type' will get into illegal reflective access on JDK 9+.
166+
Either modify your test to avoid usage of this class or update FieldWalker code to retrieve
167+
the captured state of this class without going through reflection (see how collections are handled).
168+
""".trimIndent()
169+
}
162170
fields.forEach { it.isAccessible = true } // make them all accessible
163171
result.addAll(fields)
164172
type = type.superclass
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* Copyright 2016-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.coroutines.flow
6+
7+
import kotlinx.coroutines.*
8+
import org.junit.*
9+
10+
class SafeCollectorMemoryLeakTest : TestBase() {
11+
// custom List.forEach impl to avoid using iterator (FieldWalker cannot scan it)
12+
private inline fun <T> List<T>.listForEach(action: (T) -> Unit) {
13+
for (i in indices) action(get(i))
14+
}
15+
16+
@Test
17+
fun testCompletionIsProperlyCleanedUp() = runBlocking {
18+
val job = flow {
19+
emit(listOf(239))
20+
expect(2)
21+
hang {}
22+
}.transform { l -> l.listForEach { _ -> emit(42) } }
23+
.onEach { expect(1) }
24+
.launchIn(this)
25+
yield()
26+
expect(3)
27+
FieldWalker.assertReachableCount(0, job) { it == 239 }
28+
job.cancelAndJoin()
29+
finish(4)
30+
}
31+
32+
@Test
33+
fun testCompletionIsNotCleanedUp() = runBlocking {
34+
val job = flow {
35+
emit(listOf(239))
36+
hang {}
37+
}.transform { l -> l.listForEach { _ -> emit(42) } }
38+
.onEach {
39+
expect(1)
40+
hang { finish(3) }
41+
}
42+
.launchIn(this)
43+
yield()
44+
expect(2)
45+
FieldWalker.assertReachableCount(1, job) { it == 239 }
46+
job.cancelAndJoin()
47+
}
48+
}

0 commit comments

Comments
 (0)