Skip to content

Demo: Internals of FlatMapGroupsWithStateExec Physical Operator

The following demo shows the internals of the FlatMapGroupsWithStateExec physical operator in an Arbitrary Stateful Streaming Aggregation.

// A single shuffle partition means a single state store,
// which keeps the checkpointed state easy to follow while debugging
import org.apache.spark.sql.internal.SQLConf.SHUFFLE_PARTITIONS
val numShufflePartitions = 1
val sessionConf = spark.sessionState.conf
sessionConf.setConf(SHUFFLE_PARTITIONS, numShufflePartitions)
assert(sessionConf.numShufflePartitions == numShufflePartitions)

// Input event "format" (enter this with :paste mode in spark-shell,
// so the case class and its companion object are compiled together)
import java.sql.Timestamp
import scala.concurrent.duration._

/** One input record: an event-time timestamp and a value (used as the grouping key). */
case class Event(time: Timestamp, value: Long)

object Event {
  /**
   * Convenience factory taking the event time as seconds since the Unix epoch.
   *
   * @param secs  event time in seconds since the epoch
   * @param value the event's value
   */
  def apply(secs: Long, value: Long): Event = {
    val eventTimeMillis = Duration(secs, SECONDS).toMillis
    new Event(new Timestamp(eventTimeMillis), value)
  }
}

// A memory data source gives full control over the input: rows are added by hand below.
// MemoryStream[Event] picks up the implicit SQLContext defined here.
import org.apache.spark.sql.execution.streaming.MemoryStream
implicit val sqlCtx: org.apache.spark.sql.SQLContext = spark.sqlContext
val events = MemoryStream[Event]
val values = events.toDS()
assert(values.isStreaming, "values must be a streaming Dataset")

values.printSchema()
/**
root
 |-- time: timestamp (nullable = true)
 |-- value: long (nullable = false)
*/

import scala.concurrent.duration._
// An event-time watermark is required for GroupStateTimeout.EventTimeTimeout.
// withWatermark expects the delay as a string, so the duration is rendered with toString.
val delayThreshold = 10.seconds
val valuesWatermarked = values.withWatermark(
  eventTime = "time",
  delayThreshold = delayThreshold.toString)

/**
 * Per-key count emitted by the stateful function, and also its declared state type.
 * A plain Long would work too; a case class makes the demo a bit more advanced.
 *
 * @param value number of rows counted for a key
 */
case class Count(value: Long)

import java.sql.Timestamp
import org.apache.spark.sql.streaming.GroupState

/**
 * Function for flatMapGroupsWithState.
 *
 * Prints the group's key, its current state and the timing information exposed
 * by GroupState. It then emits one (key, Count) pair holding the number of rows
 * in the group.
 *
 * It never calls `state.update`, so no state is ever stored for a key.
 */
val keyCounts = (key: Long, values: Iterator[(Timestamp, Long)], state: GroupState[Count]) => {
  val stateDescription = state.getOption.fold("<empty>")(_.toString)
  println(s">>> keyCounts(key = $key, state = $stateDescription)")
  println(s">>> >>> currentProcessingTimeMs: ${state.getCurrentProcessingTimeMs}")
  println(s">>> >>> currentWatermarkMs: ${state.getCurrentWatermarkMs}")
  println(s">>> >>> hasTimedOut: ${state.hasTimedOut}")
  // Counting consumes the group's iterator
  Iterator.single(key -> Count(values.size))
}

import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
// Group every (time, value) row by its value and let keyCounts process each group.
// EventTimeTimeout relies on the watermark defined on valuesWatermarked.
val valuesCounted = valuesWatermarked
  .as[(Timestamp, Long)] // typed view so the grouping function can destructure tuples
  .groupByKey { case (_, value) => value }
  .flatMapGroupsWithState(
    outputMode = OutputMode.Update,
    timeoutConf = GroupStateTimeout.EventTimeTimeout)(keyCounts)
  .toDF("value", "count")

valuesCounted.explain()
/**
== Physical Plan ==
*(2) Project [_1#928L AS value#931L, _2#929 AS count#932]
+- *(2) SerializeFromObject [assertnotnull(input[0, scala.Tuple2, true])._1 AS _1#928L, if (isnull(assertnotnull(input[0, scala.Tuple2, true])._2)) null else named_struct(value, assertnotnull(assertnotnull(input[0, scala.Tuple2, true])._2).value) AS _2#929]
   +- FlatMapGroupsWithState $line140.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$Lambda$4117/181063008@d2cdc82, value#923: bigint, newInstance(class scala.Tuple2), [value#923L], [time#915-T10000ms, value#916L], obj#927: scala.Tuple2, state info [ checkpoint = <unknown>, runId = 9af3d00c-fe1f-46a0-8630-4e0d0af88042, opId = 0, ver = 0, numPartitions = 1], class[value[0]: bigint], 2, Update, EventTimeTimeout, 0, 0
      +- *(1) Sort [value#923L ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(value#923L, 1)
            +- AppendColumns $line140.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$Lambda$4118/2131767153@3e606b4c, newInstance(class scala.Tuple2), [input[0, bigint, false] AS value#923L]
               +- EventTimeWatermark time#915: timestamp, interval 10 seconds
                  +- StreamingRelation MemoryStream[time#915,value#916L], [time#915, value#916L]
*/

val queryName = "FlatMapGroupsWithStateExec_demo"
val checkpointLocation = s"/tmp/checkpoint-$queryName"

// Delete the checkpoint location from previous executions so the demo starts
// from a clean state rather than from a previous run's checkpoint
import java.nio.file.{Files, FileSystems}
import java.util.Comparator
import scala.collection.JavaConverters._
val path = FileSystems.getDefault.getPath(checkpointLocation)
if (Files.exists(path)) {
  // Files.walk keeps directory handles open until the stream is closed
  val walk = Files.walk(path)
  try {
    // Reverse order visits children before their parent directories,
    // so every directory is already empty when it is deleted.
    // Files.deleteIfExists reports failures with an IOException instead of
    // silently leaving stale checkpoint files behind (as File.delete would)
    walk
      .sorted(Comparator.reverseOrder())
      .iterator
      .asScala
      .foreach(p => Files.deleteIfExists(p))
  } finally walk.close()
}

import org.apache.spark.sql.streaming.OutputMode.Update
val streamingQuery = valuesCounted
  .writeStream
  .format("memory")
  .queryName(queryName)
  .option("checkpointLocation", checkpointLocation)
  .outputMode(Update)
  .start

// The query executes on its own thread, so right after `start` returns its status
// may still describe an initialization step. Poll for a bounded time instead of
// asserting on a single (timing-dependent) read of the status message.
val expectedStatusMessage = "Waiting for data to arrive"
val statusDeadlineNanos = System.nanoTime + java.util.concurrent.TimeUnit.SECONDS.toNanos(10)
var observedStatusMessage = streamingQuery.status.message
while (observedStatusMessage != expectedStatusMessage && System.nanoTime < statusDeadlineNanos) {
  Thread.sleep(100)
  observedStatusMessage = streamingQuery.status.message
}
assert(
  observedStatusMessage == expectedStatusMessage,
  s"Query status is '$observedStatusMessage', but should be '$expectedStatusMessage'")

// Use web UI to monitor the metrics of the streaming query
// Go to http://localhost:4040/SQL/ and click one of the Completed Queries with Job IDs

// You may also want to check out checkpointed state
// in /tmp/checkpoint-FlatMapGroupsWithStateExec_demo/state/0/0

// First micro-batch: one event per key (values 1 and 2)
val firstEvent = Event(secs = 1, value = 1)
val secondEvent = Event(secs = 15, value = 2)
val batch = Seq(firstEvent, secondEvent)
events.addData(batch)
streamingQuery.processAllAvailable()

/**
>>> keyCounts(key = 1, state = <empty>)
>>> >>> currentProcessingTimeMs: 1561881557237
>>> >>> currentWatermarkMs: 0
>>> >>> hasTimedOut: false
>>> keyCounts(key = 2, state = <empty>)
>>> >>> currentProcessingTimeMs: 1561881557237
>>> >>> currentWatermarkMs: 0
>>> >>> hasTimedOut: false
*/

spark.table(queryName).show(truncate = false)
/**
+-----+-----+
|value|count|
+-----+-----+
|1    |[1]  |
|2    |[1]  |
+-----+-----+
*/

// Once a micro-batch has run, the plan shows the state checkpoint location
streamingQuery.explain()
/**
== Physical Plan ==
*(2) Project [_1#928L AS value#931L, _2#929 AS count#932]
+- *(2) SerializeFromObject [assertnotnull(input[0, scala.Tuple2, true])._1 AS _1#928L, if (isnull(assertnotnull(input[0, scala.Tuple2, true])._2)) null else named_struct(value, assertnotnull(assertnotnull(input[0, scala.Tuple2, true])._2).value) AS _2#929]
   +- FlatMapGroupsWithState $line140.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$Lambda$4117/181063008@d2cdc82, value#923: bigint, newInstance(class scala.Tuple2), [value#923L], [time#915-T10000ms, value#916L], obj#927: scala.Tuple2, state info [ checkpoint = file:/tmp/checkpoint-FlatMapGroupsWithStateExec_demo/state, runId = 95c3917c-2fd7-45b2-86f6-6c01f0115e1d, opId = 0, ver = 1, numPartitions = 1], class[value[0]: bigint], 2, Update, EventTimeTimeout, 1561881557499, 5000
      +- *(1) Sort [value#923L ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(value#923L, 1)
            +- AppendColumns $line140.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$Lambda$4118/2131767153@3e606b4c, newInstance(class scala.Tuple2), [input[0, bigint, false] AS value#923L]
               +- EventTimeWatermark time#915: timestamp, interval 10 seconds
                  +- LocalTableScan <empty>, [time#915, value#916L]
*/

type Millis = Long

/**
 * Converts an ISO-8601 date-time string (for example the `watermark` entry of a
 * query progress's `eventTime`, such as `1970-01-01T00:00:05.000Z`) to milliseconds
 * since the Unix epoch.
 *
 * An explicit offset in the input (`Z`, `+02:00`, ...) is honored; a string without
 * an offset is interpreted as UTC.
 *
 * @param datetime date-time in `DateTimeFormatter.ISO_DATE_TIME` format
 * @return epoch milliseconds
 * @throws java.time.format.DateTimeParseException if `datetime` cannot be parsed
 */
def toMillis(datetime: String): Millis = {
  import java.time.format.DateTimeFormatter
  import java.time.temporal.TemporalQueries
  import java.time.{LocalDateTime, ZoneOffset}
  val parsed = DateTimeFormatter.ISO_DATE_TIME.parse(datetime)
  // TemporalQueries.offset() yields null when the input carries no offset
  val offset = Option(parsed.query(TemporalQueries.offset())).getOrElse(ZoneOffset.UTC)
  LocalDateTime.from(parsed).toInstant(offset).toEpochMilli
}

// lastProgress may be null before any progress is reported, and eventTime may have
// no "watermark" entry; fail with a clear message instead of an NPE inside toMillis
val currentWatermark: String = Option(streamingQuery.lastProgress)
  .flatMap(progress => Option(progress.eventTime.get("watermark")))
  .getOrElse(sys.error("No event-time watermark reported in the last query progress"))
// Truncated to whole seconds for the comparison below
val currentWatermarkSecs = toMillis(currentWatermark).millis.toSeconds.seconds

val expectedWatermarkSecs = 5.seconds
assert(
  currentWatermarkSecs == expectedWatermarkSecs,
  s"Current event-time watermark is $currentWatermarkSecs, but should be $expectedWatermarkSecs " +
    s"(maximum event time - delayThreshold $delayThreshold)")

// Let's access the FlatMapGroupsWithStateExec physical operator
import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper
import org.apache.spark.sql.execution.streaming.StreamExecution
// Unwrap the user-facing StreamingQuery to reach the underlying streaming engine
val engine: StreamExecution = streamingQuery match {
  case wrapper: StreamingQueryWrapper => wrapper.streamingQuery
  case other =>
    throw new IllegalStateException(
      s"Expected a StreamingQueryWrapper, but got ${other.getClass.getName}")
}

import org.apache.spark.sql.execution.streaming.IncrementalExecution
// NOTE: lastExecution may be null until a micro-batch has run
val lastMicroBatch: IncrementalExecution = Option(engine.lastExecution).getOrElse(
  throw new IllegalStateException("No micro-batch has been executed yet"))

// Access executedPlan that is the optimized physical query plan ready for execution
// All streaming optimizations have been applied at this point
val plan = lastMicroBatch.executedPlan

// Find the (first) FlatMapGroupsWithStateExec physical operator in the plan
import org.apache.spark.sql.execution.streaming.FlatMapGroupsWithStateExec
val flatMapOp = plan
  .collectFirst { case op: FlatMapGroupsWithStateExec => op }
  .getOrElse(throw new IllegalStateException(
    "FlatMapGroupsWithStateExec not found in the executed plan"))

// Display metrics
import org.apache.spark.sql.execution.metric.SQLMetric

/**
 * Formats one SQL metric as a fixed-width table row.
 *
 * @param name   the metric's key in the operator's metrics map
 * @param metric the metric; its description is empty when it has no name
 * @return a row with the key, description and current value
 */
def formatMetrics(name: String, metric: SQLMetric): String =
  "| %-30s | %-69s | %-10s".format(name, metric.name.getOrElse(""), metric.value)

flatMapOp.metrics.foreach { case (key, metric) => println(formatMetrics(key, metric)) }
/**
| numTotalStateRows              | number of total state rows                                            | 0
| stateMemory                    | memory used by state total (min, med, max)                            | 390
| loadedMapCacheHitCount         | count of cache hit on states cache in provider                        | 1
| numOutputRows                  | number of output rows                                                 | 0
| stateOnCurrentVersionSizeBytes | estimated size of state only on current version total (min, med, max) | 102
| loadedMapCacheMissCount        | count of cache miss on states cache in provider                       | 0
| commitTimeMs                   | time to commit changes total (min, med, max)                          | -2
| allRemovalsTimeMs              | total time to remove rows total (min, med, max)                       | -2
| numUpdatedStateRows            | number of updated state rows                                          | 0
| allUpdatesTimeMs               | total time to update rows total (min, med, max)                       | -2
*/

// Second micro-batch: one late event and one on-time event
val lateEvent = Event(secs = 1, value = 1)   // under the watermark (5000 ms) so it's disregarded
val onTimeEvent = Event(secs = 6, value = 3) // above the watermark so it should be counted
val batch = Seq(lateEvent, onTimeEvent)
events.addData(batch)
streamingQuery.processAllAvailable()

/**
>>> keyCounts(key = 3, state = <empty>)
>>> >>> currentProcessingTimeMs: 1561881643568
>>> >>> currentWatermarkMs: 5000
>>> >>> hasTimedOut: false
*/

spark.table(queryName).show(truncate = false)
/**
+-----+-----+
|value|count|
+-----+-----+
|1    |[1]  |
|2    |[1]  |
|3    |[1]  |
+-----+-----+
*/

val batch = Seq(
  Event(secs = 17,  value = 3))  // advances the watermark
events.addData(batch)
streamingQuery.processAllAvailable()

/**
>>> keyCounts(key = 3, state = <empty>)
>>> >>> currentProcessingTimeMs: 1561881672887
>>> >>> currentWatermarkMs: 5000
>>> >>> hasTimedOut: false
*/

// lastProgress may be null before any progress is reported, and eventTime may have
// no "watermark" entry; fail with a clear message instead of an NPE inside toMillis
val currentWatermark: String = Option(streamingQuery.lastProgress)
  .flatMap(progress => Option(progress.eventTime.get("watermark")))
  .getOrElse(sys.error("No event-time watermark reported in the last query progress"))
// Truncated to whole seconds for the comparison below
val currentWatermarkSecs = toMillis(currentWatermark).millis.toSeconds.seconds

val expectedWatermarkSecs = 7.seconds
assert(
  currentWatermarkSecs == expectedWatermarkSecs,
  s"Current event-time watermark is $currentWatermarkSecs, but should be $expectedWatermarkSecs " +
    s"(maximum event time - delayThreshold $delayThreshold)")

spark.table(queryName).show(truncate = false)
/**
+-----+-----+
|value|count|
+-----+-----+
|1    |[1]  |
|2    |[1]  |
|3    |[1]  |
|3    |[1]  |
+-----+-----+
*/

// Fourth micro-batch: a later event for key 3 that advances the watermark again
val batch = Seq(Event(secs = 18, value = 3))
events.addData(batch)
streamingQuery.processAllAvailable()

/**
>>> keyCounts(key = 3, state = <empty>)
>>> >>> currentProcessingTimeMs: 1561881778165
>>> >>> currentWatermarkMs: 7000
>>> >>> hasTimedOut: false
*/

// Eventually, stop the streaming query
streamingQuery.stop()