
Commit 3ef547d

Adding support for update output mode to structured streaming (#1839)
This commit adds support for "update" as the output mode when writing a Spark Structured Streaming query to Elasticsearch. Closes #1123
1 parent c7bdb67 commit 3ef547d
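
With this change, a streaming sink configured with outputMode("update") is accepted and translated into upsert writes keyed by es.mapping.id. Below is a minimal usage sketch, loosely modeled on the integration test added in this commit; the Record case class, the rate source, the checkpoint path, and the index name are illustrative assumptions, not part of the commit itself.

```scala
import org.apache.spark.sql.SparkSession
import org.elasticsearch.hadoop.cfg.ConfigurationOptions.ES_MAPPING_ID

case class Record(id: Int, name: String)

object EsUpdateModeExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("es-update-mode").getOrCreate()
    import spark.implicits._

    // Any streaming Dataset with a stable id column works; the rate source is only for illustration.
    val records = spark.readStream
      .format("rate")
      .option("rowsPerSecond", 1)
      .load()
      .select($"value")
      .as[Long]
      .map(v => Record((v % 3).toInt, s"name-$v"))

    records.writeStream
      .outputMode("update")                                     // previously rejected; now implies an upsert write
      .option("checkpointLocation", "/tmp/es-update-checkpoint")
      .option(ES_MAPPING_ID, "id")                              // update mode requires a document id
      .format("es")
      .start("test-update/data")
      .awaitTermination()
  }
}
```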

File tree: 7 files changed, +213 −25 lines


mr/src/main/java/org/elasticsearch/hadoop/rest/InitializationUtils.java

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ public abstract class InitializationUtils {
     public static void checkIdForOperation(Settings settings) {
         String operation = settings.getOperation();
 
-        if (ConfigurationOptions.ES_OPERATION_UPDATE.equals(operation)) {
+        if (ConfigurationOptions.ES_OPERATION_UPDATE.equals(operation) || ConfigurationOptions.ES_OPERATION_UPSERT.equals(operation)) {
             Assert.isTrue(StringUtils.hasText(settings.getMappingId()),
                     String.format("Operation [%s] requires an id but none (%s) was specified", operation, ConfigurationOptions.ES_MAPPING_ID));
         }
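
The practical effect of this one-line change: an upsert write now fails fast when no mapping id is configured, exactly as an update write always has. A hypothetical standalone check using the real configuration constants (the PropertiesSettings wiring is illustrative, not code from this commit):

```scala
import org.elasticsearch.hadoop.cfg.{ConfigurationOptions, PropertiesSettings}
import org.elasticsearch.hadoop.rest.InitializationUtils

// Illustrative only: upsert without es.mapping.id now trips the same validation as update.
val settings = new PropertiesSettings()
settings.setProperty(ConfigurationOptions.ES_WRITE_OPERATION, ConfigurationOptions.ES_OPERATION_UPSERT)
// settings.setProperty(ConfigurationOptions.ES_MAPPING_ID, "id")  // uncomment to satisfy the check
InitializationUtils.checkIdForOperation(settings)
// fails with: "Operation [upsert] requires an id but none (es.mapping.id) was specified"
```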

spark/sql-20/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkStructuredStreaming.scala

Lines changed: 62 additions & 3 deletions
@@ -19,13 +19,14 @@
 
 package org.elasticsearch.spark.integration
 
+import com.fasterxml.jackson.databind.ObjectMapper
+
 import java.io.File
 import java.nio.file.Files
 import java.sql.Timestamp
 import java.util.concurrent.TimeUnit
 import java.{lang => jl}
 import java.{util => ju}
-
 import javax.xml.bind.DatatypeConverter
 import org.apache.hadoop.fs.Path
 import org.apache.spark.SparkConf
@@ -56,8 +57,7 @@ import org.hamcrest.Matchers.containsString
 import org.hamcrest.Matchers.is
 import org.hamcrest.Matchers.not
 import org.junit.{AfterClass, Assert, Assume, BeforeClass, ClassRule, FixMethodOrder, Rule, Test}
-import org.junit.Assert.assertThat
-import org.junit.Assert.assertTrue
+import org.junit.Assert.{assertEquals, assertThat, assertTrue}
 import org.junit.rules.TemporaryFolder
 import org.junit.runner.RunWith
 import org.junit.runners.MethodSorters
@@ -585,4 +585,63 @@ class AbstractScalaEsSparkStructuredStreaming(prefix: String, something: Boolean
           .start(target)
       }
   }
+
+  @Test
+  def testUpdate(): Unit = {
+    val target = wrapIndex(resource("test-update", "data", version))
+    val docPath = wrapIndex(docEndpoint("test-update", "data", version))
+    val test = new StreamingQueryTestHarness[Record](spark)
+
+    test.withInput(Record(1, "Spark"))
+      .withInput(Record(2, "Hadoop"))
+      .withInput(Record(3, "YARN"))
+      .startTest {
+        test.stream
+          .writeStream
+          .outputMode("update")
+          .option("checkpointLocation", checkpoint(target))
+          .option(ES_MAPPING_ID, "id")
+          .format("es")
+          .start(target)
+      }
+    test.waitForPartialCompletion()
+
+    assertTrue(RestUtils.exists(target))
+    assertTrue(RestUtils.exists(docPath + "/1"))
+    assertTrue(RestUtils.exists(docPath + "/2"))
+    assertTrue(RestUtils.exists(docPath + "/3"))
+    var searchResult = RestUtils.get(target + "/_search?")
+    assertThat(searchResult, containsString("Spark"))
+    assertThat(searchResult, containsString("Hadoop"))
+    assertThat(searchResult, containsString("YARN"))
+
+    test.withInput(Record(1, "Spark"))
+      .withInput(Record(2, "Hadoop2"))
+      .withInput(Record(3, "YARN"))
+    test.waitForCompletion()
+    searchResult = RestUtils.get(target + "/_search?version=true")
+    val result: java.util.Map[String, Object] = new ObjectMapper().readValue(searchResult, classOf[java.util.Map[String, Object]])
+    val hits = result.get("hits").asInstanceOf[java.util.Map[String, Object]].get("hits").asInstanceOf[java.util.List[java.util.Map[String,
+      Object]]]
+    hits.forEach(hit => {
+      hit.get("_id").asInstanceOf[String] match {
+        case "1" => {
+          assertEquals(1, hit.get("_version"))
+          val value = hit.get("_source").asInstanceOf[java.util.Map[String, Object]].get("name").asInstanceOf[String]
+          assertEquals("Spark", value)
+        }
+        case "2" => {
+          assertEquals(2, hit.get("_version")) // The only one that should have been updated
+          val value = hit.get("_source").asInstanceOf[java.util.Map[String, Object]].get("name").asInstanceOf[String]
+          assertEquals("Hadoop2", value)
+        }
+        case "3" => {
+          assertEquals(1, hit.get("_version"))
+          val value = hit.get("_source").asInstanceOf[java.util.Map[String, Object]].get("name").asInstanceOf[String]
+          assertEquals("YARN", value)
+        }
+        case _ => throw new AssertionError("Unexpected result")
+      }
+    })
+  }
 }

spark/sql-20/src/itest/scala/org/elasticsearch/spark/sql/streaming/StreamingQueryTestHarness.scala

Lines changed: 30 additions & 2 deletions
@@ -76,7 +76,7 @@ class StreamingQueryTestHarness[S <: java.io.Serializable : Encoder](val sparkSe
   private var foundExpectedException: Boolean = false
   private var encounteredException: Option[String] = None
 
-  private val latch = new CountDownLatch(1)
+  private var latch = new CountDownLatch(1) // expects just a single batch
 
   def incrementExpected(): Unit = inputsRequired = inputsRequired + 1
 
@@ -153,6 +153,10 @@ class StreamingQueryTestHarness[S <: java.io.Serializable : Encoder](val sparkSe
 
   def waitOnComplete(timeValue: TimeValue): Boolean = latch.await(timeValue.millis, TimeUnit.MILLISECONDS)
 
+  def expectAnotherBatch(): Unit = {
+    latch = new CountDownLatch(1)
+  }
+
   def assertExpectedExceptions(message: Option[String]): Unit = {
     expectingToThrow match {
       case Some(exceptionClass) =>
@@ -211,7 +215,7 @@ class StreamingQueryTestHarness[S <: java.io.Serializable : Encoder](val sparkSe
    * Add input to test server. Updates listener's bookkeeping to know when it's safe to shut down the stream
    */
   def withInput(data: S): StreamingQueryTestHarness[S] = {
-    ensureState(Init) {
+    ensureState(Init, Running) {
       testingServer.sendData(TestingSerde.serialize(data))
      listener.incrementExpected()
    }
@@ -320,6 +324,30 @@ class StreamingQueryTestHarness[S <: java.io.Serializable : Encoder](val sparkSe
     }
   }
 
+  /**
+   * Waits until all inputs are processed on the streaming query, but leaves the query open with the listener still in place, expecting
+   * another batch of inputs.
+   */
+  def waitForPartialCompletion(): Unit = {
+    ensureState(Running) {
+      currentState match {
+        case Running =>
+          try {
+            // Wait for query to complete consuming records
+            if (!listener.waitOnComplete(testTimeout)) {
+              throw new TimeoutException("Timed out on waiting for stream to complete.")
+            }
+            listener.expectAnotherBatch()
+          } catch {
+            case e: Throwable =>
+              // Best effort to shutdown queries before throwing
+              scrubState()
+              throw e
+          }
+      }
+    }
+  }
+
   // tears down literally everything indiscriminately, mostly for cleanup after a failure
   private[this] def scrubState(): Unit = {
     sparkSession.streams.removeListener(listener)
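
Taken together, these harness changes let a single streaming query be driven across two batches. A condensed sketch of the intended flow (names such as spark, target, checkpoint, ES_MAPPING_ID, and Record come from the test class scope above; this is not additional code from the commit):

```scala
// Two-phase use of the harness (condensed from testUpdate above).
val test = new StreamingQueryTestHarness[Record](spark)

test.withInput(Record(1, "Spark"))          // first batch is registered while the harness is in Init
  .startTest {
    test.stream.writeStream
      .outputMode("update")
      .option("checkpointLocation", checkpoint(target))
      .option(ES_MAPPING_ID, "id")
      .format("es")
      .start(target)
  }
test.waitForPartialCompletion()             // waits for batch 1, then re-arms the listener's latch

test.withInput(Record(1, "Spark v2"))       // allowed now that withInput also accepts the Running state
test.waitForCompletion()                    // waits for batch 2 and shuts the query down
```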

spark/sql-20/src/main/scala/org/elasticsearch/spark/sql/DefaultSource.scala

Lines changed: 14 additions & 7 deletions
@@ -22,7 +22,6 @@ import java.util.Calendar
 import java.util.Date
 import java.util.Locale
 import java.util.UUID
-
 import javax.xml.bind.DatatypeConverter
 import org.apache.commons.logging.LogFactory
 import org.apache.spark.rdd.RDD
@@ -63,6 +62,7 @@ import org.apache.spark.sql.internal.SQLConf
 import org.elasticsearch.hadoop.EsHadoopIllegalArgumentException
 import org.elasticsearch.hadoop.EsHadoopIllegalStateException
 import org.elasticsearch.hadoop.cfg.ConfigurationOptions
+import org.elasticsearch.hadoop.cfg.ConfigurationOptions.ES_WRITE_OPERATION
 import org.elasticsearch.hadoop.cfg.InternalConfigurationOptions
 import org.elasticsearch.hadoop.cfg.InternalConfigurationOptions.INTERNAL_TRANSPORT_POOLING_KEY
 import org.elasticsearch.hadoop.cfg.Settings
@@ -122,12 +122,6 @@ private[sql] class DefaultSource extends RelationProvider with SchemaRelationPro
     // Verify compatiblity versions for alpha:
     StructuredStreamingVersionLock.checkCompatibility(sparkSession)
 
-    // For now we only support Append style output mode
-    if (outputMode != OutputMode.Append()) {
-      throw new EsHadoopIllegalArgumentException("Append is only supported OutputMode for Elasticsearch. " +
-        s"Cannot continue with [$outputMode].")
-    }
-
     // Should not support partitioning. We already allow people to split data into different
     // indices with the index pattern functionality. Potentially could add this later if a need
     // arises by appending patterns to the provided index, but that's probably feature overload.
@@ -142,6 +136,19 @@ private[sql] class DefaultSource extends RelationProvider with SchemaRelationPro
       .load(sqlContext.sparkContext.getConf)
       .merge(streamParams(mapConfig.toMap, sparkSession).asJava)
 
+    // For now we only support Update and Append style output modes
+    if (outputMode == OutputMode.Update()) {
+      val writeOperation = jobSettings.getProperty(ES_WRITE_OPERATION);
+      if (writeOperation == null) {
+        jobSettings.setProperty(ES_WRITE_OPERATION, ConfigurationOptions.ES_OPERATION_UPSERT)
+      } else if (writeOperation != ConfigurationOptions.ES_OPERATION_UPSERT) {
+        throw new EsHadoopIllegalArgumentException("Output mode update is only supported if es.write.operation is unset or set to upsert")
+      }
+    } else if (outputMode != OutputMode.Append()) {
+      throw new EsHadoopIllegalArgumentException("Append and update are the only supported OutputModes for Elasticsearch. " +
+        s"Cannot continue with [$outputMode].")
+    }
+
     InitializationUtils.discoverClusterInfo(jobSettings, LogFactory.getLog(classOf[DefaultSource]))
     InitializationUtils.checkIdForOperation(jobSettings)
     InitializationUtils.checkIndexExistence(jobSettings)
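
The net validation behavior introduced above can be summarized with a small standalone function — a simplified sketch of the decision, not the literal code path in DefaultSource:

```scala
import org.apache.spark.sql.streaming.OutputMode
import org.elasticsearch.hadoop.EsHadoopIllegalArgumentException

// Simplified sketch: how the sink now treats the combination of output mode and es.write.operation.
def resolveWriteOperation(outputMode: OutputMode, configuredOperation: Option[String]): Option[String] =
  outputMode match {
    case m if m == OutputMode.Update() =>
      configuredOperation match {
        case None           => Some("upsert")   // unset: default the write operation to upsert
        case Some("upsert") => Some("upsert")   // explicit upsert is accepted
        case Some(other)    =>
          throw new EsHadoopIllegalArgumentException(
            s"Output mode update is only supported if es.write.operation is unset or set to upsert (was [$other])")
      }
    case m if m == OutputMode.Append() =>
      configuredOperation                       // append keeps whatever operation was configured
    case other =>
      throw new EsHadoopIllegalArgumentException(
        s"Append and update are the only supported OutputModes for Elasticsearch. Cannot continue with [$other].")
  }
```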

spark/sql-30/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkStructuredStreaming.scala

Lines changed: 62 additions & 3 deletions
@@ -19,13 +19,14 @@
 
 package org.elasticsearch.spark.integration
 
+import com.fasterxml.jackson.databind.ObjectMapper
+
 import java.io.File
 import java.nio.file.Files
 import java.sql.Timestamp
 import java.util.concurrent.TimeUnit
 import java.{lang => jl}
 import java.{util => ju}
-
 import javax.xml.bind.DatatypeConverter
 import org.apache.hadoop.fs.Path
 import org.apache.spark.SparkConf
@@ -56,8 +57,7 @@ import org.hamcrest.Matchers.containsString
 import org.hamcrest.Matchers.is
 import org.hamcrest.Matchers.not
 import org.junit.{AfterClass, Assert, Assume, BeforeClass, ClassRule, FixMethodOrder, Rule, Test}
-import org.junit.Assert.assertThat
-import org.junit.Assert.assertTrue
+import org.junit.Assert.{assertEquals, assertThat, assertTrue}
 import org.junit.rules.TemporaryFolder
 import org.junit.runner.RunWith
 import org.junit.runners.MethodSorters
@@ -585,4 +585,63 @@ class AbstractScalaEsSparkStructuredStreaming(prefix: String, something: Boolean
           .start(target)
       }
   }
+
+  @Test
+  def testUpdate(): Unit = {
+    val target = wrapIndex(resource("test-update", "data", version))
+    val docPath = wrapIndex(docEndpoint("test-update", "data", version))
+    val test = new StreamingQueryTestHarness[Record](spark)
+
+    test.withInput(Record(1, "Spark"))
+      .withInput(Record(2, "Hadoop"))
+      .withInput(Record(3, "YARN"))
+      .startTest {
+        test.stream
+          .writeStream
+          .outputMode("update")
+          .option("checkpointLocation", checkpoint(target))
+          .option(ES_MAPPING_ID, "id")
+          .format("es")
+          .start(target)
+      }
+    test.waitForPartialCompletion()
+
+    assertTrue(RestUtils.exists(target))
+    assertTrue(RestUtils.exists(docPath + "/1"))
+    assertTrue(RestUtils.exists(docPath + "/2"))
+    assertTrue(RestUtils.exists(docPath + "/3"))
+    var searchResult = RestUtils.get(target + "/_search?")
+    assertThat(searchResult, containsString("Spark"))
+    assertThat(searchResult, containsString("Hadoop"))
+    assertThat(searchResult, containsString("YARN"))
+
+    test.withInput(Record(1, "Spark"))
+      .withInput(Record(2, "Hadoop2"))
+      .withInput(Record(3, "YARN"))
+    test.waitForCompletion()
+    searchResult = RestUtils.get(target + "/_search?version=true")
+    val result: java.util.Map[String, Object] = new ObjectMapper().readValue(searchResult, classOf[java.util.Map[String, Object]])
+    val hits = result.get("hits").asInstanceOf[java.util.Map[String, Object]].get("hits").asInstanceOf[java.util.List[java.util.Map[String,
+      Object]]]
+    hits.forEach(hit => {
+      hit.get("_id").asInstanceOf[String] match {
+        case "1" => {
+          assertEquals(1, hit.get("_version"))
+          val value = hit.get("_source").asInstanceOf[java.util.Map[String, Object]].get("name").asInstanceOf[String]
+          assertEquals("Spark", value)
+        }
+        case "2" => {
+          assertEquals(2, hit.get("_version")) // The only one that should have been updated
+          val value = hit.get("_source").asInstanceOf[java.util.Map[String, Object]].get("name").asInstanceOf[String]
+          assertEquals("Hadoop2", value)
+        }
+        case "3" => {
+          assertEquals(1, hit.get("_version"))
+          val value = hit.get("_source").asInstanceOf[java.util.Map[String, Object]].get("name").asInstanceOf[String]
+          assertEquals("YARN", value)
+        }
+        case _ => throw new AssertionError("Unexpected result")
+      }
+    })
+  }
 }

spark/sql-30/src/itest/scala/org/elasticsearch/spark/sql/streaming/StreamingQueryTestHarness.scala

Lines changed: 30 additions & 2 deletions
@@ -76,7 +76,7 @@ class StreamingQueryTestHarness[S <: java.io.Serializable : Encoder](val sparkSe
   private var foundExpectedException: Boolean = false
   private var encounteredException: Option[String] = None
 
-  private val latch = new CountDownLatch(1)
+  private var latch = new CountDownLatch(1) // expects just a single batch
 
   def incrementExpected(): Unit = inputsRequired = inputsRequired + 1
 
@@ -153,6 +153,10 @@ class StreamingQueryTestHarness[S <: java.io.Serializable : Encoder](val sparkSe
 
   def waitOnComplete(timeValue: TimeValue): Boolean = latch.await(timeValue.millis, TimeUnit.MILLISECONDS)
 
+  def expectAnotherBatch(): Unit = {
+    latch = new CountDownLatch(1)
+  }
+
   def assertExpectedExceptions(message: Option[String]): Unit = {
     expectingToThrow match {
       case Some(exceptionClass) =>
@@ -211,7 +215,7 @@ class StreamingQueryTestHarness[S <: java.io.Serializable : Encoder](val sparkSe
    * Add input to test server. Updates listener's bookkeeping to know when it's safe to shut down the stream
    */
   def withInput(data: S): StreamingQueryTestHarness[S] = {
-    ensureState(Init) {
+    ensureState(Init, Running) {
       testingServer.sendData(TestingSerde.serialize(data))
      listener.incrementExpected()
    }
@@ -320,6 +324,30 @@ class StreamingQueryTestHarness[S <: java.io.Serializable : Encoder](val sparkSe
     }
   }
 
+  /**
+   * Waits until all inputs are processed on the streaming query, but leaves the query open with the listener still in place, expecting
+   * another batch of inputs.
+   */
+  def waitForPartialCompletion(): Unit = {
+    ensureState(Running) {
+      currentState match {
+        case Running =>
+          try {
+            // Wait for query to complete consuming records
+            if (!listener.waitOnComplete(testTimeout)) {
+              throw new TimeoutException("Timed out on waiting for stream to complete.")
+            }
+            listener.expectAnotherBatch()
+          } catch {
+            case e: Throwable =>
+              // Best effort to shutdown queries before throwing
+              scrubState()
+              throw e
+          }
+      }
+    }
+  }
+
   // tears down literally everything indiscriminately, mostly for cleanup after a failure
   private[this] def scrubState(): Unit = {
     sparkSession.streams.removeListener(listener)
