diff --git a/build.sbt b/build.sbt index 4ecf29d24..5757d2c10 100644 --- a/build.sbt +++ b/build.sbt @@ -41,6 +41,7 @@ lazy val xml = crossProject.in(file(".")) libraryDependencies += "junit" % "junit" % "4.11" % "test", libraryDependencies += "com.novocode" % "junit-interface" % "0.10" % "test", + libraryDependencies += "org.apache.commons" % "commons-lang3" % "3.6" % "test", libraryDependencies += ("org.scala-lang" % "scala-compiler" % scalaVersion.value % "test").exclude("org.scala-lang.modules", s"scala-xml_${scalaVersion.value}") ) .jsSettings( diff --git a/jvm/src/test/scala/scala/xml/JavaByteSerialization.scala b/jvm/src/test/scala/scala/xml/JavaByteSerialization.scala new file mode 100644 index 000000000..2d9f286d0 --- /dev/null +++ b/jvm/src/test/scala/scala/xml/JavaByteSerialization.scala @@ -0,0 +1,27 @@ +package scala.xml + +import java.io.Serializable +import java.util.Base64 +import org.apache.commons.lang3.SerializationUtils + +object JavaByteSerialization { + def roundTrip[T <: Serializable](obj: T): T = { + SerializationUtils.roundtrip(obj) + } + + def serialize[T <: Serializable](in: T): Array[Byte] = { + SerializationUtils.serialize(in) + } + + def deserialize[T <: Serializable](in: Array[Byte]): T = { + SerializationUtils.deserialize(in) + } + + def base64Encode[T <: Serializable](in: T): String = { + Base64.getEncoder.encodeToString(serialize[T](in)) + } + + def base64Decode[T <: Serializable](in: String): T = { + deserialize[T](Base64.getDecoder.decode(in)) + } +} diff --git a/jvm/src/test/scala/scala/xml/SerializationTest.scala b/jvm/src/test/scala/scala/xml/SerializationTest.scala index 6a63eae84..eb8f716ea 100644 --- a/jvm/src/test/scala/scala/xml/SerializationTest.scala +++ b/jvm/src/test/scala/scala/xml/SerializationTest.scala @@ -1,38 +1,23 @@ package scala.xml -import java.io._ - import org.junit.Assert.assertEquals import org.junit.Test class SerializationTest { - def roundTrip[T](obj: T): T = { - def serialize(in: T): Array[Byte] = { - val bos = new ByteArrayOutputStream() - val oos = new ObjectOutputStream(bos) - oos.writeObject(in) - oos.flush() - bos.toByteArray() - } - - def deserialize(in: Array[Byte]): T = { - val bis = new ByteArrayInputStream(in) - val ois = new ObjectInputStream(bis) - ois.readObject.asInstanceOf[T] - } - - deserialize(serialize(obj)) - } - @Test def xmlLiteral: Unit = { val n = - assertEquals(n, roundTrip(n)) + assertEquals(n, JavaByteSerialization.roundTrip(n)) } @Test def empty: Unit = { - assertEquals(NodeSeq.Empty, roundTrip(NodeSeq.Empty)) + assertEquals(NodeSeq.Empty, JavaByteSerialization.roundTrip(NodeSeq.Empty)) + } + + @Test + def unmatched: Unit = { + assertEquals(NodeSeq.Empty, JavaByteSerialization.roundTrip( \ "HTML")) } @Test @@ -40,6 +25,13 @@ class SerializationTest { val parent = val children: Seq[Node] = parent.child val asNodeSeq: NodeSeq = children - assertEquals(asNodeSeq, roundTrip(asNodeSeq)) + assertEquals(asNodeSeq, JavaByteSerialization.roundTrip(asNodeSeq)) + } + + @Test + def base64Encode: Unit = { + val str = JavaByteSerialization.base64Encode(NodeSeq.Empty) + assertEquals("rO0ABXNy", str.take(8)) + assertEquals("AHhweA==", str.takeRight(8)) } }