Skip to content

Fix CVE-2022-36944 for LazyList #569

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import scala.collection.generic.{
SeqFactory
}
import scala.collection.immutable.{LinearSeq, NumericRange}
import scala.collection.mutable.{ArrayBuffer, Builder, StringBuilder}
import scala.collection.mutable.{Builder, StringBuilder}
import scala.language.implicitConversions

/** This class implements an immutable linked list that evaluates elements
Expand Down Expand Up @@ -516,10 +516,6 @@ final class LazyList[+A] private (private[this] var lazyState: () => LazyList.St
else newLL(stateFromIteratorConcatSuffix(prefix.toIterator)(state))
} else super.++:(prefix)(bf)

private def prependedAllToLL[B >: A](prefix: Traversable[B]): LazyList[B] =
if (knownIsEmpty) LazyList.from(prefix)
else newLL(stateFromIteratorConcatSuffix(prefix.toIterator)(state))

/** @inheritdoc
*
* $preservesLaziness
Expand Down Expand Up @@ -1512,14 +1508,17 @@ object LazyList extends SeqFactory[LazyList] {

private[this] def readObject(in: ObjectInputStream): Unit = {
in.defaultReadObject()
val init = new ArrayBuffer[A]
val init = new mutable.ListBuffer[A]
var initRead = false
while (!initRead) in.readObject match {
case SerializeEnd => initRead = true
case a => init += a.asInstanceOf[A]
}
val tail = in.readObject().asInstanceOf[LazyList[A]]
coll = tail.prependedAllToLL(init)
// scala/scala#10118: caution that no code path can evaluate `tail.state`
// before the resulting LazyList is returned
val it = init.toList.iterator
coll = newLL(stateFromIteratorConcatSuffix(it)(tail.state))
}

private[this] def readResolve(): Any = coll
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,86 @@ class LazyListGCTest {
def tapEach_takeRight_headOption_allowsGC(): Unit = {
assertLazyListOpAllowsGC(_.tapEach(_).takeRight(2).headOption, _ => ())
}

@Test
def serialization(): Unit =
if (scala.util.Properties.releaseVersion.exists(_.startsWith("2.12"))) {
import java.io._

def serialize(obj: AnyRef): Array[Byte] = {
val buffer = new ByteArrayOutputStream
val out = new ObjectOutputStream(buffer)
out.writeObject(obj)
buffer.toByteArray
}

def deserialize(a: Array[Byte]): AnyRef = {
val in = new ObjectInputStream(new ByteArrayInputStream(a))
in.readObject
}

def serializeDeserialize[T <: AnyRef](obj: T) = deserialize(serialize(obj)).asInstanceOf[T]

val l = LazyList.from(10)

val ld1 = serializeDeserialize(l)
assertEquals(l.take(10).toList, ld1.take(10).toList)

l.tail.head
val ld2 = serializeDeserialize(l)
assertEquals(l.take(10).toList, ld2.take(10).toList)

LazyListGCTest.serializationForceCount = 0
val u = LazyList
.from(10)
.map(x => {
LazyListGCTest.serializationForceCount += 1; x
})

def printDiff(): Unit = {
val a = serialize(u)
classOf[LazyList[_]]
.getDeclaredField("scala$collection$compat$immutable$LazyList$$stateEvaluated")
.setBoolean(u, true)
val b = serialize(u)
val i = a.zip(b).indexWhere(p => p._1 != p._2)
println("difference: ")
println(s"val from = ${a.slice(i - 10, i + 10).mkString("List[Byte](", ", ", ")")}")
println(s"val to = ${b.slice(i - 10, i + 10).mkString("List[Byte](", ", ", ")")}")
}

// to update this test, comment-out `LazyList.writeReplace` and run `printDiff`
// printDiff()

val from = List[Byte](83, 116, 97, 116, 101, 59, 120, 112, 0, 0, 0, 115, 114, 0, 33, 106, 97,
118, 97, 46)
val to = List[Byte](83, 116, 97, 116, 101, 59, 120, 112, 0, 0, 1, 115, 114, 0, 33, 106, 97,
118, 97, 46)

assertEquals(LazyListGCTest.serializationForceCount, 0)

u.head
assertEquals(LazyListGCTest.serializationForceCount, 1)

val data = serialize(u)
var i = data.indexOfSlice(from)
to.foreach(x => {
data(i) = x; i += 1
})

val ud1 = deserialize(data).asInstanceOf[LazyList[Int]]

// this check failed before scala/scala#10118, deserialization triggered evaluation
assertEquals(LazyListGCTest.serializationForceCount, 1)

ud1.tail.head
assertEquals(LazyListGCTest.serializationForceCount, 2)

u.tail.head
assertEquals(LazyListGCTest.serializationForceCount, 3)
}
}

object LazyListGCTest {
var serializationForceCount = 0
}