
completion listener changes #27

Open · wants to merge 5 commits into base: read-store-changes

7 changes: 7 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -4972,6 +4972,13 @@
],
"sqlState" : "42802"
},
"STATE_STORE_UPDATING_AFTER_TASK_COMPLETION" : {
"message" : [
"State store id=<stateStoreId> still in updating state after task completed. If using foreachBatch, ",
"verify that it consumes the entire dataframe and does not catch and suppress errors during dataframe iteration."
],
"sqlState" : "XXKST"
},
"STATE_STORE_VALUE_ROW_FORMAT_VALIDATION_FAILURE" : {
"message" : [
"The streaming query failed to validate written state for value row.",
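To make the guidance in the new STATE_STORE_UPDATING_AFTER_TASK_COMPLETION message concrete, here is a minimal sketch (illustrative only; the element type and value names are not part of this PR) contrasting a foreachBatch body that stops iterating a batch early with one that consumes it fully:

import org.apache.spark.sql.Dataset

// Sketch only. show(2) materializes just a prefix of the batch, so an upstream
// stateful task can complete while its state store is still UPDATING, which is
// exactly the situation the new error condition reports.
val partialConsumer = (batchDf: Dataset[Int], batchId: Long) => batchDf.show(2)

// Consuming the entire DataFrame (count, collect, or a write) lets the state
// store commit before the task completes; errors raised while iterating should
// be allowed to propagate rather than being caught and suppressed.
val fullConsumer = (batchDf: Dataset[Int], batchId: Long) => println(s"processed ${batchDf.count()} rows")

The new ForeachBatchSinkSuite test further down exercises exactly the partial-consumption case.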
@@ -31,7 +31,7 @@ import org.apache.commons.io.IOUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs._

import org.apache.spark.{SparkConf, SparkEnv, SparkException}
import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext}
import org.apache.spark.internal.{Logging, LogKeys, MDC, MessageWithContext}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
@@ -116,6 +116,18 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
case object ABORTED extends STATE
case object RELEASED extends STATE

Option(TaskContext.get()).foreach { ctxt =>
ctxt.addTaskCompletionListener[Unit]( ctx => {
if (state == UPDATING) {
abort()
// Only throw error if the task is not already failed or interrupted
if (!ctx.isFailed() && !ctx.isInterrupted()) {
throw StateStoreErrors.stateStoreUpdatingAfterTaskCompletion(id)
}
}
})
}

private val newVersion = version + 1
@volatile private var state: STATE = UPDATING
private val finalDeltaFile: Path = deltaFile(newVersion)
@@ -960,7 +972,9 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
* @param endVersion checkpoint version to end with
* @return [[HDFSBackedStateStore]]
*/
override def replayStateFromSnapshot(snapshotVersion: Long, endVersion: Long): StateStore = {
override def replayStateFromSnapshot(snapshotVersion: Long, endVersion: Long,
readOnly: Boolean): StateStore = {
assert(!readOnly)
val newMap = replayLoadedMapFromSnapshot(snapshotVersion, endVersion)
logInfo(log"Retrieved snapshot at version " +
log"${MDC(LogKeys.STATE_STORE_VERSION, snapshotVersion)} and apply delta files to version " +
@@ -26,7 +26,7 @@ import scala.util.control.NonFatal
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.{SparkConf, SparkEnv, SparkException}
import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext}
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.io.CompressionCodec
@@ -43,7 +43,7 @@ private[sql] class RocksDBStateStoreProvider
with SupportsFineGrainedReplay {
import RocksDBStateStoreProvider._

class RocksDBStateStore(lastVersion: Long) extends StateStore {
class RocksDBStateStore(lastVersion: Long, var readOnly: Boolean) extends StateStore {
/** Trait and classes representing the internal state of the store */
trait STATE
case object UPDATING extends STATE
@@ -52,6 +52,36 @@ private[sql] class RocksDBStateStoreProvider
case object RELEASED extends STATE

@volatile private var state: STATE = UPDATING

Option(TaskContext.get()).foreach { ctxt =>
// Failure listeners are invoked before completion listeners.
// Listeners are invoked in LIFO manner compared to their
// registration, so we should not register any listeners
// after this one that could interfere with this logic.
ctxt.addTaskCompletionListener[Unit]( ctx => {
if (state == UPDATING) {
if (readOnly) {
release() // Only release, do not throw an error because we rely on
// CompletionListener to release for read-only store in
// mapPartitionsWithReadStateStore.
} else {
abort() // Abort since this is an error if stateful task completes
// without committing or aborting
// Only throw error if the task is not already failed or interrupted
// so that we don't override the original error.
if (!ctx.isFailed() && !ctx.isInterrupted()) {
throw StateStoreErrors.stateStoreUpdatingAfterTaskCompletion(id)
}
}
}
})

ctxt.addTaskFailureListener( (_, _) => {
abort() // Either the store is already aborted (this is a no-op) or
// we need to abort it.
})
}

@volatile private var isValidated = false

override def id: StateStoreId = RocksDBStateStoreProvider.this.stateStoreId
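The ordering assumption in the listener comments above can be illustrated with a small standalone sketch (registerOrderingDemo is a hypothetical helper, not part of this PR):

import org.apache.spark.TaskContext

// Sketch only. Completion listeners fire in reverse (LIFO) order of registration,
// and failure listeners fire before completion listeners. A listener registered
// after the store's listener therefore runs before it, which is why the comment
// above warns against registering anything later that could interfere with the
// UPDATING check.
def registerOrderingDemo(ctx: TaskContext): Unit = {
  ctx.addTaskCompletionListener[Unit](_ => println("completion: registered first, runs last"))
  ctx.addTaskCompletionListener[Unit](_ => println("completion: registered second, runs first"))
  ctx.addTaskFailureListener((_, error) =>
    println(s"failure: runs before the completion listeners: ${error.getMessage}"))
}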
@@ -489,13 +519,15 @@
// We need to match like this as opposed to case Some(ss: RocksDBStateStore)
// because of how the tests create the class in StateStoreRDDSuite
case Some(stateStore: ReadStateStore) if stateStore.isInstanceOf[RocksDBStateStore] =>
val rocksDBStateStore = stateStore.asInstanceOf[RocksDBStateStore]
rocksDBStateStore.readOnly = readOnly
stateStore.asInstanceOf[StateStore]
case Some(other) =>
throw new IllegalArgumentException(s"Existing store must be a RocksDBStateStore," +
s" store is actually ${other.getClass.getSimpleName}")
case None =>
// Create new store instance for getStore/getReadStore cases
new RocksDBStateStore(version)
new RocksDBStateStore(version, readOnly)
}
} catch {
case e: Throwable =>
@@ -619,7 +651,8 @@ private[sql] class RocksDBStateStoreProvider
* @param endVersion checkpoint version to end with
* @return [[StateStore]]
*/
override def replayStateFromSnapshot(snapshotVersion: Long, endVersion: Long): StateStore = {
override def replayStateFromSnapshot(snapshotVersion: Long, endVersion: Long,
readOnly: Boolean): StateStore = {
try {
if (snapshotVersion < 1) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(snapshotVersion)
@@ -628,7 +661,7 @@
throw QueryExecutionErrors.unexpectedStateStoreVersion(endVersion)
}
rocksDB.loadFromSnapshot(snapshotVersion, endVersion)
new RocksDBStateStore(endVersion)
new RocksDBStateStore(endVersion, readOnly)
}
catch {
case e: OutOfMemoryError =>
@@ -759,7 +759,8 @@ trait SupportsFineGrainedReplay {
* @param snapshotVersion checkpoint version of the snapshot to start with
* @param endVersion checkpoint version to end with
*/
def replayStateFromSnapshot(snapshotVersion: Long, endVersion: Long): StateStore
def replayStateFromSnapshot(snapshotVersion: Long, endVersion: Long,
readOnly: Boolean = false): StateStore

/**
* Return an instance of [[ReadStateStore]] representing state data of the given version.
@@ -772,7 +773,7 @@
* @param endVersion checkpoint version to end with
*/
def replayReadStateFromSnapshot(snapshotVersion: Long, endVersion: Long): ReadStateStore = {
new WrappedReadStateStore(replayStateFromSnapshot(snapshotVersion, endVersion))
new WrappedReadStateStore(replayStateFromSnapshot(snapshotVersion, endVersion, readOnly = true))
}
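A brief usage sketch of the new readOnly parameter (the provider value and version numbers are assumed for illustration; the types are the ones defined alongside this trait):

// Sketch only; assumes some provider instance that mixes in SupportsFineGrainedReplay.
def replayBothWays(provider: SupportsFineGrainedReplay): (StateStore, ReadStateStore) = {
  // Write path: readOnly keeps its default of false, so the store's completion
  // listener treats a store left in UPDATING state as an error.
  val writable = provider.replayStateFromSnapshot(snapshotVersion = 10L, endVersion = 12L)
  // Read path: the wrapper above passes readOnly = true, so the listener releases
  // the store instead of aborting and raising an error.
  val readable = provider.replayReadStateFromSnapshot(snapshotVersion = 10L, endVersion = 12L)
  (writable, readable)
}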

/**
@@ -221,6 +221,11 @@ object StateStoreErrors {
def stateStoreOperationOutOfOrder(errorMsg: String): StateStoreOperationOutOfOrder = {
new StateStoreOperationOutOfOrder(errorMsg)
}

def stateStoreUpdatingAfterTaskCompletion(stateStoreId: StateStoreId):
StateStoreUpdatingAfterTaskCompletion = {
new StateStoreUpdatingAfterTaskCompletion(stateStoreId.toString)
}
}

class StateStoreDuplicateStateVariableDefined(stateVarName: String)
@@ -455,3 +460,9 @@ class StateStoreOperationOutOfOrder(errorMsg: String)
errorClass = "STATE_STORE_OPERATION_OUT_OF_ORDER",
messageParameters = Map("errorMsg" -> errorMsg)
)

class StateStoreUpdatingAfterTaskCompletion(stateStoreID: String)
extends SparkRuntimeException(
errorClass = "STATE_STORE_UPDATING_AFTER_TASK_COMPLETION",
messageParameters = Map("stateStoreId" -> stateStoreID)
)
@@ -32,32 +32,14 @@ import org.apache.spark.util.SerializableConfiguration
* This allows a ReadStateStore to be reused by a subsequent StateStore operation.
*/
object StateStoreThreadLocalTracker {
/** Case class to hold both the store and its usage state */
case class StoreInfo(store: ReadStateStore, usedForWriteStore: Boolean = false)

private val storeInfo: ThreadLocal[StoreInfo] = new ThreadLocal[StoreInfo]
private val storeInfo: ThreadLocal[ReadStateStore] = new ThreadLocal[ReadStateStore]

def setStore(store: ReadStateStore): Unit = {
Option(storeInfo.get()) match {
case Some(info) => storeInfo.set(info.copy(store = store))
case None => storeInfo.set(StoreInfo(store))
}
storeInfo.set(store)
}

def getStore: Option[ReadStateStore] = {
Option(storeInfo.get()).map(_.store)
}

def setUsedForWriteStore(used: Boolean): Unit = {
Option(storeInfo.get()) match {
case Some(info) =>
storeInfo.set(info.copy(usedForWriteStore = used))
case None => // If there's no store set, we don't need to track usage
}
}

def isUsedForWriteStore: Boolean = {
Option(storeInfo.get()).exists(_.usedForWriteStore)
Option(storeInfo.get())
}

def clearStore(): Unit = {
@@ -177,9 +159,6 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
stateSchemaBroadcast,
useColumnFamilies, storeConf, hadoopConfBroadcast.value.value,
useMultipleValuesPerKey)
if (writeStore.equals(readStateStore)) {
StateStoreThreadLocalTracker.setUsedForWriteStore(true)
}
writeStore
case None =>
StateStore.get(
Expand Down
@@ -67,16 +67,8 @@ package object state {

val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction)
val wrappedF = (store: StateStore, iter: Iterator[T]) => {
// Abort the state store in case of error
val ctxt = TaskContext.get()
ctxt.addTaskCompletionListener[Unit](_ => {
if (!store.hasCommitted) store.abort()
})
ctxt.addTaskFailureListener(new TaskFailureListener {
override def onTaskFailure(context: TaskContext, error: Throwable): Unit = {
store.abort()
}
})
// Do not add CompletionListener here to clean up the state store because
// it is already added in RocksDBStateStore/HDFSBackedStateStore.
cleanedF(store, iter)
}

@@ -115,22 +107,12 @@

val cleanedF = dataRDD.sparkContext.clean(storeReadFn)
val wrappedF = (store: ReadStateStore, iter: Iterator[T]) => {
// Clean up the state store.
val ctxt = TaskContext.get()
ctxt.addTaskCompletionListener[Unit](_ => {
if (!StateStoreThreadLocalTracker.isUsedForWriteStore) {
store.release()
}
// Do not call abort/release here to clean up the state store because
// it is already added in RocksDBStateStore/HDFSBackedStateStore.
// However, we still do need to clear the store from the StateStoreThreadLocalTracker.
TaskContext.get().addTaskCompletionListener[Unit](_ => {
StateStoreThreadLocalTracker.clearStore()
})
ctxt.addTaskFailureListener(new TaskFailureListener {
override def onTaskFailure(context: TaskContext, error: Throwable): Unit = {
if (!StateStoreThreadLocalTracker.isUsedForWriteStore) {
store.abort()
}
StateStoreThreadLocalTracker.clearStore()
}
})
cleanedF(store, iter)
}
new ReadStateStoreRDD(
@@ -254,6 +254,23 @@ class ForeachBatchSinkSuite extends StreamTest {
query.awaitTermination()
}

test("SPARK-52008: foreachBatch that doesn't consume entire iterator") {
val mem = MemoryStream[Int]
val ds = mem.toDS().map(_ + 1)
mem.addData(1, 2, 3, 4, 5)

val queryEx = intercept[StreamingQueryException] {
val query = ds.writeStream.foreachBatch(
(batchDf: Dataset[Int], _: Long) => batchDf.show(2)).start()
query.awaitTermination()
}

val errClass = "STATE_STORE_UPDATING_AFTER_TASK_COMPLETION"

// verify that we classified the exception
assert(queryEx.getMessage.contains(errClass))
}

// ============== Helper classes and methods =================

private class ForeachBatchTester[T: Encoder](memoryStream: MemoryStream[Int]) {
@@ -259,6 +259,9 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter {
// Create input data for our chained operations
val inputData = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("c", 0)))

var mappedReadStore: ReadStateStore = null
var mappedWriteStore: StateStore = null

// Chain operations: first read with ReadStateStore, then write with StateStore
val chainedResults = inputData
// First pass: read-only state store access
@@ -270,6 +273,8 @@
spark.sessionState,
Some(castToImpl(spark).streams.stateStoreCoordinator)
) { (readStore, iter) =>
mappedReadStore = readStore

// Read values and store them for later verification
val inputItems = iter.toSeq // Materialize the input data

@@ -296,7 +301,7 @@
) { (writeStore, writeIter) =>
if (writeIter.hasNext) {
val (readValues, allStoreValues, originalItems) = writeIter.next()
val usedForWriteStore = StateStoreThreadLocalTracker.isUsedForWriteStore
mappedWriteStore = writeStore
// Get all existing values from the write store to verify reuse
val storeValues = writeStore.iterator().map(rowPairToDataPair).toSeq

@@ -310,15 +315,14 @@
writeStore.commit()

// Return all collected information for verification
Iterator((readValues, allStoreValues, storeValues, usedForWriteStore))
Iterator((readValues, allStoreValues, storeValues))
} else {
Iterator.empty
}
}

// Collect the results
val (readValues, initialStoreState, writeStoreValues,
storeWasReused) = chainedResults.collect().head
val (readValues, initialStoreState, writeStoreValues) = chainedResults.collect().head

// Verify read results
assert(readValues.toSet === Set(
@@ -333,8 +337,8 @@
// Verify the existing values in the write store (should be the same as initial state)
assert(writeStoreValues.toSet === Set((("a", 0), 1), (("b", 0), 2)))

// Verify the thread local flag indicates reuse
assert(storeWasReused,
// Verify that the same store was used for both read and write operations
assert(mappedReadStore == mappedWriteStore,
"StateStoreThreadLocalTracker should indicate the read store was reused")

// Create another ReadStateStoreRDD to verify the final state (version 2)
Expand Down
Loading