diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index 0dafcf6bd339..b3ab903f0222 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -24,15 +24,21 @@ use std::sync::Arc; use super::SendableRecordBatchStream; use crate::expressions::{CastExpr, Column}; use crate::projection::{ProjectionExec, ProjectionExpr}; -use crate::stream::RecordBatchReceiverStream; -use crate::{ColumnStatistics, ExecutionPlan, Statistics}; +use crate::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}; +use crate::{ + ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, + PlanProperties, Statistics, +}; use arrow::array::Array; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::stats::Precision; +use datafusion_common::tree_node::TreeNodeRecursion; use datafusion_common::{Result, plan_err}; +use datafusion_execution::TaskContext; use datafusion_execution::memory_pool::MemoryReservation; +use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr}; use futures::{StreamExt, TryStreamExt}; @@ -94,57 +100,73 @@ fn build_file_list_recurse( /// /// This helper is intended for operators that combine independently planned children but /// expose a single declared output schema. It returns `input` unchanged when schemas already -/// match exactly. Otherwise, it validates that projection can safely produce the expected -/// schema, then wraps `input` in a [`ProjectionExec`] that keeps columns in their existing -/// positional order and aliases them to `expected_schema`'s field names. +/// match exactly. Otherwise, it validates positional compatibility and uses a plan-time +/// adapter whose advertised and emitted schema is exactly `expected_schema`. /// -/// [`ProjectionExec`] can rename fields. 
When the expected field is nullable and the input
-/// field is not, this helper also widens nullability with a same-type [`CastExpr`]. It rejects
-/// differences that projection cannot safely normalize exactly, such as data type, metadata,
-/// schema metadata, and nullability narrowing.
-pub fn project_plan_to_schema(
+/// Prefer this helper over rebinding batches inside a parent operator's stream. The alignment
+/// is visible in the physical plan, while batch schema rebinding remains contained in the
+/// adapter as the implementation detail required to uphold the plan-level schema contract.
+///
+/// This helper can align field names and nullability to the declared schema. It rejects
+/// differences that would change values or silently lose schema information, such as column
+/// count, data type, field metadata, or schema metadata mismatches.
+///
+/// When an adapter is required, it conservatively derives fresh equivalence properties from
+/// `expected_schema` and drops child hash partitioning because field names/nullability may have
+/// changed while the underlying partitioning expressions still refer to the child schema.
+pub fn align_plan_to_schema( input: Arc, expected_schema: &SchemaRef, ) -> Result> { let input_schema = input.schema(); + if input_schema.as_ref() == expected_schema.as_ref() { return Ok(input); } - if input_schema.fields().len() != expected_schema.fields().len() { - return plan_err!( - "Cannot project plan to expected schema: expected {} column(s), got {}", - expected_schema.fields().len(), - input_schema.fields().len() - ); + if let Ok(projected) = project_plan_to_schema(Arc::clone(&input), expected_schema) { + debug_assert_eq!(projected.schema().as_ref(), expected_schema.as_ref()); + return Ok(projected); } - if input_schema.metadata() != expected_schema.metadata() { - return plan_err!( - "Cannot project plan to expected schema: schema metadata differ" - ); + Ok(Arc::new(SchemaAlignExec::try_new( + input, + Arc::clone(expected_schema), + )?)) +} + +/// Project `input` to `expected_schema` when [`ProjectionExec`] can produce that exact schema. +/// +/// This is a narrower helper than [`align_plan_to_schema`]. It is useful when a positional +/// projection/alias is sufficient. It rejects requests where projection cannot advertise the +/// exact expected schema, such as nullability narrowing. 
+pub fn project_plan_to_schema( + input: Arc, + expected_schema: &SchemaRef, +) -> Result> { + let input_schema = input.schema(); + validate_schema_alignment(&input_schema, expected_schema, "project")?; + + if input_schema.as_ref() == expected_schema.as_ref() { + return Ok(input); } - if let Some((i, input_field, expected_field, mismatch)) = input_schema + if let Some((i, input_field, expected_field)) = input_schema .fields() .iter() .zip(expected_schema.fields().iter()) .enumerate() .find_map(|(i, (input_field, expected_field))| { - if input_field.data_type() != expected_field.data_type() { - Some((i, input_field, expected_field, "data type")) - } else if input_field.is_nullable() && !expected_field.is_nullable() { - Some((i, input_field, expected_field, "nullability")) - } else if input_field.metadata() != expected_field.metadata() { - Some((i, input_field, expected_field, "metadata")) - } else { - None - } + (input_field.is_nullable() && !expected_field.is_nullable()).then_some(( + i, + input_field, + expected_field, + )) }) { return plan_err!( "Cannot project plan column {i} ('{}') to expected output field '{}': \ - field {mismatch} differs (input field: {:?}, expected field: {:?})", + field nullability differs (input field: {:?}, expected field: {:?})", input_field.name(), expected_field.name(), input_field, @@ -180,6 +202,181 @@ pub fn project_plan_to_schema( Ok(Arc::new(projection)) } +fn validate_schema_alignment( + input_schema: &SchemaRef, + expected_schema: &SchemaRef, + operation: &str, +) -> Result<()> { + if input_schema.fields().len() != expected_schema.fields().len() { + return plan_err!( + "Cannot {operation} plan to expected schema: expected {} column(s), got {}", + expected_schema.fields().len(), + input_schema.fields().len() + ); + } + + if input_schema.metadata() != expected_schema.metadata() { + return plan_err!( + "Cannot {operation} plan to expected schema: schema metadata differ" + ); + } + + if let Some((i, input_field, expected_field, 
mismatch)) = input_schema + .fields() + .iter() + .zip(expected_schema.fields().iter()) + .enumerate() + .find_map(|(i, (input_field, expected_field))| { + if input_field.data_type() != expected_field.data_type() { + Some((i, input_field, expected_field, "data type")) + } else if input_field.metadata() != expected_field.metadata() { + Some((i, input_field, expected_field, "metadata")) + } else { + None + } + }) + { + return plan_err!( + "Cannot {operation} plan column {i} ('{}') to expected output field '{}': \ + field {mismatch} differs (input field: {:?}, expected field: {:?})", + input_field.name(), + expected_field.name(), + input_field, + expected_field + ); + } + + Ok(()) +} + +/// Plan-time schema adapter for positional schema alignment. +/// +/// [`ProjectionExec`] cannot express every schema-only alignment. In particular, a column +/// expression remains nullable when its input field is nullable, so projection cannot advertise +/// a non-null expected field. This adapter is for cases where the operator-level contract has +/// already established that columns are positionally compatible and the child plan must expose +/// the declared schema exactly. +#[derive(Debug, Clone)] +pub struct SchemaAlignExec { + input: Arc, + schema: SchemaRef, + cache: Arc, +} + +impl SchemaAlignExec { + /// Create a new schema alignment adapter. 
+ pub fn try_new(input: Arc, schema: SchemaRef) -> Result { + validate_schema_alignment(&input.schema(), &schema, "align")?; + + let input_properties = input.properties(); + let partitioning = match &input_properties.partitioning { + Partitioning::RoundRobinBatch(partitions) => { + Partitioning::RoundRobinBatch(*partitions) + } + partitioning => { + Partitioning::UnknownPartitioning(partitioning.partition_count()) + } + }; + let properties = PlanProperties::new( + EquivalenceProperties::new(Arc::clone(&schema)), + partitioning, + input_properties.emission_type, + input_properties.boundedness, + ) + .with_evaluation_type(input_properties.evaluation_type) + .with_scheduling_type(input_properties.scheduling_type); + + Ok(Self { + input, + schema, + cache: Arc::new(properties), + }) + } + + /// Input plan being aligned. + pub fn input(&self) -> &Arc { + &self.input + } +} + +impl DisplayAs for SchemaAlignExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "SchemaAlignExec") + } + DisplayFormatType::TreeRender => Ok(()), + } + } +} + +impl ExecutionPlan for SchemaAlignExec { + fn name(&self) -> &'static str { + "SchemaAlignExec" + } + + fn properties(&self) -> &Arc { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn apply_expressions( + &self, + _f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + Ok(TreeNodeRecursion::Continue) + } + + fn maintains_input_order(&self) -> Vec { + vec![true] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + let [input] = children.try_into().map_err(|children: Vec<_>| { + datafusion_common::DataFusionError::Internal(format!( + "SchemaAlignExec expected 1 child, got {}", + children.len() + )) + })?; + Ok(Arc::new(Self::try_new(input, Arc::clone(&self.schema))?)) + } + + fn execute( + &self, + partition: usize, + 
context: Arc, + ) -> Result { + let schema = Arc::clone(&self.schema); + let stream = self.input.execute(partition, context)?.map({ + let schema = Arc::clone(&schema); + move |batch| { + let batch = batch?; + if batch.schema().as_ref() == schema.as_ref() { + Ok(batch) + } else { + RecordBatch::try_new(Arc::clone(&schema), batch.columns().to_vec()) + .map_err(Into::into) + } + } + }); + Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) + } + + fn partition_statistics(&self, partition: Option) -> Result> { + self.input.partition_statistics(partition) + } +} + /// If running in a tokio context spawns the execution of `stream` to a separate task /// allowing it to execute in parallel with an intermediate buffer of size `buffer` pub fn spawn_buffered( @@ -309,6 +506,40 @@ mod tests { Arc::new(EmptyExec::new(Arc::new(Schema::new(fields)))) } + fn single_field_schema(name: &str, data_type: DataType, nullable: bool) -> SchemaRef { + Arc::new(Schema::new(vec![Field::new(name, data_type, nullable)])) + } + + fn single_i32_exec(name: &str, nullable: bool) -> Arc { + empty_exec(vec![Field::new(name, DataType::Int32, nullable)]) + } + + fn field_metadata_mismatch() -> (Arc, SchemaRef) { + let input = + empty_exec(vec![Field::new("a", DataType::Int32, false).with_metadata( + HashMap::from([("source".to_string(), "input".to_string())]), + )]); + let expected_schema = Arc::new(Schema::new(vec![ + Field::new("renamed", DataType::Int32, false).with_metadata(HashMap::from([ + ("source".to_string(), "expected".to_string()), + ])), + ])); + (input, expected_schema) + } + + fn schema_metadata_mismatch() -> (Arc, SchemaRef) { + let input_schema = Arc::new(Schema::new_with_metadata( + vec![Field::new("a", DataType::Int32, false)], + HashMap::from([("source".to_string(), "input".to_string())]), + )); + let input: Arc = Arc::new(EmptyExec::new(input_schema)); + let expected_schema = Arc::new(Schema::new_with_metadata( + vec![Field::new("renamed", DataType::Int32, false)], + 
HashMap::from([("source".to_string(), "expected".to_string())]), + )); + (input, expected_schema) + } + #[test] fn test_compute_record_batch_statistics_empty() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -410,11 +641,7 @@ mod tests { #[test] fn project_plan_to_schema_returns_input_when_schema_matches() -> Result<()> { - let schema = Arc::new(Schema::new(vec![Field::new( - "value", - DataType::Int32, - false, - )])); + let schema = single_field_schema("value", DataType::Int32, false); let input: Arc = Arc::new(EmptyExec::new(Arc::clone(&schema))); let result = project_plan_to_schema(Arc::clone(&input), &schema)?; @@ -475,7 +702,7 @@ mod tests { #[test] fn project_plan_to_schema_errors_on_column_count_mismatch() { - let input = empty_exec(vec![Field::new("a", DataType::Int32, false)]); + let input = single_i32_exec("a", false); let expected_schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), @@ -487,9 +714,8 @@ mod tests { #[test] fn project_plan_to_schema_errors_on_type_mismatch() { - let input = empty_exec(vec![Field::new("a", DataType::Int32, false)]); - let expected_schema = - Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); + let input = single_i32_exec("a", false); + let expected_schema = single_field_schema("a", DataType::Float32, false); let err = project_plan_to_schema(input, &expected_schema).unwrap_err(); assert!(err.to_string().contains("field data type differs")); @@ -497,12 +723,8 @@ mod tests { #[test] fn project_plan_to_schema_widens_nullability() -> Result<()> { - let input = empty_exec(vec![Field::new("a", DataType::Int32, false)]); - let expected_schema = Arc::new(Schema::new(vec![Field::new( - "renamed", - DataType::Int32, - true, - )])); + let input = single_i32_exec("a", false); + let expected_schema = single_field_schema("renamed", DataType::Int32, true); let result = project_plan_to_schema(input, &expected_schema)?; @@ -512,44 +734,102 @@ 
mod tests { #[test] fn project_plan_to_schema_errors_on_nullability_narrowing() { - let input = empty_exec(vec![Field::new("a", DataType::Int32, true)]); - let expected_schema = Arc::new(Schema::new(vec![Field::new( - "renamed", - DataType::Int32, - false, - )])); + let input = single_i32_exec("a", true); + let expected_schema = single_field_schema("renamed", DataType::Int32, false); let err = project_plan_to_schema(input, &expected_schema).unwrap_err(); assert!(err.to_string().contains("field nullability differs")); } #[test] - fn project_plan_to_schema_errors_on_field_metadata_mismatch() { - let input = - empty_exec(vec![Field::new("a", DataType::Int32, false).with_metadata( - HashMap::from([("source".to_string(), "input".to_string())]), - )]); + fn align_plan_to_schema_returns_input_when_schema_matches() -> Result<()> { + let schema = single_field_schema("value", DataType::Int32, false); + let input: Arc = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + let result = align_plan_to_schema(Arc::clone(&input), &schema)?; + + assert!(Arc::ptr_eq(&input, &result)); + Ok(()) + } + + #[test] + fn align_plan_to_schema_uses_projection_for_rename_only() -> Result<()> { + let input = single_i32_exec("recursive_a", false); + let expected_schema = single_field_schema("a", DataType::Int32, false); + + let result = align_plan_to_schema(Arc::clone(&input), &expected_schema)?; + + let projection = result + .downcast_ref::() + .expect("rename-only alignment should use ProjectionExec"); + assert!(Arc::ptr_eq(projection.input(), &input)); + assert_eq!(projection.schema(), expected_schema); + Ok(()) + } + + #[test] + fn align_plan_to_schema_uses_adapter_for_nullability_narrowing() -> Result<()> { + let input = single_i32_exec("a", true); + let expected_schema = single_field_schema("renamed", DataType::Int32, false); + + let result = align_plan_to_schema(Arc::clone(&input), &expected_schema)?; + + let aligned = result + .downcast_ref::() + .expect("nullability narrowing should use 
SchemaAlignExec"); + assert!(Arc::ptr_eq(aligned.input(), &input)); + assert_eq!(aligned.schema(), expected_schema); + Ok(()) + } + + #[test] + fn align_plan_to_schema_errors_on_column_count_mismatch() { + let input = single_i32_exec("a", false); let expected_schema = Arc::new(Schema::new(vec![ - Field::new("renamed", DataType::Int32, false).with_metadata(HashMap::from([ - ("source".to_string(), "expected".to_string()), - ])), + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), ])); + let err = align_plan_to_schema(input, &expected_schema).unwrap_err(); + assert!(err.to_string().contains("expected 2 column")); + } + + #[test] + fn align_plan_to_schema_errors_on_type_mismatch() { + let input = single_i32_exec("a", false); + let expected_schema = single_field_schema("a", DataType::Float32, false); + + let err = align_plan_to_schema(input, &expected_schema).unwrap_err(); + assert!(err.to_string().contains("field data type differs")); + } + + #[test] + fn align_plan_to_schema_errors_on_field_metadata_mismatch() { + let (input, expected_schema) = field_metadata_mismatch(); + + let err = align_plan_to_schema(input, &expected_schema).unwrap_err(); + assert!(err.to_string().contains("field metadata differs")); + } + + #[test] + fn align_plan_to_schema_errors_on_schema_metadata_mismatch() { + let (input, expected_schema) = schema_metadata_mismatch(); + + let err = align_plan_to_schema(input, &expected_schema).unwrap_err(); + assert!(err.to_string().contains("schema metadata differ")); + } + + #[test] + fn project_plan_to_schema_errors_on_field_metadata_mismatch() { + let (input, expected_schema) = field_metadata_mismatch(); + let err = project_plan_to_schema(input, &expected_schema).unwrap_err(); assert!(err.to_string().contains("field metadata differs")); } #[test] fn project_plan_to_schema_errors_on_schema_metadata_mismatch() { - let input_schema = Arc::new(Schema::new_with_metadata( - vec![Field::new("a", DataType::Int32, false)], - 
HashMap::from([("source".to_string(), "input".to_string())]), - )); - let input: Arc = Arc::new(EmptyExec::new(input_schema)); - let expected_schema = Arc::new(Schema::new_with_metadata( - vec![Field::new("renamed", DataType::Int32, false)], - HashMap::from([("source".to_string(), "expected".to_string())]), - )); + let (input, expected_schema) = schema_metadata_mismatch(); let err = project_plan_to_schema(input, &expected_schema).unwrap_err(); assert!(err.to_string().contains("schema metadata differ")); diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index c160f9a0dc76..5c79750332cc 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -24,7 +24,7 @@ use std::task::{Context, Poll}; use super::work_table::{ReservedBatches, WorkTable}; use crate::aggregates::group_values::{GroupValues, new_group_values}; use crate::aggregates::order::GroupOrdering; -use crate::common::project_plan_to_schema; +use crate::common::align_plan_to_schema; use crate::execution_plan::{Boundedness, EmissionType, reset_plan_states}; use crate::metrics::{ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput, @@ -35,7 +35,7 @@ use crate::{ }; use arrow::array::{BooleanArray, BooleanBuilder}; use arrow::compute::filter_record_batch; -use arrow::datatypes::{Field, Schema, SchemaRef}; +use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::tree_node::TreeNodeRecursion; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; @@ -90,13 +90,13 @@ impl RecursiveQueryExec { ) -> Result { // Each recursive query needs its own work table let work_table = Arc::new(WorkTable::new(name.clone())); - // Use the same work table for both the WorkTableExec and the recursive term - let output_schema = - recursive_output_schema(&static_term.schema(), &recursive_term.schema()); - let static_term = 
project_plan_to_schema(static_term, &output_schema)?; + // Use the static term as the declared recursive CTE output schema. The + // recursive term is planned independently, so align it at plan construction + // time instead of patching batches in RecursiveQueryStream. + let output_schema = static_term.schema(); let recursive_term = assign_work_table(recursive_term, &work_table)?; - let recursive_term = project_plan_to_schema(recursive_term, &output_schema)?; - let cache = Self::compute_properties(output_schema); + let recursive_term = align_plan_to_schema(recursive_term, &output_schema)?; + let cache = Self::compute_properties(Arc::clone(&output_schema)); Ok(RecursiveQueryExec { name, static_term, @@ -370,30 +370,6 @@ impl RecursiveQueryStream { } } -fn recursive_output_schema( - static_schema: &SchemaRef, - recursive_schema: &SchemaRef, -) -> SchemaRef { - let fields = static_schema - .fields() - .iter() - .zip(recursive_schema.fields()) - .map(|(static_field, recursive_field)| { - Field::new( - static_field.name(), - static_field.data_type().clone(), - static_field.is_nullable() || recursive_field.is_nullable(), - ) - .with_metadata(static_field.metadata().clone()) - }) - .collect::>(); - - Arc::new(Schema::new_with_metadata( - fields, - static_schema.metadata().clone(), - )) -} - fn assign_work_table( plan: Arc, work_table: &Arc, @@ -520,6 +496,7 @@ fn new_groups_mask( #[cfg(test)] mod tests { use super::*; + use crate::common::SchemaAlignExec; use crate::empty::EmptyExec; use crate::projection::ProjectionExec; @@ -529,18 +506,25 @@ mod tests { Arc::new(EmptyExec::new(Arc::new(Schema::new(fields)))) } + fn recursive_exec( + static_term: Arc, + recursive_term: Arc, + ) -> Result { + RecursiveQueryExec::try_new( + "numbers".to_string(), + static_term, + recursive_term, + false, + ) + } + #[test] fn recursive_query_exec_projects_recursive_term_to_reconciled_schema() -> Result<()> { let static_term = empty_exec(vec![Field::new("value", DataType::Int32, false)]); 
        let recursive_term =
             empty_exec(vec![Field::new("value + Int32(1)", DataType::Int32, false)]);
 
-        let exec = RecursiveQueryExec::try_new(
-            "numbers".to_string(),
-            Arc::clone(&static_term),
-            Arc::clone(&recursive_term),
-            false,
-        )?;
+        let exec = recursive_exec(Arc::clone(&static_term), Arc::clone(&recursive_term))?;
 
         assert_eq!(exec.schema(), static_term.schema());
         let projection = exec
@@ -554,21 +538,23 @@ mod tests {
     }
 
     #[test]
-    fn recursive_query_exec_reconciles_nullability() -> Result<()> {
+    fn recursive_query_exec_preserves_static_nullability_contract() -> Result<()> {
         let static_term = empty_exec(vec![Field::new("value", DataType::Int32, false)]);
         let recursive_term =
             empty_exec(vec![Field::new("value + Int32(1)", DataType::Int32, true)]);
 
-        let exec = RecursiveQueryExec::try_new(
-            "numbers".to_string(),
-            static_term,
-            recursive_term,
-            false,
-        )?;
+        let exec = recursive_exec(Arc::clone(&static_term), Arc::clone(&recursive_term))?;
 
-        assert!(exec.schema().field(0).is_nullable());
-        assert!(exec.static_term().schema().field(0).is_nullable());
-        assert!(exec.recursive_term().schema().field(0).is_nullable());
+        let static_schema = static_term.schema();
+        assert_eq!(exec.schema(), static_schema);
+        assert_eq!(exec.static_term().schema(), static_schema);
+        assert_eq!(exec.recursive_term().schema(), static_schema);
+        assert!(!exec.schema().field(0).is_nullable());
+        let aligned = exec
+            .recursive_term()
+            .downcast_ref::<SchemaAlignExec>()
+            .expect("nullable recursive term should be aligned with SchemaAlignExec");
+        assert!(Arc::ptr_eq(aligned.input(), &recursive_term));
         Ok(())
     }
 }
diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt
index d13e0d4f085e..bb5a18d53d82 100644
--- a/datafusion/sqllogictest/test_files/cte.slt
+++ b/datafusion/sqllogictest/test_files/cte.slt
@@ -699,7 +699,7 @@ WITH RECURSIVE region_sales AS (
     SELECT
         s.salesperson_id AS salesperson_id,
         SUM(s.sale_amount) AS amount,
-        SUM(0) as level
+        0 as level
FROM sales s GROUP BY