Technical Blog
Explore in-depth articles, tutorials, and insights on data analytics and machine learning in the Databricks Technical Blog. Stay updated on industry trends, best practices, and advanced techniques.
craig_lukasik
Databricks Employee

You've spent the afternoon building a StatefulProcessor for your TransformWithState streaming job. It tracks per-user sessions, accumulates running totals, or deduplicates events. Now you want to know if it actually works.

So you wire up a streaming source, start a query, push some test rows, wait for a micro-batch to fire, check the sink, squint at the logs, tweak something, restart, and do it all again. The feedback loop is slow, the state is invisible, and reproducing a specific sequence of inputs feels like guesswork.

TwsTester changes that. It's a small test helper, available in both PySpark and Scala starting in Spark 4.2.0 (DBR 18.2). You hand it your StatefulProcessor, feed it rows, and inspect outputs and internal state instantly, without standing up a streaming query.

What is TwsTester?

TwsTester is a unit-testing harness for StatefulProcessor implementations used with the transformWithState operator in Structured Streaming. You hand it your processor and some input rows; it drives the processor the same way Spark would during a real micro-batch and returns the resulting output rows.

It's available at:

  • PySpark: pyspark.sql.streaming.TwsTester
  • Scala: org.apache.spark.sql.streaming.TwsTester

The emphasis is on small, focused tests you can run in a notebook cell or a CI pipeline, not full end-to-end streaming tests with sources and sinks. Think of it as the difference between calling a function in a unit test and spinning up a full integration environment to exercise the same code path.

Why TwsTester?

Three pain points come up repeatedly when developing stateful streaming logic:

  1. Slow feedback loops. Every change to your processor means restarting a streaming query, wiring up sources and sinks, and waiting for data to flow through. For logic that's still taking shape, this is expensive.
  2. Hard-to-observe state and timers. In a running query, inspecting state mid-flight isn't straightforward. The State Reader API helps here; it lets you read state from a query's checkpoint directory after the fact, which is valuable for debugging production workloads and understanding state evolution over time. But the State Reader API requires access to the checkpoint directory, and it's oriented toward inspecting state after a query has run, not toward feeding controlled inputs and verifying outputs as part of a development loop.
  3. Limited reproducibility. Some bugs only surface with specific input orderings or timing. Reproducing those conditions end-to-end is tedious at best.

TwsTester addresses all three:

  • Process a controlled sequence of input rows per key and inspect the exact outputs.
  • Peek at and manipulate state directly, including value state, list state, and map state, at any point during a test.
  • Advance processing time or the watermark on demand, and verify that your timer logic fires correctly.
  • Inject initial state so you can test mid-stream scenarios like "this user already has an active session."

This all runs locally against a regular SparkSession. No cluster infrastructure, no streaming sources, no sinks.

How it works

The workflow is straightforward:

  1. Implement your StatefulProcessor as you normally would for transformWithState: define init, handleInputRows, and optionally handleExpiredTimer and close.
  2. Construct a TwsTester with your processor, choosing a time mode ("None", "ProcessingTime", or "EventTime") and output mode ("Append", "Update", or "Complete").
  3. Call test(key, rows) with a grouping key and a list of input rows. TwsTester drives your processor and returns the output rows.
  4. Inspect state with peekValueState, peekListState, or peekMapState. Seed or override state with the corresponding update* methods.
  5. Advance time with setProcessingTime or setWatermark to fire expired timers and collect their output.

Steps 3-5 can be repeated in any order, which makes it easy to build up multi-step test scenarios.
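Here is a toy, pure-Python model of steps 3 through 5 that makes the loop concrete. It is not TwsTester's implementation and uses none of its API. `ToyHarness`, `CountingProcessor`, and their methods are hypothetical names chosen for illustration. The timer rule it models, where a timer fires once the clock reaches its expiry, is an assumption about the behaviour being mimicked.

```python
from collections import defaultdict


class CountingProcessor:
    """Hypothetical processor: counts rows per key and arms a 60 s idle timer."""

    TIMEOUT_MS = 60_000

    def handle_input_rows(self, key, rows, state, timers, now_ms):
        count = state.get(key, 0) + len(rows)  # continue from stored state
        state[key] = count
        timers[key] = {now_ms + self.TIMEOUT_MS}  # replace any earlier timer for this key
        return [(key, count, "active")]

    def handle_expired_timer(self, key, state):
        count = state.pop(key, 0)  # clear state for the expired key
        return [(key, count, "expired")]


class ToyHarness:
    """Hypothetical stand-in for a test harness: per-key state, timers, manual clock."""

    def __init__(self, processor):
        self.processor = processor
        self.state = {}                 # key -> value state
        self.timers = defaultdict(set)  # key -> set of expiry timestamps (ms)
        self.now_ms = 0

    def test(self, key, rows):
        rows = list(rows)
        return self.processor.handle_input_rows(key, rows, self.state, self.timers, self.now_ms)

    def peek(self, key):
        return self.state.get(key)

    def update(self, key, value):
        self.state[key] = value

    def set_processing_time(self, now_ms):
        self.now_ms = now_ms
        output = []
        # Fire every timer whose expiry has been reached, earliest first.
        due = sorted((t, k) for k, ts in self.timers.items() for t in ts if t <= now_ms)
        for t, k in due:
            self.timers[k].discard(t)
            output.extend(self.processor.handle_expired_timer(k, self.state))
        return output


harness = ToyHarness(CountingProcessor())
harness.set_processing_time(1_000)
print(harness.test("alice", ["click", "scroll"]))  # [('alice', 2, 'active')]
print(harness.peek("alice"))                       # 2
print(harness.set_processing_time(62_000))         # [('alice', 2, 'expired')]
print(harness.peek("alice"))                       # None
```

Each key gets its own slot in `state` and `timers`, which is why calls for different keys cannot interfere. Advancing the clock is the only thing that causes timer output.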

Example: catching a bug with TwsTester

Let's walk through a realistic development scenario. We'll build a processor that maintains a running sum of transaction amounts per user, write tests for it, discover a bug, fix it, and verify the fix. This is exactly the kind of tight iteration loop that TwsTester enables.

The processor (first attempt)

Here's our first pass at a RunningSumProcessor. It takes input rows with an amount field, accumulates a per-user total, and emits the updated running sum with each row.

from pyspark.sql import Row, SparkSession
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.streaming import TwsTester
from pyspark.sql.types import StructType, StructField, LongType

spark = SparkSession.builder.getOrCreate()


class RunningSumProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        self.handle = handle
        sum_schema = StructType([StructField("sum", LongType())])
        self.sum_state = handle.getValueState("sum", sum_schema)

    def handleInputRows(self, key, rows, timerValues):
        running_sum = 0

        results = []
        for row in rows:
            amount = row["amount"]
            running_sum += amount
            results.append(Row(user_id=key[0], amount=amount, running_sum=running_sum))

        self.sum_state.update((running_sum,))
        return iter(results)

    def close(self) -> None:
        pass
Scala equivalent:
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.Encoders

class RunningSumProcessor
    extends StatefulProcessor[String, Long, (String, Long, Long)] {

  @transient private var sumState: ValueState[Long] = _

  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
    sumState = getHandle.getValueState[Long](
      "sum", Encoders.scalaLong, TTLConfig.NONE)
  }

  override def handleInputRows(
      key: String,
      rows: Iterator[Long],
      timerValues: TimerValues
  ): Iterator[(String, Long, Long)] = {
    var runningSum = 0L

    val results = rows.map { amount =>
      runningSum += amount
      (key, amount, runningSum)
    }.toList

    sumState.update(runningSum)
    results.iterator
  }
}

Looks reasonable. Let's test it.

Setting up TwsTester

tester = TwsTester(
    processor=RunningSumProcessor(),
    timeMode="None",
    outputMode="Append",
)
Scala equivalent:
val tester = new TwsTester(
  processor = new RunningSumProcessor(),
  timeMode = TimeMode.None(),
  outputMode = OutputMode.Append()
)

No sources, no sinks, no checkpoint directory.

First test: single batch

We process two transactions for Alice in one call. The running sum should accumulate within the batch.

output = tester.test(
    key="alice",
    input=[Row(amount=10), Row(amount=5)],
)
for row in output:
    print(row)
# Row(user_id='alice', amount=10, running_sum=10)
# Row(user_id='alice', amount=5, running_sum=15)
Scala equivalent:
val output = tester.test("alice", List(10L, 5L))
println(output)
// List(("alice", 10, 10), ("alice", 5, 15))

Looks right. Alice's running sum goes 10, then 15.

Second test: another batch for the same key

Now Alice makes another transaction. She already has a running sum of 15, so the new total should be 18.

output2 = tester.test(
    key="alice",
    input=[Row(amount=3)],
)
for row in output2:
    print(row)
# Expected: Row(user_id='alice', amount=3, running_sum=18)
# Actual:   Row(user_id='alice', amount=3, running_sum=3)   <- wrong!
Scala equivalent:
val output2 = tester.test("alice", List(3L))
println(output2)
// Expected: List(("alice", 3, 18))
// Actual:   List(("alice", 3, 3))   <- wrong!

The running sum is 3, not 18. The processor reset Alice's total instead of continuing from where it left off.

What went wrong

Look at the first line of handleInputRows:

running_sum = 0  # Always starts from zero, ignoring existing state
Scala equivalent:
var runningSum = 0L  // Always starts from zero, ignoring existing state

The processor initializes running_sum to zero on every call instead of reading the existing value from state. Within a single batch this doesn't matter because the accumulation happens in a loop. But across batches, the state from previous calls is silently discarded.

Without TwsTester, this bug might not surface until production, where it would show up as totals that periodically "reset" at unpredictable intervals (whenever a new micro-batch starts for that key).
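You can reproduce the same failure mode without Spark. In this sketch a plain dict stands in for the value state, and both function names are hypothetical. The point is that a per-call accumulator that ignores stored state is correct within one batch but wrong across batches.

```python
def running_sum_buggy(state, key, amounts):
    running = 0  # bug: ignores whatever is already stored for this key
    out = []
    for a in amounts:
        running += a
        out.append(running)
    state[key] = running
    return out


def running_sum_fixed(state, key, amounts):
    running = state.get(key, 0)  # fix: continue from the stored total
    out = []
    for a in amounts:
        running += a
        out.append(running)
    state[key] = running
    return out


for fn in (running_sum_buggy, running_sum_fixed):
    state = {}  # fresh "value state" for each variant
    first = fn(state, "alice", [10, 5])
    second = fn(state, "alice", [3])
    print(fn.__name__, first, second)
# running_sum_buggy [10, 15] [3]
# running_sum_fixed [10, 15] [18]
```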

Fixing the processor

The fix is two lines: read existing state before accumulating.

class RunningSumProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        self.handle = handle
        sum_schema = StructType([StructField("sum", LongType())])
        self.sum_state = handle.getValueState("sum", sum_schema)

    def handleInputRows(self, key, rows, timerValues):
        # Fixed: read existing state instead of starting from zero
        existing = self.sum_state.get()
        running_sum = existing[0] if existing is not None else 0

        results = []
        for row in rows:
            amount = row["amount"]
            running_sum += amount
            results.append(Row(user_id=key[0], amount=amount, running_sum=running_sum))

        self.sum_state.update((running_sum,))
        return iter(results)

    def close(self) -> None:
        pass
Scala equivalent:
class RunningSumProcessor
    extends StatefulProcessor[String, Long, (String, Long, Long)] {

  @transient private var sumState: ValueState[Long] = _

  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
    sumState = getHandle.getValueState[Long](
      "sum", Encoders.scalaLong, TTLConfig.NONE)
  }

  override def handleInputRows(
      key: String,
      rows: Iterator[Long],
      timerValues: TimerValues
  ): Iterator[(String, Long, Long)] = {
    // Fixed: read existing state instead of starting from zero
    var runningSum = if (sumState.exists()) sumState.get() else 0L

    val results = rows.map { amount =>
      runningSum += amount
      (key, amount, runningSum)
    }.toList

    sumState.update(runningSum)
    results.iterator
  }
}

Re-running the tests

tester = TwsTester(
    processor=RunningSumProcessor(),  # fixed version
    timeMode="None",
    outputMode="Append",
)

# Test 1: single batch
output = tester.test(key="alice", input=[Row(amount=10), Row(amount=5)])
for row in output:
    print(row)
# Row(user_id='alice', amount=10, running_sum=10)
# Row(user_id='alice', amount=5, running_sum=15)

# Test 2: second batch for the same key
output2 = tester.test(key="alice", input=[Row(amount=3)])
for row in output2:
    print(row)
# Row(user_id='alice', amount=3, running_sum=18)  PASS
Scala equivalent:
val tester = new TwsTester(
  processor = new RunningSumProcessor(),  // fixed version
  timeMode = TimeMode.None(),
  outputMode = OutputMode.Append()
)

// Test 1: single batch
val output = tester.test("alice", List(10L, 5L))
println(output)
// List(("alice", 10, 10), ("alice", 5, 15))

// Test 2: second batch for the same key
val output2 = tester.test("alice", List(3L))
println(output2)
// List(("alice", 3, 18))  PASS

Both tests pass. From bug to fix to verification in seconds, not the minutes (or longer) it would take to redeploy a streaming query and push test data through a source.

Inspecting state directly

After calling test(), the processor's state is still live inside the tester. We can peek at it:

alice_sum = tester.peekValueState("sum", "alice")
print(alice_sum)  # (18,)
Scala equivalent:
val aliceSum = tester.peekValueState[Long]("sum", "alice")
println(aliceSum)  // Some(18)

We can also process rows for a different key and verify that state is isolated:

output_bob = tester.test(key="bob", input=[Row(amount=7)])
print(output_bob)
# [Row(user_id='bob', amount=7, running_sum=7)]

# Bob's state is independent
print(tester.peekValueState("sum", "bob"))    # (7,)
print(tester.peekValueState("sum", "alice"))   # (18,) -- unchanged
Scala equivalent:
val outputBob = tester.test("bob", List(7L))
println(outputBob)
// List(("bob", 7, 7))

// Bob's state is independent
println(tester.peekValueState[Long]("sum", "bob"))    // Some(7)
println(tester.peekValueState[Long]("sum", "alice"))   // Some(18) -- unchanged

Seeding state for mid-stream scenarios

Sometimes you want to test behavior starting from a known state rather than building up from zero. updateValueState lets you seed state before processing:

tester.updateValueState("sum", "carol", (100,))

output_carol = tester.test(key="carol", input=[Row(amount=25)])
print(output_carol)
# [Row(user_id='carol', amount=25, running_sum=125)]
Scala equivalent:
tester.updateValueState[Long]("sum", "carol", 100L)

val outputCarol = tester.test("carol", List(25L))
println(outputCarol)
// List(("carol", 25, 125))

This is useful for testing edge cases: what happens at counter rollover, after a long idle period, or when the user already has existing state from a previous session.

Adding timer logic

Timers let a StatefulProcessor react to the passage of time: session timeouts, periodic flushes, TTL-based cleanup. TwsTester supports testing timer logic by letting you advance time manually.

Here's a processor that registers a processing-time timer to expire idle sessions:

class SessionTimeoutProcessor(StatefulProcessor):
    TIMEOUT_MS = 60_000  # 1 minute

    def init(self, handle: StatefulProcessorHandle) -> None:
        self.handle = handle
        schema = StructType([StructField("event_count", LongType())])
        self.session = handle.getValueState("session", schema)

    def handleInputRows(self, key, rows, timerValues):
        existing = self.session.get()
        event_count = existing[0] if existing is not None else 0

        for row in rows:
            event_count += 1

        self.session.update((event_count,))

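        # Reset the timeout: delete any earlier timer for this key, then
        # register a new one at current time + timeout
        for expiry in list(self.handle.listTimers()):
            self.handle.deleteTimer(expiry)
        current_time = timerValues.getCurrentProcessingTimeInMs()
        self.handle.registerTimer(current_time + self.TIMEOUT_MS)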

        return iter([Row(user_id=key[0], event_count=event_count, status="active")])

    def handleExpiredTimer(self, key, timerValues, expiredTimerInfo):
        # Timer fired, session timed out. Clear state and emit a closure event.
        event_count = self.session.get()
        self.session.clear()
        count = event_count[0] if event_count is not None else 0
        return iter([Row(user_id=key[0], event_count=count, status="expired")])

    def close(self) -> None:
        pass
Scala equivalent:
class SessionTimeoutProcessor
    extends StatefulProcessor[String, String, (String, Long, String)] {

  private val TimeoutMs = 60000L
  @transient private var sessionState: ValueState[Long] = _

  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
    sessionState = getHandle.getValueState[Long](
      "session", Encoders.scalaLong, TTLConfig.NONE)
  }

  override def handleInputRows(
      key: String,
      rows: Iterator[String],
      timerValues: TimerValues
  ): Iterator[(String, Long, String)] = {
    var eventCount = if (sessionState.exists()) sessionState.get() else 0L
    rows.foreach(_ => eventCount += 1)
    sessionState.update(eventCount)

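    // Reset the timeout: delete any earlier timer for this key, then register a new one
    getHandle.listTimers().toList.foreach(t => getHandle.deleteTimer(t))
    val currentTime = timerValues.getCurrentProcessingTimeInMs()
    getHandle.registerTimer(currentTime + TimeoutMs)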

    Iterator((key, eventCount, "active"))
  }

  override def handleExpiredTimer(
      key: String,
      timerValues: TimerValues,
      expiredTimerInfo: ExpiredTimerInfo
  ): Iterator[(String, Long, String)] = {
    val count = if (sessionState.exists()) sessionState.get() else 0L
    sessionState.clear()
    Iterator((key, count, "expired"))
  }
}

Now test it with TwsTester in processing-time mode:

tester = TwsTester(
    processor=SessionTimeoutProcessor(),
    timeMode="ProcessingTime",
    outputMode="Append",
)

# Set an initial processing time
tester.setProcessingTime(1_000)

# Process some events for alice
output = tester.test(
    key="alice",
    input=[Row(event="click"), Row(event="scroll")],
)
print(output)
# [Row(user_id='alice', event_count=2, status='active')]

# Advance time past the timeout threshold
expired_output = tester.setProcessingTime(62_000)
print(expired_output)
# [Row(user_id='alice', event_count=2, status='expired')]

# State has been cleared
print(tester.peekValueState("session", "alice"))
# None
Scala equivalent:
val tester = new TwsTester(
  processor = new SessionTimeoutProcessor(),
  timeMode = TimeMode.ProcessingTime(),
  outputMode = OutputMode.Append()
)

// Set an initial processing time
tester.setProcessingTime(1000L)

// Process some events for alice
val output = tester.test("alice", List("click", "scroll"))
println(output)
// List(("alice", 2, "active"))

// Advance time past the timeout threshold
val expiredOutput = tester.setProcessingTime(62000L)
println(expiredOutput)
// List(("alice", 2, "expired"))

// State has been cleared
println(tester.peekValueState[Long]("session", "alice"))
// None

We just tested a complete session lifecycle (activity, timeout, cleanup) in a few lines, with full control over time.
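This kind of processor has a subtlety worth covering with a test. In transformWithState, registering a new timer does not cancel an earlier one for the same key. "Resetting" a timeout therefore means deleting the old timer first, for example with the handle's listTimers() and deleteTimer(). The pure-Python simulation below shows the premature expiry that results otherwise. It uses a hypothetical `simulate` helper and a set standing in for one key's timers, and it is not Spark code.

```python
def simulate(events_ms, check_at_ms, timeout_ms=60_000, reset=True):
    """Return the expiry timestamps that have fired by check_at_ms.

    Each event registers a timer at event time + timeout. With reset=True the
    earlier timers are deleted first, mirroring a listTimers()/deleteTimer() loop.
    """
    timers = set()
    for t in events_ms:
        if reset:
            timers.clear()
        timers.add(t + timeout_ms)
    return sorted(x for x in timers if x <= check_at_ms)


# Alice is active at t=1s and again at t=30s; then time advances to t=62s.
print(simulate([1_000, 30_000], 62_000, reset=False))  # [61000] -> premature expiry
print(simulate([1_000, 30_000], 62_000, reset=True))   # [] -> session still active
```

With TwsTester, the corresponding check is a second test() call after setProcessingTime(30_000), followed by asserting that setProcessingTime(62_000) emits no expired row.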

Wrapping tests for CI

The examples above work well in a notebook for interactive exploration. To lock in behavior for CI, wrap them in your test framework of choice.

Here's a pytest example. The same pattern works in ScalaTest; the Scala version follows the Python one.

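import pytest
from pyspark.sql import Row, SparkSession
from pyspark.sql.streaming import TwsTester

# Hypothetical module name: import the fixed RunningSumProcessor from wherever you define it
from running_sum import RunningSumProcessor


@pytest.fixture(scope="module")
def spark():
    # A regular local SparkSession; each test requests this fixture so the
    # session exists before any TwsTester is constructed
    return SparkSession.builder.master("local[*]").getOrCreate()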


def test_running_sum_single_batch(spark):
    tester = TwsTester(
        processor=RunningSumProcessor(),
        timeMode="None",
        outputMode="Append",
    )

    output = tester.test("alice", [Row(amount=10), Row(amount=5)])

    assert len(output) == 2
    assert output[0]["running_sum"] == 10
    assert output[1]["running_sum"] == 15


def test_running_sum_across_batches(spark):
    tester = TwsTester(
        processor=RunningSumProcessor(),
        timeMode="None",
        outputMode="Append",
    )

    tester.test("alice", [Row(amount=10), Row(amount=5)])
    output = tester.test("alice", [Row(amount=3)])

    assert output[0]["running_sum"] == 18


def test_state_isolation_across_keys(spark):
    tester = TwsTester(
        processor=RunningSumProcessor(),
        timeMode="None",
        outputMode="Append",
    )

    tester.test("alice", [Row(amount=10)])
    tester.test("bob", [Row(amount=20), Row(amount=5)])

    assert tester.peekValueState("sum", "alice") == (10,)
    assert tester.peekValueState("sum", "bob") == (25,)
ScalaTest equivalent:
import org.scalatest.funsuite.AnyFunSuite
import org.apache.spark.sql.streaming._

class RunningSumProcessorTest extends AnyFunSuite {

  test("running sum accumulates within a single batch") {
    val tester = new TwsTester(
      processor = new RunningSumProcessor(),
      timeMode = TimeMode.None(),
      outputMode = OutputMode.Append()
    )

    val output = tester.test("alice", List(10L, 5L))

    assert(output === List(("alice", 10L, 10L), ("alice", 5L, 15L)))
  }

  test("running sum accumulates across batches") {
    val tester = new TwsTester(
      processor = new RunningSumProcessor(),
      timeMode = TimeMode.None(),
      outputMode = OutputMode.Append()
    )

    tester.test("alice", List(10L, 5L))
    val output = tester.test("alice", List(3L))

    assert(output === List(("alice", 3L, 18L)))
  }

  test("state is isolated across keys") {
    val tester = new TwsTester(
      processor = new RunningSumProcessor(),
      timeMode = TimeMode.None(),
      outputMode = OutputMode.Append()
    )

    tester.test("alice", List(10L))
    tester.test("bob", List(20L, 5L))

    assert(tester.peekValueState[Long]("sum", "alice") === Some(10L))
    assert(tester.peekValueState[Long]("sum", "bob") === Some(25L))
  }
}

No mock streaming sources, no awaitTermination, no checkpoint cleanup. Just direct assertions on processor behavior.

Fitting TwsTester into your workflow

TwsTester is most valuable when you're in the thick of building or debugging stateful logic:

  • During development: run tests in a notebook cell after each change. The feedback loop drops from minutes to seconds.
  • For corner cases: construct specific input sequences that exercise edge conditions such as late arrivals, duplicate keys, empty batches, and timer expirations at exact boundaries.
  • In CI: wrap the same tests in pytest or ScalaTest so regressions are caught before they reach a streaming job.
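The bug earlier in this post only appeared at a batch boundary. One way to build corner-case inputs is to replay the same rows under every possible split into batches and check that the final state does not depend on the split. The helpers below are a hypothetical, pure-Python sketch. In a real test, each batch would be one tester.test(key, rows) call on a fresh TwsTester per split, with peekValueState read at the end.

```python
from itertools import combinations


def batch_splits(rows):
    """Yield every way to cut rows into contiguous, non-empty batches."""
    n = len(rows)
    for k in range(n):  # k = number of cut points between rows
        for cuts in combinations(range(1, n), k):
            bounds = (0, *cuts, n)
            yield [rows[a:b] for a, b in zip(bounds, bounds[1:])]


def final_total(batches, read_existing_state):
    """Fold batches like a running-sum processor; None means 'no state yet'."""
    state = None
    for batch in batches:
        total = (state or 0) if read_existing_state else 0
        for amount in batch:
            total += amount
        state = total
    return state


amounts = [10, 5, 3]
splits = list(batch_splits(amounts))
print(len(splits))                                     # 4
print(sorted({final_total(s, True) for s in splits}))  # [18]
print(sorted({final_total(s, False) for s in splits})) # [3, 8, 18]
```

The number of splits grows as 2^(n-1), so exhaustive splitting is practical only for short input lists. For longer inputs, a handful of hand-picked splits usually covers the interesting boundaries.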

It's not a replacement for end-to-end integration tests. You'll still want to verify that your processor works correctly within a full transformWithState pipeline. But for the business logic inside the processor, TwsTester gives you the tight feedback loop that stateful streaming has always been missing.

Get started

Copy the example above into a Databricks Notebook or a local Spark environment running Spark 4.2.0+, replace RunningSumProcessor with your own processor, and start testing.

For the full API reference and more examples:

Frequently asked questions (FAQ)

What is TwsTester?
TwsTester is a unit-testing harness for StatefulProcessor implementations used with the transformWithState operator in Apache Spark Structured Streaming. You hand it your processor and a list of input rows, and it drives the processor the same way Spark would inside a real micro-batch, returning the resulting output rows so you can assert on them directly.

Which Spark and Databricks Runtime versions include TwsTester?
TwsTester is available starting in Apache Spark 4.2.0, which corresponds to Databricks Runtime 18.2 and later.

Is TwsTester available in both PySpark and Scala?
Yes. The PySpark entry point is pyspark.sql.streaming.TwsTester and the Scala entry point is org.apache.spark.sql.streaming.TwsTester. The two surfaces mirror each other, so the same test patterns translate directly between languages.

Does TwsTester replace end-to-end streaming integration tests?
No. TwsTester targets the business logic inside a StatefulProcessor: per-key state, timers, output rows for given inputs. You should still run end-to-end tests against a real transformWithState pipeline to verify source/sink wiring, checkpointing, and fault tolerance. TwsTester compresses the inner feedback loop; integration tests still validate the outer one.

Can TwsTester test processing-time and event-time timers?
Yes. Construct the tester with timeMode="ProcessingTime" or timeMode="EventTime" and use setProcessingTime(...) or setWatermark(...) to advance time on demand. Any expired timers fire and their output is collected, so timeout, session-expiry, and periodic-flush logic can be exercised deterministically.

How is TwsTester different from the State Reader API?
The State Reader API reads state from a query's checkpoint directory after the fact, which is great for inspecting production state. TwsTester runs locally, drives the processor with controlled inputs, and lets you inspect or override state mid-test with peekValueState / updateValueState and friends. They are complementary: State Reader for production debugging, TwsTester for development and CI.

Can I use TwsTester in a CI pipeline?
Yes. Wrap each test in your framework of choice (pytest for PySpark, ScalaTest for Scala) and run them like any other unit test. TwsTester needs only a regular SparkSession, no streaming sources, sinks, or checkpoint directory, so it fits naturally into existing test infrastructure.