You've spent the afternoon building a StatefulProcessor for your TransformWithState streaming job. It tracks per-user sessions, accumulates running totals, or deduplicates events. Now you want to know if it actually works.
So you wire up a streaming source, start a query, push some test rows, wait for a micro-batch to fire, check the sink, squint at the logs, tweak something, restart, and do it all again. The feedback loop is slow, the state is invisible, and reproducing a specific sequence of inputs feels like guesswork.
TwsTester changes that. It's a small test helper, available in both PySpark and Scala starting in Spark 4.2.0 (DBR 18.2). You hand it your StatefulProcessor, feed it rows, and inspect outputs and internal state instantly, without standing up a streaming query.
TwsTester is a unit-testing harness for StatefulProcessor implementations used with the transformWithState operator in Structured Streaming. You hand it your processor and some input rows; it drives the processor the same way Spark would during a real micro-batch and returns the resulting output rows.
It's available at:
- Python: pyspark.sql.streaming.TwsTester
- Scala: org.apache.spark.sql.streaming.TwsTester

The emphasis is on small, focused tests you can run in a notebook cell or a CI pipeline, not full end-to-end streaming tests with sources and sinks. Think of it as the difference between calling a function in a unit test and spinning up a full integration environment to exercise the same code path.
Three pain points come up repeatedly when developing stateful streaming logic:

- Slow feedback loops: every change means restarting a streaming query, pushing rows through a source, and waiting for a micro-batch to fire.
- Invisible state: there's no direct way to see what the processor has stored for a given key while you iterate.
- Hard-to-reproduce inputs: recreating a specific sequence of rows and timings through a real source is guesswork.

TwsTester addresses all three: it runs your processor synchronously and returns output immediately, it exposes per-key state through peek and update methods, and it takes input as plain in-memory rows, so every scenario is deterministic and replayable.
This all runs locally against a regular SparkSession. No cluster infrastructure, no streaming sources, no sinks.
The workflow is straightforward:
1. Write your StatefulProcessor as you would for transformWithState: define init, handleInputRows, and optionally handleExpiredTimer and close.
2. Construct a TwsTester with your processor, a time mode ("None", "ProcessingTime", or "EventTime"), and an output mode ("Append", "Update", or "Complete").
3. Call test(key, rows) with a grouping key and a list of input rows. TwsTester drives your processor and returns the output rows.
4. Inspect state with peekValueState, peekListState, or peekMapState. Seed or override state with the corresponding update* methods.
5. If your processor uses timers, advance time with setProcessingTime or setWatermark to fire expired timers and collect their output.

Steps 3-5 can be repeated in any order, which makes it easy to build up multi-step test scenarios.
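To make the shape concrete before the full walkthrough, here's the whole loop in sketch form. MyProcessor and the state name "total" are placeholders for your own code, not APIs from the example below:

from pyspark.sql import Row
from pyspark.sql.streaming import TwsTester

# Step 2: wrap your processor (MyProcessor is a stand-in for your own class).
tester = TwsTester(processor=MyProcessor(), timeMode="None", outputMode="Append")

# Step 3: drive it with a grouping key and some rows; output comes back synchronously.
output = tester.test(key="user-1", input=[Row(value=1), Row(value=2)])

# Step 4: inspect or seed per-key state directly ("total" is a placeholder state name).
print(tester.peekValueState("total", "user-1"))
tester.updateValueState("total", "user-1", (42,))

# Step 5 (timer modes only): advance time to fire expired timers.
# tester.setProcessingTime(60_000)  # or tester.setWatermark(60_000)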
Let's walk through a realistic development scenario. We'll build a processor that maintains a running sum of transaction amounts per user, write tests for it, discover a bug, fix it, and verify the fix. This is exactly the kind of tight iteration loop that TwsTester enables.
Here's our first pass at a RunningSumProcessor. It takes input rows with an amount field, accumulates a per-user total, and emits the updated running sum with each row.
from pyspark.sql import Row, SparkSession
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.streaming import TwsTester
from pyspark.sql.types import StructType, StructField, LongType
spark = SparkSession.builder.getOrCreate()
class RunningSumProcessor(StatefulProcessor):
def init(self, handle: StatefulProcessorHandle) -> None:
self.handle = handle
sum_schema = StructType([StructField("sum", LongType())])
self.sum_state = handle.getValueState("sum", sum_schema)
def handleInputRows(self, key, rows, timerValues):
running_sum = 0
results = []
for row in rows:
amount = row["amount"]
running_sum += amount
results.append(Row(user_id=key[0], amount=amount, running_sum=running_sum))
self.sum_state.update((running_sum,))
return iter(results)
def close(self) -> None:
        pass

Scala:

import org.apache.spark.sql.streaming._
import org.apache.spark.sql.Encoders
class RunningSumProcessor
extends StatefulProcessor[String, Long, (String, Long, Long)] {
@transient private var sumState: ValueState[Long] = _
override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
sumState = getHandle.getValueState[Long](
"sum", Encoders.scalaLong, TTLConfig.NONE)
}
override def handleInputRows(
key: String,
rows: Iterator[Long],
timerValues: TimerValues
): Iterator[(String, Long, Long)] = {
var runningSum = 0L
val results = rows.map { amount =>
runningSum += amount
(key, amount, runningSum)
}.toList
sumState.update(runningSum)
results.iterator
}
}

Looks reasonable. Let's test it.
tester = TwsTester(
processor=RunningSumProcessor(),
timeMode="None",
outputMode="Append",
)

Scala:

val tester = new TwsTester(
processor = new RunningSumProcessor(),
timeMode = TimeMode.None(),
outputMode = OutputMode.Append()
)

No sources, no sinks, no checkpoint directory.
We process two transactions for Alice in one call. The running sum should accumulate within the batch.
output = tester.test(
key="alice",
input=[Row(amount=10), Row(amount=5)],
)
for row in output:
print(row)
# Row(user_id='alice', amount=10, running_sum=10)
# Row(user_id='alice', amount=5, running_sum=15)

Scala:

val output = tester.test("alice", List(10L, 5L))
println(output)
// List(("alice", 10, 10), ("alice", 5, 15))Looks right. Alice's running sum goes 10, then 15.
Now Alice makes another transaction. She already has a running sum of 15, so the new total should be 18.
output2 = tester.test(
key="alice",
input=[Row(amount=3)],
)
for row in output2:
print(row)
# Expected: Row(user_id='alice', amount=3, running_sum=18)
# Actual: Row(user_id='alice', amount=3, running_sum=3) <- wrong!

Scala:

val output2 = tester.test("alice", List(3L))
println(output2)
// Expected: List(("alice", 3, 18))
// Actual: List(("alice", 3, 3)) <- wrong!The running sum is 3, not 18. The processor reset Alice's total instead of continuing from where it left off.
Look at the first line of handleInputRows:
running_sum = 0  # Always starts from zero, ignoring existing state

And in Scala:

var runningSum = 0L // Always starts from zero, ignoring existing state

The processor initializes running_sum to zero on every call instead of reading the existing value from state. Within a single batch this doesn't matter, because the accumulation happens in a loop. But across batches, the state from previous calls is silently discarded.
Without TwsTester, this bug might not surface until production, where it would show up as totals that periodically "reset" at unpredictable intervals (whenever a new micro-batch starts for that key).
The fix is two lines: read existing state before accumulating.
class RunningSumProcessor(StatefulProcessor):
def init(self, handle: StatefulProcessorHandle) -> None:
self.handle = handle
sum_schema = StructType([StructField("sum", LongType())])
self.sum_state = handle.getValueState("sum", sum_schema)
def handleInputRows(self, key, rows, timerValues):
# Fixed: read existing state instead of starting from zero
existing = self.sum_state.get()
running_sum = existing[0] if existing is not None else 0
results = []
for row in rows:
amount = row["amount"]
running_sum += amount
results.append(Row(user_id=key[0], amount=amount, running_sum=running_sum))
self.sum_state.update((running_sum,))
return iter(results)
def close(self) -> None:
        pass

Scala:

class RunningSumProcessor
extends StatefulProcessor[String, Long, (String, Long, Long)] {
@transient private var sumState: ValueState[Long] = _
override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
sumState = getHandle.getValueState[Long](
"sum", Encoders.scalaLong, TTLConfig.NONE)
}
override def handleInputRows(
key: String,
rows: Iterator[Long],
timerValues: TimerValues
): Iterator[(String, Long, Long)] = {
// Fixed: read existing state instead of starting from zero
var runningSum = if (sumState.exists()) sumState.get() else 0L
val results = rows.map { amount =>
runningSum += amount
(key, amount, runningSum)
}.toList
sumState.update(runningSum)
results.iterator
}
}

Now re-run the same two scenarios:

Python:

tester = TwsTester(
processor=RunningSumProcessor(), # fixed version
timeMode="None",
outputMode="Append",
)
# Test 1: single batch
output = tester.test(key="alice", input=[Row(amount=10), Row(amount=5)])
for row in output:
print(row)
# Row(user_id='alice', amount=10, running_sum=10)
# Row(user_id='alice', amount=5, running_sum=15)
# Test 2: second batch for the same key
output2 = tester.test(key="alice", input=[Row(amount=3)])
for row in output2:
print(row)
# Row(user_id='alice', amount=3, running_sum=18)  PASS

Scala:

val tester = new TwsTester(
processor = new RunningSumProcessor(), // fixed version
timeMode = TimeMode.None(),
outputMode = OutputMode.Append()
)
// Test 1: single batch
val output = tester.test("alice", List(10L, 5L))
println(output)
// List(("alice", 10, 10), ("alice", 5, 15))
// Test 2: second batch for the same key
val output2 = tester.test("alice", List(3L))
println(output2)
// List(("alice", 3, 18)) PASSBoth tests pass. From bug to fix to verification in seconds, not the minutes (or longer) it would take to redeploy a streaming query and push test data through a source.
After calling test(), the processor's state is still live inside the tester. We can peek at it:
alice_sum = tester.peekValueState("sum", "alice")
print(alice_sum)  # (18,)

Scala:

val aliceSum = tester.peekValueState[Long]("sum", "alice")
println(aliceSum) // Some(18)

We can also process rows for a different key and verify that state is isolated:
output_bob = tester.test(key="bob", input=[Row(amount=7)])
print(output_bob)
# [Row(user_id='bob', amount=7, running_sum=7)]
# Bob's state is independent
print(tester.peekValueState("sum", "bob")) # (7,)
print(tester.peekValueState("sum", "alice")) # (18,) -- unchangedval outputBob = tester.test("bob", List(7L))
println(outputBob)
// List(("bob", 7, 7))
// Bob's state is independent
println(tester.peekValueState[Long]("sum", "bob")) // Some(7)
println(tester.peekValueState[Long]("sum", "alice")) // Some(18) -- unchangedSometimes you want to test behavior starting from a known state rather than building up from zero. updateValueState lets you seed state before processing:
tester.updateValueState("sum", "carol", (100,))
output_carol = tester.test(key="carol", input=[Row(amount=25)])
print(output_carol)
# [Row(user_id='carol', amount=25, running_sum=125)]

Scala:

tester.updateValueState[Long]("sum", "carol", 100L)
val outputCarol = tester.test("carol", List(25L))
println(outputCarol)
// List(("carol", 25, 125))This is useful for testing edge cases: what happens at counter rollover, after a long idle period, or when the user already has existing state from a previous session.
Timers let a StatefulProcessor react to the passage of time: session timeouts, periodic flushes, TTL-based cleanup. TwsTester supports testing timer logic by letting you advance time manually.
Here's a processor that registers a processing-time timer to expire idle sessions:
class SessionTimeoutProcessor(StatefulProcessor):
TIMEOUT_MS = 60_000 # 1 minute
def init(self, handle: StatefulProcessorHandle) -> None:
self.handle = handle
schema = StructType([StructField("event_count", LongType())])
self.session = handle.getValueState("session", schema)
def handleInputRows(self, key, rows, timerValues):
existing = self.session.get()
event_count = existing[0] if existing is not None else 0
for row in rows:
event_count += 1
self.session.update((event_count,))
# Reset the timeout: register a timer at current time + timeout
current_time = timerValues.getCurrentProcessingTimeInMs()
self.handle.registerTimer(current_time + self.TIMEOUT_MS)
return iter([Row(user_id=key[0], event_count=event_count, status="active")])
def handleExpiredTimer(self, key, timerValues, expiredTimerInfo):
# Timer fired, session timed out. Clear state and emit a closure event.
event_count = self.session.get()
self.session.clear()
count = event_count[0] if event_count is not None else 0
return iter([Row(user_id=key[0], event_count=count, status="expired")])
def close(self) -> None:
        pass

Scala:

class SessionTimeoutProcessor
extends StatefulProcessor[String, String, (String, Long, String)] {
private val TimeoutMs = 60000L
@transient private var sessionState: ValueState[Long] = _
override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
sessionState = getHandle.getValueState[Long](
"session", Encoders.scalaLong, TTLConfig.NONE)
}
override def handleInputRows(
key: String,
rows: Iterator[String],
timerValues: TimerValues
): Iterator[(String, Long, String)] = {
var eventCount = if (sessionState.exists()) sessionState.get() else 0L
rows.foreach(_ => eventCount += 1)
sessionState.update(eventCount)
val currentTime = timerValues.getCurrentProcessingTimeInMs()
getHandle.registerTimer(currentTime + TimeoutMs)
Iterator((key, eventCount, "active"))
}
override def handleExpiredTimer(
key: String,
timerValues: TimerValues,
expiredTimerInfo: ExpiredTimerInfo
): Iterator[(String, Long, String)] = {
val count = if (sessionState.exists()) sessionState.get() else 0L
sessionState.clear()
Iterator((key, count, "expired"))
}
}

Now test it with TwsTester in processing-time mode:
tester = TwsTester(
processor=SessionTimeoutProcessor(),
timeMode="ProcessingTime",
outputMode="Append",
)
# Set an initial processing time
tester.setProcessingTime(1_000)
# Process some events for alice
output = tester.test(
key="alice",
input=[Row(event="click"), Row(event="scroll")],
)
print(output)
# [Row(user_id='alice', event_count=2, status='active')]
# Advance time past the timeout threshold
expired_output = tester.setProcessingTime(62_000)
print(expired_output)
# [Row(user_id='alice', event_count=2, status='expired')]
# State has been cleared
print(tester.peekValueState("session", "alice"))
# None

Scala:

val tester = new TwsTester(
processor = new SessionTimeoutProcessor(),
timeMode = TimeMode.ProcessingTime(),
outputMode = OutputMode.Append()
)
// Set an initial processing time
tester.setProcessingTime(1000L)
// Process some events for alice
val output = tester.test("alice", List("click", "scroll"))
println(output)
// List(("alice", 2, "active"))
// Advance time past the timeout threshold
val expiredOutput = tester.setProcessingTime(62000L)
println(expiredOutput)
// List(("alice", 2, "expired"))
// State has been cleared
println(tester.peekValueState[Long]("session", "alice"))
// None

We just tested a complete session lifecycle (activity, timeout, cleanup) in a few lines, with full control over time.
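Event-time timers follow the same pattern, with setWatermark in place of setProcessingTime. Here's a minimal sketch; this processor is illustrative (not part of the walkthrough above) and assumes event-time timers fire once the watermark passes their registered timestamp:

class EventTimeTimeoutProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        self.handle = handle

    def handleInputRows(self, key, rows, timerValues):
        # Register an event-time timer 10 seconds past the current watermark.
        watermark_ms = timerValues.getCurrentWatermarkInMs()
        self.handle.registerTimer(watermark_ms + 10_000)
        return iter([Row(user_id=key[0], status="active")])

    def handleExpiredTimer(self, key, timerValues, expiredTimerInfo):
        return iter([Row(user_id=key[0], status="timed_out")])

    def close(self) -> None:
        pass

tester = TwsTester(
    processor=EventTimeTimeoutProcessor(),
    timeMode="EventTime",
    outputMode="Append",
)
tester.setWatermark(1_000)
tester.test(key="alice", input=[Row(event="click")])
expired = tester.setWatermark(20_000)
print(expired)  # expect the "timed_out" row once the watermark passes the timer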
The examples above work well in a notebook for interactive exploration. To lock in behavior for CI, wrap them in your test framework of choice.
Here's a pytest example; the equivalent ScalaTest version follows.
import pytest
from pyspark.sql import Row, SparkSession
from pyspark.sql.streaming import TwsTester
@pytest.fixture(scope="module")
def spark():
return SparkSession.builder.master("local[*]").getOrCreate()
def test_running_sum_single_batch(spark):
tester = TwsTester(
processor=RunningSumProcessor(),
timeMode="None",
outputMode="Append",
)
output = tester.test("alice", [Row(amount=10), Row(amount=5)])
assert len(output) == 2
assert output[0]["running_sum"] == 10
assert output[1]["running_sum"] == 15
def test_running_sum_across_batches(spark):
tester = TwsTester(
processor=RunningSumProcessor(),
timeMode="None",
outputMode="Append",
)
tester.test("alice", [Row(amount=10), Row(amount=5)])
output = tester.test("alice", [Row(amount=3)])
assert output[0]["running_sum"] == 18
def test_state_isolation_across_keys(spark):
tester = TwsTester(
processor=RunningSumProcessor(),
timeMode="None",
outputMode="Append",
)
tester.test("alice", [Row(amount=10)])
tester.test("bob", [Row(amount=20), Row(amount=5)])
assert tester.peekValueState("sum", "alice") == (10,)
assert tester.peekValueState("sum", "bob") == (25,)import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.sql.streaming._
class RunningSumProcessorTest extends AnyFunSuite {
test("running sum accumulates within a single batch") {
val tester = new TwsTester(
processor = new RunningSumProcessor(),
timeMode = TimeMode.None(),
outputMode = OutputMode.Append()
)
val output = tester.test("alice", List(10L, 5L))
assert(output === List(("alice", 10L, 10L), ("alice", 5L, 15L)))
}
test("running sum accumulates across batches") {
val tester = new TwsTester(
processor = new RunningSumProcessor(),
timeMode = TimeMode.None(),
outputMode = OutputMode.Append()
)
tester.test("alice", List(10L, 5L))
val output = tester.test("alice", List(3L))
assert(output === List(("alice", 3L, 18L)))
}
test("state is isolated across keys") {
val tester = new TwsTester(
processor = new RunningSumProcessor(),
timeMode = TimeMode.None(),
outputMode = OutputMode.Append()
)
tester.test("alice", List(10L))
tester.test("bob", List(20L, 5L))
assert(tester.peekValueState[Long]("sum", "alice") === Some(10L))
assert(tester.peekValueState[Long]("sum", "bob") === Some(25L))
}
}

No mock streaming sources, no awaitTermination, no checkpoint cleanup. Just direct assertions on processor behavior.
TwsTester is most valuable when you're in the thick of building or debugging stateful logic:

- Iterating on handleInputRows logic without restarting a streaming query after every change.
- Verifying timer and timeout behavior deterministically, with full control over processing time and watermarks.
- Reproducing edge cases by seeding per-key state directly instead of replaying long input sequences.
- Locking in regressions as fast unit tests that run in CI alongside the rest of your suite.
It's not a replacement for end-to-end integration tests. You'll still want to verify that your processor works correctly within a full transformWithState pipeline. But for the business logic inside the processor, TwsTester gives you the tight feedback loop that stateful streaming has always been missing.
Copy the example above into a Databricks Notebook or a local Spark environment running Spark 4.2.0+, replace RunningSumProcessor with your own processor, and start testing.
For the full API reference and more examples, see the TwsTester entries in the PySpark and Scala API docs, and the transformWithState documentation covering StatefulProcessor, state variables, and timer semantics.

What is TwsTester?
TwsTester is a unit-testing harness for StatefulProcessor implementations used with the transformWithState operator in Apache Spark Structured Streaming. You hand it your processor and a list of input rows, and it drives the processor the same way Spark would inside a real micro-batch, returning the resulting output rows so you can assert on them directly.
Which Spark and Databricks Runtime versions include TwsTester?
TwsTester is available starting in Apache Spark 4.2.0, which corresponds to Databricks Runtime 18.2 and later.
Is TwsTester available in both PySpark and Scala?
Yes. The PySpark entry point is pyspark.sql.streaming.TwsTester and the Scala entry point is org.apache.spark.sql.streaming.TwsTester. The two surfaces mirror each other, so the same test patterns translate directly between languages.
Does TwsTester replace end-to-end streaming integration tests?
No. TwsTester targets the business logic inside a StatefulProcessor: per-key state, timers, output rows for given inputs. You should still run end-to-end tests against a real transformWithState pipeline to verify source/sink wiring, checkpointing, and fault tolerance. TwsTester compresses the inner feedback loop; integration tests still validate the outer one.
Can TwsTester test processing-time and event-time timers?
Yes. Construct the tester with timeMode="ProcessingTime" or timeMode="EventTime" and use setProcessingTime(...) or setWatermark(...) to advance time on demand. Any expired timers fire and their output is collected, so timeout, session-expiry, and periodic-flush logic can be exercised deterministically.
How is TwsTester different from the State Reader API?
The State Reader API reads state from a query's checkpoint directory after the fact, which is great for inspecting production state. TwsTester runs locally, drives the processor with controlled inputs, and lets you inspect or override state mid-test with peekValueState / updateValueState and friends. They are complementary: State Reader for production debugging, TwsTester for development and CI.
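To make the contrast concrete, a post-hoc State Reader query looks roughly like this. A sketch: the checkpoint path, operator ID, and state variable name are placeholders for a real query's values:

# Reading committed state from a checkpoint, outside any test harness.
state_df = (
    spark.read.format("statestore")
    .option("operatorId", 0)
    .option("stateVarName", "sum")  # the transformWithState state variable to read
    .load("/path/to/checkpoint")
)
state_df.show()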
Can I use TwsTester in a CI pipeline?
Yes. Wrap each test in your framework of choice (pytest for PySpark, ScalaTest for Scala) and run them like any other unit test. TwsTester needs only a regular SparkSession, no streaming sources, sinks, or checkpoint directory, so it fits naturally into existing test infrastructure.