Skip to content

Use sparse arrays instead of maps for pickling #9604

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import Contexts._, Symbols._, Annotations._, Decorators._
import collection.mutable
import util.Spans._

class PositionPickler(pickler: TastyPickler, addrOfTree: untpd.Tree => Addr) {
class PositionPickler(pickler: TastyPickler, addrOfTree: PositionPickler.TreeToAddr) {
val buf: TastyBuffer = new TastyBuffer(5000)
pickler.newSection("Positions", buf)
import ast.tpd._
Expand Down Expand Up @@ -121,3 +121,8 @@ class PositionPickler(pickler: TastyPickler, addrOfTree: untpd.Tree => Addr) {
traverse(root, NoSource)
}
}
object PositionPickler:

  /** A function from an untyped tree to an address.
   *
   *  A plain `untpd.Tree => Addr` function type would do if functions were
   *  specialized to value classes. A dedicated SAM type is used instead to
   *  avoid boxing of `Addr`.
   */
  @FunctionalInterface
  trait TreeToAddr:
    def apply(x: untpd.Tree): Addr
  end TreeToAddr

end PositionPickler
34 changes: 15 additions & 19 deletions compiler/src/dotty/tools/dotc/core/tasty/TreeBuffer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import dotty.tools.tasty.TastyBuffer
import TastyBuffer.{Addr, NoAddr, AddrWidth}

import util.Util.bestFit
import util.SparseIntArray
import config.Printers.pickling
import ast.untpd.Tree

Expand All @@ -20,18 +21,20 @@ class TreeBuffer extends TastyBuffer(50000) {
private var delta: Array[Int] = _
private var numOffsets = 0

/** A map from trees to the address at which a tree is pickled. */
private val treeAddrs = new java.util.IdentityHashMap[Tree, Any] // really: Addr | Null
/** A map from tree unique ids to the address index at which a tree is pickled. */
private val addrOfTree = SparseIntArray()

/** Returns the address stored for `tree` in `treeAddrs`. If no address is
 *  stored yet, stores `currentAddr` for `tree` and returns `currentAddr`.
 */
def registerTreeAddr(tree: Tree): Addr = treeAddrs.get(tree) match {
// `get` yields null when the tree has no entry: record it at the current address
case null => treeAddrs.put(tree, currentAddr); currentAddr
// already recorded: keep the first address
case addr: Addr => addr
}
/** Returns the address recorded under `tree.uniqueId`. If none is recorded
 *  yet, records the index of `currentAddr` under that id and returns
 *  `currentAddr`.
 */
def registerTreeAddr(tree: Tree): Addr =
  val id = tree.uniqueId
  if !addrOfTree.contains(id) then
    // first registration: remember where the tree is pickled
    addrOfTree(tree.uniqueId) = currentAddr.index
    currentAddr
  else Addr(addrOfTree(id))

/** The address stored for `tree` in `treeAddrs`, or `NoAddr` if there is none. */
def addrOfTree(tree: Tree): Addr = treeAddrs.get(tree) match {
// no entry for this tree
case null => NoAddr
case addr: Addr => addr
}
/** The address recorded under `tree.uniqueId`, or `NoAddr` if nothing is
 *  recorded for that id.
 */
def addrOfTree(tree: Tree): Addr =
  val id = tree.uniqueId
  if !addrOfTree.contains(id) then NoAddr
  else Addr(addrOfTree(id))

/** The `i`-th entry of `offsets`, wrapped as an `Addr`. */
private def offset(i: Int): Addr = Addr(offsets(i))

Expand Down Expand Up @@ -156,15 +159,8 @@ class TreeBuffer extends TastyBuffer(50000) {
wasted
}

/** Replaces every address stored in `treeAddrs` by the result of applying
 *  `adjusted` to it.
 */
def adjustTreeAddrs(): Unit = {
val it = treeAddrs.keySet.iterator
while (it.hasNext) {
val tree = it.next
// NOTE(review): only the `Addr` case is handled; a null value would not match.
treeAddrs.get(tree) match {
case addr: Addr => treeAddrs.put(tree, adjusted(addr))
}
}
}
/** Rewrites every recorded address index to the index of its `adjusted` address. */
def adjustTreeAddrs(): Unit =
  addrOfTree.transform { (_, addrIndex) =>
    val relocated = adjusted(Addr(addrIndex))
    relocated.index
  }

/** Final assembly, involving the following steps:
* - compute deltas
Expand Down
239 changes: 239 additions & 0 deletions compiler/src/dotty/tools/dotc/util/SparseIntArray.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
package dotty.tools.dotc
package util

import java.util.NoSuchElementException

/** A mutable map from non-negative `Int` indices to `Int` values.
 *
 *  Values are stored in a tree of nodes with `NodeSize` slots each. Leaves
 *  hold values directly, inner nodes hold child nodes. The tree gets an extra
 *  level on top whenever an index beyond the current capacity is written.
 *  Indices must lie in the range `0 .. 2^30 - 1`.
 */
class SparseIntArray:
  import SparseIntArray._

  /** The number of indices that currently have a value */
  private var siz: Int = 0
  private var root: Node = LeafNode()

  /** Add a level on top of the tree; the previous root becomes child 0 of the new root. */
  private def grow(): Unit =
    val newRoot = InnerNode(root.level + 1)
    // Install through `adoptFirst` rather than writing `elems(0)` directly, so
    // that the new node's emptiness flag matches the subtree it now holds.
    // Otherwise a node created by an intermediate growth step stays flagged as
    // empty and a later `remove` drops it together with its live elements.
    newRoot.adoptFirst(root)
    root = newRoot

  /** Exclusive upper bound of the indices the current tree can address */
  private def capacity: Int = root.elemSize * NodeSize

  /** The number of defined elements */
  def size: Int = siz

  /** Is there a value at `index`? Negative and out-of-capacity indices yield `false`. */
  def contains(index: Int): Boolean =
    0 <= index && index < capacity && root.contains(index)

  /** The value at `index`.
   *  @throws IllegalArgumentException if `index` is negative
   *  @throws NoSuchElementException   if there is no value at `index`
   */
  def apply(index: Int): Value =
    require(index >= 0)
    if index >= capacity then throw NoSuchElementException()
    root.apply(index)

  /** Set the value at `index` to `value`, growing the tree as needed.
   *  @throws IllegalArgumentException if `index` is negative or exceeds 2^30 - 1
   */
  def update(index: Int, value: Value): Unit =
    require(index >= 0)
    while capacity <= index do
      require(root.level < MaxLevels, "array index too large, maximum is 2^30 - 1")
      grow()
    if !root.update(index, value) then siz += 1

  /** Remove element at `index` if it is present
   *  @return element was present
   */
  def remove(index: Int): Boolean =
    require(index >= 0)
    index < capacity && {
      val result = root.remove(index)
      if result then siz -= 1
      result
    }

  /** All defined indices in an iterator */
  def keysIterator: Iterator[Int] = root.keysIterator(0)

  /** Perform operation for each key/value pair */
  def foreachBinding(op: (Int, Value) => Unit): Unit =
    root.foreachBinding(op, 0)

  /** Transform each defined value with transformation `op`.
   *  The transformation takes the element index and value as parameters.
   */
  def transform(op: Transform): Unit =
    root.transform(op, 0)

  /** Access to some info about low-level representation */
  def repr: Repr = root

  override def toString: String =
    val b = StringBuilder() ++= "SparseIntArray("
    var first = true
    foreachBinding { (idx, elem) =>
      if first then first = false else b ++= ", "
      b ++= s"$idx -> $elem"
    }
    b ++= ")"
    b.toString

object SparseIntArray:
  type Value = Int

  /** A value transformation that also receives the index of the transformed element */
  trait Transform:
    def apply(key: Int, v: Value): Value

  private inline val NodeSizeLog = 5
  private inline val NodeSize = 1 << NodeSizeLog
  private inline val MaxLevels = 5 // max size is 2 ^ ((MaxLevels + 1) * NodeSizeLog) = 2 ^ 30

  /** The exposed representation. Should be used just for nodeCount and
   *  low-level toString.
   */
  abstract class Repr:
    def nodeCount: Int

  /** A tree node at the given `level`; leaves are at level 0.
   *  Indices passed to a node are relative to the start of that node.
   */
  private abstract class Node(val level: Int) extends Repr:
    private[SparseIntArray] def elemShift = level * NodeSizeLog
    /** The number of indices covered by one slot of this node */
    private[SparseIntArray] def elemSize = 1 << elemShift
    private[SparseIntArray] def elemMask = elemSize - 1
    def contains(index: Int): Boolean
    def apply(index: Int): Value
    /** Store `value` at `index`; the result tells whether a value was already present. */
    def update(index: Int, value: Value): Boolean
    /** Remove the value at `index`; the result tells whether a value was present. */
    def remove(index: Int): Boolean
    def isEmpty: Boolean
    /** The defined indices of this node, each shifted by `offset` */
    def keysIterator(offset: Int): Iterator[Int]
    def foreachBinding(op: (Int, Value) => Unit, offset: Int): Unit
    def transform(op: Transform, offset: Int): Unit
    def nodeCount: Int
  end Node

  /** A node holding up to `NodeSize` values; bit `i` of `present` is set iff slot `i` is defined. */
  private class LeafNode extends Node(0):
    private val elems = new Array[Value](NodeSize)
    private var present: Int = 0

    def contains(index: Int): Boolean =
      (present & (1 << index)) != 0

    def apply(index: Int): Value =
      if !contains(index) then throw NoSuchElementException()
      elems(index)

    def update(index: Int, value: Value): Boolean =
      elems(index) = value
      val result = contains(index)
      present = present | (1 << index)
      result

    def remove(index: Int): Boolean =
      val result = contains(index)
      present = present & ~(1 << index)
      result

    def isEmpty: Boolean = present == 0

    /** The first defined slot at or after `i`, or a value >= `NodeSize` if there is none */
    private def skipUndefined(i: Int): Int =
      if i < NodeSize && !contains(i) then skipUndefined(i + 1) else i

    def keysIterator(offset: Int): Iterator[Int] = new Iterator[Int]:
      private var curIdx = skipUndefined(0)
      def hasNext = curIdx < NodeSize
      def next(): Int =
        // Fail like other iterators do instead of handing out an index past the node
        if !hasNext then throw NoSuchElementException()
        val result = curIdx + offset
        curIdx = skipUndefined(curIdx + 1)
        result

    def foreachBinding(op: (Int, Value) => Unit, offset: Int): Unit =
      var i = 0
      while i < NodeSize do
        if contains(i) then op(offset + i, elems(i))
        i += 1

    def transform(op: Transform, offset: Int): Unit =
      var i = 0
      while i < NodeSize do
        if contains(i) then elems(i) = op(offset + i, elems(i))
        i += 1

    def nodeCount: Int = 1

    override def toString: String =
      elems
        .zipWithIndex
        .filter((elem, idx) => contains(idx))
        .map((elem, idx) => s"$idx -> $elem").mkString(s"0#(", ", ", ")")
  end LeafNode

  /** A node whose slots hold child nodes of level `level - 1`; `null` marks an absent child. */
  private class InnerNode(level: Int) extends Node(level):
    private[SparseIntArray] val elems = new Array[Node](NodeSize)
    private var empty: Boolean = true

    /** Install `child` as slot 0 of this (fresh) node and take over its emptiness.
     *  Used when the tree grows a level on top of an existing root.
     */
    private[SparseIntArray] def adoptFirst(child: Node): Unit =
      elems(0) = child
      empty = child.isEmpty

    def contains(index: Int): Boolean =
      val elem = elems(index >>> elemShift)
      elem != null && elem.contains(index & elemMask)

    def apply(index: Int): Value =
      val elem = elems(index >>> elemShift)
      if elem == null then throw NoSuchElementException()
      elem.apply(index & elemMask)

    def update(index: Int, value: Value): Boolean =
      empty = false
      var elem = elems(index >>> elemShift)
      if elem == null then
        elem = newNode(level - 1)
        elems(index >>> elemShift) = elem
      elem.update(index & elemMask, value)

    def remove(index: Int): Boolean =
      val elem = elems(index >>> elemShift)
      if elem == null then false
      else
        val result = elem.remove(index & elemMask)
        if elem.isEmpty then
          // Drop the exhausted child and re-check whether any child is left
          elems(index >>> elemShift) = null
          var i = 0
          while i < NodeSize && elems(i) == null do i += 1
          if i == NodeSize then empty = true
        result

    def isEmpty: Boolean = empty

    /** The first slot at or after `i` that holds a child, or a value >= `NodeSize` if there is none */
    private def skipUndefined(i: Int): Int =
      if i < NodeSize && elems(i) == null then skipUndefined(i + 1) else i

    // Note: This takes (depth of tree) recursive steps to produce the
    // next index. It could be more efficient if we kept all active iterators
    // in a path.
    def keysIterator(offset: Int): Iterator[Int] = new Iterator[Int]:
      private var curIdx = skipUndefined(0)
      private var elemIt = Iterator.empty[Int]

      /** Move on to the next child that still has keys, if the current one is exhausted.
       *  A child can be present without holding any key (an empty root that
       *  was wrapped when the tree grew), so children must be skipped until
       *  one actually delivers a key.
       */
      private def advance(): Unit =
        while !elemIt.hasNext && curIdx < NodeSize do
          elemIt = elems(curIdx).keysIterator(offset + curIdx * elemSize)
          curIdx = skipUndefined(curIdx + 1)

      def hasNext =
        advance()
        elemIt.hasNext

      def next(): Int =
        advance()
        elemIt.next()

    def foreachBinding(op: (Int, Value) => Unit, offset: Int): Unit =
      var i = 0
      while i < NodeSize do
        if elems(i) != null then
          elems(i).foreachBinding(op, offset + i * elemSize)
        i += 1

    def transform(op: Transform, offset: Int): Unit =
      var i = 0
      while i < NodeSize do
        if elems(i) != null then
          elems(i).transform(op, offset + i * elemSize)
        i += 1

    def nodeCount: Int =
      1 + elems.filter(_ != null).map(_.nodeCount).sum

    override def toString: String =
      elems
        .zipWithIndex
        .filter((elem, idx) => elem != null)
        .map((elem, idx) => s"$idx -> $elem").mkString(s"$level#(", ", ", ")")
  end InnerNode

  private def newNode(level: Int): Node =
    if level == 0 then LeafNode() else InnerNode(level)

end SparseIntArray
38 changes: 38 additions & 0 deletions compiler/test/dotty/tools/dotc/util/SparseIntArrayTests.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package dotty.tools.dotc.util

import org.junit.Assert._
import org.junit.Test

class SparseIntArrayTests:

  /** Takes one array through insertion, overwriting, lookup and removal,
   *  checking its rendering, key order and size along the way.
   */
  @Test
  def sparseArrayTests: Unit =
    val arr = SparseIntArray()
    assertEquals("SparseIntArray()", arr.toString)

    // insertions at increasingly distant indices
    arr(1) = 22
    assertEquals("SparseIntArray(1 -> 22)", arr.toString)
    arr(222) = 33
    assertEquals("SparseIntArray(1 -> 22, 222 -> 33)", arr.toString)
    arr(55555) = 44
    assertEquals("SparseIntArray(1 -> 22, 222 -> 33, 55555 -> 44)", arr.toString)
    assertEquals(List(1, 222, 55555), arr.keysIterator.toList)
    assertEquals(arr.toString, 3, arr.size)

    // membership
    assertTrue(arr.toString, arr.contains(1))
    assertTrue(arr.toString, arr.contains(222))
    assertTrue(arr.toString, arr.contains(55555))
    assertFalse(arr.contains(2))
    assertFalse(arr.contains(20000000))

    // overwriting keeps the size
    arr(222) = 44
    assertEquals(3, arr.size)
    assertEquals(22, arr(1))
    assertEquals(44, arr(222))
    assertEquals(44, arr(55555))

    // removals
    assertTrue(arr.remove(1))
    assertEquals("SparseIntArray(222 -> 44, 55555 -> 44)", arr.toString)
    assertEquals(arr.toString, 44, arr(222))
    assertTrue(arr.remove(55555))
    assertEquals(arr.toString, 44, arr(222))
    assertEquals(1, arr.size)
    assertFalse(arr.contains(1))
    assertFalse(arr.remove(55555))
    assertTrue(arr.remove(222))
    assertEquals(0, arr.size)