|
1 | 1 | /*
|
2 |
| - * Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. |
| 2 | + * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. |
3 | 3 | */
|
4 | 4 |
|
5 | 5 | package kotlinx.coroutines
|
6 | 6 |
|
7 | 7 | import java.lang.reflect.*
|
8 | 8 | import java.util.*
|
9 | 9 | import java.util.Collections.*
|
| 10 | +import java.util.concurrent.atomic.* |
10 | 11 | import kotlin.collections.ArrayList
|
| 12 | +import kotlin.test.* |
11 | 13 |
|
12 | 14 | object FieldWalker {
|
| 15 | + sealed class Ref { |
| 16 | + object RootRef : Ref() |
| 17 | + class FieldRef(val parent: Any, val name: String) : Ref() |
| 18 | + class ArrayRef(val parent: Any, val index: Int) : Ref() |
| 19 | + } |
| 20 | + |
| 21 | + private val fieldsCache = HashMap<Class<*>, List<Field>>() |
| 22 | + |
| 23 | + init { |
| 24 | + // excluded/terminal classes (don't walk them) |
| 25 | + fieldsCache += listOf(Any::class, String::class, Thread::class, Throwable::class) |
| 26 | + .map { it.java } |
| 27 | + .associateWith { emptyList<Field>() } |
| 28 | + } |
13 | 29 |
|
14 | 30 | /*
|
15 | 31 | * Reflectively starts to walk through object graph and returns identity set of all reachable objects.
|
| 32 | + * Use [walkRefs] if you need a path from root for debugging. |
| 33 | + */ |
| 34 | + public fun walk(root: Any?): Set<Any> = walkRefs(root).keys |
| 35 | + |
| 36 | + public fun assertReachableCount(expected: Int, root: Any?, predicate: (Any) -> Boolean) { |
| 37 | + val visited = walkRefs(root) |
| 38 | + val actual = visited.keys.filter(predicate) |
| 39 | + if (actual.size != expected) { |
| 40 | + val textDump = actual.joinToString("") { "\n\t" + showPath(it, visited) } |
| 41 | + assertEquals( |
| 42 | + expected, actual.size, |
| 43 | + "Unexpected number objects. Expected $expected, found ${actual.size}$textDump" |
| 44 | + ) |
| 45 | + } |
| 46 | + } |
| 47 | + |
| 48 | + /* |
| 49 | + * Reflectively starts to walk through object graph and map to all the reached object to their path |
| 50 | + * in from root. Use [showPath] do display a path if needed. |
16 | 51 | */
|
17 |
| - public fun walk(root: Any): Set<Any> { |
18 |
| - val result = newSetFromMap<Any>(IdentityHashMap()) |
19 |
| - result.add(root) |
| 52 | + private fun walkRefs(root: Any?): Map<Any, Ref> { |
| 53 | + val visited = IdentityHashMap<Any, Ref>() |
| 54 | + if (root == null) return visited |
| 55 | + visited[root] = Ref.RootRef |
20 | 56 | val stack = ArrayDeque<Any>()
|
21 | 57 | stack.addLast(root)
|
22 | 58 | while (stack.isNotEmpty()) {
|
23 | 59 | val element = stack.removeLast()
|
24 |
| - val type = element.javaClass |
25 |
| - type.visit(element, result, stack) |
| 60 | + try { |
| 61 | + visit(element, visited, stack) |
| 62 | + } catch (e: Exception) { |
| 63 | + error("Failed to visit element ${showPath(element, visited)}: $e") |
| 64 | + } |
26 | 65 | }
|
27 |
| - return result |
| 66 | + return visited |
28 | 67 | }
|
29 | 68 |
|
30 |
| - private fun Class<*>.visit( |
31 |
| - element: Any, |
32 |
| - result: MutableSet<Any>, |
33 |
| - stack: ArrayDeque<Any> |
34 |
| - ) { |
35 |
| - val fields = fields() |
36 |
| - fields.forEach { |
37 |
| - it.isAccessible = true |
38 |
| - val value = it.get(element) ?: return@forEach |
39 |
| - if (result.add(value)) { |
40 |
| - stack.addLast(value) |
| 69 | + private fun showPath(element: Any, visited: Map<Any, Ref>): String { |
| 70 | + val path = ArrayList<String>() |
| 71 | + var cur = element |
| 72 | + while (true) { |
| 73 | + val ref = visited.getValue(cur) |
| 74 | + if (ref is Ref.RootRef) break |
| 75 | + when (ref) { |
| 76 | + is Ref.FieldRef -> { |
| 77 | + cur = ref.parent |
| 78 | + path += ".${ref.name}" |
| 79 | + } |
| 80 | + is Ref.ArrayRef -> { |
| 81 | + cur = ref.parent |
| 82 | + path += "[${ref.index}]" |
| 83 | + } |
41 | 84 | }
|
42 | 85 | }
|
| 86 | + path.reverse() |
| 87 | + return path.joinToString("") |
| 88 | + } |
43 | 89 |
|
44 |
| - if (isArray && !componentType.isPrimitive) { |
45 |
| - val array = element as Array<Any?> |
46 |
| - array.filterNotNull().forEach { |
47 |
| - if (result.add(it)) { |
48 |
| - stack.addLast(it) |
| 90 | + private fun visit(element: Any, visited: IdentityHashMap<Any, Ref>, stack: ArrayDeque<Any>) { |
| 91 | + val type = element.javaClass |
| 92 | + when { |
| 93 | + // Special code for arrays |
| 94 | + type.isArray && !type.componentType.isPrimitive -> { |
| 95 | + @Suppress("UNCHECKED_CAST") |
| 96 | + val array = element as Array<Any?> |
| 97 | + array.forEachIndexed { index, value -> |
| 98 | + push(value, visited, stack) { Ref.ArrayRef(element, index) } |
| 99 | + } |
| 100 | + } |
| 101 | + // Special code for platform types that cannot be reflectively accessed on modern JDKs |
| 102 | + type.name.startsWith("java.") && element is Collection<*> -> { |
| 103 | + element.forEachIndexed { index, value -> |
| 104 | + push(value, visited, stack) { Ref.ArrayRef(element, index) } |
| 105 | + } |
| 106 | + } |
| 107 | + type.name.startsWith("java.") && element is Map<*, *> -> { |
| 108 | + push(element.keys, visited, stack) { Ref.FieldRef(element, "keys") } |
| 109 | + push(element.values, visited, stack) { Ref.FieldRef(element, "values") } |
| 110 | + } |
| 111 | + element is AtomicReference<*> -> { |
| 112 | + push(element.get(), visited, stack) { Ref.FieldRef(element, "value") } |
| 113 | + } |
| 114 | + // All the other classes are reflectively scanned |
| 115 | + else -> fields(type).forEach { field -> |
| 116 | + push(field.get(element), visited, stack) { Ref.FieldRef(element, field.name) } |
| 117 | + // special case to scan Throwable cause (cannot get it reflectively) |
| 118 | + if (element is Throwable) { |
| 119 | + push(element.cause, visited, stack) { Ref.FieldRef(element, "cause") } |
49 | 120 | }
|
50 | 121 | }
|
51 | 122 | }
|
52 | 123 | }
|
53 | 124 |
|
54 |
| - private fun Class<*>.fields(): List<Field> { |
| 125 | + private inline fun push(value: Any?, visited: IdentityHashMap<Any, Ref>, stack: ArrayDeque<Any>, ref: () -> Ref) { |
| 126 | + if (value != null && !visited.containsKey(value)) { |
| 127 | + visited[value] = ref() |
| 128 | + stack.addLast(value) |
| 129 | + } |
| 130 | + } |
| 131 | + |
| 132 | + private fun fields(type0: Class<*>): List<Field> { |
| 133 | + fieldsCache[type0]?.let { return it } |
55 | 134 | val result = ArrayList<Field>()
|
56 |
| - var type = this |
57 |
| - while (type != Any::class.java) { |
| 135 | + var type = type0 |
| 136 | + while (true) { |
58 | 137 | val fields = type.declaredFields.filter {
|
59 | 138 | !it.type.isPrimitive
|
60 | 139 | && !Modifier.isStatic(it.modifiers)
|
61 | 140 | && !(it.type.isArray && it.type.componentType.isPrimitive)
|
62 | 141 | }
|
| 142 | + fields.forEach { it.isAccessible = true } // make them all accessible |
63 | 143 | result.addAll(fields)
|
64 | 144 | type = type.superclass
|
65 |
| - } |
66 |
| - |
67 |
| - return result |
68 |
| - } |
69 |
| - |
70 |
| - // Debugging-only |
71 |
| - @Suppress("UNUSED") |
72 |
| - fun printPath(from: Any, to: Any) { |
73 |
| - val pathNodes = ArrayList<String>() |
74 |
| - val visited = newSetFromMap<Any>(IdentityHashMap()) |
75 |
| - visited.add(from) |
76 |
| - if (findPath(from, to, visited, pathNodes)) { |
77 |
| - pathNodes.reverse() |
78 |
| - println(pathNodes.joinToString(" -> ", from.javaClass.simpleName + " -> ", "-> " + to.javaClass.simpleName)) |
79 |
| - } else { |
80 |
| - println("Path from $from to $to not found") |
81 |
| - } |
82 |
| - } |
83 |
| - |
84 |
| - private fun findPath(from: Any, to: Any, visited: MutableSet<Any>, pathNodes: MutableList<String>): Boolean { |
85 |
| - if (from === to) { |
86 |
| - return true |
87 |
| - } |
88 |
| - |
89 |
| - val type = from.javaClass |
90 |
| - if (type.isArray) { |
91 |
| - if (type.componentType.isPrimitive) return false |
92 |
| - val array = from as Array<Any?> |
93 |
| - array.filterNotNull().forEach { |
94 |
| - if (findPath(it, to, visited, pathNodes)) { |
95 |
| - return true |
96 |
| - } |
| 145 | + val superFields = fieldsCache[type] // will stop at Any anyway |
| 146 | + if (superFields != null) { |
| 147 | + result.addAll(superFields) |
| 148 | + break |
97 | 149 | }
|
98 |
| - return false |
99 | 150 | }
|
100 |
| - |
101 |
| - val fields = type.fields() |
102 |
| - fields.forEach { |
103 |
| - it.isAccessible = true |
104 |
| - val value = it.get(from) ?: return@forEach |
105 |
| - if (!visited.add(value)) return@forEach |
106 |
| - val found = findPath(value, to, visited, pathNodes) |
107 |
| - if (found) { |
108 |
| - pathNodes += from.javaClass.simpleName + ":" + it.name |
109 |
| - return true |
110 |
| - } |
111 |
| - } |
112 |
| - |
113 |
| - return false |
| 151 | + fieldsCache[type0] = result |
| 152 | + return result |
114 | 153 | }
|
115 | 154 | }
|
0 commit comments