From b18b7442d1eb01230d1e7de2622950e0c1491b0a Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Wed, 2 Aug 2023 16:47:24 +0200 Subject: [PATCH 1/9] faster class dependency cache --- .../tools/dotc/sbt/ExtractDependencies.scala | 78 ++++++++++++------- .../dotty/tools/dotc/typer/Synthesizer.scala | 1 - 2 files changed, 51 insertions(+), 28 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala b/compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala index 01a3362c659a..d7d3678a3298 100644 --- a/compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala +++ b/compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala @@ -26,6 +26,8 @@ import xsbti.UseScope import xsbti.api.DependencyContext import xsbti.api.DependencyContext._ +import scala.jdk.CollectionConverters.* + import scala.collection.{Set, mutable} @@ -74,7 +76,11 @@ class ExtractDependencies extends Phase { collector.traverse(unit.tpdTree) if (ctx.settings.YdumpSbtInc.value) { - val deps = rec.classDependencies.map(_.toString).toArray[Object] + val deps = rec.classDependencies.flatMap((k,vs) => + vs.iterator.flatMap((to, depCtxs) => + depCtxs.asScala.map(depCtx => s"ClassDependency($k, $to, $depCtx)") + ) + ).toArray[Object] val names = rec.usedNames.map { case (clazz, names) => s"$clazz: $names" }.toArray[Object] Arrays.sort(deps) Arrays.sort(names) @@ -265,7 +271,7 @@ private class ExtractDependenciesCollector(rec: DependencyRecorder) extends tpd. // Avoid cycles by remembering both the types (testcase: // tests/run/enum-values.scala) and the symbols of named types (testcase: // tests/pos-java-interop/i13575) we've seen before. - val seen = new mutable.HashSet[Symbol | Type] + val seen = new util.HashSet[Symbol | Type](64) def traverse(tp: Type): Unit = if (!seen.contains(tp)) { seen += tp tp match { @@ -306,7 +312,15 @@ private class ExtractDependenciesCollector(rec: DependencyRecorder) extends tpd. 
} } -case class ClassDependency(fromClass: Symbol, toClass: Symbol, context: DependencyContext) +class ClassDepsInClass: + private val _classes = util.EqHashMap[Symbol, EnumSet[DependencyContext]]() + + def addDependency(fromClass: Symbol, context: DependencyContext): Unit = + val set = _classes.getOrElseUpdate(fromClass, EnumSet.noneOf(classOf[DependencyContext])) + set.add(context) + + def iterator: Iterator[(Symbol, EnumSet[DependencyContext])] = + _classes.iterator /** Record dependencies using `addUsedName`/`addClassDependency` and inform Zinc using `sendToZinc()`. * @@ -355,10 +369,9 @@ class DependencyRecorder { * safely. */ def addUsedRawName(name: Name, includeSealedChildren: Boolean = false)(using Context): Unit = { - val fromClass = resolveDependencySource + val fromClass = resolveDependencyFromClass if (fromClass.exists) { - val usedName = _usedNames.getOrElseUpdate(fromClass, new UsedNamesInClass) - usedName.update(name, includeSealedChildren) + lastUsedCache.update(name, includeSealedChildren) } } @@ -373,9 +386,9 @@ class DependencyRecorder { * of the associated value, see the documentation of parameter `includeSealedChildren` * of `addUsedRawName`. 
*/ - private val _names = new mutable.HashMap[Name, DefaultScopes.type | PatMatScopes.type] + private val _names = new util.HashMap[Name, DefaultScopes.type | PatMatScopes.type] - def names: collection.Map[Name, EnumSet[UseScope]] = _names + def iterator: Iterator[(Name, EnumSet[UseScope])] = _names.iterator private[DependencyRecorder] def update(name: Name, includeSealedChildren: Boolean): Unit = { if (includeSealedChildren) @@ -386,7 +399,7 @@ class DependencyRecorder { override def toString(): String = { val builder = new StringBuilder - names.foreach { case (name, scopes) => + iterator.foreach { (name, scopes) => builder.append(name.mangledString) builder.append(" in [") scopes.forEach(scope => builder.append(scope.toString)) @@ -398,17 +411,17 @@ class DependencyRecorder { } - private val _classDependencies = new mutable.HashSet[ClassDependency] + private val _classDependencies = new mutable.HashMap[Symbol, ClassDepsInClass] - def classDependencies: Set[ClassDependency] = _classDependencies + def classDependencies: collection.Map[Symbol, ClassDepsInClass] = _classDependencies /** Record a dependency to the class `to` in a given `context` * from the current non-local enclosing class. 
*/ def addClassDependency(toClass: Symbol, context: DependencyContext)(using Context): Unit = - val fromClass = resolveDependencySource + val fromClass = resolveDependencyFromClass if (fromClass.exists) - _classDependencies += ClassDependency(fromClass, toClass, context) + lastDepCache.addDependency(toClass, context) private val _usedNames = new mutable.HashMap[Symbol, UsedNamesInClass] @@ -418,11 +431,13 @@ class DependencyRecorder { usedNames.foreach: case (clazz, usedNames) => val className = classNameAsString(clazz) - usedNames.names.foreach: - case (usedName, scopes) => - cb.usedName(className, usedName.toString, scopes) + usedNames.iterator.foreach: (usedName, scopes) => + cb.usedName(className, usedName.toString, scopes) val siblingClassfiles = new mutable.HashMap[PlainFile, Path] - classDependencies.foreach(recordClassDependency(cb, _, siblingClassfiles)) + for (fromClass, partialDependencies) <- _classDependencies do + for (toClass, deps) <- partialDependencies.iterator do + for dep <- deps.asScala do + recordClassDependency(cb, fromClass, toClass, dep, siblingClassfiles) clear() /** Clear all state. */ @@ -431,15 +446,17 @@ class DependencyRecorder { _classDependencies.clear() lastOwner = NoSymbol lastDepSource = NoSymbol + lastDepCache = null + lastUsedCache = null _responsibleForImports = NoSymbol /** Handles dependency on given symbol by trying to figure out if represents a term * that is coming from either source code (not necessarily compiled in this compilation * run) or from class file and calls respective callback method. 
*/ - private def recordClassDependency(cb: interfaces.IncrementalCallback, dep: ClassDependency, - siblingClassfiles: mutable.Map[PlainFile, Path])(using Context): Unit = { - val fromClassName = classNameAsString(dep.fromClass) + private def recordClassDependency(cb: interfaces.IncrementalCallback, fromClass: Symbol, toClass: Symbol, + depCtx: DependencyContext, siblingClassfiles: mutable.Map[PlainFile, Path])(using Context): Unit = { + val fromClassName = classNameAsString(fromClass) val sourceFile = ctx.compilationUnit.source /**For a `.tasty` file, constructs a sibling class to the `jpath`. @@ -465,13 +482,13 @@ class DependencyRecorder { }) def binaryDependency(path: Path, binaryClassName: String) = - cb.binaryDependency(path, binaryClassName, fromClassName, sourceFile, dep.context) + cb.binaryDependency(path, binaryClassName, fromClassName, sourceFile, depCtx) - val depClass = dep.toClass + val depClass = toClass val depFile = depClass.associatedFile if depFile != null then { // Cannot ignore inheritance relationship coming from the same source (see sbt/zinc#417) - def allowLocal = dep.context == DependencyByInheritance || dep.context == LocalDependencyByInheritance + def allowLocal = depCtx == DependencyByInheritance || depCtx == LocalDependencyByInheritance val isTasty = depFile.hasTastyExtension def processExternalDependency() = { @@ -485,7 +502,7 @@ class DependencyRecorder { case pf: PlainFile => // The dependency comes from a class file, Zinc handles JRT filesystem binaryDependency(if isTasty then cachedSiblingClass(pf) else pf.jpath, binaryClassName) case _ => - internalError(s"Ignoring dependency $depFile of unknown class ${depFile.getClass}}", dep.fromClass.srcPos) + internalError(s"Ignoring dependency $depFile of unknown class ${depFile.getClass}}", fromClass.srcPos) } } @@ -495,23 +512,30 @@ class DependencyRecorder { // We cannot ignore dependencies coming from the same source file because // the dependency info needs to propagate. 
See source-dependencies/trait-trait-211. val toClassName = classNameAsString(depClass) - cb.classDependency(toClassName, fromClassName, dep.context) + cb.classDependency(toClassName, fromClassName, depCtx) } } private var lastOwner: Symbol = _ private var lastDepSource: Symbol = _ + private var lastDepCache: ClassDepsInClass | Null = _ + private var lastUsedCache: UsedNamesInClass | Null = _ /** The source of the dependency according to `nonLocalEnclosingClass` * if it exists, otherwise fall back to `responsibleForImports`. * * This is backed by a cache which is invalidated when `ctx.owner` changes. */ - private def resolveDependencySource(using Context): Symbol = { + private def resolveDependencyFromClass(using Context): Symbol = { + import dotty.tools.uncheckedNN if (lastOwner != ctx.owner) { lastOwner = ctx.owner val source = nonLocalEnclosingClass - lastDepSource = if (source.is(PackageClass)) responsibleForImports else source + val fromClass = if (source.is(PackageClass)) responsibleForImports else source + if lastDepSource != fromClass then + lastDepSource = fromClass + lastDepCache = _classDependencies.getOrElseUpdate(fromClass, new ClassDepsInClass) + lastUsedCache = _usedNames.getOrElseUpdate(fromClass, new UsedNamesInClass) } lastDepSource diff --git a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala index cbb13a841946..35bb36b003f9 100644 --- a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala @@ -20,7 +20,6 @@ import annotation.{tailrec, constructorOnly} import ast.tpd._ import Synthesizer._ import sbt.ExtractDependencies.* -import sbt.ClassDependency import xsbti.api.DependencyContext._ /** Synthesize terms for special classes */ From d20c6246504fb4ea5595821533dbfc43ac483775 Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Thu, 3 Aug 2023 13:46:32 +0200 Subject: [PATCH 2/9] add EqHashSet --- 
.../src/dotty/tools/dotc/util/EqHashSet.scala | 136 +++++++++++++ .../tools/dotc/util/GenericHashSet.scala | 191 ++++++++++++++++++ .../src/dotty/tools/dotc/util/HashSet.scala | 76 +------ .../dotty/tools/dotc/util/MutableSet.scala | 7 + 4 files changed, 343 insertions(+), 67 deletions(-) create mode 100644 compiler/src/dotty/tools/dotc/util/EqHashSet.scala create mode 100644 compiler/src/dotty/tools/dotc/util/GenericHashSet.scala diff --git a/compiler/src/dotty/tools/dotc/util/EqHashSet.scala b/compiler/src/dotty/tools/dotc/util/EqHashSet.scala new file mode 100644 index 000000000000..42aee97ce79c --- /dev/null +++ b/compiler/src/dotty/tools/dotc/util/EqHashSet.scala @@ -0,0 +1,136 @@ +package dotty.tools.dotc.util + +import dotty.tools.uncheckedNN + +object EqHashSet: + + def from[T](xs: IterableOnce[T]): EqHashSet[T] = + val set = new EqHashSet[T]() + set ++= xs + set + +/** A hash set that allows some privileged protected access to its internals + * @param initialCapacity Indicates the initial number of slots in the hash table. + * The actual number of slots is always a power of 2, so the + * initial size of the table will be the smallest power of two + * that is equal or greater than the given `initialCapacity`. + * Minimum value is 4. +* @param capacityMultiple The minimum multiple of capacity relative to used elements. + * The hash table will be re-sized once the number of elements + * multiplied by capacityMultiple exceeds the current size of the hash table. + * However, a table of size up to DenseLimit will be re-sized only + * once the number of elements reaches the table's size. 
+ */ +class EqHashSet[T](initialCapacity: Int = 8, capacityMultiple: Int = 2) extends GenericHashSet[T](initialCapacity, capacityMultiple) { + import GenericHashSet.DenseLimit + + /** System's identity hashcode left shifted by 1 */ + final def hash(key: T): Int = + System.identityHashCode(key) << 1 + + /** reference equality */ + final def isEqual(x: T, y: T): Boolean = x.asInstanceOf[AnyRef] eq y.asInstanceOf[AnyRef] + + /** Turn hashcode `x` into a table index */ + private def index(x: Int): Int = x & (table.length - 1) + + private def firstIndex(x: T) = if isDense then 0 else index(hash(x)) + private def nextIndex(idx: Int) = + Stats.record(statsItem("miss")) + index(idx + 1) + + private def entryAt(idx: Int): T | Null = table(idx).asInstanceOf[T | Null] + private def setEntry(idx: Int, x: T) = table(idx) = x.asInstanceOf[AnyRef | Null] + + override def lookup(x: T): T | Null = + Stats.record(statsItem("lookup")) + var idx = firstIndex(x) + var e: T | Null = entryAt(idx) + while e != null do + if isEqual(e.uncheckedNN, x) then return e + idx = nextIndex(idx) + e = entryAt(idx) + null + + /** Add entry at `x` at index `idx` */ + private def addEntryAt(idx: Int, x: T): T = + Stats.record(statsItem("addEntryAt")) + setEntry(idx, x) + used += 1 + if used > limit then growTable() + x + + /** attempts to put `x` in the Set, if it was not entered before, return true, else return false. 
*/ + override def add(x: T): Boolean = + Stats.record(statsItem("enter")) + var idx = firstIndex(x) + var e: T | Null = entryAt(idx) + while e != null do + if isEqual(e.uncheckedNN, x) then return false // already entered + idx = nextIndex(idx) + e = entryAt(idx) + addEntryAt(idx, x) + true // first entry + + override def put(x: T): T = + Stats.record(statsItem("put")) + var idx = firstIndex(x) + var e: T | Null = entryAt(idx) + while e != null do + // TODO: remove uncheckedNN when explicit-nulls is enabled for regular compiling + if isEqual(e.uncheckedNN, x) then return e.uncheckedNN + idx = nextIndex(idx) + e = entryAt(idx) + addEntryAt(idx, x) + + override def +=(x: T): Unit = put(x) + + override def remove(x: T): Boolean = + Stats.record(statsItem("remove")) + var idx = firstIndex(x) + var e: T | Null = entryAt(idx) + while e != null do + if isEqual(e.uncheckedNN, x) then + var hole = idx + while + idx = nextIndex(idx) + e = entryAt(idx) + e != null + do + val eidx = index(hash(e.uncheckedNN)) + if isDense + || index(eidx - (hole + 1)) > index(idx - (hole + 1)) + // entry `e` at `idx` can move unless `index(hash(e))` is in + // the (ring-)interval [hole + 1 ..
idx] + then + setEntry(hole, e.uncheckedNN) + hole = idx + table(hole) = null + used -= 1 + return true + idx = nextIndex(idx) + e = entryAt(idx) + false + + override def -=(x: T): Unit = + remove(x) + + private def addOld(x: T) = + Stats.record(statsItem("re-enter")) + var idx = firstIndex(x) + var e = entryAt(idx) + while e != null do + idx = nextIndex(idx) + e = entryAt(idx) + setEntry(idx, x) + + override def copyFrom(oldTable: Array[AnyRef | Null]): Unit = + if isDense then + Array.copy(oldTable, 0, table, 0, oldTable.length) + else + var idx = 0 + while idx < oldTable.length do + val e: T | Null = oldTable(idx).asInstanceOf[T | Null] + if e != null then addOld(e.uncheckedNN) + idx += 1 +} diff --git a/compiler/src/dotty/tools/dotc/util/GenericHashSet.scala b/compiler/src/dotty/tools/dotc/util/GenericHashSet.scala new file mode 100644 index 000000000000..704298e55fb7 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/util/GenericHashSet.scala @@ -0,0 +1,191 @@ +package dotty.tools.dotc.util + +import dotty.tools.uncheckedNN + +object GenericHashSet: + + /** The number of elements up to which dense packing is used. + * If the number of elements reaches `DenseLimit` a hash table is used instead + */ + inline val DenseLimit = 8 + +/** A hash set that allows some privileged protected access to its internals + * @param initialCapacity Indicates the initial number of slots in the hash table. + * The actual number of slots is always a power of 2, so the + * initial size of the table will be the smallest power of two + * that is equal or greater than the given `initialCapacity`. + * Minimum value is 4. +* @param capacityMultiple The minimum multiple of capacity relative to used elements. + * The hash table will be re-sized once the number of elements + * multiplied by capacityMultiple exceeds the current size of the hash table. + * However, a table of size up to DenseLimit will be re-sized only + * once the number of elements reaches the table's size. 
+ */ +abstract class GenericHashSet[T](initialCapacity: Int = 8, capacityMultiple: Int = 2) extends MutableSet[T] { + import GenericHashSet.DenseLimit + + protected var used: Int = _ + protected var limit: Int = _ + protected var table: Array[AnyRef | Null] = _ + + clear() + + private def allocate(capacity: Int) = + table = new Array[AnyRef | Null](capacity) + limit = if capacity <= DenseLimit then capacity - 1 else capacity / capacityMultiple + + private def roundToPower(n: Int) = + if n < 4 then 4 + else if Integer.bitCount(n) == 1 then n + else 1 << (32 - Integer.numberOfLeadingZeros(n)) + + def clear(resetToInitial: Boolean): Unit = + used = 0 + if resetToInitial then allocate(roundToPower(initialCapacity)) + else java.util.Arrays.fill(table, null) + + /** The number of elements in the set */ + def size: Int = used + + protected def isDense = limit < DenseLimit + + /** Hashcode, by default a processed `x.hashCode`, can be overridden */ + protected def hash(key: T): Int + + /** Equality test, by default `equals`, can be overridden */ + protected def isEqual(x: T, y: T): Boolean + + /** Turn hashcode `x` into a table index */ + private def index(x: Int): Int = x & (table.length - 1) + + protected def currentTable: Array[AnyRef | Null] = table + + private def firstIndex(x: T) = if isDense then 0 else index(hash(x)) + private def nextIndex(idx: Int) = + Stats.record(statsItem("miss")) + index(idx + 1) + + private def entryAt(idx: Int): T | Null = table(idx).asInstanceOf[T | Null] + private def setEntry(idx: Int, x: T) = table(idx) = x.asInstanceOf[AnyRef | Null] + + def lookup(x: T): T | Null = + Stats.record(statsItem("lookup")) + var idx = firstIndex(x) + var e: T | Null = entryAt(idx) + while e != null do + if isEqual(e.uncheckedNN, x) then return e + idx = nextIndex(idx) + e = entryAt(idx) + null + + /** Add entry `x` at index `idx` */ + private def addEntryAt(idx: Int, x: T): T = + Stats.record(statsItem("addEntryAt")) + setEntry(idx, x) + used += 1 + if used >
limit then growTable() + x + + /** attempts to put `x` in the Set, if it was not entered before, return true, else return false. */ + override def add(x: T): Boolean = + Stats.record(statsItem("enter")) + var idx = firstIndex(x) + var e: T | Null = entryAt(idx) + while e != null do + if isEqual(e.uncheckedNN, x) then return false // already entered + idx = nextIndex(idx) + e = entryAt(idx) + addEntryAt(idx, x) + true // first entry + + def put(x: T): T = + Stats.record(statsItem("put")) + var idx = firstIndex(x) + var e: T | Null = entryAt(idx) + while e != null do + // TODO: remove uncheckedNN when explicit-nulls is enabled for regular compiling + if isEqual(e.uncheckedNN, x) then return e.uncheckedNN + idx = nextIndex(idx) + e = entryAt(idx) + addEntryAt(idx, x) + + def +=(x: T): Unit = put(x) + + def remove(x: T): Boolean = + Stats.record(statsItem("remove")) + var idx = firstIndex(x) + var e: T | Null = entryAt(idx) + while e != null do + if isEqual(e.uncheckedNN, x) then + var hole = idx + while + idx = nextIndex(idx) + e = entryAt(idx) + e != null + do + val eidx = index(hash(e.uncheckedNN)) + if isDense + || index(eidx - (hole + 1)) > index(idx - (hole + 1)) + // entry `e` at `idx` can move unless `index(hash(e))` is in + // the (ring-)interval [hole + 1 ..
idx] + then + setEntry(hole, e.uncheckedNN) + hole = idx + table(hole) = null + used -= 1 + return true + idx = nextIndex(idx) + e = entryAt(idx) + false + + def -=(x: T): Unit = + remove(x) + + private def addOld(x: T) = + Stats.record(statsItem("re-enter")) + var idx = firstIndex(x) + var e = entryAt(idx) + while e != null do + idx = nextIndex(idx) + e = entryAt(idx) + setEntry(idx, x) + + def copyFrom(oldTable: Array[AnyRef | Null]): Unit = + if isDense then + Array.copy(oldTable, 0, table, 0, oldTable.length) + else + var idx = 0 + while idx < oldTable.length do + val e: T | Null = oldTable(idx).asInstanceOf[T | Null] + if e != null then addOld(e.uncheckedNN) + idx += 1 + + protected def growTable(): Unit = + val oldTable = table + val newLength = + if oldTable.length == DenseLimit then DenseLimit * 2 * roundToPower(capacityMultiple) + else table.length * 2 + allocate(newLength) + copyFrom(oldTable) + + abstract class EntryIterator extends Iterator[T]: + def entry(idx: Int): T | Null + private var idx = 0 + def hasNext = + while idx < table.length && table(idx) == null do idx += 1 + idx < table.length + def next() = + require(hasNext) + try entry(idx).uncheckedNN finally idx += 1 + + def iterator: Iterator[T] = new EntryIterator(): + def entry(idx: Int) = entryAt(idx) + + override def toString: String = + iterator.mkString("HashSet(", ", ", ")") + + protected def statsItem(op: String) = + val prefix = if isDense then "HashSet(dense)." else "HashSet." + val suffix = getClass.getSimpleName + s"$prefix$op $suffix" +} diff --git a/compiler/src/dotty/tools/dotc/util/HashSet.scala b/compiler/src/dotty/tools/dotc/util/HashSet.scala index a6e1532c804f..e8cabd13a097 100644 --- a/compiler/src/dotty/tools/dotc/util/HashSet.scala +++ b/compiler/src/dotty/tools/dotc/util/HashSet.scala @@ -4,11 +4,6 @@ import dotty.tools.uncheckedNN object HashSet: - /** The number of elements up to which dense packing is used. 
- * If the number of elements reaches `DenseLimit` a hash table is used instead - */ - inline val DenseLimit = 8 - def from[T](xs: IterableOnce[T]): HashSet[T] = val set = new HashSet[T]() set ++= xs @@ -26,33 +21,8 @@ object HashSet: * However, a table of size up to DenseLimit will be re-sized only * once the number of elements reaches the table's size. */ -class HashSet[T](initialCapacity: Int = 8, capacityMultiple: Int = 2) extends MutableSet[T] { - import HashSet.DenseLimit - - private var used: Int = _ - private var limit: Int = _ - private var table: Array[AnyRef | Null] = _ - - clear() - - private def allocate(capacity: Int) = - table = new Array[AnyRef | Null](capacity) - limit = if capacity <= DenseLimit then capacity - 1 else capacity / capacityMultiple - - private def roundToPower(n: Int) = - if n < 4 then 4 - else if Integer.bitCount(n) == 1 then n - else 1 << (32 - Integer.numberOfLeadingZeros(n)) - - def clear(resetToInitial: Boolean): Unit = - used = 0 - if resetToInitial then allocate(roundToPower(initialCapacity)) - else java.util.Arrays.fill(table, null) - - /** The number of elements in the set */ - def size: Int = used - - protected def isDense = limit < DenseLimit +class HashSet[T](initialCapacity: Int = 8, capacityMultiple: Int = 2) extends GenericHashSet[T](initialCapacity, capacityMultiple) { + import GenericHashSet.DenseLimit /** Hashcode, by default a processed `x.hashCode`, can be overridden */ protected def hash(key: T): Int = @@ -68,8 +38,6 @@ class HashSet[T](initialCapacity: Int = 8, capacityMultiple: Int = 2) extends Mu /** Turn hashcode `x` into a table index */ protected def index(x: Int): Int = x & (table.length - 1) - protected def currentTable: Array[AnyRef | Null] = table - protected def firstIndex(x: T) = if isDense then 0 else index(hash(x)) protected def nextIndex(idx: Int) = Stats.record(statsItem("miss")) @@ -78,7 +46,7 @@ class HashSet[T](initialCapacity: Int = 8, capacityMultiple: Int = 2) extends Mu protected def 
entryAt(idx: Int): T | Null = table(idx).asInstanceOf[T | Null] protected def setEntry(idx: Int, x: T) = table(idx) = x.asInstanceOf[AnyRef | Null] - def lookup(x: T): T | Null = + override def lookup(x: T): T | Null = Stats.record(statsItem("lookup")) var idx = firstIndex(x) var e: T | Null = entryAt(idx) @@ -96,7 +64,7 @@ class HashSet[T](initialCapacity: Int = 8, capacityMultiple: Int = 2) extends Mu if used > limit then growTable() x - def put(x: T): T = + override def put(x: T): T = Stats.record(statsItem("put")) var idx = firstIndex(x) var e: T | Null = entryAt(idx) @@ -107,9 +75,9 @@ class HashSet[T](initialCapacity: Int = 8, capacityMultiple: Int = 2) extends Mu e = entryAt(idx) addEntryAt(idx, x) - def +=(x: T): Unit = put(x) + override def +=(x: T): Unit = put(x) - def remove(x: T): Boolean = + override def remove(x: T): Boolean = Stats.record(statsItem("remove")) var idx = firstIndex(x) var e: T | Null = entryAt(idx) @@ -136,7 +104,7 @@ class HashSet[T](initialCapacity: Int = 8, capacityMultiple: Int = 2) extends Mu e = entryAt(idx) false - def -=(x: T): Unit = + override def -=(x: T): Unit = remove(x) private def addOld(x: T) = @@ -148,7 +116,7 @@ class HashSet[T](initialCapacity: Int = 8, capacityMultiple: Int = 2) extends Mu e = entryAt(idx) setEntry(idx, x) - def copyFrom(oldTable: Array[AnyRef | Null]): Unit = + override def copyFrom(oldTable: Array[AnyRef | Null]): Unit = if isDense then Array.copy(oldTable, 0, table, 0, oldTable.length) else @@ -158,32 +126,6 @@ class HashSet[T](initialCapacity: Int = 8, capacityMultiple: Int = 2) extends Mu if e != null then addOld(e.uncheckedNN) idx += 1 - protected def growTable(): Unit = - val oldTable = table - val newLength = - if oldTable.length == DenseLimit then DenseLimit * 2 * roundToPower(capacityMultiple) - else table.length * 2 - allocate(newLength) - copyFrom(oldTable) - - abstract class EntryIterator extends Iterator[T]: - def entry(idx: Int): T | Null - private var idx = 0 - def hasNext = - while 
idx < table.length && table(idx) == null do idx += 1 - idx < table.length - def next() = - require(hasNext) - try entry(idx).uncheckedNN finally idx += 1 - - def iterator: Iterator[T] = new EntryIterator(): + override def iterator: Iterator[T] = new EntryIterator(): def entry(idx: Int) = entryAt(idx) - - override def toString: String = - iterator.mkString("HashSet(", ", ", ")") - - protected def statsItem(op: String) = - val prefix = if isDense then "HashSet(dense)." else "HashSet." - val suffix = getClass.getSimpleName - s"$prefix$op $suffix" } diff --git a/compiler/src/dotty/tools/dotc/util/MutableSet.scala b/compiler/src/dotty/tools/dotc/util/MutableSet.scala index 9529262fa5ec..05fd57a50e71 100644 --- a/compiler/src/dotty/tools/dotc/util/MutableSet.scala +++ b/compiler/src/dotty/tools/dotc/util/MutableSet.scala @@ -7,6 +7,13 @@ abstract class MutableSet[T] extends ReadOnlySet[T]: /** Add element `x` to the set */ def +=(x: T): Unit + /** attempts to put `x` in the Set, if it was not entered before, return true, else return false. + * Overridden in GenericHashSet. + */ + def add(x: T): Boolean = + if lookup(x) == null then { this += x; true } + else false + /** Like `+=` but return existing element equal to `x` of it exists, * `x` itself otherwise. */ From 3702fe9a32b3fb6651155d21326ca66e1040a88c Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Thu, 3 Aug 2023 13:46:55 +0200 Subject: [PATCH 3/9] use EqHashSet in extractDependencies --- .../src/dotty/tools/dotc/sbt/ExtractDependencies.scala | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala b/compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala index d7d3678a3298..65203bc8cc7f 100644 --- a/compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala +++ b/compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala @@ -271,14 +271,12 @@ private class ExtractDependenciesCollector(rec: DependencyRecorder) extends tpd. 
// Avoid cycles by remembering both the types (testcase: // tests/run/enum-values.scala) and the symbols of named types (testcase: // tests/pos-java-interop/i13575) we've seen before. - val seen = new util.HashSet[Symbol | Type](64) - def traverse(tp: Type): Unit = if (!seen.contains(tp)) { - seen += tp + val seen = new util.EqHashSet[Symbol | Type](128) // 64 still needs to grow often for scala3-compiler + def traverse(tp: Type): Unit = if seen.add(tp) then { tp match { case tp: NamedType => val sym = tp.symbol - if !seen.contains(sym) && !sym.is(Package) then - seen += sym + if !sym.is(Package) && seen.add(sym) then addDependency(sym) if !sym.isClass then traverse(tp.info) traverse(tp.prefix) From 992f200ace402412946d855b320ef9e969b7e0a7 Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Thu, 3 Aug 2023 14:26:05 +0200 Subject: [PATCH 4/9] use scratch type dependencies set --- .../tools/dotc/sbt/ExtractDependencies.scala | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala b/compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala index 65203bc8cc7f..b3162a309a40 100644 --- a/compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala +++ b/compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala @@ -232,6 +232,13 @@ private class ExtractDependenciesCollector(rec: DependencyRecorder) extends tpd. throw ex } + /**Reused EqHashSet, safe to use as each TypeDependencyTraverser is used atomically + * Avoid cycles by remembering both the types (testcase: + * tests/run/enum-values.scala) and the symbols of named types (testcase: + * tests/pos-java-interop/i13575) we've seen before. + */ + private val scratchSeen = new util.EqHashSet[Symbol | Type](128) + /** Traverse a used type and record all the dependencies we need to keep track * of for incremental recompilation. * @@ -268,15 +275,13 @@ private class ExtractDependenciesCollector(rec: DependencyRecorder) extends tpd. 
private abstract class TypeDependencyTraverser(using Context) extends TypeTraverser() { protected def addDependency(symbol: Symbol): Unit - // Avoid cycles by remembering both the types (testcase: - // tests/run/enum-values.scala) and the symbols of named types (testcase: - // tests/pos-java-interop/i13575) we've seen before. - val seen = new util.EqHashSet[Symbol | Type](128) // 64 still needs to grow often for scala3-compiler - def traverse(tp: Type): Unit = if seen.add(tp) then { + scratchSeen.clear(resetToInitial = false) + + def traverse(tp: Type): Unit = if scratchSeen.add(tp) then { tp match { case tp: NamedType => val sym = tp.symbol - if !sym.is(Package) && seen.add(sym) then + if !sym.is(Package) && scratchSeen.add(sym) then addDependency(sym) if !sym.isClass then traverse(tp.info) traverse(tp.prefix) From cdd353c3c7218fc8203ed56f1eca7d512ba7a3f8 Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Fri, 4 Aug 2023 23:27:41 +0200 Subject: [PATCH 5/9] optimise getOrElseUpdate --- .../src/dotty/tools/dotc/util/EqHashMap.scala | 16 +++++++++++++++ .../tools/dotc/util/GenericHashMap.scala | 20 +++++++++++++------ .../src/dotty/tools/dotc/util/HashMap.scala | 16 +++++++++++++++ 3 files changed, 46 insertions(+), 6 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/util/EqHashMap.scala b/compiler/src/dotty/tools/dotc/util/EqHashMap.scala index ea049acba02b..25d9fb2907b8 100644 --- a/compiler/src/dotty/tools/dotc/util/EqHashMap.scala +++ b/compiler/src/dotty/tools/dotc/util/EqHashMap.scala @@ -58,6 +58,22 @@ extends GenericHashMap[Key, Value](initialCapacity, capacityMultiple): used += 1 if used > limit then growTable() + override def getOrElseUpdate(key: Key, value: => Value): Value = + // created by blending lookup and update, avoid having to recompute hash and probe + Stats.record(statsItem("lookup-or-update")) + var idx = firstIndex(key) + var k = keyAt(idx) + while k != null do + if isEqual(k, key) then return valueAt(idx) + idx = nextIndex(idx) + k = 
keyAt(idx) + val v = value + setKey(idx, key) + setValue(idx, v) + used += 1 + if used > limit then growTable() + v + private def addOld(key: Key, value: Value): Unit = Stats.record(statsItem("re-enter")) var idx = firstIndex(key) diff --git a/compiler/src/dotty/tools/dotc/util/GenericHashMap.scala b/compiler/src/dotty/tools/dotc/util/GenericHashMap.scala index a21a4af37038..6d013717ec52 100644 --- a/compiler/src/dotty/tools/dotc/util/GenericHashMap.scala +++ b/compiler/src/dotty/tools/dotc/util/GenericHashMap.scala @@ -129,12 +129,20 @@ abstract class GenericHashMap[Key, Value] null def getOrElseUpdate(key: Key, value: => Value): Value = - var v: Value | Null = lookup(key) - if v == null then - val v1 = value - v = v1 - update(key, v1) - v.uncheckedNN + // created by blending lookup and update, avoid having to recompute hash and probe + Stats.record(statsItem("lookup-or-update")) + var idx = firstIndex(key) + var k = keyAt(idx) + while k != null do + if isEqual(k, key) then return valueAt(idx) + idx = nextIndex(idx) + k = keyAt(idx) + val v = value + setKey(idx, key) + setValue(idx, v) + used += 1 + if used > limit then growTable() + v private def addOld(key: Key, value: Value): Unit = Stats.record(statsItem("re-enter")) diff --git a/compiler/src/dotty/tools/dotc/util/HashMap.scala b/compiler/src/dotty/tools/dotc/util/HashMap.scala index aaae781c310a..eec3a604b5e2 100644 --- a/compiler/src/dotty/tools/dotc/util/HashMap.scala +++ b/compiler/src/dotty/tools/dotc/util/HashMap.scala @@ -63,6 +63,22 @@ extends GenericHashMap[Key, Value](initialCapacity, capacityMultiple): used += 1 if used > limit then growTable() + override def getOrElseUpdate(key: Key, value: => Value): Value = + // created by blending lookup and update, avoid having to recompute hash and probe + Stats.record(statsItem("lookup-or-update")) + var idx = firstIndex(key) + var k = keyAt(idx) + while k != null do + if isEqual(k, key) then return valueAt(idx) + idx = nextIndex(idx) + k = keyAt(idx) + val v 
= value + setKey(idx, key) + setValue(idx, v) + used += 1 + if used > limit then growTable() + v + private def addOld(key: Key, value: Value): Unit = Stats.record(statsItem("re-enter")) var idx = firstIndex(key) From 4bffbc5e35d791c70542070cf9f576ea12b2743e Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Tue, 15 Aug 2023 16:36:27 +0200 Subject: [PATCH 6/9] merge usednames and classdeps caches --- .../tools/dotc/sbt/ExtractDependencies.scala | 96 +++++++++---------- 1 file changed, 47 insertions(+), 49 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala b/compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala index b3162a309a40..e3b5f375f585 100644 --- a/compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala +++ b/compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala @@ -76,12 +76,8 @@ class ExtractDependencies extends Phase { collector.traverse(unit.tpdTree) if (ctx.settings.YdumpSbtInc.value) { - val deps = rec.classDependencies.flatMap((k,vs) => - vs.iterator.flatMap((to, depCtxs) => - depCtxs.asScala.map(depCtx => s"ClassDependency($k, $to, $depCtx)") - ) - ).toArray[Object] - val names = rec.usedNames.map { case (clazz, names) => s"$clazz: $names" }.toArray[Object] + val deps = rec.foundDeps.map { case (clazz, found) => s"$clazz: ${found.classesString}" }.toArray[Object] + val names = rec.foundDeps.map { case (clazz, found) => s"$clazz: ${found.namesString}" }.toArray[Object] Arrays.sort(deps) Arrays.sort(names) @@ -168,7 +164,7 @@ private class ExtractDependenciesCollector(rec: DependencyRecorder) extends tpd. /** Traverse the tree of a source file and record the dependencies and used names which - * can be retrieved using `dependencies` and`usedNames`. + * can be retrieved using `foundDeps`. */ override def traverse(tree: Tree)(using Context): Unit = try { tree match { @@ -315,16 +311,6 @@ private class ExtractDependenciesCollector(rec: DependencyRecorder) extends tpd. 
} } -class ClassDepsInClass: - private val _classes = util.EqHashMap[Symbol, EnumSet[DependencyContext]]() - - def addDependency(fromClass: Symbol, context: DependencyContext): Unit = - val set = _classes.getOrElseUpdate(fromClass, EnumSet.noneOf(classOf[DependencyContext])) - set.add(context) - - def iterator: Iterator[(Symbol, EnumSet[DependencyContext])] = - _classes.iterator - /** Record dependencies using `addUsedName`/`addClassDependency` and inform Zinc using `sendToZinc()`. * * Note: As an alternative design choice, we could directly call the appropriate @@ -336,10 +322,10 @@ class ClassDepsInClass: class DependencyRecorder { import ExtractDependencies.* - /** A map from a non-local class to the names it uses, this does not include + /** A map from a non-local class to the names and classes it uses, this does not include * names which are only defined and not referenced. */ - def usedNames: collection.Map[Symbol, UsedNamesInClass] = _usedNames + def foundDeps: collection.Map[Symbol, FoundDepsInClass] = _foundDeps /** Record a reference to the name of `sym` from the current non-local * enclosing class. @@ -374,7 +360,7 @@ class DependencyRecorder { def addUsedRawName(name: Name, includeSealedChildren: Boolean = false)(using Context): Unit = { val fromClass = resolveDependencyFromClass if (fromClass.exists) { - lastUsedCache.update(name, includeSealedChildren) + lastFoundCache.recordName(name, includeSealedChildren) } } @@ -383,26 +369,36 @@ class DependencyRecorder { private val DefaultScopes = EnumSet.of(UseScope.Default) private val PatMatScopes = EnumSet.of(UseScope.Default, UseScope.PatMatTarget) - /** An object that maintain the set of used names from within a class */ - final class UsedNamesInClass { + /** An object that maintain the set of used names and class dependencies from within a class */ + final class FoundDepsInClass { /** Each key corresponds to a name used in the class. 
To understand the meaning * of the associated value, see the documentation of parameter `includeSealedChildren` * of `addUsedRawName`. */ private val _names = new util.HashMap[Name, DefaultScopes.type | PatMatScopes.type] - def iterator: Iterator[(Name, EnumSet[UseScope])] = _names.iterator + /** Each key corresponds to a class dependency used in the class. + */ + private val _classes = util.EqHashMap[Symbol, EnumSet[DependencyContext]]() + + def addDependency(fromClass: Symbol, context: DependencyContext): Unit = + val set = _classes.getOrElseUpdate(fromClass, EnumSet.noneOf(classOf[DependencyContext])) + set.add(context) + + def classes: Iterator[(Symbol, EnumSet[DependencyContext])] = _classes.iterator - private[DependencyRecorder] def update(name: Name, includeSealedChildren: Boolean): Unit = { + def names: Iterator[(Name, EnumSet[UseScope])] = _names.iterator + + private[DependencyRecorder] def recordName(name: Name, includeSealedChildren: Boolean): Unit = { if (includeSealedChildren) _names(name) = PatMatScopes else _names.getOrElseUpdate(name, DefaultScopes) } - override def toString(): String = { + def namesString: String = { val builder = new StringBuilder - iterator.foreach { (name, scopes) => + names.foreach { case (name, scopes) => builder.append(name.mangledString) builder.append(" in [") scopes.forEach(scope => builder.append(scope.toString)) @@ -411,12 +407,19 @@ class DependencyRecorder { } builder.toString() } - } - - private val _classDependencies = new mutable.HashMap[Symbol, ClassDepsInClass] - - def classDependencies: collection.Map[Symbol, ClassDepsInClass] = _classDependencies + def classesString: String = { + val builder = new StringBuilder + classes.foreach { case (clazz, scopes) => + builder.append(clazz.toString) + builder.append(" in [") + scopes.forEach(scope => builder.append(scope.toString)) + builder.append("]") + builder.append(", ") + } + builder.toString() + } + } /** Record a dependency to the class `to` in a given `context` * 
from the current non-local enclosing class. @@ -424,33 +427,30 @@ class DependencyRecorder { def addClassDependency(toClass: Symbol, context: DependencyContext)(using Context): Unit = val fromClass = resolveDependencyFromClass if (fromClass.exists) - lastDepCache.addDependency(toClass, context) + lastFoundCache.addDependency(toClass, context) - private val _usedNames = new mutable.HashMap[Symbol, UsedNamesInClass] + private val _foundDeps = new mutable.HashMap[Symbol, FoundDepsInClass] /** Send the collected dependency information to Zinc and clear the local caches. */ def sendToZinc()(using Context): Unit = ctx.withIncCallback: cb => - usedNames.foreach: - case (clazz, usedNames) => + val siblingClassfiles = new mutable.HashMap[PlainFile, Path] + foundDeps.foreach: + case (clazz, foundDeps) => val className = classNameAsString(clazz) - usedNames.iterator.foreach: (usedName, scopes) => + foundDeps.names.foreach: (usedName, scopes) => cb.usedName(className, usedName.toString, scopes) - val siblingClassfiles = new mutable.HashMap[PlainFile, Path] - for (fromClass, partialDependencies) <- _classDependencies do - for (toClass, deps) <- partialDependencies.iterator do - for dep <- deps.asScala do - recordClassDependency(cb, fromClass, toClass, dep, siblingClassfiles) + for (toClass, deps) <- foundDeps.classes do + for dep <- deps.asScala do + recordClassDependency(cb, clazz, toClass, dep, siblingClassfiles) clear() /** Clear all state. 
*/ def clear(): Unit = - _usedNames.clear() - _classDependencies.clear() + _foundDeps.clear() lastOwner = NoSymbol lastDepSource = NoSymbol - lastDepCache = null - lastUsedCache = null + lastFoundCache = null _responsibleForImports = NoSymbol /** Handles dependency on given symbol by trying to figure out if represents a term @@ -521,8 +521,7 @@ class DependencyRecorder { private var lastOwner: Symbol = _ private var lastDepSource: Symbol = _ - private var lastDepCache: ClassDepsInClass | Null = _ - private var lastUsedCache: UsedNamesInClass | Null = _ + private var lastFoundCache: FoundDepsInClass | Null = _ /** The source of the dependency according to `nonLocalEnclosingClass` * if it exists, otherwise fall back to `responsibleForImports`. @@ -537,8 +536,7 @@ class DependencyRecorder { val fromClass = if (source.is(PackageClass)) responsibleForImports else source if lastDepSource != fromClass then lastDepSource = fromClass - lastDepCache = _classDependencies.getOrElseUpdate(fromClass, new ClassDepsInClass) - lastUsedCache = _usedNames.getOrElseUpdate(fromClass, new UsedNamesInClass) + lastFoundCache = _foundDeps.getOrElseUpdate(fromClass, new FoundDepsInClass) } lastDepSource From 04fb40db4f6400956a16e77a3cb21755eb920a3c Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Tue, 15 Aug 2023 17:16:11 +0200 Subject: [PATCH 7/9] main cache is util.HashMap --- .../src/dotty/tools/dotc/sbt/ExtractDependencies.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala b/compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala index e3b5f375f585..bd3ab4e3ae0f 100644 --- a/compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala +++ b/compiler/src/dotty/tools/dotc/sbt/ExtractDependencies.scala @@ -76,8 +76,8 @@ class ExtractDependencies extends Phase { collector.traverse(unit.tpdTree) if (ctx.settings.YdumpSbtInc.value) { - val deps = rec.foundDeps.map { case (clazz, found) => 
s"$clazz: ${found.classesString}" }.toArray[Object] - val names = rec.foundDeps.map { case (clazz, found) => s"$clazz: ${found.namesString}" }.toArray[Object] + val deps = rec.foundDeps.iterator.map { case (clazz, found) => s"$clazz: ${found.classesString}" }.toArray[Object] + val names = rec.foundDeps.iterator.map { case (clazz, found) => s"$clazz: ${found.namesString}" }.toArray[Object] Arrays.sort(deps) Arrays.sort(names) @@ -325,7 +325,7 @@ class DependencyRecorder { /** A map from a non-local class to the names and classes it uses, this does not include * names which are only defined and not referenced. */ - def foundDeps: collection.Map[Symbol, FoundDepsInClass] = _foundDeps + def foundDeps: util.ReadOnlyMap[Symbol, FoundDepsInClass] = _foundDeps /** Record a reference to the name of `sym` from the current non-local * enclosing class. @@ -429,13 +429,13 @@ class DependencyRecorder { if (fromClass.exists) lastFoundCache.addDependency(toClass, context) - private val _foundDeps = new mutable.HashMap[Symbol, FoundDepsInClass] + private val _foundDeps = new util.EqHashMap[Symbol, FoundDepsInClass] /** Send the collected dependency information to Zinc and clear the local caches. 
*/ def sendToZinc()(using Context): Unit = ctx.withIncCallback: cb => val siblingClassfiles = new mutable.HashMap[PlainFile, Path] - foundDeps.foreach: + _foundDeps.iterator.foreach: case (clazz, foundDeps) => val className = classNameAsString(clazz) foundDeps.names.foreach: (usedName, scopes) => From a813f2f4ef31d52fd0691ab0a1e58fe714644fa1 Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Wed, 16 Aug 2023 14:30:02 +0200 Subject: [PATCH 8/9] add unit tests for maps and sets --- .../src/dotty/tools/dotc/util/EqHashSet.scala | 30 ---- .../src/dotty/tools/dotc/util/HashSet.scala | 42 ++---- .../dotty/tools/dotc/util/EqHashMapTest.scala | 115 +++++++++++++++ .../dotty/tools/dotc/util/EqHashSetTest.scala | 119 +++++++++++++++ .../dotty/tools/dotc/util/HashMapTest.scala | 137 ++++++++++++++++++ .../dotty/tools/dotc/util/HashSetTest.scala | 117 +++++++++++++++ 6 files changed, 498 insertions(+), 62 deletions(-) create mode 100644 compiler/test/dotty/tools/dotc/util/EqHashMapTest.scala create mode 100644 compiler/test/dotty/tools/dotc/util/EqHashSetTest.scala create mode 100644 compiler/test/dotty/tools/dotc/util/HashMapTest.scala create mode 100644 compiler/test/dotty/tools/dotc/util/HashSetTest.scala diff --git a/compiler/src/dotty/tools/dotc/util/EqHashSet.scala b/compiler/src/dotty/tools/dotc/util/EqHashSet.scala index 42aee97ce79c..44a050ae2bf8 100644 --- a/compiler/src/dotty/tools/dotc/util/EqHashSet.scala +++ b/compiler/src/dotty/tools/dotc/util/EqHashSet.scala @@ -85,36 +85,6 @@ class EqHashSet[T](initialCapacity: Int = 8, capacityMultiple: Int = 2) extends override def +=(x: T): Unit = put(x) - override def remove(x: T): Boolean = - Stats.record(statsItem("remove")) - var idx = firstIndex(x) - var e: T | Null = entryAt(idx) - while e != null do - if isEqual(e.uncheckedNN, x) then - var hole = idx - while - idx = nextIndex(idx) - e = entryAt(idx) - e != null - do - val eidx = index(hash(e.uncheckedNN)) - if isDense - || index(eidx - (hole + 1)) > index(idx - 
(hole + 1)) - // entry `e` at `idx` can move unless `index(hash(e))` is in - // the (ring-)interval [hole + 1 .. idx] - then - setEntry(hole, e.uncheckedNN) - hole = idx - table(hole) = null - used -= 1 - return true - idx = nextIndex(idx) - e = entryAt(idx) - false - - override def -=(x: T): Unit = - remove(x) - private def addOld(x: T) = Stats.record(statsItem("re-enter")) var idx = firstIndex(x) diff --git a/compiler/src/dotty/tools/dotc/util/HashSet.scala b/compiler/src/dotty/tools/dotc/util/HashSet.scala index e8cabd13a097..3a973793d542 100644 --- a/compiler/src/dotty/tools/dotc/util/HashSet.scala +++ b/compiler/src/dotty/tools/dotc/util/HashSet.scala @@ -64,48 +64,29 @@ class HashSet[T](initialCapacity: Int = 8, capacityMultiple: Int = 2) extends Ge if used > limit then growTable() x - override def put(x: T): T = - Stats.record(statsItem("put")) + override def add(x: T): Boolean = + Stats.record(statsItem("enter")) var idx = firstIndex(x) var e: T | Null = entryAt(idx) while e != null do - // TODO: remove uncheckedNN when explicit-nulls is enabled for regule compiling - if isEqual(e.uncheckedNN, x) then return e.uncheckedNN + if isEqual(e.uncheckedNN, x) then return false // already entered idx = nextIndex(idx) e = entryAt(idx) addEntryAt(idx, x) + true // first entry - override def +=(x: T): Unit = put(x) - - override def remove(x: T): Boolean = - Stats.record(statsItem("remove")) + override def put(x: T): T = + Stats.record(statsItem("put")) var idx = firstIndex(x) var e: T | Null = entryAt(idx) while e != null do - if isEqual(e.uncheckedNN, x) then - var hole = idx - while - idx = nextIndex(idx) - e = entryAt(idx) - e != null - do - val eidx = index(hash(e.uncheckedNN)) - if isDense - || index(eidx - (hole + 1)) > index(idx - (hole + 1)) - // entry `e` at `idx` can move unless `index(hash(e))` is in - // the (ring-)interval [hole + 1 .. 
idx] - then - setEntry(hole, e.uncheckedNN) - hole = idx - table(hole) = null - used -= 1 - return true + // TODO: remove uncheckedNN when explicit-nulls is enabled for regular compiling + if isEqual(e.uncheckedNN, x) then return e.uncheckedNN idx = nextIndex(idx) e = entryAt(idx) - false + addEntryAt(idx, x) - override def -=(x: T): Unit = - remove(x) + override def +=(x: T): Unit = put(x) private def addOld(x: T) = Stats.record(statsItem("re-enter")) @@ -125,7 +106,4 @@ class HashSet[T](initialCapacity: Int = 8, capacityMultiple: Int = 2) extends Ge val e: T | Null = oldTable(idx).asInstanceOf[T | Null] if e != null then addOld(e.uncheckedNN) idx += 1 - - override def iterator: Iterator[T] = new EntryIterator(): - def entry(idx: Int) = entryAt(idx) } diff --git a/compiler/test/dotty/tools/dotc/util/EqHashMapTest.scala b/compiler/test/dotty/tools/dotc/util/EqHashMapTest.scala new file mode 100644 index 000000000000..561dabb555a9 --- /dev/null +++ b/compiler/test/dotty/tools/dotc/util/EqHashMapTest.scala @@ -0,0 +1,115 @@ +package dotty.tools.dotc.util + +import org.junit.Test +import org.junit.Assert.* + +class EqHashMapTest: + + var counter = 0 + + // basic identity hash, and reference equality, but with a counter for ordering + class Id: + val count = { counter += 1; counter } + + val id1, id2, id3 = Id() + + given Ordering[Id] = Ordering.by(_.count) + + @Test + def invariant: Unit = + assert((id1 ne id2) && (id1 ne id3) && (id2 ne id3)) + + @Test + def newEmpty: Unit = + val m = EqHashMap[Id, Int]() + assert(m.size == 0) + assert(m.iterator.toList == Nil) + + @Test + def update: Unit = + val m = EqHashMap[Id, Int]() + assert(m.size == 0 && !m.contains(id1)) + m.update(id1, 1) + assert(m.size == 1 && m(id1) == 1) + m.update(id1, 2) // replace value + assert(m.size == 1 && m(id1) == 2) + m.update(id3, 3) // new key + assert(m.size == 2 && m(id1) == 2 && m(id3) == 3) + + @Test + def getOrElseUpdate: Unit = + val m = EqHashMap[Id, Int]() + // add id1 + assert(m.size
== 0 && !m.contains(id1)) + val added = m.getOrElseUpdate(id1, 1) + assert(added == 1 && m.size == 1 && m(id1) == 1) + // try add id1 again + val addedAgain = m.getOrElseUpdate(id1, 23) + assert(addedAgain != 23 && m.size == 1 && m(id1) == 1) // no change + + private def fullMap() = + val m = EqHashMap[Id, Int]() + m.update(id1, 1) + m.update(id2, 2) + m + + @Test + def remove: Unit = + val m = fullMap() + // remove id2 + m.remove(id2) + assert(m.size == 1) + assert(m.contains(id1) && !m.contains(id2)) + // remove id1 + m -= id1 + assert(m.size == 0) + assert(!m.contains(id1) && !m.contains(id2)) + + @Test + def lookup: Unit = + val m = fullMap() + assert(m.lookup(id1) == 1) + assert(m.lookup(id2) == 2) + assert(m.lookup(id3) == null) + + @Test + def iterator: Unit = + val m = fullMap() + assert(m.iterator.toList.sorted == List(id1 -> 1,id2 -> 2)) + + @Test + def clear: Unit = + locally: + val s1 = fullMap() + s1.clear() + assert(s1.size == 0) + locally: + val s2 = fullMap() + s2.clear(resetToInitial = false) + assert(s2.size == 0) + + // basic structural equality and hash code + class I32(val x: Int): + override def hashCode(): Int = x + override def equals(that: Any): Boolean = that match + case that: I32 => this.x == that.x + case _ => false + + /** the hash map is based on reference equality, i.e.
does not use universal equality */ + @Test + def referenceEquality: Unit = + val i1, i2 = I32(1) // different instances + + assert(i1.equals(i2)) // structural equality + assert(i1 ne i2) // reference inequality + + val m = locally: + val m = EqHashMap[I32, Int]() + m(i1) = 23 + m(i2) = 29 + m + + assert(m.size == 2 && m(i1) == 23 && m(i2) == 29) + assert(m.keysIterator.toSet == Set(i1)) // scala.Set delegates to universal equality + end referenceEquality + diff --git a/compiler/test/dotty/tools/dotc/util/EqHashSetTest.scala b/compiler/test/dotty/tools/dotc/util/EqHashSetTest.scala new file mode 100644 index 000000000000..1c1ffe0b7931 --- /dev/null +++ b/compiler/test/dotty/tools/dotc/util/EqHashSetTest.scala @@ -0,0 +1,119 @@ +package dotty.tools.dotc.util + +import org.junit.Test +import org.junit.Assert.* + +class EqHashSetTest: + + var counter = 0 + + // basic identity hash, and reference equality, but with a counter for ordering + class Id: + val count = { counter += 1; counter } + + val id1, id2, id3 = Id() + + given Ordering[Id] = Ordering.by(_.count) + + @Test + def invariant: Unit = + assert((id1 ne id2) && (id1 ne id3) && (id2 ne id3)) + + @Test + def newEmpty: Unit = + val s = EqHashSet[Id]() + assert(s.size == 0) + assert(s.iterator.toList == Nil) + + @Test + def put: Unit = + val s = EqHashSet[Id]() + // put id1 + assert(s.size == 0 && !s.contains(id1)) + s += id1 + assert(s.size == 1 && s.contains(id1)) + // put id2 + assert(!s.contains(id2)) + s.put(id2) + assert(s.size == 2 && s.contains(id1) && s.contains(id2)) + // put id3 + s ++= List(id3) + assert(s.size == 3 && s.contains(id1) && s.contains(id2) && s.contains(id3)) + + @Test + def add: Unit = + val s = EqHashSet[Id]() + // add id1 + assert(s.size == 0 && !s.contains(id1)) + val added = s.add(id1) + assert(added && s.size == 1 && s.contains(id1)) + // try add id1 again + val addedAgain = s.add(id1) + assert(!addedAgain && s.size == 1 && s.contains(id1)) // no change + + @Test + def construct: 
Unit = + val s = EqHashSet.from(List(id1,id2,id3)) + assert(s.size == 3) + assert(s.contains(id1) && s.contains(id2) && s.contains(id3)) + + @Test + def remove: Unit = + val s = EqHashSet.from(List(id1,id2,id3)) + // remove id2 + s.remove(id2) + assert(s.size == 2) + assert(s.contains(id1) && !s.contains(id2) && s.contains(id3)) + // remove id1 + s -= id1 + assert(s.size == 1) + assert(!s.contains(id1) && !s.contains(id2) && s.contains(id3)) + // remove id3 + s --= List(id3) + assert(s.size == 0) + assert(!s.contains(id1) && !s.contains(id2) && !s.contains(id3)) + + @Test + def lookup: Unit = + val s = EqHashSet.from(List(id1, id2)) + assert(s.lookup(id1) eq id1) + assert(s.lookup(id2) eq id2) + assert(s.lookup(id3) eq null) + + @Test + def iterator: Unit = + val s = EqHashSet.from(List(id1,id2,id3)) + assert(s.iterator.toList.sorted == List(id1,id2,id3)) + + @Test + def clear: Unit = + locally: + val s1 = EqHashSet.from(List(id1,id2,id3)) + s1.clear() + assert(s1.size == 0) + locally: + val s2 = EqHashSet.from(List(id1,id2,id3)) + s2.clear(resetToInitial = false) + assert(s2.size == 0) + + // basic structural equality and hash code + class I32(val x: Int): + override def hashCode(): Int = x + override def equals(that: Any): Boolean = that match + case that: I32 => this.x == that.x + case _ => false + + /** the hash set is based on reference equality, i.e.
does not use universal equality */ + @Test + def referenceEquality: Unit = + val i1, i2 = I32(1) // different instances + + assert(i1.equals(i2)) // structural equality + assert(i1 ne i2) // reference inequality + + val s = EqHashSet.from(List(i1,i2)) + + assert(s.size == 2 && s.contains(i1) && s.contains(i2)) + assert(s.iterator.toSet == Set(i1)) // scala.Set delegates to universal equality + end referenceEquality + diff --git a/compiler/test/dotty/tools/dotc/util/HashMapTest.scala b/compiler/test/dotty/tools/dotc/util/HashMapTest.scala new file mode 100644 index 000000000000..97bf8446756c --- /dev/null +++ b/compiler/test/dotty/tools/dotc/util/HashMapTest.scala @@ -0,0 +1,137 @@ +package dotty.tools.dotc.util + +import org.junit.Test +import org.junit.Assert.* + +class HashMapTest: + + var counter = 0 + + // structural hash and equality, but with a counter for ordering + class Id(val count: Int = { counter += 1; counter }): + override def hashCode(): Int = count + override def equals(that: Any): Boolean = that match + case that: Id => this.count == that.count + case _ => false + def makeCopy: Id = new Id(count) + + val id1, id2, id3 = Id() + + given Ordering[Id] = Ordering.by(_.count) + + @Test + def invariant: Unit = + assert((id1 ne id2) && (id1 ne id3) && (id2 ne id3)) + assert(id1 != id2 && id1 != id3 && id2 != id3) + + @Test + def newEmpty: Unit = + val m = HashMap[Id, Int]() + assert(m.size == 0) + assert(m.iterator.toList == Nil) + + @Test + def update: Unit = + val m = HashMap[Id, Int]() + assert(m.size == 0 && !m.contains(id1)) + m.update(id1, 1) + assert(m.size == 1 && m(id1) == 1) + m.update(id1, 2) // replace value + assert(m.size == 1 && m(id1) == 2) + m.update(id3, 3) // new key + assert(m.size == 2 && m(id1) == 2 && m(id3) == 3) + + @Test + def getOrElseUpdate: Unit = + val m = HashMap[Id, Int]() + // add id1 + assert(m.size == 0 && !m.contains(id1)) + val added = m.getOrElseUpdate(id1, 1) + assert(added == 1 && m.size == 1 && m(id1) == 1) + // try 
add id1 again + val addedAgain = m.getOrElseUpdate(id1, 23) + assert(addedAgain != 23 && m.size == 1 && m(id1) == 1) // no change + + class StatefulHash: + var hashCount = 0 + override def hashCode(): Int = { hashCount += 1; super.hashCode() } + + @Test + def getOrElseUpdate_hashesAtMostOnce: Unit = + locally: + val sh1 = StatefulHash() + val m = HashMap[StatefulHash, Int]() // will be a dense map with default size + val added = m.getOrElseUpdate(sh1, 1) + assert(sh1.hashCount == 0) // no hashing at all for dense maps + locally: + val sh1 = StatefulHash() + val m = HashMap[StatefulHash, Int](64) // not dense + val added = m.getOrElseUpdate(sh1, 1) + assert(sh1.hashCount == 1) // would be 2 if for example getOrElseUpdate was implemented as lookup + update + + private def fullMap() = + val m = HashMap[Id, Int]() + m.update(id1, 1) + m.update(id2, 2) + m + + @Test + def remove: Unit = + val m = fullMap() + // remove id2 + m.remove(id2) + assert(m.size == 1) + assert(m.contains(id1) && !m.contains(id2)) + // remove id1 + m -= id1 + assert(m.size == 0) + assert(!m.contains(id1) && !m.contains(id2)) + + @Test + def lookup: Unit = + val m = fullMap() + assert(m.lookup(id1) == 1) + assert(m.lookup(id2) == 2) + assert(m.lookup(id3) == null) + + @Test + def iterator: Unit = + val m = fullMap() + assert(m.iterator.toList.sorted == List(id1 -> 1,id2 -> 2)) + + @Test + def clear: Unit = + locally: + val s1 = fullMap() + s1.clear() + assert(s1.size == 0) + locally: + val s2 = fullMap() + s2.clear(resetToInitial = false) + assert(s2.size == 0) + + // basic structural equality and hash code + class I32(val x: Int): + override def hashCode(): Int = x + override def equals(that: Any): Boolean = that match + case that: I32 => this.x == that.x + case _ => false + + /** the hash map is based on universal equality, i.e. 
does not use reference equality */ + @Test + def universalEquality: Unit = + val id2_2 = id2.makeCopy + + assert(id2.equals(id2_2)) // structural equality + assert(id2 ne id2_2) // reference inequality + + val m = locally: + val m = HashMap[Id, Int]() + m(id2) = 23 + m(id2_2) = 29 + m + + assert(m.size == 1 && m(id2) == 29 && m(id2_2) == 29) + assert(m.keysIterator.toList.head eq id2) // does not replace id2 with id2_2 + end universalEquality + diff --git a/compiler/test/dotty/tools/dotc/util/HashSetTest.scala b/compiler/test/dotty/tools/dotc/util/HashSetTest.scala new file mode 100644 index 000000000000..2089be508a4c --- /dev/null +++ b/compiler/test/dotty/tools/dotc/util/HashSetTest.scala @@ -0,0 +1,117 @@ +package dotty.tools.dotc.util + +import org.junit.Test +import org.junit.Assert.* + +class HashSetTest: + + var counter = 0 + + // structural hash and equality, with a counter for ordering + class Id(val count: Int = { counter += 1; counter }): + override def hashCode: Int = count + override def equals(that: Any): Boolean = that match + case that: Id => this.count == that.count + case _ => false + def makeCopy: Id = new Id(count) + + val id1, id2, id3 = Id() + + given Ordering[Id] = Ordering.by(_.count) + + @Test + def invariant: Unit = + assert((id1 ne id2) && (id1 ne id3) && (id2 ne id3)) + assert(id1 != id2 && id1 != id3 && id2 != id3) + + @Test + def newEmpty: Unit = + val s = HashSet[Id]() + assert(s.size == 0) + assert(s.iterator.toList == Nil) + + @Test + def put: Unit = + val s = HashSet[Id]() + // put id1 + assert(s.size == 0 && !s.contains(id1)) + s += id1 + assert(s.size == 1 && s.contains(id1)) + // put id2 + assert(!s.contains(id2)) + s.put(id2) + assert(s.size == 2 && s.contains(id1) && s.contains(id2)) + // put id3 + s ++= List(id3) + assert(s.size == 3 && s.contains(id1) && s.contains(id2) && s.contains(id3)) + + @Test + def add: Unit = + val s = HashSet[Id]() + // add id1 + assert(s.size == 0 && !s.contains(id1)) + val added = s.add(id1) + 
assert(added && s.size == 1 && s.contains(id1)) + // try add id1 again + val addedAgain = s.add(id1) + assert(!addedAgain && s.size == 1 && s.contains(id1)) // no change + + @Test + def construct: Unit = + val s = HashSet.from(List(id1,id2,id3)) + assert(s.size == 3) + assert(s.contains(id1) && s.contains(id2) && s.contains(id3)) + + @Test + def remove: Unit = + val s = HashSet.from(List(id1,id2,id3)) + // remove id2 + s.remove(id2) + assert(s.size == 2) + assert(s.contains(id1) && !s.contains(id2) && s.contains(id3)) + // remove id1 + s -= id1 + assert(s.size == 1) + assert(!s.contains(id1) && !s.contains(id2) && s.contains(id3)) + // remove id3 + s --= List(id3) + assert(s.size == 0) + assert(!s.contains(id1) && !s.contains(id2) && !s.contains(id3)) + + @Test + def lookup: Unit = + val s = HashSet.from(List(id1, id2)) + assert(s.lookup(id1) eq id1) + assert(s.lookup(id2) eq id2) + assert(s.lookup(id3) eq null) + + @Test + def iterator: Unit = + val s = HashSet.from(List(id1,id2,id3)) + assert(s.iterator.toList.sorted == List(id1,id2,id3)) + + @Test + def clear: Unit = + locally: + val s1 = HashSet.from(List(id1,id2,id3)) + s1.clear() + assert(s1.size == 0) + locally: + val s2 = HashSet.from(List(id1,id2,id3)) + s2.clear(resetToInitial = false) + assert(s2.size == 0) + + /** the hash set is based on universal equality, i.e. 
does not use reference equality */ + @Test + def universalEquality: Unit = + val id2_2 = id2.makeCopy + + assert(id2.equals(id2_2)) // structural equality + assert(id2 ne id2_2) // reference inequality + + val s = HashSet.from(List(id2,id2_2)) + + assert(s.size == 1 && s.contains(id2) && s.contains(id2_2)) + assert(s.iterator.toList == List(id2)) // single element + end universalEquality + From 8113165545320d44ecc7b96ebf11a6d97a3b7a4e Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Fri, 22 Sep 2023 16:52:30 +0200 Subject: [PATCH 9/9] address review comments --- compiler/src/dotty/tools/dotc/util/EqHashSet.scala | 2 +- compiler/src/dotty/tools/dotc/util/GenericHashSet.scala | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/util/EqHashSet.scala b/compiler/src/dotty/tools/dotc/util/EqHashSet.scala index 44a050ae2bf8..d584441fd00a 100644 --- a/compiler/src/dotty/tools/dotc/util/EqHashSet.scala +++ b/compiler/src/dotty/tools/dotc/util/EqHashSet.scala @@ -15,7 +15,7 @@ object EqHashSet: * initial size of the table will be the smallest power of two * that is equal or greater than the given `initialCapacity`. * Minimum value is 4. -* @param capacityMultiple The minimum multiple of capacity relative to used elements. + * @param capacityMultiple The minimum multiple of capacity relative to used elements. * The hash table will be re-sized once the number of elements * multiplied by capacityMultiple exceeds the current size of the hash table. 
* However, a table of size up to DenseLimit will be re-sized only diff --git a/compiler/src/dotty/tools/dotc/util/GenericHashSet.scala b/compiler/src/dotty/tools/dotc/util/GenericHashSet.scala index 704298e55fb7..7abe40a8e13d 100644 --- a/compiler/src/dotty/tools/dotc/util/GenericHashSet.scala +++ b/compiler/src/dotty/tools/dotc/util/GenericHashSet.scala @@ -36,8 +36,7 @@ abstract class GenericHashSet[T](initialCapacity: Int = 8, capacityMultiple: Int private def roundToPower(n: Int) = if n < 4 then 4 - else if Integer.bitCount(n) == 1 then n - else 1 << (32 - Integer.numberOfLeadingZeros(n)) + else 1 << (32 - Integer.numberOfLeadingZeros(n - 1)) def clear(resetToInitial: Boolean): Unit = used = 0 @@ -49,10 +48,10 @@ abstract class GenericHashSet[T](initialCapacity: Int = 8, capacityMultiple: Int protected def isDense = limit < DenseLimit - /** Hashcode, by default a processed `x.hashCode`, can be overridden */ + /** Hashcode, to be implemented in subclass */ protected def hash(key: T): Int - /** Hashcode, by default `equals`, can be overridden */ + /** Equality, to be implemented in subclass */ protected def isEqual(x: T, y: T): Boolean /** Turn hashcode `x` into a table index */