Skip to content

Commit aea4fd6

Browse files
authored
Add test for CVE-2022-36944
Test is currently broken, as `ReflectUtil` doesn't exist in this repo
1 parent e4ea3ab commit aea4fd6

File tree

1 file changed

+72
-0
lines changed

1 file changed

+72
-0
lines changed

compat/src/test/scala/test/scala/collection/LazyListTest.scala

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,74 @@ import scala.collection.mutable.{Builder, ListBuffer}
2222
import scala.util.Try
2323

2424
class LazyListTest {
25+
26+
@Test
27+
def serialization(): Unit = {
28+
import java.io._
29+
30+
def serialize(obj: AnyRef): Array[Byte] = {
31+
val buffer = new ByteArrayOutputStream
32+
val out = new ObjectOutputStream(buffer)
33+
out.writeObject(obj)
34+
buffer.toByteArray
35+
}
36+
37+
def deserialize(a: Array[Byte]): AnyRef = {
38+
val in = new ObjectInputStream(new ByteArrayInputStream(a))
39+
in.readObject
40+
}
41+
42+
def serializeDeserialize[T <: AnyRef](obj: T) = deserialize(serialize(obj)).asInstanceOf[T]
43+
44+
val l = LazyList.from(10)
45+
46+
val ld1 = serializeDeserialize(l)
47+
assertEquals(l.take(10).toList, ld1.take(10).toList)
48+
49+
l.tail.head
50+
val ld2 = serializeDeserialize(l)
51+
assertEquals(l.take(10).toList, ld2.take(10).toList)
52+
53+
LazyListTest.serializationForceCount = 0
54+
val u = LazyList.from(10).map(x => { LazyListTest.serializationForceCount += 1; x })
55+
56+
@unused def printDiff(): Unit = {
57+
val a = serialize(u)
58+
// TODO: replace with implementation of `getFieldAccessible`, since it's private to the scala/scala repo
59+
ReflectUtil.getFieldAccessible[LazyList[_]]("scala$collection$immutable$LazyList$$stateEvaluated").setBoolean(u, true)
60+
val b = serialize(u)
61+
val i = a.zip(b).indexWhere(p => p._1 != p._2)
62+
println("difference: ")
63+
println(s"val from = ${a.slice(i - 10, i + 10).mkString("List[Byte](", ", ", ")")}")
64+
println(s"val to = ${b.slice(i - 10, i + 10).mkString("List[Byte](", ", ", ")")}")
65+
}
66+
67+
// to update this test, comment-out `LazyList.writeReplace` and run `printDiff`
68+
// printDiff()
69+
70+
val from = List[Byte](83, 116, 97, 116, 101, 59, 120, 112, 0, 0, 0, 115, 114, 0, 33, 106, 97, 118, 97, 46)
71+
val to = List[Byte](83, 116, 97, 116, 101, 59, 120, 112, 0, 0, 1, 115, 114, 0, 33, 106, 97, 118, 97, 46)
72+
73+
assertEquals(LazyListTest.serializationForceCount, 0)
74+
75+
u.head
76+
assertEquals(LazyListTest.serializationForceCount, 1)
77+
78+
val data = serialize(u)
79+
var i = data.indexOfSlice(from)
80+
to.foreach(x => {data(i) = x; i += 1})
81+
82+
val ud1 = deserialize(data).asInstanceOf[LazyList[Int]]
83+
84+
// this check failed before scala/scala#10118, deserialization triggered evaluation
85+
assertEquals(LazyListTest.serializationForceCount, 1)
86+
87+
ud1.tail.head
88+
assertEquals(LazyListTest.serializationForceCount, 2)
89+
90+
u.tail.head
91+
assertEquals(LazyListTest.serializationForceCount, 3)
92+
}
2593

2694
@Test
2795
def t6727_and_t6440_and_8627(): Unit = {
@@ -403,3 +471,7 @@ class LazyListTest {
403471
assertEquals(1, count)
404472
}
405473
}
474+
475+
object LazyListTest {
476+
var serializationForceCount = 0
477+
}

0 commit comments

Comments
 (0)