flatMapGroupsWithState Operator — Arbitrary Stateful Streaming Aggregation (with Explicit State Logic)
flatMapGroupsWithState[S: Encoder, U: Encoder](
outputMode: OutputMode,
timeoutConf: GroupStateTimeout)(
func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U]
Every time the state function func is executed for a key, the state (as GroupState[S]) is for this key only.
FIXME Why can’t flatMapGroupsWithState work with Complete output mode?
scala> spark.version
res0: String = 2.3.0-SNAPSHOT
import java.sql.Timestamp
// Devices are identified by a plain integer id
type DeviceId = Int
// A single record of the input stream: its timestamp, its value and the id of the device it is assigned to
case class Signal(timestamp: java.sql.Timestamp, value: Long, deviceId: DeviceId)
// input stream
import org.apache.spark.sql.functions._
// Streaming Dataset[Signal] built on top of the "rate" source,
// which provides the timestamp and value columns (see the physical plan below)
val signals = spark.
  readStream.
  format("rate").
  option("rowsPerSecond", 1).
  load.
  withColumn("value", $"value" % 10). // <-- randomize the values (just for fun)
  withColumn("deviceId", rint(rand() * 10) cast "int"). // <-- device ids (0 to 10) randomly assigned to values
  as[Signal] // <-- convert to our type (from "unpleasant" Row)
scala> signals.explain
== Physical Plan ==
*Project [timestamp#0, (value#1L % 10) AS value#5L, cast(ROUND((rand(4440296395341152993) * 10.0)) as int) AS deviceId#9]
+- StreamingRelation rate, [timestamp#0, value#1L]
// stream processing using flatMapGroupsWithState operator
// Key extractor: every signal is grouped under the id of its device
val device: Signal => DeviceId = _.deviceId
val signalsByDevice = signals.groupByKey(device)
import org.apache.spark.sql.streaming.GroupState
// Aliases that spell out the shape of the per-group state: a map from a key to a count
type Key = Int
type Count = Long
type State = Map[Key, Count]
// Output record: a device id together with a count of events for that device
case class EventsCounted(deviceId: DeviceId, count: Long)
/**
 * State function for flatMapGroupsWithState that keeps a running count of signals per device.
 *
 * Prints the device id, the signals of the current batch and the current state,
 * adds the number of signals to the count stored in the state, saves the new count
 * back to the state, and emits a single EventsCounted record with the new count.
 *
 * @param deviceId         the grouping key the function is executed for
 * @param signalsPerDevice the signals of the current batch for this key (one-pass iterator)
 * @param state            the state for this key only
 * @return a one-element iterator with the updated count for the device
 */
def countValuesPerKey(deviceId: Int, signalsPerDevice: Iterator[Signal], state: GroupState[State]): Iterator[EventsCounted] = {
  // materialize the iterator once since it can only be traversed a single time
  val values = signalsPerDevice.toList
  println(s"Device: $deviceId")
  println(s"Signals (${values.size}):")
  values.zipWithIndex.foreach { case (v, idx) => println(s"$idx. $v") }
  println(s"State: $state")

  // update the state with the count of elements for the key
  val initialState: State = Map(deviceId -> 0L)
  val oldState = state.getOption.getOrElse(initialState)
  // the name to highlight that the state is for the key only
  val newValue = oldState.getOrElse(deviceId, 0L) + values.size
  val newState = Map(deviceId -> newValue)
  // without saving the new state the count would start from scratch in every batch
  state.update(newState)

  // you must not return signalsPerDevice as it's already consumed
  // that leads to a very subtle error where no elements are in an iterator
  // iterators are one-pass data structures
  Iterator(EventsCounted(deviceId, newValue))
}
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
// Stateful counting per device: Append output mode, no state timeout
val signalCounter = signalsByDevice.
  flatMapGroupsWithState(OutputMode.Append, GroupStateTimeout.NoTimeout)(countValuesPerKey)
import org.apache.spark.sql.streaming.{OutputMode, Trigger}
import scala.concurrent.duration._
// Start the streaming query that prints every batch to the console (untruncated).
// NOTE(review): the 10-second trigger interval is inferred from the progress
// timestamps in the logs below -- confirm against the original session.
val sq = signalCounter.
  writeStream.
  format("console").
  option("truncate", false).
  trigger(Trigger.ProcessingTime(10.seconds)).
  outputMode(OutputMode.Append).
  start
Batch: 0
17/08/21 08:57:29 INFO StreamExecution: Streaming query made progress: {
"id" : "a43822a6-500b-4f02-9133-53e9d39eedbf",
"runId" : "79cb037e-0f28-4faf-a03e-2572b4301afe",
"name" : null,
"timestamp" : "2017-08-21T06:57:26.719Z",
"batchId" : 0,
"numInputRows" : 0,
"processedRowsPerSecond" : 0.0,
"durationMs" : {
"addBatch" : 2404,
"getBatch" : 22,
"getOffset" : 0,
"queryPlanning" : 141,
"triggerExecution" : 2626,
"walCommit" : 41
"stateOperators" : [ {
"numRowsTotal" : 0,
"numRowsUpdated" : 0,
"memoryUsedBytes" : 12599
} ],
"sources" : [ {
"description" : "RateSource[rowsPerSecond=1, rampUpTimeSeconds=0, numPartitions=8]",
"startOffset" : null,
"endOffset" : 0,
"numInputRows" : 0,
"processedRowsPerSecond" : 0.0
} ],
"sink" : {
"description" : "ConsoleSink[numRows=20, truncate=false]"
17/08/21 08:57:29 DEBUG StreamExecution: batch 0 committed
Batch: 1
Device: 3
Signals (1):
0. Signal(2017-08-21 08:57:27.682,1,3)
State: GroupState(<undefined>)
Device: 8
Signals (1):
0. Signal(2017-08-21 08:57:26.682,0,8)
State: GroupState(<undefined>)
Device: 7
Signals (1):
0. Signal(2017-08-21 08:57:28.682,2,7)
State: GroupState(<undefined>)
|3 |1 |
|8 |1 |
|7 |1 |
17/08/21 08:57:31 INFO StreamExecution: Streaming query made progress: {
"id" : "a43822a6-500b-4f02-9133-53e9d39eedbf",
"runId" : "79cb037e-0f28-4faf-a03e-2572b4301afe",
"name" : null,
"timestamp" : "2017-08-21T06:57:30.004Z",
"batchId" : 1,
"numInputRows" : 3,
"inputRowsPerSecond" : 0.91324200913242,
"processedRowsPerSecond" : 2.2388059701492535,
"durationMs" : {
"addBatch" : 1245,
"getBatch" : 22,
"getOffset" : 0,
"queryPlanning" : 23,
"triggerExecution" : 1340,
"walCommit" : 44
"stateOperators" : [ {
"numRowsTotal" : 3,
"numRowsUpdated" : 3,
"memoryUsedBytes" : 18095
} ],
"sources" : [ {
"description" : "RateSource[rowsPerSecond=1, rampUpTimeSeconds=0, numPartitions=8]",
"startOffset" : 0,
"endOffset" : 3,
"numInputRows" : 3,
"inputRowsPerSecond" : 0.91324200913242,
"processedRowsPerSecond" : 2.2388059701492535
} ],
"sink" : {
"description" : "ConsoleSink[numRows=20, truncate=false]"
17/08/21 08:57:31 DEBUG StreamExecution: batch 1 committed
Batch: 2
Device: 1
Signals (1):
0. Signal(2017-08-21 08:57:36.682,0,1)
State: GroupState(<undefined>)
Device: 3
Signals (2):
0. Signal(2017-08-21 08:57:32.682,6,3)
1. Signal(2017-08-21 08:57:35.682,9,3)
State: GroupState(Map(3 -> 1))
Device: 5
Signals (1):
0. Signal(2017-08-21 08:57:34.682,8,5)
State: GroupState(<undefined>)
Device: 4
Signals (1):
0. Signal(2017-08-21 08:57:29.682,3,4)
State: GroupState(<undefined>)
Device: 8
Signals (2):
0. Signal(2017-08-21 08:57:31.682,5,8)
1. Signal(2017-08-21 08:57:33.682,7,8)
State: GroupState(Map(8 -> 1))
Device: 7
Signals (2):
0. Signal(2017-08-21 08:57:30.682,4,7)
1. Signal(2017-08-21 08:57:37.682,1,7)
State: GroupState(Map(7 -> 1))
Device: 0
Signals (1):
0. Signal(2017-08-21 08:57:38.682,2,0)
State: GroupState(<undefined>)
|1 |1 |
|3 |3 |
|5 |1 |
|4 |1 |
|8 |3 |
|7 |3 |
|0 |1 |
17/08/21 08:57:41 INFO StreamExecution: Streaming query made progress: {
"id" : "a43822a6-500b-4f02-9133-53e9d39eedbf",
"runId" : "79cb037e-0f28-4faf-a03e-2572b4301afe",
"name" : null,
"timestamp" : "2017-08-21T06:57:40.005Z",
"batchId" : 2,
"numInputRows" : 10,
"inputRowsPerSecond" : 0.9999000099990002,
"processedRowsPerSecond" : 9.242144177449168,
"durationMs" : {
"addBatch" : 1032,
"getBatch" : 8,
"getOffset" : 0,
"queryPlanning" : 19,
"triggerExecution" : 1082,
"walCommit" : 21
"stateOperators" : [ {
"numRowsTotal" : 7,
"numRowsUpdated" : 7,
"memoryUsedBytes" : 19023
} ],
"sources" : [ {
"description" : "RateSource[rowsPerSecond=1, rampUpTimeSeconds=0, numPartitions=8]",
"startOffset" : 3,
"endOffset" : 13,
"numInputRows" : 10,
"inputRowsPerSecond" : 0.9999000099990002,
"processedRowsPerSecond" : 9.242144177449168
} ],
"sink" : {
"description" : "ConsoleSink[numRows=20, truncate=false]"
17/08/21 08:57:41 DEBUG StreamExecution: batch 2 committed
// In the end...
// Use stateOperators to access the stats
scala> println(sq.lastProgress.stateOperators(0).prettyJson)
"numRowsTotal" : 7,
"numRowsUpdated" : 7,
"memoryUsedBytes" : 19023
Internally, the flatMapGroupsWithState operator creates a Dataset with a FlatMapGroupsWithState unary logical operator.
scala> :type signalCounter
scala> println(signalCounter.queryExecution.logical.numberedTreeString)
00 'SerializeFromObject [assertnotnull(assertnotnull(input[0, $line27.$read$$iw$$iw$EventsCounted, true])).deviceId AS deviceId#25, assertnotnull(assertnotnull(input[0, $line27.$read$$iw$$iw$EventsCounted, true])).count AS count#26L]
01 +- 'FlatMapGroupsWithState <function3>, unresolveddeserializer(upcast(getcolumnbyordinal(0, IntegerType), IntegerType, - root class: "scala.Int"), value#20), unresolveddeserializer(newInstance(class $line17.$read$$iw$$iw$Signal), timestamp#0, value#5L, deviceId#9), [value#20], [timestamp#0, value#5L, deviceId#9], obj#24: $line27.$read$$iw$$iw$EventsCounted, class[value[0]: map<int,bigint>], Append, false, NoTimeout
02 +- AppendColumns <function1>, class $line17.$read$$iw$$iw$Signal, [StructField(timestamp,TimestampType,true), StructField(value,LongType,false), StructField(deviceId,IntegerType,false)], newInstance(class $line17.$read$$iw$$iw$Signal), [input[0, int, false] AS value#20]
03 +- Project [timestamp#0, value#5L, cast(ROUND((rand(4440296395341152993) * cast(10 as double))) as int) AS deviceId#9]
04 +- Project [timestamp#0, (value#1L % cast(10 as bigint)) AS value#5L]
05 +- StreamingRelation DataSource(org.apache.spark.sql.SparkSession@385c6d6b,rate,List(),None,List(),None,Map(rowsPerSecond -> 1),None), rate, [timestamp#0, value#1L]
scala> signalCounter.explain
== Physical Plan ==
*SerializeFromObject [assertnotnull(input[0, $line27.$read$$iw$$iw$EventsCounted, true]).deviceId AS deviceId#25, assertnotnull(input[0, $line27.$read$$iw$$iw$EventsCounted, true]).count AS count#26L]
+- FlatMapGroupsWithState <function3>, value#20: int, newInstance(class $line17.$read$$iw$$iw$Signal), [value#20], [timestamp#0, value#5L, deviceId#9], obj#24: $line27.$read$$iw$$iw$EventsCounted, StatefulOperatorStateInfo(<unknown>,50c7ece5-0716-4e43-9b56-09842db8baf1,0,0), class[value[0]: map<int,bigint>], Append, NoTimeout, 0, 0
+- *Sort [value#20 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(value#20, 200)
+- AppendColumns <function1>, newInstance(class $line17.$read$$iw$$iw$Signal), [input[0, int, false] AS value#20]
+- *Project [timestamp#0, (value#1L % 10) AS value#5L, cast(ROUND((rand(4440296395341152993) * 10.0)) as int) AS deviceId#9]
+- StreamingRelation rate, [timestamp#0, value#1L]
flatMapGroupsWithState reports an IllegalArgumentException when the input outputMode is neither Append nor Update.
scala> val result = signalsByDevice.flatMapGroupsWithState(
| outputMode = OutputMode.Complete,
| timeoutConf = GroupStateTimeout.NoTimeout)(func = stateFn)
java.lang.IllegalArgumentException: The output mode of function should be append or update
at org.apache.spark.sql.KeyValueGroupedDataset.flatMapGroupsWithState(KeyValueGroupedDataset.scala:381)
... 54 elided
FIXME Examples for append and update output modes (to demo the difference)
FIXME Examples for GroupStateTimeout.EventTimeTimeout with withWatermark operator