Demo: Exploring Checkpointed State
The following demo shows the internals of the checkpointed state of a stateful streaming query.
The demo uses the state checkpoint directory created by the query in Demo: Streaming Watermark with Aggregation in Append Output Mode.
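The state directory explored below is the state subdirectory of that query's checkpoint location. The following is a minimal sketch (not the exact demo code) of a windowed aggregation with a watermark that would produce such a checkpoint; the rate source, the batch column, and windowDuration are assumptions based on the comments quoted later in this demo.
// A sketch only: a streaming aggregation whose checkpoint location contains the state directory explored below
import scala.concurrent.duration._
import org.apache.spark.sql.functions.{collect_list, window}
val windowDuration = 5.seconds
val events = spark.readStream
  .format("rate")
  .load
  .withColumnRenamed("timestamp", "time")
  .withColumn("batch", $"value" % 2) // hypothetical batch column
val counts = events
  .withWatermark("time", "5 seconds")
  .groupBy(window($"time", windowDuration.toString) as "sliding_window")
  .agg(collect_list("batch") as "batches", collect_list("value") as "values")
counts.writeStream
  .format("console")
  .option("checkpointLocation", "/tmp/checkpoint-watermark_demo")
  .outputMode("append")
  .start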
// Change the path to match your configuration
val checkpointRootLocation = "/tmp/checkpoint-watermark_demo/state"
val version = 1L
import org.apache.spark.sql.execution.streaming.state.StateStoreId
val storeId = StateStoreId(
  checkpointRootLocation,
  operatorId = 0,
  partitionId = 0)
// The key and value schemas should match the watermark demo
// .groupBy(window($"time", windowDuration.toString) as "sliding_window")
import org.apache.spark.sql.types.{TimestampType, StructField, StructType}
val keySchema = StructType(
  StructField("sliding_window",
    StructType(
      StructField("start", TimestampType, nullable = true) ::
      StructField("end", TimestampType, nullable = true) :: Nil),
    nullable = false) :: Nil)
scala> keySchema.printTreeString
root
 |-- sliding_window: struct (nullable = false)
 |    |-- start: timestamp (nullable = true)
 |    |-- end: timestamp (nullable = true)
// .agg(collect_list("batch") as "batches", collect_list("value") as "values")
import org.apache.spark.sql.types.{ArrayType, LongType}
val valueSchema = StructType(
  StructField("batches", ArrayType(LongType, containsNull = true), nullable = true) ::
  StructField("values", ArrayType(LongType, containsNull = true), nullable = true) :: Nil)
scala> valueSchema.printTreeString
root
 |-- batches: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- values: array (nullable = true)
 |    |-- element: long (containsNull = true)
val indexOrdinal = None // no index on the key columns
import org.apache.spark.sql.execution.streaming.state.StateStoreConf
val storeConf = StateStoreConf(spark.sessionState.conf)
val hadoopConf = spark.sessionState.newHadoopConf()
import org.apache.spark.sql.execution.streaming.state.StateStoreProvider
val provider = StateStoreProvider.createAndInit(
  storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
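With the provider created and initialized, the versioned state can be loaded directly from it (a minimal sketch using StateStoreProvider.getStore):
// Load the state at the given version straight from the provider
val storeFromProvider = provider.getStore(version)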
// You may want to use the following higher-level code instead
import java.util.UUID
val queryRunId = UUID.randomUUID
import org.apache.spark.sql.execution.streaming.state.StateStoreProviderId
val storeProviderId = StateStoreProviderId(storeId, queryRunId)
import org.apache.spark.sql.execution.streaming.state.StateStore
val store = StateStore.get(
  storeProviderId,
  keySchema,
  valueSchema,
  indexOrdinal,
  version,
  storeConf,
  hadoopConf)
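As a quick sanity check before looking at the rows, the store's version and metrics can be printed (version and metrics are part of the StateStore contract):
// The version the store was loaded at and its memory-usage metrics
println(store.version)
println(store.metrics)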
import org.apache.spark.sql.execution.streaming.state.UnsafeRowPair
// A first, naive look at the state rows: print the first 8-byte slot of the key and value rows.
// Both columns are nested (a struct key and array values), so these longs are not decoded yet;
// a refined version follows.
def formatRowPair(rowPair: UnsafeRowPair) = {
  s"(${rowPair.key.getLong(0)}, ${rowPair.value.getLong(0)})"
}
store.iterator.map(formatRowPair).foreach(println)
// A refined version that decodes the windowed key and the aggregated arrays
def formatRowPair(rowPair: UnsafeRowPair) = {
  import scala.concurrent.duration._
  // The key is a single struct column: sliding_window(start, end)
  val window = rowPair.key.getStruct(0, 2)
  // TimestampType values are stored as microseconds since the epoch
  val begin = window.getLong(0).micros.toSeconds
  val end = window.getLong(1).micros.toSeconds
  // The value row has two array columns (see valueSchema above): batches and values
  val batches = rowPair.value.getArray(0).toLongArray.mkString("[", ", ", "]")
  val values = rowPair.value.getArray(1).toLongArray.mkString("[", ", ", "]")
  s"(key: [$begin, $end], batches: $batches, values: $values)"
}
store.iterator.map(formatRowPair).foreach(println)
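Once done exploring, the loaded store can be released without committing anything; abort discards the uncommitted state and StateStore.stop shuts down the maintenance task (a minimal cleanup sketch):
// Discard the uncommitted state and stop the StateStore maintenance thread
store.abort()
StateStore.stop()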