diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index ed57e271a4c96..5c8fc08556a10 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -4972,6 +4972,13 @@
     ],
     "sqlState" : "42802"
   },
+  "STATE_STORE_UPDATING_AFTER_TASK_COMPLETION" : {
+    "message" : [
+      "State store id=<stateStoreId> still in updating state after task completed. If using foreachBatch, ",
+      "verify that it consumes the entire dataframe and does not catch and suppress errors during dataframe iteration."
+    ],
+    "sqlState" : "XXKST"
+  },
   "STATE_STORE_VALUE_ROW_FORMAT_VALIDATION_FAILURE" : {
     "message" : [
       "The streaming query failed to validate written state for value row.",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 20189c8007ef7..b91195ac410a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -31,7 +31,7 @@ import org.apache.commons.io.IOUtils
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs._
 
-import org.apache.spark.{SparkConf, SparkEnv, SparkException}
+import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext}
 import org.apache.spark.internal.{Logging, LogKeys, MDC, MessageWithContext}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
@@ -116,6 +116,18 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
     case object ABORTED extends STATE
     case object RELEASED extends STATE
 
+    Option(TaskContext.get()).foreach { ctxt =>
+      ctxt.addTaskCompletionListener[Unit]( ctx => {
+        if (state == UPDATING) {
+          abort()
+          // Only throw error if the task is not already failed or interrupted
+          if (!ctx.isFailed() && !ctx.isInterrupted()) {
+            throw StateStoreErrors.stateStoreUpdatingAfterTaskCompletion(id)
+          }
+        }
+      })
+    }
+
     private val newVersion = version + 1
     @volatile private var state: STATE = UPDATING
     private val finalDeltaFile: Path = deltaFile(newVersion)
@@ -960,7 +972,9 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
    * @param endVersion checkpoint version to end with
    * @return [[HDFSBackedStateStore]]
    */
-  override def replayStateFromSnapshot(snapshotVersion: Long, endVersion: Long): StateStore = {
+  override def replayStateFromSnapshot(snapshotVersion: Long, endVersion: Long,
+      readOnly: Boolean): StateStore = {
+    assert(!readOnly)
     val newMap = replayLoadedMapFromSnapshot(snapshotVersion, endVersion)
     logInfo(log"Retrieved snapshot at version " +
       log"${MDC(LogKeys.STATE_STORE_VERSION, snapshotVersion)} and apply delta files to version " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index 8f4a7041e6581..f819315c90a7e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -26,7 +26,7 @@ import scala.util.control.NonFatal
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.{SparkConf, SparkEnv, SparkException}
+import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext}
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys._
 import org.apache.spark.io.CompressionCodec
@@ -43,7 +43,7 @@ private[sql] class RocksDBStateStoreProvider
   with SupportsFineGrainedReplay {
   import RocksDBStateStoreProvider._
 
-  class RocksDBStateStore(lastVersion: Long) extends StateStore {
+  class RocksDBStateStore(lastVersion: Long, var readOnly: Boolean) extends StateStore {
     /** Trait and classes representing the internal state of the store */
     trait STATE
     case object UPDATING extends STATE
@@ -52,6 +52,36 @@ private[sql] class RocksDBStateStoreProvider
     case object RELEASED extends STATE
 
     @volatile private var state: STATE = UPDATING
+
+    Option(TaskContext.get()).foreach { ctxt =>
+      // Failure listeners are invoked before completion listeners.
+      // Listeners are invoked in LIFO manner compared to their
+      // registration, so we should not register any listeners
+      // after this one that could interfere with this logic.
+      ctxt.addTaskCompletionListener[Unit]( ctx => {
+        if (state == UPDATING) {
+          if (readOnly) {
+            release() // Only release, do not throw an error because we rely on
+                      // CompletionListener to release for read-only store in
+                      // mapPartitionsWithReadStateStore.
+          } else {
+            abort() // Abort since this is an error if stateful task completes
+                    // without committing or aborting
+            // Only throw error if the task is not already failed or interrupted
+            // so that we don't override the original error.
+            if (!ctx.isFailed() && !ctx.isInterrupted()) {
+              throw StateStoreErrors.stateStoreUpdatingAfterTaskCompletion(id)
+            }
+          }
+        }
+      })
+
+      ctxt.addTaskFailureListener( (_, _) => {
+        abort() // Either the store is already aborted (this is a no-op) or
+                // we need to abort it.
+      })
+    }
+
     @volatile private var isValidated = false
 
     override def id: StateStoreId = RocksDBStateStoreProvider.this.stateStoreId
@@ -489,13 +519,15 @@ private[sql] class RocksDBStateStoreProvider
         // We need to match like this as opposed to case Some(ss: RocksDBStateStore)
         // because of how the tests create the class in StateStoreRDDSuite
         case Some(stateStore: ReadStateStore) if stateStore.isInstanceOf[RocksDBStateStore] =>
+          val rocksDBStateStore = stateStore.asInstanceOf[RocksDBStateStore]
+          rocksDBStateStore.readOnly = readOnly
           stateStore.asInstanceOf[StateStore]
         case Some(other) =>
           throw new IllegalArgumentException(s"Existing store must be a RocksDBStateStore," +
             s" store is actually ${other.getClass.getSimpleName}")
         case None =>
           // Create new store instance for getStore/getReadStore cases
-          new RocksDBStateStore(version)
+          new RocksDBStateStore(version, readOnly)
       }
     } catch {
       case e: Throwable =>
@@ -619,7 +651,8 @@ private[sql] class RocksDBStateStoreProvider
    * @param endVersion checkpoint version to end with
    * @return [[StateStore]]
    */
-  override def replayStateFromSnapshot(snapshotVersion: Long, endVersion: Long): StateStore = {
+  override def replayStateFromSnapshot(snapshotVersion: Long, endVersion: Long,
+      readOnly: Boolean): StateStore = {
     try {
       if (snapshotVersion < 1) {
         throw QueryExecutionErrors.unexpectedStateStoreVersion(snapshotVersion)
@@ -628,7 +661,7 @@ private[sql] class RocksDBStateStoreProvider
         throw QueryExecutionErrors.unexpectedStateStoreVersion(endVersion)
       }
       rocksDB.loadFromSnapshot(snapshotVersion, endVersion)
-      new RocksDBStateStore(endVersion)
+      new RocksDBStateStore(endVersion, readOnly)
     } catch {
       case e: OutOfMemoryError =>
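Note (not part of the diff): the comment block added above relies on TaskContext invoking completion listeners in LIFO order relative to registration, which is why the store registers its listener at construction time and warns against registering later listeners that could interfere. A minimal plain-Scala sketch of that ordering, with no Spark dependency (all names below are illustrative only):

    import scala.collection.mutable.ListBuffer

    object ListenerOrderSketch {
      def main(args: Array[String]): Unit = {
        // Callbacks are stored in registration order...
        val listeners = ListBuffer.empty[() => Unit]
        def addListener(f: () => Unit): Unit = listeners += f

        // The store's listener is registered first (at store creation), so it runs last.
        addListener(() => println("registered first, runs last (store's UPDATING check)"))
        // A listener registered later (e.g. by an operator) runs before it.
        addListener(() => println("registered second, runs first (operator cleanup)"))

        // ...and invoked in reverse order on completion, mirroring TaskContext's LIFO behavior.
        listeners.reverseIterator.foreach(f => f())
      }
    }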
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index ea1085749e26b..d6221341ababb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -759,7 +759,8 @@ trait SupportsFineGrainedReplay {
    * @param snapshotVersion checkpoint version of the snapshot to start with
    * @param endVersion checkpoint version to end with
    */
-  def replayStateFromSnapshot(snapshotVersion: Long, endVersion: Long): StateStore
+  def replayStateFromSnapshot(snapshotVersion: Long, endVersion: Long,
+      readOnly: Boolean = false): StateStore
 
   /**
    * Return an instance of [[ReadStateStore]] representing state data of the given version.
@@ -772,7 +773,7 @@ trait SupportsFineGrainedReplay {
    * @param endVersion checkpoint version to end with
    */
   def replayReadStateFromSnapshot(snapshotVersion: Long, endVersion: Long): ReadStateStore = {
-    new WrappedReadStateStore(replayStateFromSnapshot(snapshotVersion, endVersion))
+    new WrappedReadStateStore(replayStateFromSnapshot(snapshotVersion, endVersion, readOnly = true))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
index af96db4a50361..26715f0269b6f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
@@ -221,6 +221,11 @@ object StateStoreErrors {
   def stateStoreOperationOutOfOrder(errorMsg: String): StateStoreOperationOutOfOrder = {
     new StateStoreOperationOutOfOrder(errorMsg)
   }
+
+  def stateStoreUpdatingAfterTaskCompletion(stateStoreId: StateStoreId):
+    StateStoreUpdatingAfterTaskCompletion = {
+    new StateStoreUpdatingAfterTaskCompletion(stateStoreId.toString)
+  }
 }
 
 class StateStoreDuplicateStateVariableDefined(stateVarName: String)
@@ -455,3 +460,9 @@ class StateStoreOperationOutOfOrder(errorMsg: String)
     errorClass = "STATE_STORE_OPERATION_OUT_OF_ORDER",
     messageParameters = Map("errorMsg" -> errorMsg)
   )
+
+class StateStoreUpdatingAfterTaskCompletion(stateStoreID: String)
+  extends SparkRuntimeException(
+    errorClass = "STATE_STORE_UPDATING_AFTER_TASK_COMPLETION",
+    messageParameters = Map("stateStoreId" -> stateStoreID)
+  )
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
index d78c5229e0ac2..03a9e0e7aa69b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
@@ -32,32 +32,14 @@ import org.apache.spark.util.SerializableConfiguration
  * This allows a ReadStateStore to be reused by a subsequent StateStore operation.
  */
 object StateStoreThreadLocalTracker {
-  /** Case class to hold both the store and its usage state */
-  case class StoreInfo(store: ReadStateStore, usedForWriteStore: Boolean = false)
-
-  private val storeInfo: ThreadLocal[StoreInfo] = new ThreadLocal[StoreInfo]
+  private val storeInfo: ThreadLocal[ReadStateStore] = new ThreadLocal[ReadStateStore]
 
   def setStore(store: ReadStateStore): Unit = {
-    Option(storeInfo.get()) match {
-      case Some(info) => storeInfo.set(info.copy(store = store))
-      case None => storeInfo.set(StoreInfo(store))
-    }
+    storeInfo.set(store)
   }
 
   def getStore: Option[ReadStateStore] = {
-    Option(storeInfo.get()).map(_.store)
-  }
-
-  def setUsedForWriteStore(used: Boolean): Unit = {
-    Option(storeInfo.get()) match {
-      case Some(info) =>
-        storeInfo.set(info.copy(usedForWriteStore = used))
-      case None => // If there's no store set, we don't need to track usage
-    }
-  }
-
-  def isUsedForWriteStore: Boolean = {
-    Option(storeInfo.get()).exists(_.usedForWriteStore)
+    Option(storeInfo.get())
   }
 
   def clearStore(): Unit = {
@@ -177,9 +159,6 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
           stateSchemaBroadcast,
           useColumnFamilies, storeConf, hadoopConfBroadcast.value.value,
           useMultipleValuesPerKey)
-        if (writeStore.equals(readStateStore)) {
-          StateStoreThreadLocalTracker.setUsedForWriteStore(true)
-        }
         writeStore
       case None =>
         StateStore.get(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index b0a94052c9900..f30cee029cfcf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -67,16 +67,8 @@ package object state {
       val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction)
 
       val wrappedF = (store: StateStore, iter: Iterator[T]) => {
-        // Abort the state store in case of error
-        val ctxt = TaskContext.get()
-        ctxt.addTaskCompletionListener[Unit](_ => {
-          if (!store.hasCommitted) store.abort()
-        })
-        ctxt.addTaskFailureListener(new TaskFailureListener {
-          override def onTaskFailure(context: TaskContext, error: Throwable): Unit = {
-            store.abort()
-          }
-        })
+        // Do not add CompletionListener here to clean up the state store because
+        // it is already added in RocksDBStateStore/HDFSBackedStateStore.
         cleanedF(store, iter)
       }
 
@@ -115,22 +107,12 @@ package object state {
      val cleanedF = dataRDD.sparkContext.clean(storeReadFn)
 
       val wrappedF = (store: ReadStateStore, iter: Iterator[T]) => {
-        // Clean up the state store.
-        val ctxt = TaskContext.get()
-        ctxt.addTaskCompletionListener[Unit](_ => {
-          if (!StateStoreThreadLocalTracker.isUsedForWriteStore) {
-            store.release()
-          }
+        // Do not call abort/release here to clean up the state store because
+        // it is already added in RocksDBStateStore/HDFSBackedStateStore.
+        // However, we still do need to clear the store from the StateStoreThreadLocalTracker.
+        TaskContext.get().addTaskCompletionListener[Unit](_ => {
           StateStoreThreadLocalTracker.clearStore()
         })
-        ctxt.addTaskFailureListener(new TaskFailureListener {
-          override def onTaskFailure(context: TaskContext, error: Throwable): Unit = {
-            if (!StateStoreThreadLocalTracker.isUsedForWriteStore) {
-              store.abort()
-            }
-            StateStoreThreadLocalTracker.clearStore()
-          }
-        })
         cleanedF(store, iter)
       }
       new ReadStateStoreRDD(
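Note (not part of the diff): with the usedForWriteStore flag removed, StateStoreThreadLocalTracker reduces to a plain ThreadLocal handle, and store reuse is instead asserted by comparing the stores themselves (see the updated StateStoreRDDSuite below). A self-contained sketch of that simplified pattern (names are illustrative, not Spark APIs):

    object ThreadLocalHandleSketch {
      // One ThreadLocal slot per thread, mirroring the simplified tracker.
      private val handle = new ThreadLocal[AnyRef]

      def setStore(store: AnyRef): Unit = handle.set(store)
      def getStore: Option[AnyRef] = Option(handle.get())
      def clearStore(): Unit = handle.remove()

      def main(args: Array[String]): Unit = {
        val readStore = new AnyRef
        setStore(readStore)

        // Reuse is verified by identity of the tracked store,
        // not by a separate usedForWriteStore flag.
        val reused = getStore.getOrElse(new AnyRef)
        assert(reused eq readStore)

        clearStore()
        assert(getStore.isEmpty)
      }
    }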
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala
index 10ca65ec8c41c..5e7eb3b9674d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala
@@ -254,6 +254,23 @@ class ForeachBatchSinkSuite extends StreamTest {
     query.awaitTermination()
   }
 
+  test("SPARK-52008: foreachBatch that doesn't consume entire iterator") {
+    val mem = MemoryStream[Int]
+    val ds = mem.toDS().map(_ + 1)
+    mem.addData(1, 2, 3, 4, 5)
+
+    val queryEx = intercept[StreamingQueryException] {
+      val query = ds.writeStream.foreachBatch(
+        (batchDf: Dataset[Int], _: Long) => batchDf.show(2)).start()
+      query.awaitTermination()
+    }
+
+    val errClass = "STATE_STORE_UPDATING_AFTER_TASK_COMPLETION"
+
+    // verify that we classified the exception
+    assert(queryEx.getMessage.contains(errClass))
+  }
+
   // ============== Helper classes and methods =================
 
   private class ForeachBatchTester[T: Encoder](memoryStream: MemoryStream[Int]) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
index 6095b26ecd6fe..2247fd21da597 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
@@ -259,6 +259,9 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter {
       // Create input data for our chained operations
       val inputData = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("c", 0)))
 
+      var mappedReadStore: ReadStateStore = null
+      var mappedWriteStore: StateStore = null
+
       // Chain operations: first read with ReadStateStore, then write with StateStore
       val chainedResults = inputData
         // First pass: read-only state store access
@@ -270,6 +273,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter {
           spark.sessionState,
           Some(castToImpl(spark).streams.stateStoreCoordinator)
         ) { (readStore, iter) =>
+          mappedReadStore = readStore
+
           // Read values and store them for later verification
           val inputItems = iter.toSeq // Materialize the input data
 
@@ -296,7 +301,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter {
         ) { (writeStore, writeIter) =>
           if (writeIter.hasNext) {
             val (readValues, allStoreValues, originalItems) = writeIter.next()
-            val usedForWriteStore = StateStoreThreadLocalTracker.isUsedForWriteStore
+            mappedWriteStore = writeStore
 
             // Get all existing values from the write store to verify reuse
             val storeValues = writeStore.iterator().map(rowPairToDataPair).toSeq
@@ -310,15 +315,14 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter {
             writeStore.commit()
 
             // Return all collected information for verification
-            Iterator((readValues, allStoreValues, storeValues, usedForWriteStore))
+            Iterator((readValues, allStoreValues, storeValues))
           } else {
             Iterator.empty
           }
         }
 
       // Collect the results
-      val (readValues, initialStoreState, writeStoreValues,
-        storeWasReused) = chainedResults.collect().head
+      val (readValues, initialStoreState, writeStoreValues) = chainedResults.collect().head
 
       // Verify read results
       assert(readValues.toSet === Set(
@@ -333,8 +337,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter {
       // Verify the existing values in the write store (should be the same as initial state)
       assert(writeStoreValues.toSet === Set((("a", 0), 1), (("b", 0), 2)))
 
-      // Verify the thread local flag indicates reuse
-      assert(storeWasReused,
+      // Verify that the same store was used for both read and write operations
+      assert(mappedReadStore == mappedWriteStore,
         "StateStoreThreadLocalTracker should indicate the read store was reused")
 
       // Create another ReadStateStoreRDD to verify the final state (version 2)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 1034a5edbdfc9..fc1ff1d5300f3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -24,6 +24,7 @@ import java.util.UUID
 import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.collection.mutable
+import scala.concurrent.{ExecutionContext, Future}
 import scala.jdk.CollectionConverters._
 import scala.util.Random
 
@@ -48,7 +49,7 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.tags.ExtendedSQLTest
 import org.apache.spark.unsafe.types.UTF8String
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 // MaintenanceErrorOnCertainPartitionsProvider is a test-only provider that throws an
 // exception during maintenance for partitions 0 and 1 (these are arbitrary choices). It is
@@ -1819,6 +1820,33 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
     assert(encoderSpec == deserializedEncoderSpec)
   }
 
+  test("SPARK-52008: TaskCompletionListener fails if store isn't committed/aborted") {
+    // Create a custom ExecutionContext with 1 thread
+    implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(
+      ThreadUtils.newDaemonSingleThreadExecutor("single-thread-executor"))
+    val timeout = 5.seconds
+
+    tryWithProviderResource(newStoreProvider()) { provider: StateStoreProvider =>
+      val taskContext = TaskContext.empty()
+      var store: StateStore = null
+
+      val fut = Future {
+        TaskContext.setTaskContext(taskContext)
+
+        store = provider.getStore(0)
+
+        // Task completion listener should abort and throw exception
+        taskContext.markTaskCompleted(None)
+      }
+
+      val ex = intercept[SparkException] {
+        ThreadUtils.awaitResult(fut, timeout)
+      }
+      assert(ex.getCause.getMessage.contains("STATE_STORE_UPDATING_AFTER_TASK_COMPLETION"))
+      assert(taskContext.isFailed())
+    }
+  }
+
   /** Return a new provider with a random id */
   def newStoreProvider(): ProviderClass
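Note (not part of the diff): the remediation for the new error is the one spelled out in the message text, i.e. a foreachBatch function should consume the entire batch and must not swallow exceptions raised while iterating it. A hedged usage sketch of the fixed pattern; the session setup, the stateful operator, and all names below are illustrative only, not taken from this patch:

    import org.apache.spark.sql.{DataFrame, SparkSession}

    object ForeachBatchFullConsumptionSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
        import spark.implicits._

        val counts = spark.readStream
          .format("rate")                 // built-in test source producing (timestamp, value) rows
          .option("rowsPerSecond", "5")
          .load()
          .groupBy($"value" % 10)         // a stateful aggregation, so state stores are involved
          .count()

        val query = counts.writeStream
          .outputMode("update")
          .foreachBatch((batchDf: DataFrame, _: Long) => {
            // Consume the whole batch instead of stopping early with batchDf.show(2),
            // and let any exception raised during iteration propagate.
            batchDf.collect().foreach(println)
          })
          .start()

        query.awaitTermination(10000)     // run briefly for the sketch, then stop
        query.stop()
        spark.stop()
      }
    }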