Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@ package org.apache.spark.sql.pipelines.graph
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.atomic.AtomicBoolean

import scala.collection.mutable
import scala.concurrent.{ExecutionContext, Future}
import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal

import org.json4s.JsonAST.{JArray, JString}
import org.json4s.jackson.JsonMethods.{compact, parse}

import org.apache.spark.SparkException
import org.apache.spark.internal.{Logging, LogKeys}
import org.apache.spark.sql.{AnalysisException, Dataset, Row}
Expand Down Expand Up @@ -346,8 +350,6 @@ object AutoCdcAuxiliaryTable {
* upstream.
*/
def serializeKeyColumnNames(names: Seq[String]): String = {
  import org.json4s.JsonAST.{JArray, JString}
  import org.json4s.jackson.JsonMethods.compact
  // Wrap every name in a JSON string node, keeping the caller's ordering.
  val jsonNames = names.toList.map(name => JString(name))
  // Render the resulting JSON array in compact (single-line) form.
  compact(JArray(jsonNames))
}

Expand All @@ -357,8 +359,6 @@ object AutoCdcAuxiliaryTable {
* upstream.
*/
def parseKeyColumnNames(raw: String): Option[Seq[String]] = {
import org.json4s.JsonAST.{JArray, JString}
import org.json4s.jackson.JsonMethods.parse
val parsed = try Some(parse(raw)) catch { case NonFatal(_) => None }
parsed.flatMap {
case JArray(elems) =>
Expand Down Expand Up @@ -406,7 +406,7 @@ trait AutoCdcMergeWriteBase {
val (catalog, v2Identifier) = PipelinesCatalogUtils.resolveTableCatalog(spark, auxIdent)

if (!catalog.tableExists(v2Identifier)) {
val properties = scala.collection.mutable.Map.empty[String, String]
val properties = mutable.Map.empty[String, String]

// Inherit the target's format so MERGE semantics line up. When unspecified, omit the
// provider so the catalog falls back to its default.
Expand Down Expand Up @@ -440,18 +440,21 @@ trait AutoCdcMergeWriteBase {
}

/**
* Returns the resolved AutoCDC key column names as they appear in the auxiliary schema, in
* `changeArgs.keys` declaration order.
* Resolves each AutoCDC key in `changeArgs.keys` to its [[StructField]] in
* [[auxiliaryTableSchema]], preserving `changeArgs.keys` declaration order. This is the
* expected (flow-declared) side of drift validation, distinct from the keys recorded on an
* existing auxiliary table.
*
* [[AutoCdcMergeFlow]] should have validated that all `changeArgs.keys` exist in the deduced
* aux/target schemas by now, so a missing key is an internal error rather than a user-facing
* condition.
*/
private def auxiliaryKeyColumnNames: Seq[String] = {
private def expectedAuxiliaryKeyFields: Seq[StructField] = {
val resolver = spark.sessionState.conf.resolver
changeArgs.keys.map { key =>
auxiliaryTableSchema.fields
.find(field => resolver(field.name, key.name))
.map(_.name)
.getOrElse(
// This should never happen at this point, as [[AutoCdcMergeFlow]] should have validated
// all changeArgs.keys exist in the deduced aux/target table schemas by now.
throw SparkException.internalError(
s"AutoCDC key column '${key.name}' is missing from the auxiliary table schema " +
s"for flow ${identifier.unquotedString} writing to target " +
Expand All @@ -461,6 +464,12 @@ trait AutoCdcMergeWriteBase {
}
}

/**
 * Names of the AutoCDC key columns, spelled the way they appear in the auxiliary schema and
 * ordered as declared in `changeArgs.keys`.
 */
private def auxiliaryKeyColumnNames: Seq[String] =
  for (keyField <- expectedAuxiliaryKeyFields) yield keyField.name

/**
* Validate that the target table's underlying connector implements
* [[SupportsRowLevelOperations]], which is the V2 connector contract for MERGE/UPDATE/DELETE
Expand Down Expand Up @@ -512,21 +521,10 @@ trait AutoCdcMergeWriteBase {
val resolver = spark.sessionState.conf.resolver
val existingAuxSchema = CatalogV2Util.v2ColumnsToStructType(existingAuxTable.columns())

// The expected key fields are looked up in [[auxiliaryTableSchema]], which by construction
// contains every key column with its source-derived dataType. We deliberately do not look
// them up in [[existingAuxSchema]] - that's the recorded side, and conflating the two
// sides would mask drift.
val expectedKeyFields: Seq[StructField] = changeArgs.keys.map { key =>
auxiliaryTableSchema.fields
.find(field => resolver(field.name, key.name))
.getOrElse(
// Construction of [[auxiliaryTableSchema]] already enforces all of the user-specified
// keys are present, so if we don't find a key it is truly an internal error.
throw SparkException.internalError(
s"Key column '${key.name}' was not found in the AutoCDC auxiliary table schema."
)
)
}
// Resolve the flow-declared (expected) keys from [[auxiliaryTableSchema]]. We deliberately
// do not look them up in [[existingAuxSchema]] - that's the recorded side, and conflating
// the two sides would mask drift. See [[expectedAuxiliaryKeyFields]].
val expectedKeyFields: Seq[StructField] = expectedAuxiliaryKeyFields
val recordedKeyNames = parseRecordedKeyColumnNames(existingAuxTable, auxIdent)
val recordedKeyFields: Seq[StructField] = recordedKeyNames.map { name =>
existingAuxSchema.fields
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ import org.scalatest.{BeforeAndAfterEach, Suite}

import org.apache.spark.SparkThrowable
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.classic.DataFrame
import org.apache.spark.sql.connector.catalog.SharedTablesInMemoryRowLevelOperationTableCatalog
import org.apache.spark.sql.functions
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.pipelines.autocdc.{
ChangeArgs,
Expand Down Expand Up @@ -207,6 +209,35 @@ trait AutoCdcGraphExecutionTestMixin extends BeforeAndAfterEach {
)
)

/**
 * Assemble the registration context for a pipeline made of exactly one AutoCDC flow.
 *
 * The returned [[TestGraphRegistrationContext]] first registers `target` under
 * [[catalog]].[[namespace]] and then registers a single [[autoCdcFlow]] that writes into it
 * from `sourceDf`. This is the common single-table/single-flow shape; tests that need several
 * flows or non-AutoCDC datasets should build their context inline instead.
 *
 * @param flowName name given to the registered AutoCDC flow
 * @param target name of the target table, also used as the flow's destination
 * @param sourceDf DataFrame whose rows feed the flow
 * @param keys AutoCDC key column names
 * @param sequencing sequencing column; defaults to `version`
 * @param columnSelection optional column selection forwarded to the flow
 * @param deleteCondition optional delete predicate forwarded to the flow
 * @param scdType SCD type forwarded to the flow; defaults to [[ScdType.Type1]]
 */
protected def singleAutoCdcFlowPipeline(
    flowName: String,
    target: String,
    sourceDf: DataFrame,
    keys: Seq[String],
    sequencing: Column = functions.col("version"),
    columnSelection: Option[ColumnSelection] = None,
    deleteCondition: Option[Column] = None,
    scdType: ScdType = ScdType.Type1): TestGraphRegistrationContext = {
  val pipeline = new TestGraphRegistrationContext(spark) {
    // The target table must be registered before the flow that writes into it.
    registerTable(target, catalog = Some(catalog), database = Some(namespace))
    registerFlow {
      autoCdcFlow(
        name = flowName,
        target = target,
        query = dfFlowFunc(sourceDf),
        keys = keys,
        sequencing = sequencing,
        columnSelection = columnSelection,
        deleteCondition = deleteCondition,
        scdType = scdType
      )
    }
  }
  pipeline
}

/**
 * Build the value of a target row's `_cdc_metadata` struct: the delete sequence followed by
 * the upsert sequence, with an absent sequence represented as `null`.
 */
protected def cdcMeta(deleteSeq: Option[Long], upsertSeq: Option[Long]): Row = {
  // Box each present sequence so a missing one can be stored as a plain null reference.
  val fields: Seq[Any] = Seq(deleteSeq, upsertSeq).map(_.map(Long.box).orNull)
  Row.fromSeq(fields)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,11 @@ class AutoCdcScd1AuxiliaryTableDurabilitySuite
// resume cleanly.
val changeDataFeedStream = MemoryStream[(Int, String, Long)]
def buildGraphRegistrationContext(): TestGraphRegistrationContext =
new TestGraphRegistrationContext(spark) {
registerTable("target", catalog = Some(catalog), database = Some(namespace))
registerFlow(autoCdcFlow(
name = "auto_cdc_flow",
target = "target",
query = dfFlowFunc(
changeDataFeedStream.toDF().toDF("id", "name", "version")
),
keys = Seq("id"),
sequencing = functions.col("version")
))
}
singleAutoCdcFlowPipeline(
"auto_cdc_flow",
"target",
changeDataFeedStream.toDF().toDF("id", "name", "version"),
Seq("id"))

// Run #1: insert id=1 at seq=1.
changeDataFeedStream.addData((1, "alice", 1L))
Expand Down Expand Up @@ -98,20 +91,17 @@ class AutoCdcScd1AuxiliaryTableDurabilitySuite

// Single MemoryStream reused across both runs so the streaming checkpoint can resume.
val stream = MemoryStream[(Int, String, Long, Boolean)]
def buildCtx(): TestGraphRegistrationContext = new TestGraphRegistrationContext(spark) {
registerTable("target", catalog = Some(catalog), database = Some(namespace))
registerFlow(autoCdcFlow(
name = "auto_cdc_flow",
target = "target",
query = dfFlowFunc(stream.toDF().toDF("id", "name", "version", "is_delete")),
keys = Seq("id"),
sequencing = functions.col("version"),
def buildCtx(): TestGraphRegistrationContext =
singleAutoCdcFlowPipeline(
"auto_cdc_flow",
"target",
stream.toDF().toDF("id", "name", "version", "is_delete"),
Seq("id"),
deleteCondition = Some(functions.col("is_delete") === true),
columnSelection = Some(ColumnSelection.ExcludeColumns(
Seq(UnqualifiedColumnName("is_delete"))
))
))
}
)

// Run #1: delete id=1 at seq=10. Auxiliary table records seq=10 as the watermark.
stream.addData((1, "alice", 10L, true))
Expand Down Expand Up @@ -141,17 +131,8 @@ class AutoCdcScd1AuxiliaryTableDurabilitySuite

val stream = MemoryStream[(String, Int, Long)]
stream.addData(("alice", 1, 1L))
val ctx = new TestGraphRegistrationContext(spark) {
registerTable("target", catalog = Some(catalog), database = Some(namespace))
registerFlow(autoCdcFlow(
name = "auto_cdc_flow",
target = "target",
query = dfFlowFunc(stream.toDF().toDF("name", "id", "version")),
keys = Seq("id"),
sequencing = functions.col("version")
))
}
runPipeline(ctx)
runPipeline(singleAutoCdcFlowPipeline(
"auto_cdc_flow", "target", stream.toDF().toDF("name", "id", "version"), Seq("id")))

val auxSchema = spark.table(auxTableNameFor("target")).schema

Expand Down Expand Up @@ -181,17 +162,9 @@ class AutoCdcScd1AuxiliaryTableDurabilitySuite

val stream = MemoryStream[(String, Int, String, Long)]
stream.addData(("v", 1, "us", 1L))
val ctx = new TestGraphRegistrationContext(spark) {
registerTable("target", catalog = Some(catalog), database = Some(namespace))
registerFlow(autoCdcFlow(
name = "auto_cdc_flow",
target = "target",
query = dfFlowFunc(stream.toDF().toDF("value", "id", "region", "version")),
keys = Seq("region", "id"),
sequencing = functions.col("version")
))
}
runPipeline(ctx)
runPipeline(singleAutoCdcFlowPipeline(
"auto_cdc_flow", "target", stream.toDF().toDF("value", "id", "region", "version"),
Seq("region", "id")))

val auxSchema = spark.table(auxTableNameFor("target")).schema
assert(auxSchema.fieldNames.toSeq ==
Expand All @@ -211,16 +184,9 @@ class AutoCdcScd1AuxiliaryTableDurabilitySuite

// Single MemoryStream reused across both runs so the streaming checkpoint can resume.
val stream = MemoryStream[(Int, Long)]
def buildCtx(): TestGraphRegistrationContext = new TestGraphRegistrationContext(spark) {
registerTable("target", catalog = Some(catalog), database = Some(namespace))
registerFlow(autoCdcFlow(
name = "auto_cdc_flow",
target = "target",
query = dfFlowFunc(stream.toDF().toDF("id", "version")),
keys = Seq("id"),
sequencing = functions.col("version")
))
}
def buildCtx(): TestGraphRegistrationContext =
singleAutoCdcFlowPipeline(
"auto_cdc_flow", "target", stream.toDF().toDF("id", "version"), Seq("id"))

stream.addData((1, 1L))
runPipeline(buildCtx())
Expand Down Expand Up @@ -276,18 +242,12 @@ class AutoCdcScd1AuxiliaryTableDurabilitySuite

// Single MemoryStream reused across both runs so the streaming checkpoint can resume.
val stream = MemoryStream[(String, String, String, String, Long)]
def buildCtx(): TestGraphRegistrationContext = new TestGraphRegistrationContext(spark) {
registerTable("target", catalog = Some(catalog), database = Some(namespace))
registerFlow(autoCdcFlow(
name = "auto_cdc_flow",
target = "target",
query = dfFlowFunc(
stream.toDF().toDF((keyNames :+ "version"): _*)
),
keys = backtickQuotedKeys,
sequencing = functions.col("version")
))
}
def buildCtx(): TestGraphRegistrationContext =
singleAutoCdcFlowPipeline(
"auto_cdc_flow",
"target",
stream.toDF().toDF((keyNames :+ "version"): _*),
backtickQuotedKeys)

// Run #1: a single insert with arbitrary non-empty key values.
stream.addData(("v1", "v2", "v3", "v4", 1L))
Expand Down
Loading