diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 9ca6941a61ce6..0858c1ad84081 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -4144,6 +4144,8 @@ fn calc_func_dependencies_for_project( exprs: &[Expr], input: &LogicalPlan, ) -> Result { + const COMPUTED_EXPR_INDEX: usize = usize::MAX; + let input_fields = input.schema().field_names(); // Calculate expression indices (if present) in the input schema. let proj_indices = exprs @@ -4161,30 +4163,33 @@ fn calc_func_dependencies_for_project( Ok::<_, DataFusionError>( wildcard_fields .into_iter() - .filter_map(|(qualifier, f)| { + .map(|(qualifier, f)| { let flat_name = qualifier .map(|t| format!("{}.{}", t, f.name())) .unwrap_or_else(|| f.name().clone()); - input_fields.iter().position(|item| *item == flat_name) + input_fields + .iter() + .position(|item| *item == flat_name) + .unwrap_or(COMPUTED_EXPR_INDEX) }) .collect::>(), ) } Expr::Alias(alias) => { let name = format!("{}", alias.expr); - Ok(input_fields + let input_index = input_fields .iter() .position(|item| *item == name) - .map(|i| vec![i]) - .unwrap_or(vec![])) + .unwrap_or(COMPUTED_EXPR_INDEX); + Ok(vec![input_index]) } _ => { let name = format!("{expr}"); - Ok(input_fields + let input_index = input_fields .iter() .position(|item| *item == name) - .map(|i| vec![i]) - .unwrap_or(vec![])) + .unwrap_or(COMPUTED_EXPR_INDEX); + Ok(vec![input_index]) } }) .collect::>>()? @@ -4947,6 +4952,30 @@ mod tests { ]) } + #[test] + fn projection_with_leading_computed_column_preserves_pk() -> Result<()> { + let constraints = + Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0])]); + let source = Arc::new( + LogicalTableSource::new(Arc::new(employee_schema())) + .with_constraints(constraints), + ); + let plan = LogicalPlanBuilder::scan("employee_csv", source, None)? + .project(vec![ + lit(1i32).alias("__common_expr_1"), + col("id"), + col("first_name"), + col("salary"), + ])? + .build()?; + + let deps = plan.schema().functional_dependencies(); + assert_eq!(deps.len(), 1); + assert_eq!(deps[0].source_indices, vec![1]); + + Ok(()) + } + fn i32_split_point(value: i32) -> SplitPoint { SplitPoint::new(vec![ScalarValue::Int32(Some(value))]) } diff --git a/datafusion/sqllogictest/test_files/tpch/plans/q17.slt.part b/datafusion/sqllogictest/test_files/tpch/plans/q17.slt.part index ad23cd9079d48..e678f8b440dd4 100644 --- a/datafusion/sqllogictest/test_files/tpch/plans/q17.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/plans/q17.slt.part @@ -39,7 +39,7 @@ logical_plan 01)Projection: CAST(sum(lineitem.l_extendedprice) AS Float64) / Float64(7) AS avg_yearly 02)--Aggregate: groupBy=[[]], aggr=[[sum(lineitem.l_extendedprice)]] 03)----Projection: lineitem.l_extendedprice -04)------Inner Join: part.p_partkey = __scalar_sq_1.l_partkey Filter: CAST(lineitem.l_quantity AS Decimal128(30, 15)) < __scalar_sq_1.Float64(0.2) * avg(lineitem.l_quantity) +04)------LeftSemi Join: part.p_partkey = __scalar_sq_1.l_partkey Filter: CAST(lineitem.l_quantity AS Decimal128(30, 15)) < __scalar_sq_1.Float64(0.2) * avg(lineitem.l_quantity) 05)--------Projection: lineitem.l_quantity, lineitem.l_extendedprice, part.p_partkey 06)----------Inner Join: lineitem.l_partkey = part.p_partkey 07)------------TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice] @@ -55,7 +55,7 @@ physical_plan 02)--AggregateExec: mode=Final, gby=[], aggr=[sum(lineitem.l_extendedprice)] 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[], aggr=[sum(lineitem.l_extendedprice)] -05)--------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(p_partkey@2, l_partkey@1)], filter=CAST(l_quantity@0 AS Decimal128(30, 15)) < Float64(0.2) * avg(lineitem.l_quantity)@1, projection=[l_extendedprice@1] +05)--------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(p_partkey@2, l_partkey@1)], filter=CAST(l_quantity@0 AS Decimal128(30, 15)) < Float64(0.2) * avg(lineitem.l_quantity)@1, projection=[l_extendedprice@1] 06)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], projection=[l_quantity@1, l_extendedprice@2, p_partkey@3] 07)------------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4 08)--------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_partkey, l_quantity, l_extendedprice], constraints=[PrimaryKey([0, 3])], file_type=csv, has_header=false