Skip to content

Commit 1884125

Browse files
committed
Make util.WeakHashSet a subclass of util.MutableSet
Thus making it a drop-in replacement for util.HashSet. Also add `@constructorOnly` annotations for clarity.
1 parent dd8398d commit 1884125

File tree

2 files changed

+16
-58
lines changed

2 files changed

+16
-58
lines changed

compiler/src/dotty/tools/dotc/util/MutableSet.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ abstract class MutableSet[T] extends ReadOnlySet[T]:
88
def +=(x: T): Unit
99

1010
/** Like `+=` but return existing element equal to `x` of it exists,
11-
* `x` itself otherwose.
11+
* `x` itself otherwise.
1212
*/
1313
def put(x: T): T
1414

compiler/src/dotty/tools/dotc/util/WeakHashSet.scala

Lines changed: 15 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
/** Taken from the original implementation of WeakHashSet in scala-reflect
1+
/** Adapted from the original implementation of WeakHashSet in scala-reflect
22
*/
33
package dotty.tools.dotc.util
44

55
import java.lang.ref.{ReferenceQueue, WeakReference}
66

7-
import scala.annotation.tailrec
7+
import scala.annotation.{ constructorOnly, tailrec }
88
import scala.collection.mutable
99

1010
/**
@@ -17,7 +17,7 @@ import scala.collection.mutable
1717
* This set implementation is not in general thread safe without external concurrency control. However it behaves
1818
* properly when GC concurrently collects elements in this set.
1919
*/
20-
final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) extends mutable.Set[A] {
20+
final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) extends MutableSet[A] {
2121

2222
import WeakHashSet._
2323

@@ -61,8 +61,6 @@ final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) e
6161

6262
private def computeThreshold: Int = (table.size * loadFactor).ceil.toInt
6363

64-
def get(elem: A): Option[A] = Option(findEntry(elem))
65-
6664
/**
6765
* find the bucket associated with an element's hash code
6866
*/
@@ -145,10 +143,7 @@ final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) e
145143
tableLoop(0)
146144
}
147145

148-
def contains(elem: A): Boolean = findEntry(elem) ne null
149-
150-
// from scala.reflect.internal.Set, find an element or null if it isn't contained
151-
def findEntry(elem: A): A = elem match {
146+
def lookup(elem: A): A | Null = elem match {
152147
case null => throw new NullPointerException("WeakHashSet cannot hold nulls")
153148
case _ =>
154149
removeStaleEntries()
@@ -160,14 +155,14 @@ final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) e
160155
case null => null.asInstanceOf[A]
161156
case _ =>
162157
val entryElem = entry.get
163-
if (elem.equals(entryElem)) entryElem
158+
if (isEqual(elem, entryElem)) entryElem
164159
else linkedListLoop(entry.tail)
165160
}
166161

167162
linkedListLoop(table(bucket))
168163
}
169-
// add an element to this set unless it's already in there and return the element
170-
def findEntryOrUpdate(elem: A): A = elem match {
164+
165+
def put(elem: A): A = elem match {
171166
case null => throw new NullPointerException("WeakHashSet cannot hold nulls")
172167
case _ =>
173168
removeStaleEntries()
@@ -187,42 +182,17 @@ final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) e
187182
case null => add()
188183
case _ =>
189184
val entryElem = entry.get
190-
if (elem.equals(entryElem)) entryElem
185+
if (isEqual(elem, entryElem)) entryElem
191186
else linkedListLoop(entry.tail)
192187
}
193188

194189
linkedListLoop(oldHead)
195190
}
196191

197-
// add an element to this set unless it's already in there and return this set
198-
override def addOne(elem: A): this.type = elem match {
199-
case null => throw new NullPointerException("WeakHashSet cannot hold nulls")
200-
case _ =>
201-
removeStaleEntries()
202-
val hash = elem.hashCode
203-
val bucket = bucketFor(hash)
204-
val oldHead = table(bucket)
205-
206-
def add(): Unit = {
207-
table(bucket) = new Entry(elem, hash, oldHead, queue)
208-
count += 1
209-
if (count > threshold) resize()
210-
}
211-
212-
@tailrec
213-
def linkedListLoop(entry: Entry[A]): Unit = entry match {
214-
case null => add()
215-
case _ if elem.equals(entry.get) => ()
216-
case _ => linkedListLoop(entry.tail)
217-
}
218-
219-
linkedListLoop(oldHead)
220-
this
221-
}
192+
def +=(elem: A): Unit = put(elem)
222193

223-
// remove an element from this set and return this set
224-
override def subtractOne(elem: A): this.type = elem match {
225-
case null => this
194+
def -=(elem: A): Unit = elem match {
195+
case null =>
226196
case _ =>
227197
removeStaleEntries()
228198
val bucket = bucketFor(elem.hashCode)
@@ -232,16 +202,14 @@ final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) e
232202
@tailrec
233203
def linkedListLoop(prevEntry: Entry[A], entry: Entry[A]): Unit = entry match {
234204
case null => ()
235-
case _ if elem.equals(entry.get) => remove(bucket, prevEntry, entry)
205+
case _ if isEqual(elem, entry.get) => remove(bucket, prevEntry, entry)
236206
case _ => linkedListLoop(entry, entry.tail)
237207
}
238208

239209
linkedListLoop(null, table(bucket))
240-
this
241210
}
242211

243-
// empty this set
244-
override def clear(): Unit = {
212+
def clear(): Unit = {
245213
table = new Array[Entry[A]](table.size)
246214
threshold = computeThreshold
247215
count = 0
@@ -251,21 +219,11 @@ final class WeakHashSet[A <: AnyRef](initialCapacity: Int, loadFactor: Double) e
251219
queueLoop()
252220
}
253221

254-
// true if this set is empty
255-
override def empty: This = new WeakHashSet[A](initialCapacity, loadFactor)
256-
257-
// the number of elements in this set
258-
override def size: Int = {
222+
def size: Int = {
259223
removeStaleEntries()
260224
count
261225
}
262226

263-
override def isEmpty: Boolean = size == 0
264-
override def foreach[U](f: A => U): Unit = iterator foreach f
265-
266-
// It has the `()` because iterator runs `removeStaleEntries()`
267-
override def toList(): List[A] = iterator.toList
268-
269227
// Iterator over all the elements in this set in no particular order
270228
override def iterator: Iterator[A] = {
271229
removeStaleEntries()
@@ -386,7 +344,7 @@ object WeakHashSet {
386344
* A single entry in a WeakHashSet. It's a WeakReference plus a cached hash code and
387345
* a link to the next Entry in the same bucket
388346
*/
389-
private class Entry[A](element: A, val hash:Int, var tail: Entry[A], queue: ReferenceQueue[A]) extends WeakReference[A](element, queue)
347+
private class Entry[A](@constructorOnly element: A, val hash:Int, var tail: Entry[A], @constructorOnly queue: ReferenceQueue[A]) extends WeakReference[A](element, queue)
390348

391349
private final val defaultInitialCapacity = 16
392350
private final val defaultLoadFactor = .75

0 commit comments

Comments
 (0)