Skip to content

Commit b64a23b

Browse files
authored
Improve FieldWalker, don't access JDK classes (#1799)
* Improve FieldWalker, don't access JDK classes * Works on future JDKs that forbid reflective access to JDK classes * Show human-readable path to field is something fails
1 parent 4aa3880 commit b64a23b

File tree

3 files changed

+123
-84
lines changed

3 files changed

+123
-84
lines changed
Lines changed: 112 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,115 +1,154 @@
11
/*
2-
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
2+
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
33
*/
44

55
package kotlinx.coroutines
66

77
import java.lang.reflect.*
88
import java.util.*
99
import java.util.Collections.*
10+
import java.util.concurrent.atomic.*
1011
import kotlin.collections.ArrayList
12+
import kotlin.test.*
1113

1214
object FieldWalker {
15+
sealed class Ref {
16+
object RootRef : Ref()
17+
class FieldRef(val parent: Any, val name: String) : Ref()
18+
class ArrayRef(val parent: Any, val index: Int) : Ref()
19+
}
20+
21+
private val fieldsCache = HashMap<Class<*>, List<Field>>()
22+
23+
init {
24+
// excluded/terminal classes (don't walk them)
25+
fieldsCache += listOf(Any::class, String::class, Thread::class, Throwable::class)
26+
.map { it.java }
27+
.associateWith { emptyList<Field>() }
28+
}
1329

1430
/*
1531
* Reflectively starts to walk through object graph and returns identity set of all reachable objects.
32+
* Use [walkRefs] if you need a path from root for debugging.
33+
*/
34+
public fun walk(root: Any?): Set<Any> = walkRefs(root).keys
35+
36+
public fun assertReachableCount(expected: Int, root: Any?, predicate: (Any) -> Boolean) {
37+
val visited = walkRefs(root)
38+
val actual = visited.keys.filter(predicate)
39+
if (actual.size != expected) {
40+
val textDump = actual.joinToString("") { "\n\t" + showPath(it, visited) }
41+
assertEquals(
42+
expected, actual.size,
43+
"Unexpected number objects. Expected $expected, found ${actual.size}$textDump"
44+
)
45+
}
46+
}
47+
48+
/*
49+
* Reflectively starts to walk through object graph and map to all the reached object to their path
50+
* in from root. Use [showPath] do display a path if needed.
1651
*/
17-
public fun walk(root: Any): Set<Any> {
18-
val result = newSetFromMap<Any>(IdentityHashMap())
19-
result.add(root)
52+
private fun walkRefs(root: Any?): Map<Any, Ref> {
53+
val visited = IdentityHashMap<Any, Ref>()
54+
if (root == null) return visited
55+
visited[root] = Ref.RootRef
2056
val stack = ArrayDeque<Any>()
2157
stack.addLast(root)
2258
while (stack.isNotEmpty()) {
2359
val element = stack.removeLast()
24-
val type = element.javaClass
25-
type.visit(element, result, stack)
60+
try {
61+
visit(element, visited, stack)
62+
} catch (e: Exception) {
63+
error("Failed to visit element ${showPath(element, visited)}: $e")
64+
}
2665
}
27-
return result
66+
return visited
2867
}
2968

30-
private fun Class<*>.visit(
31-
element: Any,
32-
result: MutableSet<Any>,
33-
stack: ArrayDeque<Any>
34-
) {
35-
val fields = fields()
36-
fields.forEach {
37-
it.isAccessible = true
38-
val value = it.get(element) ?: return@forEach
39-
if (result.add(value)) {
40-
stack.addLast(value)
69+
private fun showPath(element: Any, visited: Map<Any, Ref>): String {
70+
val path = ArrayList<String>()
71+
var cur = element
72+
while (true) {
73+
val ref = visited.getValue(cur)
74+
if (ref is Ref.RootRef) break
75+
when (ref) {
76+
is Ref.FieldRef -> {
77+
cur = ref.parent
78+
path += ".${ref.name}"
79+
}
80+
is Ref.ArrayRef -> {
81+
cur = ref.parent
82+
path += "[${ref.index}]"
83+
}
4184
}
4285
}
86+
path.reverse()
87+
return path.joinToString("")
88+
}
4389

44-
if (isArray && !componentType.isPrimitive) {
45-
val array = element as Array<Any?>
46-
array.filterNotNull().forEach {
47-
if (result.add(it)) {
48-
stack.addLast(it)
90+
private fun visit(element: Any, visited: IdentityHashMap<Any, Ref>, stack: ArrayDeque<Any>) {
91+
val type = element.javaClass
92+
when {
93+
// Special code for arrays
94+
type.isArray && !type.componentType.isPrimitive -> {
95+
@Suppress("UNCHECKED_CAST")
96+
val array = element as Array<Any?>
97+
array.forEachIndexed { index, value ->
98+
push(value, visited, stack) { Ref.ArrayRef(element, index) }
99+
}
100+
}
101+
// Special code for platform types that cannot be reflectively accessed on modern JDKs
102+
type.name.startsWith("java.") && element is Collection<*> -> {
103+
element.forEachIndexed { index, value ->
104+
push(value, visited, stack) { Ref.ArrayRef(element, index) }
105+
}
106+
}
107+
type.name.startsWith("java.") && element is Map<*, *> -> {
108+
push(element.keys, visited, stack) { Ref.FieldRef(element, "keys") }
109+
push(element.values, visited, stack) { Ref.FieldRef(element, "values") }
110+
}
111+
element is AtomicReference<*> -> {
112+
push(element.get(), visited, stack) { Ref.FieldRef(element, "value") }
113+
}
114+
// All the other classes are reflectively scanned
115+
else -> fields(type).forEach { field ->
116+
push(field.get(element), visited, stack) { Ref.FieldRef(element, field.name) }
117+
// special case to scan Throwable cause (cannot get it reflectively)
118+
if (element is Throwable) {
119+
push(element.cause, visited, stack) { Ref.FieldRef(element, "cause") }
49120
}
50121
}
51122
}
52123
}
53124

54-
private fun Class<*>.fields(): List<Field> {
125+
private inline fun push(value: Any?, visited: IdentityHashMap<Any, Ref>, stack: ArrayDeque<Any>, ref: () -> Ref) {
126+
if (value != null && !visited.containsKey(value)) {
127+
visited[value] = ref()
128+
stack.addLast(value)
129+
}
130+
}
131+
132+
private fun fields(type0: Class<*>): List<Field> {
133+
fieldsCache[type0]?.let { return it }
55134
val result = ArrayList<Field>()
56-
var type = this
57-
while (type != Any::class.java) {
135+
var type = type0
136+
while (true) {
58137
val fields = type.declaredFields.filter {
59138
!it.type.isPrimitive
60139
&& !Modifier.isStatic(it.modifiers)
61140
&& !(it.type.isArray && it.type.componentType.isPrimitive)
62141
}
142+
fields.forEach { it.isAccessible = true } // make them all accessible
63143
result.addAll(fields)
64144
type = type.superclass
65-
}
66-
67-
return result
68-
}
69-
70-
// Debugging-only
71-
@Suppress("UNUSED")
72-
fun printPath(from: Any, to: Any) {
73-
val pathNodes = ArrayList<String>()
74-
val visited = newSetFromMap<Any>(IdentityHashMap())
75-
visited.add(from)
76-
if (findPath(from, to, visited, pathNodes)) {
77-
pathNodes.reverse()
78-
println(pathNodes.joinToString(" -> ", from.javaClass.simpleName + " -> ", "-> " + to.javaClass.simpleName))
79-
} else {
80-
println("Path from $from to $to not found")
81-
}
82-
}
83-
84-
private fun findPath(from: Any, to: Any, visited: MutableSet<Any>, pathNodes: MutableList<String>): Boolean {
85-
if (from === to) {
86-
return true
87-
}
88-
89-
val type = from.javaClass
90-
if (type.isArray) {
91-
if (type.componentType.isPrimitive) return false
92-
val array = from as Array<Any?>
93-
array.filterNotNull().forEach {
94-
if (findPath(it, to, visited, pathNodes)) {
95-
return true
96-
}
145+
val superFields = fieldsCache[type] // will stop at Any anyway
146+
if (superFields != null) {
147+
result.addAll(superFields)
148+
break
97149
}
98-
return false
99150
}
100-
101-
val fields = type.fields()
102-
fields.forEach {
103-
it.isAccessible = true
104-
val value = it.get(from) ?: return@forEach
105-
if (!visited.add(value)) return@forEach
106-
val found = findPath(value, to, visited, pathNodes)
107-
if (found) {
108-
pathNodes += from.javaClass.simpleName + ":" + it.name
109-
return true
110-
}
111-
}
112-
113-
return false
151+
fieldsCache[type0] = result
152+
return result
114153
}
115154
}

kotlinx-coroutines-core/jvm/test/ReusableCancellableContinuationTest.kt

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class ReusableCancellableContinuationTest : TestBase() {
7676
expect(4)
7777
ensureActive()
7878
// Verify child was bound
79-
assertNotNull(FieldWalker.walk(coroutineContext[Job]!!).single { it === continuation })
79+
FieldWalker.assertReachableCount(1, coroutineContext[Job]) { it === continuation }
8080
suspendAtomicCancellableCoroutineReusable<Unit> {
8181
expect(5)
8282
coroutineContext[Job]!!.cancel()
@@ -97,7 +97,7 @@ class ReusableCancellableContinuationTest : TestBase() {
9797
cont = it
9898
}
9999
ensureActive()
100-
assertTrue { FieldWalker.walk(coroutineContext[Job]!!).contains(cont!!) }
100+
assertTrue { FieldWalker.walk(coroutineContext[Job]).contains(cont!!) }
101101
finish(2)
102102
}
103103

@@ -112,7 +112,7 @@ class ReusableCancellableContinuationTest : TestBase() {
112112
cont = it
113113
}
114114
ensureActive()
115-
assertFalse { FieldWalker.walk(coroutineContext[Job]!!).contains(cont!!) }
115+
FieldWalker.assertReachableCount(0, coroutineContext[Job]) { it === cont }
116116
finish(2)
117117
}
118118

@@ -127,7 +127,7 @@ class ReusableCancellableContinuationTest : TestBase() {
127127
}
128128
expectUnreached()
129129
} catch (e: CancellationException) {
130-
assertFalse { FieldWalker.walk(coroutineContext[Job]!!).contains(cont!!) }
130+
FieldWalker.assertReachableCount(0, coroutineContext[Job]) { it === cont }
131131
finish(2)
132132
}
133133
}
@@ -148,19 +148,19 @@ class ReusableCancellableContinuationTest : TestBase() {
148148
expect(4)
149149
ensureActive()
150150
// Verify child was bound
151-
assertEquals(1, FieldWalker.walk(currentJob).count { it is CancellableContinuation<*> })
151+
FieldWalker.assertReachableCount(1, currentJob) { it is CancellableContinuation<*> }
152152
currentJob.cancel()
153153
assertFalse(isActive)
154154
// Child detached
155-
assertEquals(0, FieldWalker.walk(currentJob).count { it is CancellableContinuation<*> })
155+
FieldWalker.assertReachableCount(0, currentJob) { it is CancellableContinuation<*> }
156156
suspendAtomicCancellableCoroutineReusable<Unit> { it.resume(Unit) }
157157
suspendAtomicCancellableCoroutineReusable<Unit> { it.resume(Unit) }
158-
assertEquals(0, FieldWalker.walk(currentJob).count { it is CancellableContinuation<*> })
158+
FieldWalker.assertReachableCount(0, currentJob) { it is CancellableContinuation<*> }
159159

160160
try {
161161
suspendAtomicCancellableCoroutineReusable<Unit> {}
162162
} catch (e: CancellationException) {
163-
assertEquals(0, FieldWalker.walk(currentJob).count { it is CancellableContinuation<*> })
163+
FieldWalker.assertReachableCount(0, currentJob) { it is CancellableContinuation<*> }
164164
finish(5)
165165
}
166166
}
@@ -184,12 +184,12 @@ class ReusableCancellableContinuationTest : TestBase() {
184184
expect(2)
185185
val job = coroutineContext[Job]!!
186186
// 1 for reusable CC, another one for outer joiner
187-
assertEquals(2, FieldWalker.walk(job).count { it is CancellableContinuation<*> })
187+
FieldWalker.assertReachableCount(2, job) { it is CancellableContinuation<*> }
188188
}
189189
expect(1)
190190
receiver.join()
191191
// Reference should be claimed at this point
192-
assertEquals(0, FieldWalker.walk(receiver).count { it is CancellableContinuation<*> })
192+
FieldWalker.assertReachableCount(0, receiver) { it is CancellableContinuation<*> }
193193
finish(3)
194194
}
195195
}

kotlinx-coroutines-core/jvm/test/flow/ConsumeAsFlowLeakTest.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class ConsumeAsFlowLeakTest : TestBase() {
4141
if (shouldSuspendOnSend) yield()
4242
channel.send(second)
4343
yield()
44-
assertEquals(0, FieldWalker.walk(channel).count { it === second })
44+
FieldWalker.assertReachableCount(0, channel) { it === second }
4545
finish(6)
4646
job.cancelAndJoin()
4747
}

0 commit comments

Comments
 (0)