Skip to content

Commit 590c6f1

Browse files
NthPortallrytz
authored andcommitted
Fix CVE-2022-36944 for LazyList
Backport fix for CVE-2022-36944 from 2.13. Code copy-pasted in a browser.
1 parent 53b8c17 commit 590c6f1

File tree

2 files changed

+78
-8
lines changed

2 files changed

+78
-8
lines changed

compat/src/main/scala-2.11_2.12/scala/collection/compat/immutable/LazyList.scala

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import scala.collection.generic.{
3333
SeqFactory
3434
}
3535
import scala.collection.immutable.{LinearSeq, NumericRange}
36-
import scala.collection.mutable.{ArrayBuffer, Builder, StringBuilder}
36+
import scala.collection.mutable.{Builder, StringBuilder}
3737
import scala.language.implicitConversions
3838

3939
/** This class implements an immutable linked list that evaluates elements
@@ -516,10 +516,6 @@ final class LazyList[+A] private (private[this] var lazyState: () => LazyList.St
516516
else newLL(stateFromIteratorConcatSuffix(prefix.toIterator)(state))
517517
} else super.++:(prefix)(bf)
518518

519-
private def prependedAllToLL[B >: A](prefix: Traversable[B]): LazyList[B] =
520-
if (knownIsEmpty) LazyList.from(prefix)
521-
else newLL(stateFromIteratorConcatSuffix(prefix.toIterator)(state))
522-
523519
/** @inheritdoc
524520
*
525521
* $preservesLaziness
@@ -1512,14 +1508,17 @@ object LazyList extends SeqFactory[LazyList] {
15121508

15131509
private[this] def readObject(in: ObjectInputStream): Unit = {
15141510
in.defaultReadObject()
1515-
val init = new ArrayBuffer[A]
1511+
val init = new mutable.ListBuffer[A]
15161512
var initRead = false
15171513
while (!initRead) in.readObject match {
15181514
case SerializeEnd => initRead = true
1519-
case a => init += a.asInstanceOf[A]
1515+
case a => init += a.asInstanceOf[A]
15201516
}
15211517
val tail = in.readObject().asInstanceOf[LazyList[A]]
1522-
coll = tail.prependedAllToLL(init)
1518+
// scala/scala#10118: caution that no code path can evaluate `tail.state`
1519+
// before the resulting LazyList is returned
1520+
val it = init.toList.iterator
1521+
coll = newLL(stateFromIteratorConcatSuffix(it)(tail.state))
15231522
}
15241523

15251524
private[this] def readResolve(): Any = coll

compat/src/test/scala/test/scala/collection/LazyListTest.scala

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,73 @@ import scala.collection.mutable.{Builder, ListBuffer}
2222
import scala.util.Try
2323

2424
class LazyListTest {
25+
26+
@Test
27+
def serialization(): Unit = if (scala.util.Properties.releaseVersion.exists(_.startsWith("2.12"))) {
28+
import java.io._
29+
30+
def serialize(obj: AnyRef): Array[Byte] = {
31+
val buffer = new ByteArrayOutputStream
32+
val out = new ObjectOutputStream(buffer)
33+
out.writeObject(obj)
34+
buffer.toByteArray
35+
}
36+
37+
def deserialize(a: Array[Byte]): AnyRef = {
38+
val in = new ObjectInputStream(new ByteArrayInputStream(a))
39+
in.readObject
40+
}
41+
42+
def serializeDeserialize[T <: AnyRef](obj: T) = deserialize(serialize(obj)).asInstanceOf[T]
43+
44+
val l = LazyList.from(10)
45+
46+
val ld1 = serializeDeserialize(l)
47+
assertEquals(l.take(10).toList, ld1.take(10).toList)
48+
49+
l.tail.head
50+
val ld2 = serializeDeserialize(l)
51+
assertEquals(l.take(10).toList, ld2.take(10).toList)
52+
53+
LazyListTest.serializationForceCount = 0
54+
val u = LazyList.from(10).map(x => { LazyListTest.serializationForceCount += 1; x })
55+
56+
def printDiff(): Unit = {
57+
val a = serialize(u)
58+
classOf[LazyList[_]].getDeclaredField("scala$collection$compat$immutable$LazyList$$stateEvaluated").setBoolean(u, true)
59+
val b = serialize(u)
60+
val i = a.zip(b).indexWhere(p => p._1 != p._2)
61+
println("difference: ")
62+
println(s"val from = ${a.slice(i - 10, i + 10).mkString("List[Byte](", ", ", ")")}")
63+
println(s"val to = ${b.slice(i - 10, i + 10).mkString("List[Byte](", ", ", ")")}")
64+
}
65+
66+
// to update this test, comment-out `LazyList.writeReplace` and run `printDiff`
67+
// printDiff()
68+
69+
val from = List[Byte](83, 116, 97, 116, 101, 59, 120, 112, 0, 0, 0, 115, 114, 0, 33, 106, 97, 118, 97, 46)
70+
val to = List[Byte](83, 116, 97, 116, 101, 59, 120, 112, 0, 0, 1, 115, 114, 0, 33, 106, 97, 118, 97, 46)
71+
72+
assertEquals(LazyListTest.serializationForceCount, 0)
73+
74+
u.head
75+
assertEquals(LazyListTest.serializationForceCount, 1)
76+
77+
val data = serialize(u)
78+
var i = data.indexOfSlice(from)
79+
to.foreach(x => {data(i) = x; i += 1})
80+
81+
val ud1 = deserialize(data).asInstanceOf[LazyList[Int]]
82+
83+
// this check failed before scala/scala#10118, deserialization triggered evaluation
84+
assertEquals(LazyListTest.serializationForceCount, 1)
85+
86+
ud1.tail.head
87+
assertEquals(LazyListTest.serializationForceCount, 2)
88+
89+
u.tail.head
90+
assertEquals(LazyListTest.serializationForceCount, 3)
91+
}
2592

2693
@Test
2794
def t6727_and_t6440_and_8627(): Unit = {
@@ -403,3 +470,7 @@ class LazyListTest {
403470
assertEquals(1, count)
404471
}
405472
}
473+
474+
object LazyListTest {
475+
var serializationForceCount = 0
476+
}

0 commit comments

Comments
 (0)