Skip to content

Commit f36f95b

Browse files
committed
Use WeakHashSet instead of HashSet for hash-consing types
This mimics what Scala 2 has been doing for a long time now and serves the same purpose: it considerably reduces peak memory usage when compiling some projects, for example previously compiling the Scalatest tests required a heap of at least 11 GB, but now it fits in about 4 GB. This required changing the implementation of WeakHashSet to have overridable `hash` and `isEqual` methods just like HashSet, it also required making various private methods protected since NamedTypeUniques and AppliedUniques contain an inlined implementation of `put`. This commit also changes the default load factor of a WeakHashSet from 0.75 to 0.5 to match the load factor we use for HashSets, though note that Scala 2 has always been using 0.75. For a history of the usage of WeakHashSet in Scala 2 see: - scala/scala#247 - scala/scala#2605 - scala/scala#2901
1 parent 82cd467 commit f36f95b

File tree

3 files changed

+70
-72
lines changed

3 files changed

+70
-72
lines changed

compiler/src/dotty/tools/dotc/core/Contexts.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,7 @@ object Contexts {
559559
def platform: Platform = base.platform
560560
def pendingUnderlying: util.HashSet[Type] = base.pendingUnderlying
561561
def uniqueNamedTypes: Uniques.NamedTypeUniques = base.uniqueNamedTypes
562-
def uniques: util.HashSet[Type] = base.uniques
562+
def uniques: util.WeakHashSet[Type] = base.uniques
563563

564564
def initialize()(using Context): Unit = base.initialize()
565565
}

compiler/src/dotty/tools/dotc/core/Uniques.scala

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ package core
44
import Types._, Contexts._, util.Stats._, Hashable._, Names._
55
import config.Config
66
import Decorators._
7-
import util.{HashSet, Stats}
7+
import util.{WeakHashSet, Stats}
8+
import WeakHashSet.Entry
9+
import scala.annotation.tailrec
810

9-
class Uniques extends HashSet[Type](Config.initialUniquesCapacity):
11+
class Uniques extends WeakHashSet[Type](Config.initialUniquesCapacity):
1012
override def hash(x: Type): Int = x.hash
1113
override def isEqual(x: Type, y: Type) = x.eql(y)
1214

@@ -32,7 +34,7 @@ object Uniques:
3234
if tp.hash == NotCached then tp
3335
else ctx.uniques.put(tp).asInstanceOf[T]
3436

35-
final class NamedTypeUniques extends HashSet[NamedType](Config.initialUniquesCapacity * 4) with Hashable:
37+
final class NamedTypeUniques extends WeakHashSet[NamedType](Config.initialUniquesCapacity * 4) with Hashable:
3638
override def hash(x: NamedType): Int = x.hash
3739

3840
def enterIfNew(prefix: Type, designator: Designator, isTerm: Boolean)(using Context): NamedType =
@@ -43,17 +45,25 @@ object Uniques:
4345
else new CachedTypeRef(prefix, designator, h)
4446
if h == NotCached then newType
4547
else
48+
// Inlined from WeakHashSet#put
4649
Stats.record(statsItem("put"))
47-
var idx = index(h)
48-
var e = entryAt(idx)
49-
while e != null do
50-
if (e.prefix eq prefix) && (e.designator eq designator) && (e.isTerm == isTerm) then return e
51-
idx = nextIndex(idx)
52-
e = entryAt(idx)
53-
addEntryAt(idx, newType)
50+
removeStaleEntries()
51+
val bucket = index(h)
52+
val oldHead = table(bucket)
53+
54+
@tailrec
55+
def linkedListLoop(entry: Entry[NamedType]): NamedType = entry match
56+
case null => addEntryAt(bucket, newType, h, oldHead)
57+
case _ =>
58+
val e = entry.get
59+
if e != null && (e.prefix eq prefix) && (e.designator eq designator) && (e.isTerm == isTerm) then e
60+
else linkedListLoop(entry.tail)
61+
62+
linkedListLoop(oldHead)
63+
end if
5464
end NamedTypeUniques
5565

56-
final class AppliedUniques extends HashSet[AppliedType](Config.initialUniquesCapacity * 2) with Hashable:
66+
final class AppliedUniques extends WeakHashSet[AppliedType](Config.initialUniquesCapacity * 2) with Hashable:
5767
override def hash(x: AppliedType): Int = x.hash
5868

5969
def enterIfNew(tycon: Type, args: List[Type]): AppliedType =
@@ -62,13 +72,21 @@ object Uniques:
6272
if monitored then recordCaching(h, classOf[CachedAppliedType])
6373
if h == NotCached then newType
6474
else
75+
// Inlined from WeakHashSet#put
6576
Stats.record(statsItem("put"))
66-
var idx = index(h)
67-
var e = entryAt(idx)
68-
while e != null do
69-
if (e.tycon eq tycon) && e.args.eqElements(args) then return e
70-
idx = nextIndex(idx)
71-
e = entryAt(idx)
72-
addEntryAt(idx, newType)
77+
removeStaleEntries()
78+
val bucket = index(h)
79+
val oldHead = table(bucket)
80+
81+
@tailrec
82+
def linkedListLoop(entry: Entry[AppliedType]): AppliedType = entry match
83+
case null => addEntryAt(bucket, newType, h, oldHead)
84+
case _ =>
85+
val e = entry.get
86+
if e != null && (e.tycon eq tycon) && e.args.eqElements(args) then e
87+
else linkedListLoop(entry.tail)
88+
89+
linkedListLoop(oldHead)
90+
end if
7391
end AppliedUniques
7492
end Uniques

compiler/src/dotty/tools/dotc/util/WeakHashSet.scala

Lines changed: 33 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,23 @@ import scala.collection.mutable
1717
* This set implementation is not in general thread safe without external concurrency control. However it behaves
1818
* properly when GC concurrently collects elements in this set.
1919
*/
20-
final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) extends MutableSet[A] {
20+
abstract class WeakHashSet[A <: AnyRef](initialCapacity: Int = 8, loadFactor: Double = 0.5) extends MutableSet[A] {
2121

2222
import WeakHashSet._
2323

24-
def this() = this(initialCapacity = WeakHashSet.defaultInitialCapacity, loadFactor = WeakHashSet.defaultLoadFactor)
25-
2624
type This = WeakHashSet[A]
2725

2826
/**
2927
* queue of Entries that hold elements scheduled for GC
3028
* the removeStaleEntries() method works through the queue to remove
3129
* stale entries from the table
3230
*/
33-
private val queue = new ReferenceQueue[A]
31+
protected val queue = new ReferenceQueue[A]
3432

3533
/**
3634
* the number of elements in this set
3735
*/
38-
private var count = 0
36+
protected var count = 0
3937

4038
/**
4139
* from a specified initial capacity compute the capacity we'll use as being the next
@@ -52,33 +50,20 @@ final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) e
5250
/**
5351
* the underlying table of entries which is an array of Entry linked lists
5452
*/
55-
private var table = new Array[Entry[A]](computeCapacity)
53+
protected var table = new Array[Entry[A]](computeCapacity)
5654

5755
/**
5856
* the limit at which we'll increase the size of the hash table
5957
*/
60-
private var threshold = computeThreshold
58+
protected var threshold = computeThreshold
6159

6260
private def computeThreshold: Int = (table.size * loadFactor).ceil.toInt
6361

64-
/**
65-
* find the bucket associated with an element's hash code
66-
*/
67-
private def bucketFor(hash: Int): Int = {
68-
// spread the bits around to try to avoid accidental collisions using the
69-
// same algorithm as java.util.HashMap
70-
var h = hash
71-
h ^= h >>> 20 ^ h >>> 12
72-
h ^= h >>> 7 ^ h >>> 4
73-
74-
// this is finding h % table.length, but takes advantage of the
75-
// fact that table length is a power of 2,
76-
// if you don't do bit flipping in your head, if table.length
77-
// is binary 100000.. (with n 0s) then table.length - 1
78-
// is 1111.. with n 1's.
79-
// In other words this masks on the last n bits in the hash
80-
h & (table.length - 1)
81-
}
62+
protected def hash(key: A): Int
63+
protected def isEqual(x: A, y: A): Boolean = x.equals(y)
64+
65+
/** Turn hashcode `x` into a table index */
66+
protected def index(x: Int): Int = x & (table.length - 1)
8267

8368
/**
8469
* remove a single entry from a linked list in a given bucket
@@ -95,14 +80,14 @@ final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) e
9580
/**
9681
* remove entries associated with elements that have been gc'ed
9782
*/
98-
private def removeStaleEntries(): Unit = {
83+
protected def removeStaleEntries(): Unit = {
9984
def poll(): Entry[A] = queue.poll().asInstanceOf[Entry[A]]
10085

10186
@tailrec
10287
def queueLoop(): Unit = {
10388
val stale = poll()
10489
if (stale != null) {
105-
val bucket = bucketFor(stale.hash)
90+
val bucket = index(stale.hash)
10691

10792
@tailrec
10893
def linkedListLoop(prevEntry: Entry[A], entry: Entry[A]): Unit = if (stale eq entry) remove(bucket, prevEntry, entry)
@@ -120,7 +105,7 @@ final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) e
120105
/**
121106
* Double the size of the internal table
122107
*/
123-
private def resize(): Unit = {
108+
protected def resize(): Unit = {
124109
Stats.record(statsItem("resize"))
125110
val oldTable = table
126111
table = new Array[Entry[A]](oldTable.size * 2)
@@ -132,7 +117,7 @@ final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) e
132117
def linkedListLoop(entry: Entry[A]): Unit = entry match {
133118
case null => ()
134119
case _ =>
135-
val bucket = bucketFor(entry.hash)
120+
val bucket = index(entry.hash)
136121
val oldNext = entry.tail
137122
entry.tail = table(bucket)
138123
table(bucket) = entry
@@ -150,43 +135,43 @@ final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) e
150135
case _ =>
151136
Stats.record(statsItem("lookup"))
152137
removeStaleEntries()
153-
val hash = elem.hashCode
154-
val bucket = bucketFor(hash)
138+
val bucket = index(hash(elem))
155139

156140
@tailrec
157141
def linkedListLoop(entry: Entry[A]): A = entry match {
158142
case null => null.asInstanceOf[A]
159143
case _ =>
160144
val entryElem = entry.get
161-
if (elem.equals(entryElem)) entryElem
145+
if (isEqual(elem, entryElem)) entryElem
162146
else linkedListLoop(entry.tail)
163147
}
164148

165149
linkedListLoop(table(bucket))
166150
}
167151

152+
protected def addEntryAt(bucket: Int, elem: A, elemHash: Int, oldHead: Entry[A]): A = {
153+
Stats.record(statsItem("addEntryAt"))
154+
table(bucket) = new Entry(elem, elemHash, oldHead, queue)
155+
count += 1
156+
if (count > threshold) resize()
157+
elem
158+
}
159+
168160
def put(elem: A): A = elem match {
169161
case null => throw new NullPointerException("WeakHashSet cannot hold nulls")
170162
case _ =>
171163
Stats.record(statsItem("put"))
172164
removeStaleEntries()
173-
val hash = elem.hashCode
174-
val bucket = bucketFor(hash)
165+
val h = hash(elem)
166+
val bucket = index(h)
175167
val oldHead = table(bucket)
176168

177-
def add() = {
178-
table(bucket) = new Entry(elem, hash, oldHead, queue)
179-
count += 1
180-
if (count > threshold) resize()
181-
elem
182-
}
183-
184169
@tailrec
185170
def linkedListLoop(entry: Entry[A]): A = entry match {
186-
case null => add()
171+
case null => addEntryAt(bucket, elem, h, oldHead)
187172
case _ =>
188173
val entryElem = entry.get
189-
if (elem.equals(entryElem)) entryElem
174+
if (isEqual(elem, entryElem)) entryElem
190175
else linkedListLoop(entry.tail)
191176
}
192177

@@ -200,14 +185,14 @@ final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) e
200185
case _ =>
201186
Stats.record(statsItem("-="))
202187
removeStaleEntries()
203-
val bucket = bucketFor(elem.hashCode)
188+
val bucket = index(hash(elem))
204189

205190

206191

207192
@tailrec
208193
def linkedListLoop(prevEntry: Entry[A], entry: Entry[A]): Unit = entry match {
209194
case null => ()
210-
case _ if elem.equals(entry.get) => remove(bucket, prevEntry, entry)
195+
case _ if isEqual(elem, entry.get) => remove(bucket, prevEntry, entry)
211196
case _ => linkedListLoop(entry, entry.tail)
212197
}
213198

@@ -307,9 +292,9 @@ final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) e
307292
assert(entry.get != null, s"$entry had a null value indicated that gc activity was happening during diagnostic validation or that a null value was inserted")
308293
computedCount += 1
309294
val cachedHash = entry.hash
310-
val realHash = entry.get.hashCode
295+
val realHash = hash(entry.get)
311296
assert(cachedHash == realHash, s"for $entry cached hash was $cachedHash but should have been $realHash")
312-
val computedBucket = bucketFor(realHash)
297+
val computedBucket = index(realHash)
313298
assert(computedBucket == bucket, s"for $entry the computed bucket was $computedBucket but should have been $bucket")
314299

315300
entry = entry.tail
@@ -355,11 +340,6 @@ object WeakHashSet {
355340
* A single entry in a WeakHashSet. It's a WeakReference plus a cached hash code and
356341
* a link to the next Entry in the same bucket
357342
*/
358-
private class Entry[A](@constructorOnly element: A, val hash:Int, var tail: Entry[A], @constructorOnly queue: ReferenceQueue[A]) extends WeakReference[A](element, queue)
359-
360-
private final val defaultInitialCapacity = 16
361-
private final val defaultLoadFactor = .75
343+
class Entry[A](@constructorOnly element: A, val hash:Int, var tail: Entry[A], @constructorOnly queue: ReferenceQueue[A]) extends WeakReference[A](element, queue)
362344

363-
def apply[A <: AnyRef](initialCapacity: Int = defaultInitialCapacity, loadFactor: Double = defaultLoadFactor): WeakHashSet[A] =
364-
new WeakHashSet(initialCapacity, loadFactor)
365345
}

0 commit comments

Comments
 (0)