diff --git a/build.sbt b/build.sbt
index 4ecf29d24..5757d2c10 100644
--- a/build.sbt
+++ b/build.sbt
@@ -41,6 +41,7 @@ lazy val xml = crossProject.in(file("."))
libraryDependencies += "junit" % "junit" % "4.11" % "test",
libraryDependencies += "com.novocode" % "junit-interface" % "0.10" % "test",
+ libraryDependencies += "org.apache.commons" % "commons-lang3" % "3.6" % "test",
libraryDependencies += ("org.scala-lang" % "scala-compiler" % scalaVersion.value % "test").exclude("org.scala-lang.modules", s"scala-xml_${scalaVersion.value}")
)
.jsSettings(
diff --git a/jvm/src/test/scala/scala/xml/JavaByteSerialization.scala b/jvm/src/test/scala/scala/xml/JavaByteSerialization.scala
new file mode 100644
index 000000000..2d9f286d0
--- /dev/null
+++ b/jvm/src/test/scala/scala/xml/JavaByteSerialization.scala
@@ -0,0 +1,27 @@
+package scala.xml
+
+import java.io.Serializable
+import java.util.Base64
+import org.apache.commons.lang3.SerializationUtils
+
+object JavaByteSerialization {
+ def roundTrip[T <: Serializable](obj: T): T = {
+ SerializationUtils.roundtrip(obj)
+ }
+
+ def serialize[T <: Serializable](in: T): Array[Byte] = {
+ SerializationUtils.serialize(in)
+ }
+
+ def deserialize[T <: Serializable](in: Array[Byte]): T = {
+ SerializationUtils.deserialize(in)
+ }
+
+ def base64Encode[T <: Serializable](in: T): String = {
+ Base64.getEncoder.encodeToString(serialize[T](in))
+ }
+
+ def base64Decode[T <: Serializable](in: String): T = {
+ deserialize[T](Base64.getDecoder.decode(in))
+ }
+}
diff --git a/jvm/src/test/scala/scala/xml/SerializationTest.scala b/jvm/src/test/scala/scala/xml/SerializationTest.scala
index 6a63eae84..eb8f716ea 100644
--- a/jvm/src/test/scala/scala/xml/SerializationTest.scala
+++ b/jvm/src/test/scala/scala/xml/SerializationTest.scala
@@ -1,38 +1,23 @@
package scala.xml
-import java.io._
-
import org.junit.Assert.assertEquals
import org.junit.Test
class SerializationTest {
- def roundTrip[T](obj: T): T = {
- def serialize(in: T): Array[Byte] = {
- val bos = new ByteArrayOutputStream()
- val oos = new ObjectOutputStream(bos)
- oos.writeObject(in)
- oos.flush()
- bos.toByteArray()
- }
-
- def deserialize(in: Array[Byte]): T = {
- val bis = new ByteArrayInputStream(in)
- val ois = new ObjectInputStream(bis)
- ois.readObject.asInstanceOf[T]
- }
-
- deserialize(serialize(obj))
- }
-
@Test
def xmlLiteral: Unit = {
val n =
- assertEquals(n, roundTrip(n))
+ assertEquals(n, JavaByteSerialization.roundTrip(n))
}
@Test
def empty: Unit = {
- assertEquals(NodeSeq.Empty, roundTrip(NodeSeq.Empty))
+ assertEquals(NodeSeq.Empty, JavaByteSerialization.roundTrip(NodeSeq.Empty))
+ }
+
+ @Test
+ def unmatched: Unit = {
+ assertEquals(NodeSeq.Empty, JavaByteSerialization.roundTrip( \ "HTML"))
}
@Test
@@ -40,6 +25,13 @@ class SerializationTest {
val parent =
val children: Seq[Node] = parent.child
val asNodeSeq: NodeSeq = children
- assertEquals(asNodeSeq, roundTrip(asNodeSeq))
+ assertEquals(asNodeSeq, JavaByteSerialization.roundTrip(asNodeSeq))
+ }
+
+ @Test
+ def base64Encode: Unit = {
+ val str = JavaByteSerialization.base64Encode(NodeSeq.Empty)
+ assertEquals("rO0ABXNy", str.take(8))
+ assertEquals("AHhweA==", str.takeRight(8))
}
}