Skip to content

Commit 6865160

Browse files
authored
Merge pull request #9604 from dotty-staging/use-sparse-arrays
Use sparse arrays instead of maps for pickling
2 parents f2018f0 + 09e8b17 commit 6865160

File tree

4 files changed

+298
-20
lines changed

4 files changed

+298
-20
lines changed

compiler/src/dotty/tools/dotc/core/tasty/PositionPickler.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import Contexts._, Symbols._, Annotations._, Decorators._
1616
import collection.mutable
1717
import util.Spans._
1818

19-
class PositionPickler(pickler: TastyPickler, addrOfTree: untpd.Tree => Addr) {
19+
class PositionPickler(pickler: TastyPickler, addrOfTree: PositionPickler.TreeToAddr) {
2020
val buf: TastyBuffer = new TastyBuffer(5000)
2121
pickler.newSection("Positions", buf)
2222
import ast.tpd._
@@ -121,3 +121,8 @@ class PositionPickler(pickler: TastyPickler, addrOfTree: untpd.Tree => Addr) {
121121
traverse(root, NoSource)
122122
}
123123
}
124+
object PositionPickler:
125+
// Note: This could be just `untpd.Tree => Addr` if functions were specialized for value classes.
126+
// We use a SAM type to avoid boxing of Addr
127+
@FunctionalInterface trait TreeToAddr:
128+
def apply(x: untpd.Tree): Addr

compiler/src/dotty/tools/dotc/core/tasty/TreeBuffer.scala

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import dotty.tools.tasty.TastyBuffer
88
import TastyBuffer.{Addr, NoAddr, AddrWidth}
99

1010
import util.Util.bestFit
11+
import util.SparseIntArray
1112
import config.Printers.pickling
1213
import ast.untpd.Tree
1314

@@ -20,18 +21,20 @@ class TreeBuffer extends TastyBuffer(50000) {
2021
private var delta: Array[Int] = _
2122
private var numOffsets = 0
2223

23-
/** A map from trees to the address at which a tree is pickled. */
24-
private val treeAddrs = new java.util.IdentityHashMap[Tree, Any] // really: Addr | Null
24+
/** A map from tree unique ids to the address index at which a tree is pickled. */
25+
private val addrOfTree = SparseIntArray()
2526

26-
def registerTreeAddr(tree: Tree): Addr = treeAddrs.get(tree) match {
27-
case null => treeAddrs.put(tree, currentAddr); currentAddr
28-
case addr: Addr => addr
29-
}
27+
/** Records the current write address as the address of `tree`, unless an
 *  address has already been recorded for it.
 *
 *  Trees are keyed by their `uniqueId`; the stored value is the address index.
 *
 *  @param tree the tree whose address should be registered
 *  @return     the address recorded earlier for `tree` if there is one,
 *              otherwise `currentAddr`, which this call records
 */
def registerTreeAddr(tree: Tree): Addr =
  val id = tree.uniqueId
  if addrOfTree.contains(id) then Addr(addrOfTree(id))
  else
    // Reuse the id read above instead of querying `tree.uniqueId` a second time.
    addrOfTree(id) = currentAddr.index
    currentAddr
3033

31-
def addrOfTree(tree: Tree): Addr = treeAddrs.get(tree) match {
32-
case null => NoAddr
33-
case addr: Addr => addr
34-
}
34+
/** Looks up the address recorded for `tree`.
 *
 *  @param tree the tree to look up, keyed by its `uniqueId`
 *  @return     the recorded address, or `NoAddr` if none was registered
 */
def addrOfTree(tree: Tree): Addr =
  val key = tree.uniqueId
  if !addrOfTree.contains(key) then NoAddr
  else Addr(addrOfTree(key))
3538

3639
private def offset(i: Int): Addr = Addr(offsets(i))
3740

@@ -156,15 +159,8 @@ class TreeBuffer extends TastyBuffer(50000) {
156159
wasted
157160
}
158161

159-
def adjustTreeAddrs(): Unit = {
160-
val it = treeAddrs.keySet.iterator
161-
while (it.hasNext) {
162-
val tree = it.next
163-
treeAddrs.get(tree) match {
164-
case addr: Addr => treeAddrs.put(tree, adjusted(addr))
165-
}
166-
}
167-
}
162+
/** Replaces every recorded tree address index by the index of the address
 *  that `adjusted` returns for it.
 */
def adjustTreeAddrs(): Unit =
  addrOfTree.transform { (_, oldIndex) =>
    val moved = adjusted(Addr(oldIndex))
    moved.index
  }
168164

169165
/** Final assembly, involving the following steps:
170166
* - compute deltas
Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
package dotty.tools.dotc
2+
package util
3+
4+
import java.util.NoSuchElementException
5+
6+
/** A mutable sparse map from non-negative `Int` indices to `Int` values.
 *
 *  Values live in a tree of fixed-fanout nodes: a leaf stores `NodeSize`
 *  values together with a bitmask of the defined slots, an inner node stores
 *  `NodeSize` references to nodes one level below. The tree gains a level
 *  whenever an index at or beyond the current capacity is written.
 *
 *  Invariant: an inner node never references an empty child. `keysIterator`
 *  and `remove` rely on it.
 */
class SparseIntArray:
  import SparseIntArray._

  /** Number of indices that currently have a value. */
  private var siz: Int = 0
  private var root: Node = LeafNode()

  /** Put a new inner node on top of the current root, multiplying the
   *  capacity by `NodeSize`.
   *
   *  The old root is linked through `adopt`, so the new node's emptiness
   *  bookkeeping stays correct, and it is linked only if it holds elements,
   *  so no empty subtree ever becomes reachable.
   */
  private def grow(): Unit =
    val newRoot = InnerNode(root.level + 1)
    if !root.isEmpty then newRoot.adopt(root)
    root = newRoot

  /** Exclusive upper bound of the indices the current tree can address. */
  private def capacity: Int = root.elemSize * NodeSize

  /** The number of defined elements. */
  def size: Int = siz

  /** Is there a value at `index`? Negative and too-large indices yield `false`. */
  def contains(index: Int): Boolean =
    0 <= index && index < capacity && root.contains(index)

  /** The value at `index`.
   *  @throws IllegalArgumentException if `index` is negative
   *  @throws NoSuchElementException   if no value is defined at `index`
   */
  def apply(index: Int): Value =
    require(index >= 0)
    if index >= capacity then throw NoSuchElementException()
    root.apply(index)

  /** Set the value at `index`, growing the tree as needed.
   *  @throws IllegalArgumentException if `index` is negative or too large
   */
  def update(index: Int, value: Value): Unit =
    require(index >= 0)
    while capacity <= index do
      require(root.level < MaxLevels, "array index too large, maximum is 2^30 - 1")
      grow()
    // `root.update` reports whether the slot was already defined.
    if !root.update(index, value) then siz += 1

  /** Remove element at `index` if it is present
   *  @return element was present
   */
  def remove(index: Int): Boolean =
    require(index >= 0)
    index < capacity && {
      val result = root.remove(index)
      if result then siz -= 1
      result
    }

  /** All defined indices in an iterator */
  def keysIterator: Iterator[Int] = root.keysIterator(0)

  /** Perform operation for each key/value pair */
  def foreachBinding(op: (Int, Value) => Unit): Unit =
    root.foreachBinding(op, 0)

  /** Transform each defined value with transformation `op`.
   *  The transformation takes the element index and value as parameters.
   */
  def transform(op: Transform): Unit =
    root.transform(op, 0)

  /** Access to some info about low-level representation */
  def repr: Repr = root

  override def toString: String =
    val b = StringBuilder() ++= "SparseIntArray("
    var first = true
    foreachBinding { (idx, elem) =>
      if first then first = false else b ++= ", "
      b ++= s"$idx -> $elem"
    }
    b ++= ")"
    b.toString

object SparseIntArray:
  type Value = Int

  /** A value transformation that also receives the element's index. */
  trait Transform:
    def apply(key: Int, v: Value): Value

  private inline val NodeSizeLog = 5
  private inline val NodeSize = 1 << NodeSizeLog
  private inline val MaxLevels = 5 // max size is 2 ^ ((MaxLevels + 1) * NodeSizeLog) = 2 ^ 30

  /** The exposed representation. Should be used just for nodeCount and
   *  low-level toString.
   */
  abstract class Repr:
    def nodeCount: Int

  /** A tree node at height `level`; leaves have level 0.
   *  Indices passed to a node are relative to the start of the range it covers.
   */
  private abstract class Node(val level: Int) extends Repr:
    /** Number of index bits handled by the levels below this node. */
    private[SparseIntArray] def elemShift = level * NodeSizeLog
    /** Number of indices covered by one slot of this node. */
    private[SparseIntArray] def elemSize = 1 << elemShift
    private[SparseIntArray] def elemMask = elemSize - 1
    def contains(index: Int): Boolean
    def apply(index: Int): Value
    /** Store `value` at `index`; returns whether a value was already present. */
    def update(index: Int, value: Value): Boolean
    /** Drop the value at `index`; returns whether a value was present. */
    def remove(index: Int): Boolean
    def isEmpty: Boolean
    def keysIterator(offset: Int): Iterator[Int]
    def foreachBinding(op: (Int, Value) => Unit, offset: Int): Unit
    def transform(op: Transform, offset: Int): Unit
    def nodeCount: Int
  end Node

  private class LeafNode extends Node(0):
    private val elems = new Array[Value](NodeSize)
    /** Bit `i` is set iff `elems(i)` holds a defined value. */
    private var present: Int = 0

    def contains(index: Int): Boolean =
      (present & (1 << index)) != 0

    def apply(index: Int) =
      if !contains(index) then throw NoSuchElementException()
      elems(index)

    def update(index: Int, value: Value): Boolean =
      elems(index) = value
      val result = contains(index)
      present = present | (1 << index)
      result

    def remove(index: Int): Boolean =
      val result = contains(index)
      present = present & ~(1 << index)
      result

    def isEmpty = present == 0

    /** The first defined slot at or after `i`, or `NodeSize` if there is none. */
    private def skipUndefined(i: Int): Int =
      if i < NodeSize && !contains(i) then skipUndefined(i + 1) else i

    def keysIterator(offset: Int): Iterator[Int] = new Iterator[Int]:
      private var curIdx = skipUndefined(0)
      def hasNext = curIdx < NodeSize
      def next(): Int =
        // Honour the Iterator contract instead of returning an out-of-range key.
        if !hasNext then throw NoSuchElementException()
        val result = curIdx + offset
        curIdx = skipUndefined(curIdx + 1)
        result

    def foreachBinding(op: (Int, Value) => Unit, offset: Int): Unit =
      var i = 0
      while i < NodeSize do
        if contains(i) then op(offset + i, elems(i))
        i += 1

    def transform(op: Transform, offset: Int): Unit =
      var i = 0
      while i < NodeSize do
        if contains(i) then elems(i) = op(offset + i, elems(i))
        i += 1

    def nodeCount = 1

    override def toString =
      elems
        .zipWithIndex
        .filter((elem, idx) => contains(idx))
        .map((elem, idx) => s"$idx -> $elem").mkString("0#(", ", ", ")")
  end LeafNode

  private class InnerNode(level: Int) extends Node(level):
    private[SparseIntArray] val elems = new Array[Node](NodeSize)

    /** Number of non-null entries in `elems`. Counting the links, rather than
     *  caching a boolean, keeps `isEmpty` exact for nodes that were filled via
     *  `adopt` and never written through `update`.
     */
    private var childCount: Int = 0

    /** Install `child` as the first child of this node. Used when the tree
     *  grows and the previous root becomes the subtree covering the lowest
     *  index range.
     */
    private[SparseIntArray] def adopt(child: Node): Unit =
      if elems(0) == null then childCount += 1
      elems(0) = child

    def contains(index: Int): Boolean =
      val elem = elems(index >>> elemShift)
      elem != null && elem.contains(index & elemMask)

    def apply(index: Int): Value =
      val elem = elems(index >>> elemShift)
      if elem == null then throw NoSuchElementException()
      elem.apply(index & elemMask)

    def update(index: Int, value: Value): Boolean =
      val slot = index >>> elemShift
      var elem = elems(slot)
      if elem == null then
        elem = newNode(level - 1)
        elems(slot) = elem
        childCount += 1
      elem.update(index & elemMask, value)

    def remove(index: Int): Boolean =
      val slot = index >>> elemShift
      val elem = elems(slot)
      if elem == null then false
      else
        val result = elem.remove(index & elemMask)
        // Unlink children that became empty so they are never iterated.
        if elem.isEmpty then
          elems(slot) = null
          childCount -= 1
        result

    def isEmpty = childCount == 0

    /** The first linked child slot at or after `i`, or `NodeSize` if there is none. */
    private def skipUndefined(i: Int): Int =
      if i < NodeSize && elems(i) == null then skipUndefined(i + 1) else i

    // Note: This takes (depth of tree) recursive steps to produce the
    // next index. It could be more efficient if we kept all active iterators
    // in a path.
    def keysIterator(offset: Int): Iterator[Int] = new Iterator[Int]:
      private var curIdx = skipUndefined(0)
      private var elemIt = Iterator.empty[Int]
      def hasNext = elemIt.hasNext || curIdx < NodeSize
      def next(): Int =
        if !elemIt.hasNext then
          if curIdx >= NodeSize then throw NoSuchElementException()
          // Linked children are never empty, so the fresh iterator has an element.
          elemIt = elems(curIdx).keysIterator(offset + curIdx * elemSize)
          curIdx = skipUndefined(curIdx + 1)
        elemIt.next()

    def foreachBinding(op: (Int, Value) => Unit, offset: Int): Unit =
      var i = 0
      while i < NodeSize do
        if elems(i) != null then
          elems(i).foreachBinding(op, offset + i * elemSize)
        i += 1

    def transform(op: Transform, offset: Int): Unit =
      var i = 0
      while i < NodeSize do
        if elems(i) != null then
          elems(i).transform(op, offset + i * elemSize)
        i += 1

    def nodeCount =
      1 + elems.filter(_ != null).map(_.nodeCount).sum

    override def toString =
      elems
        .zipWithIndex
        .filter((elem, idx) => elem != null)
        .map((elem, idx) => s"$idx -> $elem").mkString(s"$level#(", ", ", ")")
  end InnerNode

  /** A fresh empty node of the given level. */
  private def newNode(level: Int): Node =
    if level == 0 then LeafNode() else InnerNode(level)

end SparseIntArray
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package dotty.tools.dotc.util
2+
3+
import org.junit.Assert._
4+
import org.junit.Test
5+
6+
/** Checks insertion, overwrite, lookup, key iteration, removal and size
 *  tracking of `SparseIntArray` in one scenario.
 */
class SparseIntArrayTests:
  @Test
  def sparseArrayTests: Unit =
    val arr = SparseIntArray()
    def shown = arr.toString

    // Insert keys that force the array to add levels.
    assert(shown == "SparseIntArray()")
    arr(1) = 22
    assert(shown == "SparseIntArray(1 -> 22)")
    arr(222) = 33
    assert(shown == "SparseIntArray(1 -> 22, 222 -> 33)")
    arr(55555) = 44
    assert(shown == "SparseIntArray(1 -> 22, 222 -> 33, 55555 -> 44)")

    // Key iteration, size and membership.
    assert(arr.keysIterator.toList == List(1, 222, 55555))
    assert(arr.size == 3, arr)
    assert(arr.contains(1), arr)
    assert(arr.contains(222), arr)
    assert(arr.contains(55555), arr)
    assert(!arr.contains(2))
    assert(!arr.contains(20000000))

    // Overwriting an existing key keeps the size.
    arr(222) = 44
    assert(arr.size == 3)
    assert(arr(1) == 22)
    assert(arr(222) == 44)
    assert(arr(55555) == 44)

    // Removal, including removal of a key that is already gone.
    assert(arr.remove(1))
    assert(shown == "SparseIntArray(222 -> 44, 55555 -> 44)")
    assert(arr(222) == 44, arr)
    assert(arr.remove(55555))
    assert(arr(222) == 44, arr)
    assert(arr.size == 1)
    assert(!arr.contains(1))
    assert(!arr.remove(55555))
    assert(arr.remove(222))
    assert(arr.size == 0)

0 commit comments

Comments
 (0)