diff --git a/compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala b/compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala index 01a3362c659a..bd3ab4e3ae0f 100644 --- a/compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala +++ b/compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala @@ -26,6 +26,8 @@ import xsbti.UseScope import xsbti.api.DependencyContext import xsbti.api.DependencyContext._ +import scala.jdk.CollectionConverters.* + import scala.collection.{Set, mutable} @@ -74,8 +76,8 @@ class ExtractDependencies extends Phase { collector.traverse(unit.tpdTree) if (ctx.settings.YdumpSbtInc.value) { - val deps = rec.classDependencies.map(_.toString).toArray[Object] - val names = rec.usedNames.map { case (clazz, names) => s"$clazz: $names" }.toArray[Object] + val deps = rec.foundDeps.iterator.map { case (clazz, found) => s"$clazz: ${found.classesString}" }.toArray[Object] + val names = rec.foundDeps.iterator.map { case (clazz, found) => s"$clazz: ${found.namesString}" }.toArray[Object] Arrays.sort(deps) Arrays.sort(names) @@ -162,7 +164,7 @@ private class ExtractDependenciesCollector(rec: DependencyRecorder) extends tpd. /** Traverse the tree of a source file and record the dependencies and used names which - * can be retrieved using `dependencies` and`usedNames`. + * can be retrieved using `foundDeps`. */ override def traverse(tree: Tree)(using Context): Unit = try { tree match { @@ -226,6 +228,13 @@ private class ExtractDependenciesCollector(rec: DependencyRecorder) extends tpd. throw ex } + /**Reused EqHashSet, safe to use as each TypeDependencyTraverser is used atomically + * Avoid cycles by remembering both the types (testcase: + * tests/run/enum-values.scala) and the symbols of named types (testcase: + * tests/pos-java-interop/i13575) we've seen before. + */ + private val scratchSeen = new util.EqHashSet[Symbol | Type](128) + /** Traverse a used type and record all the dependencies we need to keep track * of for incremental recompilation. * @@ -262,17 +271,13 @@ private class ExtractDependenciesCollector(rec: DependencyRecorder) extends tpd. private abstract class TypeDependencyTraverser(using Context) extends TypeTraverser() { protected def addDependency(symbol: Symbol): Unit - // Avoid cycles by remembering both the types (testcase: - // tests/run/enum-values.scala) and the symbols of named types (testcase: - // tests/pos-java-interop/i13575) we've seen before. - val seen = new mutable.HashSet[Symbol | Type] - def traverse(tp: Type): Unit = if (!seen.contains(tp)) { - seen += tp + scratchSeen.clear(resetToInitial = false) + + def traverse(tp: Type): Unit = if scratchSeen.add(tp) then { tp match { case tp: NamedType => val sym = tp.symbol - if !seen.contains(sym) && !sym.is(Package) then - seen += sym + if !sym.is(Package) && scratchSeen.add(sym) then addDependency(sym) if !sym.isClass then traverse(tp.info) traverse(tp.prefix) @@ -306,8 +311,6 @@ private class ExtractDependenciesCollector(rec: DependencyRecorder) extends tpd. } } -case class ClassDependency(fromClass: Symbol, toClass: Symbol, context: DependencyContext) - /** Record dependencies using `addUsedName`/`addClassDependency` and inform Zinc using `sendToZinc()`. * * Note: As an alternative design choice, we could directly call the appropriate @@ -319,10 +322,10 @@ case class ClassDependency(fromClass: Symbol, toClass: Symbol, context: Dependen class DependencyRecorder { import ExtractDependencies.* - /** A map from a non-local class to the names it uses, this does not include + /** A map from a non-local class to the names and classes it uses, this does not include * names which are only defined and not referenced. */ - def usedNames: collection.Map[Symbol, UsedNamesInClass] = _usedNames + def foundDeps: util.ReadOnlyMap[Symbol, FoundDepsInClass] = _foundDeps /** Record a reference to the name of `sym` from the current non-local * enclosing class. @@ -355,10 +358,9 @@ class DependencyRecorder { * safely. */ def addUsedRawName(name: Name, includeSealedChildren: Boolean = false)(using Context): Unit = { - val fromClass = resolveDependencySource + val fromClass = resolveDependencyFromClass if (fromClass.exists) { - val usedName = _usedNames.getOrElseUpdate(fromClass, new UsedNamesInClass) - usedName.update(name, includeSealedChildren) + lastFoundCache.recordName(name, includeSealedChildren) } } @@ -367,24 +369,34 @@ class DependencyRecorder { private val DefaultScopes = EnumSet.of(UseScope.Default) private val PatMatScopes = EnumSet.of(UseScope.Default, UseScope.PatMatTarget) - /** An object that maintain the set of used names from within a class */ - final class UsedNamesInClass { + /** An object that maintain the set of used names and class dependencies from within a class */ + final class FoundDepsInClass { /** Each key corresponds to a name used in the class. To understand the meaning * of the associated value, see the documentation of parameter `includeSealedChildren` * of `addUsedRawName`. */ - private val _names = new mutable.HashMap[Name, DefaultScopes.type | PatMatScopes.type] + private val _names = new util.HashMap[Name, DefaultScopes.type | PatMatScopes.type] + + /** Each key corresponds to a class dependency used in the class. + */ + private val _classes = util.EqHashMap[Symbol, EnumSet[DependencyContext]]() + + def addDependency(fromClass: Symbol, context: DependencyContext): Unit = + val set = _classes.getOrElseUpdate(fromClass, EnumSet.noneOf(classOf[DependencyContext])) + set.add(context) + + def classes: Iterator[(Symbol, EnumSet[DependencyContext])] = _classes.iterator - def names: collection.Map[Name, EnumSet[UseScope]] = _names + def names: Iterator[(Name, EnumSet[UseScope])] = _names.iterator - private[DependencyRecorder] def update(name: Name, includeSealedChildren: Boolean): Unit = { + private[DependencyRecorder] def recordName(name: Name, includeSealedChildren: Boolean): Unit = { if (includeSealedChildren) _names(name) = PatMatScopes else _names.getOrElseUpdate(name, DefaultScopes) } - override def toString(): String = { + def namesString: String = { val builder = new StringBuilder names.foreach { case (name, scopes) => builder.append(name.mangledString) @@ -395,51 +407,59 @@ class DependencyRecorder { } builder.toString() } - } - - - private val _classDependencies = new mutable.HashSet[ClassDependency] - def classDependencies: Set[ClassDependency] = _classDependencies + def classesString: String = { + val builder = new StringBuilder + classes.foreach { case (clazz, scopes) => + builder.append(clazz.toString) + builder.append(" in [") + scopes.forEach(scope => builder.append(scope.toString)) + builder.append("]") + builder.append(", ") + } + builder.toString() + } + } /** Record a dependency to the class `to` in a given `context` * from the current non-local enclosing class. */ def addClassDependency(toClass: Symbol, context: DependencyContext)(using Context): Unit = - val fromClass = resolveDependencySource + val fromClass = resolveDependencyFromClass if (fromClass.exists) - _classDependencies += ClassDependency(fromClass, toClass, context) + lastFoundCache.addDependency(toClass, context) - private val _usedNames = new mutable.HashMap[Symbol, UsedNamesInClass] + private val _foundDeps = new util.EqHashMap[Symbol, FoundDepsInClass] /** Send the collected dependency information to Zinc and clear the local caches. */ def sendToZinc()(using Context): Unit = ctx.withIncCallback: cb => - usedNames.foreach: - case (clazz, usedNames) => - val className = classNameAsString(clazz) - usedNames.names.foreach: - case (usedName, scopes) => - cb.usedName(className, usedName.toString, scopes) val siblingClassfiles = new mutable.HashMap[PlainFile, Path] - classDependencies.foreach(recordClassDependency(cb, _, siblingClassfiles)) + _foundDeps.iterator.foreach: + case (clazz, foundDeps) => + val className = classNameAsString(clazz) + foundDeps.names.foreach: (usedName, scopes) => + cb.usedName(className, usedName.toString, scopes) + for (toClass, deps) <- foundDeps.classes do + for dep <- deps.asScala do + recordClassDependency(cb, clazz, toClass, dep, siblingClassfiles) clear() /** Clear all state. */ def clear(): Unit = - _usedNames.clear() - _classDependencies.clear() + _foundDeps.clear() lastOwner = NoSymbol lastDepSource = NoSymbol + lastFoundCache = null _responsibleForImports = NoSymbol /** Handles dependency on given symbol by trying to figure out if represents a term * that is coming from either source code (not necessarily compiled in this compilation * run) or from class file and calls respective callback method. */ - private def recordClassDependency(cb: interfaces.IncrementalCallback, dep: ClassDependency, - siblingClassfiles: mutable.Map[PlainFile, Path])(using Context): Unit = { - val fromClassName = classNameAsString(dep.fromClass) + private def recordClassDependency(cb: interfaces.IncrementalCallback, fromClass: Symbol, toClass: Symbol, + depCtx: DependencyContext, siblingClassfiles: mutable.Map[PlainFile, Path])(using Context): Unit = { + val fromClassName = classNameAsString(fromClass) val sourceFile = ctx.compilationUnit.source /**For a `.tasty` file, constructs a sibling class to the `jpath`. @@ -465,13 +485,13 @@ class DependencyRecorder { }) def binaryDependency(path: Path, binaryClassName: String) = - cb.binaryDependency(path, binaryClassName, fromClassName, sourceFile, dep.context) + cb.binaryDependency(path, binaryClassName, fromClassName, sourceFile, depCtx) - val depClass = dep.toClass + val depClass = toClass val depFile = depClass.associatedFile if depFile != null then { // Cannot ignore inheritance relationship coming from the same source (see sbt/zinc#417) - def allowLocal = dep.context == DependencyByInheritance || dep.context == LocalDependencyByInheritance + def allowLocal = depCtx == DependencyByInheritance || depCtx == LocalDependencyByInheritance val isTasty = depFile.hasTastyExtension def processExternalDependency() = { @@ -485,7 +505,7 @@ class DependencyRecorder { case pf: PlainFile => // The dependency comes from a class file, Zinc handles JRT filesystem binaryDependency(if isTasty then cachedSiblingClass(pf) else pf.jpath, binaryClassName) case _ => - internalError(s"Ignoring dependency $depFile of unknown class ${depFile.getClass}}", dep.fromClass.srcPos) + internalError(s"Ignoring dependency $depFile of unknown class ${depFile.getClass}}", fromClass.srcPos) } } @@ -495,23 +515,28 @@ class DependencyRecorder { // We cannot ignore dependencies coming from the same source file because // the dependency info needs to propagate. See source-dependencies/trait-trait-211. val toClassName = classNameAsString(depClass) - cb.classDependency(toClassName, fromClassName, dep.context) + cb.classDependency(toClassName, fromClassName, depCtx) } } private var lastOwner: Symbol = _ private var lastDepSource: Symbol = _ + private var lastFoundCache: FoundDepsInClass | Null = _ /** The source of the dependency according to `nonLocalEnclosingClass` * if it exists, otherwise fall back to `responsibleForImports`. * * This is backed by a cache which is invalidated when `ctx.owner` changes. */ - private def resolveDependencySource(using Context): Symbol = { + private def resolveDependencyFromClass(using Context): Symbol = { + import dotty.tools.uncheckedNN if (lastOwner != ctx.owner) { lastOwner = ctx.owner val source = nonLocalEnclosingClass - lastDepSource = if (source.is(PackageClass)) responsibleForImports else source + val fromClass = if (source.is(PackageClass)) responsibleForImports else source + if lastDepSource != fromClass then + lastDepSource = fromClass + lastFoundCache = _foundDeps.getOrElseUpdate(fromClass, new FoundDepsInClass) } lastDepSource diff --git a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala index cbb13a841946..35bb36b003f9 100644 --- a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala @@ -20,7 +20,6 @@ import annotation.{tailrec, constructorOnly} import ast.tpd._ import Synthesizer._ import sbt.ExtractDependencies.* -import sbt.ClassDependency import xsbti.api.DependencyContext._ /** Synthesize terms for special classes */ diff --git a/compiler/src/dotty/tools/dotc/util/EqHashMap.scala b/compiler/src/dotty/tools/dotc/util/EqHashMap.scala index ea049acba02b..25d9fb2907b8 100644 --- a/compiler/src/dotty/tools/dotc/util/EqHashMap.scala +++ b/compiler/src/dotty/tools/dotc/util/EqHashMap.scala @@ -58,6 +58,22 @@ extends GenericHashMap[Key, Value](initialCapacity, capacityMultiple): used += 1 if used > limit then growTable() + override def getOrElseUpdate(key: Key, value: => Value): Value = + // created by blending lookup and update, avoid having to recompute hash and probe + Stats.record(statsItem("lookup-or-update")) + var idx = firstIndex(key) + var k = keyAt(idx) + while k != null do + if isEqual(k, key) then return valueAt(idx) + idx = nextIndex(idx) + k = keyAt(idx) + val v = value + setKey(idx, key) + setValue(idx, v) + used += 1 + if used > limit then growTable() + v + private def addOld(key: Key, value: Value): Unit = Stats.record(statsItem("re-enter")) var idx = firstIndex(key) diff --git a/compiler/src/dotty/tools/dotc/util/EqHashSet.scala b/compiler/src/dotty/tools/dotc/util/EqHashSet.scala new file mode 100644 index 000000000000..d584441fd00a --- /dev/null +++ b/compiler/src/dotty/tools/dotc/util/EqHashSet.scala @@ -0,0 +1,106 @@ +package dotty.tools.dotc.util + +import dotty.tools.uncheckedNN + +object EqHashSet: + + def from[T](xs: IterableOnce[T]): EqHashSet[T] = + val set = new EqHashSet[T]() + set ++= xs + set + +/** A hash set that allows some privileged protected access to its internals + * @param initialCapacity Indicates the initial number of slots in the hash table. + * The actual number of slots is always a power of 2, so the + * initial size of the table will be the smallest power of two + * that is equal or greater than the given `initialCapacity`. + * Minimum value is 4. + * @param capacityMultiple The minimum multiple of capacity relative to used elements. + * The hash table will be re-sized once the number of elements + * multiplied by capacityMultiple exceeds the current size of the hash table. + * However, a table of size up to DenseLimit will be re-sized only + * once the number of elements reaches the table's size. + */ +class EqHashSet[T](initialCapacity: Int = 8, capacityMultiple: Int = 2) extends GenericHashSet[T](initialCapacity, capacityMultiple) { + import GenericHashSet.DenseLimit + + /** System's identity hashcode left shifted by 1 */ + final def hash(key: T): Int = + System.identityHashCode(key) << 1 + + /** reference equality */ + final def isEqual(x: T, y: T): Boolean = x.asInstanceOf[AnyRef] eq y.asInstanceOf[AnyRef] + + /** Turn hashcode `x` into a table index */ + private def index(x: Int): Int = x & (table.length - 1) + + private def firstIndex(x: T) = if isDense then 0 else index(hash(x)) + private def nextIndex(idx: Int) = + Stats.record(statsItem("miss")) + index(idx + 1) + + private def entryAt(idx: Int): T | Null = table(idx).asInstanceOf[T | Null] + private def setEntry(idx: Int, x: T) = table(idx) = x.asInstanceOf[AnyRef | Null] + + override def lookup(x: T): T | Null = + Stats.record(statsItem("lookup")) + var idx = firstIndex(x) + var e: T | Null = entryAt(idx) + while e != null do + if isEqual(e.uncheckedNN, x) then return e + idx = nextIndex(idx) + e = entryAt(idx) + null + + /** Add entry at `x` at index `idx` */ + private def addEntryAt(idx: Int, x: T): T = + Stats.record(statsItem("addEntryAt")) + setEntry(idx, x) + used += 1 + if used > limit then growTable() + x + + /** attempts to put `x` in the Set, if it was not entered before, return true, else return false. */ + override def add(x: T): Boolean = + Stats.record(statsItem("enter")) + var idx = firstIndex(x) + var e: T | Null = entryAt(idx) + while e != null do + if isEqual(e.uncheckedNN, x) then return false // already entered + idx = nextIndex(idx) + e = entryAt(idx) + addEntryAt(idx, x) + true // first entry + + override def put(x: T): T = + Stats.record(statsItem("put")) + var idx = firstIndex(x) + var e: T | Null = entryAt(idx) + while e != null do + // TODO: remove uncheckedNN when explicit-nulls is enabled for regule compiling + if isEqual(e.uncheckedNN, x) then return e.uncheckedNN + idx = nextIndex(idx) + e = entryAt(idx) + addEntryAt(idx, x) + + override def +=(x: T): Unit = put(x) + + private def addOld(x: T) = + Stats.record(statsItem("re-enter")) + var idx = firstIndex(x) + var e = entryAt(idx) + while e != null do + idx = nextIndex(idx) + e = entryAt(idx) + setEntry(idx, x) + + override def copyFrom(oldTable: Array[AnyRef | Null]): Unit = + if isDense then + Array.copy(oldTable, 0, table, 0, oldTable.length) + else + var idx = 0 + while idx < oldTable.length do + val e: T | Null = oldTable(idx).asInstanceOf[T | Null] + if e != null then addOld(e.uncheckedNN) + idx += 1 +} diff --git a/compiler/src/dotty/tools/dotc/util/GenericHashMap.scala b/compiler/src/dotty/tools/dotc/util/GenericHashMap.scala index a21a4af37038..6d013717ec52 100644 --- a/compiler/src/dotty/tools/dotc/util/GenericHashMap.scala +++ b/compiler/src/dotty/tools/dotc/util/GenericHashMap.scala @@ -129,12 +129,20 @@ abstract class GenericHashMap[Key, Value] null def getOrElseUpdate(key: Key, value: => Value): Value = - var v: Value | Null = lookup(key) - if v == null then - val v1 = value - v = v1 - update(key, v1) - v.uncheckedNN + // created by blending lookup and update, avoid having to recompute hash and probe + Stats.record(statsItem("lookup-or-update")) + var idx = firstIndex(key) + var k = keyAt(idx) + while k != null do + if isEqual(k, key) then return valueAt(idx) + idx = nextIndex(idx) + k = keyAt(idx) + val v = value + setKey(idx, key) + setValue(idx, v) + used += 1 + if used > limit then growTable() + v private def addOld(key: Key, value: Value): Unit = Stats.record(statsItem("re-enter")) diff --git a/compiler/src/dotty/tools/dotc/util/GenericHashSet.scala b/compiler/src/dotty/tools/dotc/util/GenericHashSet.scala new file mode 100644 index 000000000000..7abe40a8e13d --- /dev/null +++ b/compiler/src/dotty/tools/dotc/util/GenericHashSet.scala @@ -0,0 +1,190 @@ +package dotty.tools.dotc.util + +import dotty.tools.uncheckedNN + +object GenericHashSet: + + /** The number of elements up to which dense packing is used. + * If the number of elements reaches `DenseLimit` a hash table is used instead + */ + inline val DenseLimit = 8 + +/** A hash set that allows some privileged protected access to its internals + * @param initialCapacity Indicates the initial number of slots in the hash table. + * The actual number of slots is always a power of 2, so the + * initial size of the table will be the smallest power of two + * that is equal or greater than the given `initialCapacity`. + * Minimum value is 4. +* @param capacityMultiple The minimum multiple of capacity relative to used elements. + * The hash table will be re-sized once the number of elements + * multiplied by capacityMultiple exceeds the current size of the hash table. + * However, a table of size up to DenseLimit will be re-sized only + * once the number of elements reaches the table's size. + */ +abstract class GenericHashSet[T](initialCapacity: Int = 8, capacityMultiple: Int = 2) extends MutableSet[T] { + import GenericHashSet.DenseLimit + + protected var used: Int = _ + protected var limit: Int = _ + protected var table: Array[AnyRef | Null] = _ + + clear() + + private def allocate(capacity: Int) = + table = new Array[AnyRef | Null](capacity) + limit = if capacity <= DenseLimit then capacity - 1 else capacity / capacityMultiple + + private def roundToPower(n: Int) = + if n < 4 then 4 + else 1 << (32 - Integer.numberOfLeadingZeros(n - 1)) + + def clear(resetToInitial: Boolean): Unit = + used = 0 + if resetToInitial then allocate(roundToPower(initialCapacity)) + else java.util.Arrays.fill(table, null) + + /** The number of elements in the set */ + def size: Int = used + + protected def isDense = limit < DenseLimit + + /** Hashcode, to be implemented in subclass */ + protected def hash(key: T): Int + + /** Equality, to be implemented in subclass */ + protected def isEqual(x: T, y: T): Boolean + + /** Turn hashcode `x` into a table index */ + private def index(x: Int): Int = x & (table.length - 1) + + protected def currentTable: Array[AnyRef | Null] = table + + private def firstIndex(x: T) = if isDense then 0 else index(hash(x)) + private def nextIndex(idx: Int) = + Stats.record(statsItem("miss")) + index(idx + 1) + + private def entryAt(idx: Int): T | Null = table(idx).asInstanceOf[T | Null] + private def setEntry(idx: Int, x: T) = table(idx) = x.asInstanceOf[AnyRef | Null] + + def lookup(x: T): T | Null = + Stats.record(statsItem("lookup")) + var idx = firstIndex(x) + var e: T | Null = entryAt(idx) + while e != null do + if isEqual(e.uncheckedNN, x) then return e + idx = nextIndex(idx) + e = entryAt(idx) + null + + /** Add entry at `x` at index `idx` */ + private def addEntryAt(idx: Int, x: T): T = + Stats.record(statsItem("addEntryAt")) + setEntry(idx, x) + used += 1 + if used > limit then growTable() + x + + /** attempts to put `x` in the Set, if it was not entered before, return true, else return false. */ + override def add(x: T): Boolean = + Stats.record(statsItem("enter")) + var idx = firstIndex(x) + var e: T | Null = entryAt(idx) + while e != null do + if isEqual(e.uncheckedNN, x) then return false // already entered + idx = nextIndex(idx) + e = entryAt(idx) + addEntryAt(idx, x) + true // first entry + + def put(x: T): T = + Stats.record(statsItem("put")) + var idx = firstIndex(x) + var e: T | Null = entryAt(idx) + while e != null do + // TODO: remove uncheckedNN when explicit-nulls is enabled for regule compiling + if isEqual(e.uncheckedNN, x) then return e.uncheckedNN + idx = nextIndex(idx) + e = entryAt(idx) + addEntryAt(idx, x) + + def +=(x: T): Unit = put(x) + + def remove(x: T): Boolean = + Stats.record(statsItem("remove")) + var idx = firstIndex(x) + var e: T | Null = entryAt(idx) + while e != null do + if isEqual(e.uncheckedNN, x) then + var hole = idx + while + idx = nextIndex(idx) + e = entryAt(idx) + e != null + do + val eidx = index(hash(e.uncheckedNN)) + if isDense + || index(eidx - (hole + 1)) > index(idx - (hole + 1)) + // entry `e` at `idx` can move unless `index(hash(e))` is in + // the (ring-)interval [hole + 1 .. idx] + then + setEntry(hole, e.uncheckedNN) + hole = idx + table(hole) = null + used -= 1 + return true + idx = nextIndex(idx) + e = entryAt(idx) + false + + def -=(x: T): Unit = + remove(x) + + private def addOld(x: T) = + Stats.record(statsItem("re-enter")) + var idx = firstIndex(x) + var e = entryAt(idx) + while e != null do + idx = nextIndex(idx) + e = entryAt(idx) + setEntry(idx, x) + + def copyFrom(oldTable: Array[AnyRef | Null]): Unit = + if isDense then + Array.copy(oldTable, 0, table, 0, oldTable.length) + else + var idx = 0 + while idx < oldTable.length do + val e: T | Null = oldTable(idx).asInstanceOf[T | Null] + if e != null then addOld(e.uncheckedNN) + idx += 1 + + protected def growTable(): Unit = + val oldTable = table + val newLength = + if oldTable.length == DenseLimit then DenseLimit * 2 * roundToPower(capacityMultiple) + else table.length * 2 + allocate(newLength) + copyFrom(oldTable) + + abstract class EntryIterator extends Iterator[T]: + def entry(idx: Int): T | Null + private var idx = 0 + def hasNext = + while idx < table.length && table(idx) == null do idx += 1 + idx < table.length + def next() = + require(hasNext) + try entry(idx).uncheckedNN finally idx += 1 + + def iterator: Iterator[T] = new EntryIterator(): + def entry(idx: Int) = entryAt(idx) + + override def toString: String = + iterator.mkString("HashSet(", ", ", ")") + + protected def statsItem(op: String) = + val prefix = if isDense then "HashSet(dense)." else "HashSet." + val suffix = getClass.getSimpleName + s"$prefix$op $suffix" +} diff --git a/compiler/src/dotty/tools/dotc/util/HashMap.scala b/compiler/src/dotty/tools/dotc/util/HashMap.scala index aaae781c310a..eec3a604b5e2 100644 --- a/compiler/src/dotty/tools/dotc/util/HashMap.scala +++ b/compiler/src/dotty/tools/dotc/util/HashMap.scala @@ -63,6 +63,22 @@ extends GenericHashMap[Key, Value](initialCapacity, capacityMultiple): used += 1 if used > limit then growTable() + override def getOrElseUpdate(key: Key, value: => Value): Value = + // created by blending lookup and update, avoid having to recompute hash and probe + Stats.record(statsItem("lookup-or-update")) + var idx = firstIndex(key) + var k = keyAt(idx) + while k != null do + if isEqual(k, key) then return valueAt(idx) + idx = nextIndex(idx) + k = keyAt(idx) + val v = value + setKey(idx, key) + setValue(idx, v) + used += 1 + if used > limit then growTable() + v + private def addOld(key: Key, value: Value): Unit = Stats.record(statsItem("re-enter")) var idx = firstIndex(key) diff --git a/compiler/src/dotty/tools/dotc/util/HashSet.scala b/compiler/src/dotty/tools/dotc/util/HashSet.scala index a6e1532c804f..3a973793d542 100644 --- a/compiler/src/dotty/tools/dotc/util/HashSet.scala +++ b/compiler/src/dotty/tools/dotc/util/HashSet.scala @@ -4,11 +4,6 @@ import dotty.tools.uncheckedNN object HashSet: - /** The number of elements up to which dense packing is used. - * If the number of elements reaches `DenseLimit` a hash table is used instead - */ - inline val DenseLimit = 8 - def from[T](xs: IterableOnce[T]): HashSet[T] = val set = new HashSet[T]() set ++= xs @@ -26,33 +21,8 @@ object HashSet: * However, a table of size up to DenseLimit will be re-sized only * once the number of elements reaches the table's size. */ -class HashSet[T](initialCapacity: Int = 8, capacityMultiple: Int = 2) extends MutableSet[T] { - import HashSet.DenseLimit - - private var used: Int = _ - private var limit: Int = _ - private var table: Array[AnyRef | Null] = _ - - clear() - - private def allocate(capacity: Int) = - table = new Array[AnyRef | Null](capacity) - limit = if capacity <= DenseLimit then capacity - 1 else capacity / capacityMultiple - - private def roundToPower(n: Int) = - if n < 4 then 4 - else if Integer.bitCount(n) == 1 then n - else 1 << (32 - Integer.numberOfLeadingZeros(n)) - - def clear(resetToInitial: Boolean): Unit = - used = 0 - if resetToInitial then allocate(roundToPower(initialCapacity)) - else java.util.Arrays.fill(table, null) - - /** The number of elements in the set */ - def size: Int = used - - protected def isDense = limit < DenseLimit +class HashSet[T](initialCapacity: Int = 8, capacityMultiple: Int = 2) extends GenericHashSet[T](initialCapacity, capacityMultiple) { + import GenericHashSet.DenseLimit /** Hashcode, by default a processed `x.hashCode`, can be overridden */ protected def hash(key: T): Int = @@ -68,8 +38,6 @@ class HashSet[T](initialCapacity: Int = 8, capacityMultiple: Int = 2) extends Mu /** Turn hashcode `x` into a table index */ protected def index(x: Int): Int = x & (table.length - 1) - protected def currentTable: Array[AnyRef | Null] = table - protected def firstIndex(x: T) = if isDense then 0 else index(hash(x)) protected def nextIndex(idx: Int) = Stats.record(statsItem("miss")) @@ -78,7 +46,7 @@ class HashSet[T](initialCapacity: Int = 8, capacityMultiple: Int = 2) extends Mu protected def entryAt(idx: Int): T | Null = table(idx).asInstanceOf[T | Null] protected def setEntry(idx: Int, x: T) = table(idx) = x.asInstanceOf[AnyRef | Null] - def lookup(x: T): T | Null = + override def lookup(x: T): T | Null = Stats.record(statsItem("lookup")) var idx = firstIndex(x) var e: T | Null = entryAt(idx) @@ -96,48 +64,29 @@ class HashSet[T](initialCapacity: Int = 8, capacityMultiple: Int = 2) extends Mu if used > limit then growTable() x - def put(x: T): T = - Stats.record(statsItem("put")) + override def add(x: T): Boolean = + Stats.record(statsItem("enter")) var idx = firstIndex(x) var e: T | Null = entryAt(idx) while e != null do - // TODO: remove uncheckedNN when explicit-nulls is enabled for regule compiling - if isEqual(e.uncheckedNN, x) then return e.uncheckedNN + if isEqual(e.uncheckedNN, x) then return false // already entered idx = nextIndex(idx) e = entryAt(idx) addEntryAt(idx, x) + true // first entry - def +=(x: T): Unit = put(x) - - def remove(x: T): Boolean = - Stats.record(statsItem("remove")) + override def put(x: T): T = + Stats.record(statsItem("put")) var idx = firstIndex(x) var e: T | Null = entryAt(idx) while e != null do - if isEqual(e.uncheckedNN, x) then - var hole = idx - while - idx = nextIndex(idx) - e = entryAt(idx) - e != null - do - val eidx = index(hash(e.uncheckedNN)) - if isDense - || index(eidx - (hole + 1)) > index(idx - (hole + 1)) - // entry `e` at `idx` can move unless `index(hash(e))` is in - // the (ring-)interval [hole + 1 .. idx] - then - setEntry(hole, e.uncheckedNN) - hole = idx - table(hole) = null - used -= 1 - return true + // TODO: remove uncheckedNN when explicit-nulls is enabled for regule compiling + if isEqual(e.uncheckedNN, x) then return e.uncheckedNN idx = nextIndex(idx) e = entryAt(idx) - false + addEntryAt(idx, x) - def -=(x: T): Unit = - remove(x) + override def +=(x: T): Unit = put(x) private def addOld(x: T) = Stats.record(statsItem("re-enter")) @@ -148,7 +97,7 @@ class HashSet[T](initialCapacity: Int = 8, capacityMultiple: Int = 2) extends Mu e = entryAt(idx) setEntry(idx, x) - def copyFrom(oldTable: Array[AnyRef | Null]): Unit = + override def copyFrom(oldTable: Array[AnyRef | Null]): Unit = if isDense then Array.copy(oldTable, 0, table, 0, oldTable.length) else @@ -157,33 +106,4 @@ class HashSet[T](initialCapacity: Int = 8, capacityMultiple: Int = 2) extends Mu val e: T | Null = oldTable(idx).asInstanceOf[T | Null] if e != null then addOld(e.uncheckedNN) idx += 1 - - protected def growTable(): Unit = - val oldTable = table - val newLength = - if oldTable.length == DenseLimit then DenseLimit * 2 * roundToPower(capacityMultiple) - else table.length * 2 - allocate(newLength) - copyFrom(oldTable) - - abstract class EntryIterator extends Iterator[T]: - def entry(idx: Int): T | Null - private var idx = 0 - def hasNext = - while idx < table.length && table(idx) == null do idx += 1 - idx < table.length - def next() = - require(hasNext) - try entry(idx).uncheckedNN finally idx += 1 - - def iterator: Iterator[T] = new EntryIterator(): - def entry(idx: Int) = entryAt(idx) - - override def toString: String = - iterator.mkString("HashSet(", ", ", ")") - - protected def statsItem(op: String) = - val prefix = if isDense then "HashSet(dense)." else "HashSet." - val suffix = getClass.getSimpleName - s"$prefix$op $suffix" } diff --git a/compiler/src/dotty/tools/dotc/util/MutableSet.scala b/compiler/src/dotty/tools/dotc/util/MutableSet.scala index 9529262fa5ec..05fd57a50e71 100644 --- a/compiler/src/dotty/tools/dotc/util/MutableSet.scala +++ b/compiler/src/dotty/tools/dotc/util/MutableSet.scala @@ -7,6 +7,13 @@ abstract class MutableSet[T] extends ReadOnlySet[T]: /** Add element `x` to the set */ def +=(x: T): Unit + /** attempts to put `x` in the Set, if it was not entered before, return true, else return false. + * Overridden in GenericHashSet. + */ + def add(x: T): Boolean = + if lookup(x) == null then { this += x; true } + else false + /** Like `+=` but return existing element equal to `x` of it exists, * `x` itself otherwise. */ diff --git a/compiler/test/dotty/tools/dotc/util/EqHashMapTest.scala b/compiler/test/dotty/tools/dotc/util/EqHashMapTest.scala new file mode 100644 index 000000000000..561dabb555a9 --- /dev/null +++ b/compiler/test/dotty/tools/dotc/util/EqHashMapTest.scala @@ -0,0 +1,115 @@ +package dotty.tools.dotc.util + +import org.junit.Test +import org.junit.Assert.* + +class EqHashMapTest: + + var counter = 0 + + // basic identity hash, and reference equality, but with a counter for ordering + class Id: + val count = { counter += 1; counter } + + val id1, id2, id3 = Id() + + given Ordering[Id] = Ordering.by(_.count) + + @Test + def invariant: Unit = + assert((id1 ne id2) && (id1 ne id3) && (id2 ne id3)) + + @Test + def newEmpty: Unit = + val m = EqHashMap[Id, Int]() + assert(m.size == 0) + assert(m.iterator.toList == Nil) + + @Test + def update: Unit = + val m = EqHashMap[Id, Int]() + assert(m.size == 0 && !m.contains(id1)) + m.update(id1, 1) + assert(m.size == 1 && m(id1) == 1) + m.update(id1, 2) // replace value + assert(m.size == 1 && m(id1) == 2) + m.update(id3, 3) // new key + assert(m.size == 2 && m(id1) == 2 && m(id3) == 3) + + @Test + def getOrElseUpdate: Unit = + val m = EqHashMap[Id, Int]() + // add id1 + assert(m.size == 0 && !m.contains(id1)) + val added = m.getOrElseUpdate(id1, 1) + assert(added == 1 && m.size == 1 && m(id1) == 1) + // try add id1 again + val addedAgain = m.getOrElseUpdate(id1, 23) + assert(addedAgain != 23 && m.size == 1 && m(id1) == 1) // no change + + private def fullMap() = + val m = EqHashMap[Id, Int]() + m.update(id1, 1) + m.update(id2, 2) + m + + @Test + def remove: Unit = + val m = fullMap() + // remove id2 + m.remove(id2) + assert(m.size == 1) + assert(m.contains(id1) && !m.contains(id2)) + // remove id1 + m -= id1 + assert(m.size == 0) + assert(!m.contains(id1) && !m.contains(id2)) + + @Test + def lookup: Unit = + val m = fullMap() + assert(m.lookup(id1) == 1) + assert(m.lookup(id2) == 2) + assert(m.lookup(id3) == null) + + @Test + def iterator: Unit = + val m = fullMap() + assert(m.iterator.toList.sorted == List(id1 -> 1,id2 -> 2)) + + @Test + def clear: Unit = + locally: + val s1 = fullMap() + s1.clear() + assert(s1.size == 0) + locally: + val s2 = fullMap() + s2.clear(resetToInitial = false) + assert(s2.size == 0) + + // basic structural equality and hash code + class I32(val x: Int): + override def hashCode(): Int = x + override def equals(that: Any): Boolean = that match + case that: I32 => this.x == that.x + case _ => false + + /** the hash set is based on reference equality, i.e. does not use universal equality */ + @Test + def referenceEquality: Unit = + val i1, i2 = I32(1) // different instances + + assert(i1.equals(i2)) // structural equality + assert(i1 ne i2) // reference inequality + + val m = locally: + val m = EqHashMap[I32, Int]() + m(i1) = 23 + m(i2) = 29 + m + + assert(m.size == 2 && m(i1) == 23 && m(i2) == 29) + assert(m.keysIterator.toSet == Set(i1)) // scala.Set delegates to universal equality + end referenceEquality + diff --git a/compiler/test/dotty/tools/dotc/util/EqHashSetTest.scala b/compiler/test/dotty/tools/dotc/util/EqHashSetTest.scala new file mode 100644 index 000000000000..1c1ffe0b7931 --- /dev/null +++ b/compiler/test/dotty/tools/dotc/util/EqHashSetTest.scala @@ -0,0 +1,119 @@ +package dotty.tools.dotc.util + +import org.junit.Test +import org.junit.Assert.* + +class EqHashSetTest: + + var counter = 0 + + // basic identity hash, and reference equality, but with a counter for ordering + class Id: + val count = { counter += 1; counter } + + val id1, id2, id3 = Id() + + given Ordering[Id] = Ordering.by(_.count) + + @Test + def invariant: Unit = + assert((id1 ne id2) && (id1 ne id3) && (id2 ne id3)) + + @Test + def newEmpty: Unit = + val s = EqHashSet[Id]() + assert(s.size == 0) + assert(s.iterator.toList == Nil) + + @Test + def put: Unit = + val s = EqHashSet[Id]() + // put id1 + assert(s.size == 0 && !s.contains(id1)) + s += id1 + assert(s.size == 1 && s.contains(id1)) + // put id2 + assert(!s.contains(id2)) + s.put(id2) + assert(s.size == 2 && s.contains(id1) && s.contains(id2)) + // put id3 + s ++= List(id3) + assert(s.size == 3 && s.contains(id1) && s.contains(id2) && s.contains(id3)) + + @Test + def add: Unit = + val s = EqHashSet[Id]() + // add id1 + assert(s.size == 0 && !s.contains(id1)) + val added = s.add(id1) + assert(added && s.size == 1 && s.contains(id1)) + // try add id1 again + val addedAgain = s.add(id1) + assert(!addedAgain && s.size == 1 && s.contains(id1)) // no change + + @Test + def construct: Unit = + val s = EqHashSet.from(List(id1,id2,id3)) + assert(s.size == 3) + assert(s.contains(id1) && s.contains(id2) && s.contains(id3)) + + @Test + def remove: Unit = + val s = EqHashSet.from(List(id1,id2,id3)) + // remove id2 + s.remove(id2) + assert(s.size == 2) + assert(s.contains(id1) && !s.contains(id2) && s.contains(id3)) + // remove id1 + s -= id1 + assert(s.size == 1) + assert(!s.contains(id1) && !s.contains(id2) && s.contains(id3)) + // remove id3 + s --= List(id3) + assert(s.size == 0) + assert(!s.contains(id1) && !s.contains(id2) && !s.contains(id3)) + + @Test + def lookup: Unit = + val s = EqHashSet.from(List(id1, id2)) + assert(s.lookup(id1) eq id1) + assert(s.lookup(id2) eq id2) + assert(s.lookup(id3) eq null) + + @Test + def iterator: Unit = + val s = EqHashSet.from(List(id1,id2,id3)) + assert(s.iterator.toList.sorted == List(id1,id2,id3)) + + @Test + def clear: Unit = + locally: + val s1 = EqHashSet.from(List(id1,id2,id3)) + s1.clear() + assert(s1.size == 0) + locally: + val s2 = EqHashSet.from(List(id1,id2,id3)) + s2.clear(resetToInitial = false) + assert(s2.size == 0) + + // basic structural equality and hash code + class I32(val x: Int): + override def hashCode(): Int = x + override def equals(that: Any): Boolean = that match + case that: I32 => this.x == that.x + case _ => false + + /** the hash map is based on reference equality, i.e. does not use universal equality */ + @Test + def referenceEquality: Unit = + val i1, i2 = I32(1) // different instances + + assert(i1.equals(i2)) // structural equality + assert(i1 ne i2) // reference inequality + + val s = EqHashSet.from(List(i1,i2)) + + assert(s.size == 2 && s.contains(i1) && s.contains(i2)) + assert(s.iterator.toSet == Set(i1)) // scala.Set delegates to universal equality + end referenceEquality + diff --git a/compiler/test/dotty/tools/dotc/util/HashMapTest.scala b/compiler/test/dotty/tools/dotc/util/HashMapTest.scala new file mode 100644 index 000000000000..97bf8446756c --- /dev/null +++ b/compiler/test/dotty/tools/dotc/util/HashMapTest.scala @@ -0,0 +1,137 @@ +package dotty.tools.dotc.util + +import org.junit.Test +import org.junit.Assert.* + +class HashMapTest: + + var counter = 0 + + // structural hash and equality, but with a counter for ordering + class Id(val count: Int = { counter += 1; counter }): + override def hashCode(): Int = count + override def equals(that: Any): Boolean = that match + case that: Id => this.count == that.count + case _ => false + def makeCopy: Id = new Id(count) + + val id1, id2, id3 = Id() + + given Ordering[Id] = Ordering.by(_.count) + + @Test + def invariant: Unit = + assert((id1 ne id2) && (id1 ne id3) && (id2 ne id3)) + assert(id1 != id2 && id1 != id3 && id2 != id3) + + @Test + def newEmpty: Unit = + val m = HashMap[Id, Int]() + assert(m.size == 0) + assert(m.iterator.toList == Nil) + + @Test + def update: Unit = + val m = HashMap[Id, Int]() + assert(m.size == 0 && !m.contains(id1)) + m.update(id1, 1) + assert(m.size == 1 && m(id1) == 1) + m.update(id1, 2) // replace value + assert(m.size == 1 && m(id1) == 2) + m.update(id3, 3) // new key + assert(m.size == 2 && m(id1) == 2 && m(id3) == 3) + + @Test + def getOrElseUpdate: Unit = + val m = HashMap[Id, Int]() + // add id1 + assert(m.size == 0 && !m.contains(id1)) + val added = m.getOrElseUpdate(id1, 1) + assert(added == 1 && m.size == 1 && m(id1) == 1) + // try add id1 again + val addedAgain = m.getOrElseUpdate(id1, 23) + assert(addedAgain != 23 && m.size == 1 && m(id1) == 1) // no change + + class StatefulHash: + var hashCount = 0 + override def hashCode(): Int = { hashCount += 1; super.hashCode() } + + @Test + def getOrElseUpdate_hashesAtMostOnce: Unit = + locally: + val sh1 = StatefulHash() + val m = HashMap[StatefulHash, Int]() // will be a dense map with default size + val added = m.getOrElseUpdate(sh1, 1) + assert(sh1.hashCount == 0) // no hashing at all for dense maps + locally: + val sh1 = StatefulHash() + val m = HashMap[StatefulHash, Int](64) // not dense + val added = m.getOrElseUpdate(sh1, 1) + assert(sh1.hashCount == 1) // would be 2 if for example getOrElseUpdate was implemented as lookup + update + + private def fullMap() = + val m = HashMap[Id, Int]() + m.update(id1, 1) + m.update(id2, 2) + m + + @Test + def remove: Unit = + val m = fullMap() + // remove id2 + m.remove(id2) + assert(m.size == 1) + assert(m.contains(id1) && !m.contains(id2)) + // remove id1 + m -= id1 + assert(m.size == 0) + assert(!m.contains(id1) && !m.contains(id2)) + + @Test + def lookup: Unit = + val m = fullMap() + assert(m.lookup(id1) == 1) + assert(m.lookup(id2) == 2) + assert(m.lookup(id3) == null) + + @Test + def iterator: Unit = + val m = fullMap() + assert(m.iterator.toList.sorted == List(id1 -> 1,id2 -> 2)) + + @Test + def clear: Unit = + locally: + val s1 = fullMap() + s1.clear() + assert(s1.size == 0) + locally: + val s2 = fullMap() + s2.clear(resetToInitial = false) + assert(s2.size == 0) + + // basic structural equality and hash code + class I32(val x: Int): + override def hashCode(): Int = x + override def equals(that: Any): Boolean = that match + case that: I32 => this.x == that.x + case _ => false + + /** the hash map is based on universal equality, i.e. does not use reference equality */ + @Test + def universalEquality: Unit = + val id2_2 = id2.makeCopy + + assert(id2.equals(id2_2)) // structural equality + assert(id2 ne id2_2) // reference inequality + + val m = locally: + val m = HashMap[Id, Int]() + m(id2) = 23 + m(id2_2) = 29 + m + + assert(m.size == 1 && m(id2) == 29 && m(id2_2) == 29) + assert(m.keysIterator.toList.head eq id2) // does not replace id2 with id2_2 + end universalEquality + diff --git a/compiler/test/dotty/tools/dotc/util/HashSetTest.scala b/compiler/test/dotty/tools/dotc/util/HashSetTest.scala new file mode 100644 index 000000000000..2089be508a4c --- /dev/null +++ b/compiler/test/dotty/tools/dotc/util/HashSetTest.scala @@ -0,0 +1,117 @@ +package dotty.tools.dotc.util + +import org.junit.Test +import org.junit.Assert.* + +class HashSetTest: + + var counter = 0 + + // structural hash and equality, with a counter for ordering + class Id(val count: Int = { counter += 1; counter }): + override def hashCode: Int = count + override def equals(that: Any): Boolean = that match + case that: Id => this.count == that.count + case _ => false + def makeCopy: Id = new Id(count) + + val id1, id2, id3 = Id() + + given Ordering[Id] = Ordering.by(_.count) + + @Test + def invariant: Unit = + assert((id1 ne id2) && (id1 ne id3) && (id2 ne id3)) + assert(id1 != id2 && id1 != id3 && id2 != id3) + + @Test + def newEmpty: Unit = + val s = HashSet[Id]() + assert(s.size == 0) + assert(s.iterator.toList == Nil) + + @Test + def put: Unit = + val s = HashSet[Id]() + // put id1 + assert(s.size == 0 && !s.contains(id1)) + s += id1 + assert(s.size == 1 && s.contains(id1)) + // put id2 + assert(!s.contains(id2)) + s.put(id2) + assert(s.size == 2 && s.contains(id1) && s.contains(id2)) + // put id3 + s ++= List(id3) + assert(s.size == 3 && s.contains(id1) && s.contains(id2) && s.contains(id3)) + + @Test + def add: Unit = + val s = HashSet[Id]() + // add id1 + assert(s.size == 0 && !s.contains(id1)) + val added = s.add(id1) + assert(added && s.size == 1 && s.contains(id1)) + // try add id1 again + val addedAgain = s.add(id1) + assert(!addedAgain && s.size == 1 && s.contains(id1)) // no change + + @Test + def construct: Unit = + val s = HashSet.from(List(id1,id2,id3)) + assert(s.size == 3) + assert(s.contains(id1) && s.contains(id2) && s.contains(id3)) + + @Test + def remove: Unit = + val s = HashSet.from(List(id1,id2,id3)) + // remove id2 + s.remove(id2) + assert(s.size == 2) + assert(s.contains(id1) && !s.contains(id2) && s.contains(id3)) + // remove id1 + s -= id1 + assert(s.size == 1) + assert(!s.contains(id1) && !s.contains(id2) && s.contains(id3)) + // remove id3 + s --= List(id3) + assert(s.size == 0) + assert(!s.contains(id1) && !s.contains(id2) && !s.contains(id3)) + + @Test + def lookup: Unit = + val s = HashSet.from(List(id1, id2)) + assert(s.lookup(id1) eq id1) + assert(s.lookup(id2) eq id2) + assert(s.lookup(id3) eq null) + + @Test + def iterator: Unit = + val s = HashSet.from(List(id1,id2,id3)) + assert(s.iterator.toList.sorted == List(id1,id2,id3)) + + @Test + def clear: Unit = + locally: + val s1 = HashSet.from(List(id1,id2,id3)) + s1.clear() + assert(s1.size == 0) + locally: + val s2 = HashSet.from(List(id1,id2,id3)) + s2.clear(resetToInitial = false) + assert(s2.size == 0) + + /** the hash set is based on universal equality, i.e. does not use reference equality */ + @Test + def universalEquality: Unit = + val id2_2 = id2.makeCopy + + assert(id2.equals(id2_2)) // structural equality + assert(id2 ne id2_2) // reference inequality + + val s = HashSet.from(List(id2,id2_2)) + + assert(s.size == 1 && s.contains(id2) && s.contains(id2_2)) + assert(s.iterator.toList == List(id2)) // single element + end universalEquality +