Skip to content

Commit 21c6bbf

Browse files
authored
Merge pull request scala#10337 from som-snytt/issue/12745-map-from
Map.from and Set.from are more sound [ci: last-only]
2 parents 2e337cb + 943cb33 commit 21c6bbf

File tree

8 files changed

+159
-41
lines changed

8 files changed

+159
-41
lines changed

src/library/scala/collection/immutable/HashMap.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ final class HashMap[K, +V] private[immutable] (private[immutable] val rootNode:
5757

5858
override def keySet: Set[K] = if (size == 0) Set.empty else new HashKeySet
5959

60-
private final class HashKeySet extends ImmutableKeySet {
60+
private[immutable] final class HashKeySet extends ImmutableKeySet {
6161

6262
private[this] def newKeySetOrThis(newHashMap: HashMap[K, _]): Set[K] =
6363
if (newHashMap eq HashMap.this) this else newHashMap.keySet

src/library/scala/collection/immutable/Map.scala

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import scala.annotation.unchecked.uncheckedVariance
1818
import scala.collection.generic.DefaultSerializable
1919
import scala.collection.immutable.Map.Map4
2020
import scala.collection.mutable.{Builder, ReusableBuilder}
21+
import SeqMap.{SeqMap1, SeqMap2, SeqMap3, SeqMap4}
2122

2223
/** Base type of immutable Maps */
2324
trait Map[K, +V]
@@ -28,7 +29,7 @@ trait Map[K, +V]
2829

2930
override def mapFactory: scala.collection.MapFactory[Map] = Map
3031

31-
override final def toMap[K2, V2](implicit ev: (K, V) <:< (K2, V2)): Map[K2, V2] = this.asInstanceOf[Map[K2, V2]]
32+
override final def toMap[K2, V2](implicit ev: (K, V) <:< (K2, V2)): Map[K2, V2] = Map.from(this.asInstanceOf[Map[K2, V2]])
3233

3334
/** The same map with a given default function.
3435
* Note: The default is only used for `apply`. Other methods like `get`, `contains`, `iterator`, `keys`, etc.
@@ -141,7 +142,7 @@ trait MapOps[K, +V, +CC[X, +Y] <: MapOps[X, Y, CC, _], +C <: MapOps[K, V, CC, C]
141142
override def keySet: Set[K] = new ImmutableKeySet
142143

143144
/** The implementation class of the set returned by `keySet` */
144-
protected class ImmutableKeySet extends AbstractSet[K] with GenKeySet with DefaultSerializable {
145+
protected[immutable] class ImmutableKeySet extends AbstractSet[K] with GenKeySet with DefaultSerializable {
145146
def incl(elem: K): Set[K] = if (this(elem)) this else empty ++ this + elem
146147
def excl(elem: K): Set[K] = if (this(elem)) empty ++ this - elem else this
147148
}
@@ -206,11 +207,30 @@ object Map extends MapFactory[Map] {
206207

207208
def empty[K, V]: Map[K, V] = EmptyMap.asInstanceOf[Map[K, V]]
208209

209-
def from[K, V](it: collection.IterableOnce[(K, V)]): Map[K, V] =
210+
def from[K, V](it: IterableOnce[(K, V)]): Map[K, V] =
210211
it match {
211212
case it: Iterable[_] if it.isEmpty => empty[K, V]
212-
case m: Map[K, V] => m
213-
case _ => (newBuilder[K, V] ++= it).result()
213+
// Since IterableOnce[(K, V)] launders the variance of K,
214+
// identify only our implementations which can be soundly substituted.
215+
// For example, the ordering used by sorted maps would fail on widened key type. (scala/bug#12745)
216+
// The following type test is not sufficient: case m: Map[K, V] => m
217+
case m: HashMap[K, V] => m
218+
case m: Map1[K, V] => m
219+
case m: Map2[K, V] => m
220+
case m: Map3[K, V] => m
221+
case m: Map4[K, V] => m
222+
//case m: WithDefault[K, V] => m // cf SortedMap.WithDefault
223+
//case m: SeqMap[K, V] => SeqMap.from(it) // inlined here to avoid hard dependency
224+
case m: ListMap[K, V] => m
225+
case m: TreeSeqMap[K, V] => m
226+
case m: VectorMap[K, V] => m
227+
case m: SeqMap1[K, V] => m
228+
case m: SeqMap2[K, V] => m
229+
case m: SeqMap3[K, V] => m
230+
case m: SeqMap4[K, V] => m
231+
232+
// Maps with a reified key type must be rebuilt, such as `SortedMap` and `IntMap`.
233+
case _ => newBuilder[K, V].addAll(it).result()
214234
}
215235

216236
def newBuilder[K, V]: Builder[(K, V), Map[K, V]] = new MapBuilderImpl

src/library/scala/collection/immutable/SeqMap.scala

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,15 @@ object SeqMap extends MapFactory[SeqMap] {
4646

4747
def from[K, V](it: collection.IterableOnce[(K, V)]): SeqMap[K, V] =
4848
it match {
49-
case sm: SeqMap[K, V] => sm
49+
//case sm: SeqMap[K, V] => sm
50+
case m: ListMap[K, V] => m
51+
case m: TreeSeqMap[K, V] => m
52+
case m: VectorMap[K, V] => m
53+
case m: SeqMap1[K, V] => m
54+
case m: SeqMap2[K, V] => m
55+
case m: SeqMap3[K, V] => m
56+
case m: SeqMap4[K, V] => m
57+
case it: Iterable[_] if it.isEmpty => empty[K, V]
5058
case _ => (newBuilder[K, V] ++= it).result()
5159
}
5260

@@ -66,7 +74,7 @@ object SeqMap extends MapFactory[SeqMap] {
6674
}
6775

6876
@SerialVersionUID(3L)
69-
private final class SeqMap1[K, +V](key1: K, value1: V) extends SeqMap[K,V] with Serializable {
77+
private[immutable] final class SeqMap1[K, +V](key1: K, value1: V) extends SeqMap[K,V] with Serializable {
7078
override def size: Int = 1
7179
override def knownSize: Int = 1
7280
override def apply(key: K) = if (key == key1) value1 else throw new NoSuchElementException("key not found: " + key)
@@ -90,7 +98,7 @@ object SeqMap extends MapFactory[SeqMap] {
9098
}
9199

92100
@SerialVersionUID(3L)
93-
private final class SeqMap2[K, +V](key1: K, value1: V, key2: K, value2: V) extends SeqMap[K,V] with Serializable {
101+
private[immutable] final class SeqMap2[K, +V](key1: K, value1: V, key2: K, value2: V) extends SeqMap[K,V] with Serializable {
94102
override def size: Int = 2
95103
override def knownSize: Int = 2
96104
override def apply(key: K) =
@@ -125,7 +133,7 @@ object SeqMap extends MapFactory[SeqMap] {
125133
}
126134

127135
@SerialVersionUID(3L)
128-
private class SeqMap3[K, +V](key1: K, value1: V, key2: K, value2: V, key3: K, value3: V) extends SeqMap[K,V] with Serializable {
136+
private[immutable] class SeqMap3[K, +V](key1: K, value1: V, key2: K, value2: V, key3: K, value3: V) extends SeqMap[K,V] with Serializable {
129137
override def size: Int = 3
130138
override def knownSize: Int = 3
131139
override def apply(key: K) =
@@ -166,7 +174,7 @@ object SeqMap extends MapFactory[SeqMap] {
166174
}
167175

168176
@SerialVersionUID(3L)
169-
private final class SeqMap4[K, +V](key1: K, value1: V, key2: K, value2: V, key3: K, value3: V, key4: K, value4: V) extends SeqMap[K,V] with Serializable {
177+
private[immutable] final class SeqMap4[K, +V](key1: K, value1: V, key2: K, value2: V, key3: K, value3: V, key4: K, value4: V) extends SeqMap[K,V] with Serializable {
170178
override def size: Int = 4
171179
override def knownSize: Int = 4
172180
override def apply(key: K) =

src/library/scala/collection/immutable/Set.scala

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,21 @@ object Set extends IterableFactory[Set] {
9696

9797
def from[E](it: collection.IterableOnce[E]): Set[E] =
9898
it match {
99-
// We want `SortedSet` (and subclasses, such as `BitSet`) to
100-
// rebuild themselves to avoid element type widening issues
101-
case _: SortedSet[E] => (newBuilder[E] ++= it).result()
102-
case _ if it.knownSize == 0 => empty[E]
103-
case s: Set[E] => s
104-
case _ => (newBuilder[E] ++= it).result()
99+
case _ if it.knownSize == 0 => empty[E]
100+
// Since IterableOnce[E] launders the variance of E,
101+
// identify only our implementations which can be soundly substituted.
102+
// It's not sufficient to match `SortedSet[E]` to rebuild and `Set[E]` to retain.
103+
case s: HashSet[E] => s
104+
case s: ListSet[E] => s
105+
case s: Set1[E] => s
106+
case s: Set2[E] => s
107+
case s: Set3[E] => s
108+
case s: Set4[E] => s
109+
case s: HashMap[E @unchecked, _]#HashKeySet => s
110+
case s: MapOps[E, Any, Map, Map[E, Any]]#ImmutableKeySet @unchecked => s
111+
// We also want `SortedSet` (and subclasses, such as `BitSet`)
112+
// to rebuild themselves, to avoid element type widening issues.
113+
case _ => newBuilder[E].addAll(it).result()
105114
}
106115

107116
def newBuilder[A]: Builder[A, Set[A]] = new SetBuilderImpl[A]

src/testkit/scala/tools/testkit/ReflectUtil.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ object ReflectUtil {
4242
f.setAccessible(true)
4343
}
4444

45+
def getFinalFieldAccessible[T: ClassTag](n: String): Field =
46+
classTag[T]
47+
.runtimeClass.getDeclaredField(n)
48+
.tap { f =>
49+
if ((f.getModifiers & Modifier.PUBLIC) == 0)
50+
f.setAccessible(true)
51+
}
52+
4553
// finds method with exact name or name$suffix but not name$default$suffix
4654
def getMethodAccessible[A: ClassTag](name: String): Method =
4755
implicitly[ClassTag[A]]

test/junit/scala/collection/FactoriesTest.scala

Lines changed: 77 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
package scala.collection
22

3-
import org.junit.Assert.{assertEquals, assertSame, assertTrue}
3+
import org.junit.Assert.{assertEquals, assertFalse, assertSame, assertTrue}
4+
import org.junit.Test
45

56
import scala.collection.mutable.ArrayBuffer
6-
import org.junit.{Assert, Test}
77

88
import scala.collection.{immutable => im}
99

@@ -16,15 +16,15 @@ class FactoriesTest {
1616
def cloneCollection[A, C](xs: Iterable[A])(implicit bf: BuildFrom[xs.type, A, C]): C =
1717
bf.fromSpecific(xs)(xs)
1818

19-
Assert.assertEquals("ArrayBuffer", cloneCollection(seq).collectionClassName)
19+
assertEquals("ArrayBuffer", cloneCollection(seq).collectionClassName)
2020
}
2121

2222
@Test def factoryIgnoresSourceCollectionFactory(): Unit = {
2323

2424
def cloneElements[A, C](xs: Iterable[A])(cb: Factory[A, C]): C =
2525
cb.fromSpecific(xs)
2626

27-
Assert.assertEquals("List", cloneElements(seq)(Seq).collectionClassName)
27+
assertEquals("List", cloneElements(seq)(Seq).collectionClassName)
2828
}
2929

3030
def apply(factory: IterableFactory[Iterable]): Unit = {
@@ -203,8 +203,8 @@ class FactoriesTest {
203203
im.Set(1),
204204
im.HashSet("a", "b", "c"),
205205
im.ListSet('c', 'd'),
206-
im.Map("a" -> 1, "b" -> 1, "c" -> 1).keySet,
207-
im.HashMap("a" -> 1, "b" -> 1, "c" -> 1).keySet,
206+
im.Map("a" -> 1, "b" -> 1, "c" -> 1).keySet, // MapOps$ImmutableKeySet
207+
im.HashMap("a" -> 1, "b" -> 1, "c" -> 1).keySet, // HashKeySet
208208
)
209209

210210
sortedFactoryFromIterableOnceReturnsSameReference(SortedSet, im.SortedSet)(
@@ -225,8 +225,9 @@ class FactoriesTest {
225225

226226
mapFactoryFromIterableOnceReturnsSameReference(Map, im.Map)(im.Map(1 -> 2), im.HashMap(1 -> 2))
227227
mapFactoryFromIterableOnceReturnsSameReference(im.HashMap)(im.HashMap(1 -> 2))
228-
mapFactoryFromIterableOnceReturnsSameReference(Map, im.Map)(im.IntMap(1 -> 2))
229-
mapFactoryFromIterableOnceReturnsSameReference(Map, im.Map)(im.LongMap(1L -> 2))
228+
// unsound due to widening, scala/bug#12745
229+
//mapFactoryFromIterableOnceReturnsSameReference(Map, im.Map)(im.IntMap(1 -> 2))
230+
//mapFactoryFromIterableOnceReturnsSameReference(Map, im.Map)(im.LongMap(1L -> 2))
230231

231232
mapFactoryFromIterableOnceReturnsSameReference(im.SeqMap, Map, im.Map)(
232233
im.ListMap(1 -> 2),
@@ -291,6 +292,74 @@ class FactoriesTest {
291292

292293
}
293294

295+
@Test def `custom set requires rebuild`: Unit = {
296+
import scala.collection.immutable.{Set, SortedSet}
297+
def testSame(xs: Set[Int], ys: Set[Any]): Boolean = {
298+
assertFalse(ys("oops"))
299+
assertFalse(ys.contains("oops"))
300+
xs.eq(ys)
301+
}
302+
val s1 = Set(42)
303+
assertTrue(testSame(s1, Set.from(s1)))
304+
val ss = SortedSet(42)
305+
assertFalse(testSame(ss, Set.from(ss)))
306+
307+
class Custom extends Set[Int] {
308+
// Members declared in scala.collection.IterableOnce
309+
def iterator: Iterator[Int] = Iterator.empty // implements `def iterator: Iterator[A]`
310+
311+
// Members declared in scala.collection.SetOps
312+
def contains(elem: Int): Boolean = ??? // implements `def contains(elem: A): Boolean`
313+
314+
// Members declared in scala.collection.immutable.SetOps
315+
def excl(elem: Int): scala.collection.immutable.Set[Int] = ??? // implements `def excl(elem: A): C`
316+
def incl(elem: Int): scala.collection.immutable.Set[Int] = ??? // implements `def incl(elem: A): C`
317+
}
318+
val custom = new Custom
319+
assertFalse(testSame(custom, Set.from(custom)))
320+
}
321+
322+
@Test def `select maps do not require rebuild`: Unit = {
323+
import scala.collection.immutable.{IntMap, ListMap, SeqMap, SortedMap, TreeSeqMap, VectorMap}
324+
325+
object X {
326+
val iter: Iterable[(Any, String)] = List(1, 2, 3, 4, 5).map(i => i -> i.toString).to(SortedMap.sortedMapFactory)
327+
val set: Map[Any, String] = Map.from(iter)
328+
}
329+
330+
// where ys is constructed from(xs), verify no CEE, return true if same
331+
def testSame(xs: Map[Int, String], ys: Map[Any, String]): Boolean = {
332+
assertTrue(ys.get("oops").isEmpty)
333+
assertFalse(ys.contains("oops"))
334+
xs.eq(ys)
335+
}
336+
assertFalse(X.set.contains("oops")) // was CCE
337+
// exercise small Maps
338+
1.to(5).foreach { n =>
339+
val m = Map(1.to(n).map(i => i -> i.toString): _*)
340+
assertTrue(testSame(m, Map.from(m)))
341+
}
342+
// other Maps that don't require rebuilding
343+
val listMap = ListMap(42 -> "four two")
344+
assertTrue(testSame(listMap, Map.from(listMap)))
345+
val treeSeqMap = TreeSeqMap(42 -> "four two")
346+
assertTrue(testSame(treeSeqMap, Map.from(treeSeqMap)))
347+
val vectorMap = VectorMap(42 -> "four two")
348+
assertTrue(testSame(vectorMap, Map.from(vectorMap)))
349+
val seqMap = SeqMap.empty[Int, String] + (42 -> "four two")
350+
assertTrue(testSame(seqMap, Map.from(seqMap)))
351+
// require rebuilding
352+
val sordid = SortedMap(42 -> "four two")
353+
assertFalse(testSame(sordid, Map.from(sordid)))
354+
val defaulted = listMap.withDefault(_.toString * 2)
355+
assertFalse(testSame(defaulted, Map.from(defaulted))) // deoptimized, see desorted
356+
val desorted = sordid.withDefault(_.toString * 2)
357+
assertFalse(testSame(desorted, Map.from(desorted)))
358+
359+
assertTrue(Map.from(IntMap(42 -> "once", 27 -> "upon"): Iterable[(Any, String)]).get("a time").isEmpty)
360+
}
361+
362+
// java.lang.ClassCastException: class java.lang.String cannot be cast to class java.lang.Integer
294363

295364
implicitly[Factory[Char, String]]
296365
implicitly[Factory[Char, Array[Char]]]

test/junit/scala/collection/immutable/MapTest.scala

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
package scala.collection.immutable
22

3-
import org.junit.Assert.assertEquals
3+
import org.junit.Assert.{assertEquals, assertSame, assertTrue}
44
import org.junit.Test
55

6-
import scala.annotation.nowarn
7-
86
class MapTest {
97

108
@Test def builderCompare1(): Unit = {
@@ -144,10 +142,17 @@ class MapTest {
144142
}
145143
}
146144

147-
@Test @nowarn("cat=deprecation")
145+
@Test @deprecated("Tests deprecated API", since="2.13.11")
148146
def t12699(): Unit = {
149147
val m1: HashMap[Int, Int] = HashMap(1 -> 1)
150-
assertEquals(7, m1.+(elem1 = 2 -> 2, elem2 = 3 -> 3, elems = List( 4 -> 4, 5 -> 5, 6 -> 6, 7 -> 7): _*).size)
148+
assertEquals(7, m1.+(elem1 = 2 -> 2, elem2 = 3 -> 3, elems = List(4 -> 4, 5 -> 5, 6 -> 6, 7 -> 7): _*).size)
151149
assertEquals(7, m1.+(2 -> 2, 3 -> 3, 4 -> 4, 5 -> 5, 6 -> 6, 7 -> 7).size)
152150
}
151+
152+
@Test def `t10496 unsound toMap`: Unit = {
153+
val t = Map(42 -> 27)
154+
assertSame(t, t.toMap[Any, Any])
155+
assertTrue(t.toMap[Any, Any].get("hi").isEmpty)
156+
assertTrue(TreeMap((1, 2)).toMap[Any, Any].get("hi").isEmpty) // was: CCE
157+
}
153158
}

0 commit comments

Comments
 (0)