Open
163 changes: 161 additions & 2 deletions datafusion/expr/src/type_coercion/functions.rs
@@ -130,7 +130,7 @@ pub fn fields_with_udf<F: UDFCoercionExt>(
let valid_types = get_valid_types_with_udf(type_signature, &current_types, func)?;
if valid_types
.iter()
.any(|data_type| data_type == &current_types)
.any(|data_type| data_types_match(data_type, &current_types))
{
return Ok(current_fields.to_vec());
}
@@ -236,7 +236,7 @@ pub fn data_types(
get_valid_types(function_name.as_ref(), type_signature, current_types)?;
if valid_types
.iter()
.any(|data_type| data_type == current_types)
.any(|data_type| data_types_match(data_type, current_types))
Contributor

We don't need to change this function; it's deprecated and not used anywhere.

{
return Ok(current_types.to_vec());
}
@@ -307,6 +307,34 @@ fn try_coerce_types(
)
}

fn data_types_match(valid_types: &[DataType], current_types: &[DataType]) -> bool {
Contributor

Seems like we aren't handling Map, Struct, or ListView -- is there a reason for that? In fact, the original bug report uses Map.

I wonder if we can simplify this to use equals_datatype from Arrow, as suggested by the bug reporter?

Author

Thanks for mentioning this.

The original reproducer does go through a Map, but the actual mismatch at the coercion point is on the extracted map value (List<Struct<...>>), not on the Map type itself. That’s why I kept the fast-path relaxation narrow.

I did try a broader equals_datatype approach first, but it was too permissive in this path and regressed existing cases where runtime kernels still require exact type identity, especially around Struct. I agree ListView / LargeListView should be handled consistently with List / LargeList, and I’ve updated the matcher for that.

Contributor

Ah, got it.

Just to help me understand, can you point at an SLT test (e.g., involving structs) that would regress if we used equals_datatype? Or if such an SLT test doesn't already exist, it would probably be a good idea to add one as a sanity check.

Author

Yes, I locally verified that a broader equals_datatype-style matcher regresses existing SLTs.

In particular:

  • /datafusion/datafusion/sqllogictest/test_files/struct.slt
    select [{a: 1, b: 2}, {b: 3, a: 4}];
  • /datafusion/datafusion/sqllogictest/test_files/spark/array/array.slt
    SELECT array(arrow_cast(array(1,2), 'LargeList(Int64)'), array(3));

With the broader matching, both end up failing in array construction (MutableArrayData) because those paths still require exact runtime type identity. That was the main reason I kept this matcher narrower than equals_datatype, especially around Struct.

I agree it would be useful to make that boundary explicit, so I can also add a focused sanity-check regression test in this PR.
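
To make the boundary discussed above concrete, here is a minimal sketch that contrasts three notions of type equality. It assumes a toy `Ty`/`Fld` model rather than the real Arrow `DataType`; `permissive_eq` is only loosely modeled on the `equals_datatype` behavior described in this thread, and `narrow_eq` on the matcher in this PR. All names in the block are hypothetical.

```rust
// Toy stand-ins for Arrow's `DataType` and `Field`. These are NOT the real
// arrow-rs types; they exist only to illustrate the discussion.
#[derive(Clone, Debug, PartialEq)]
enum Ty {
    Int64,
    List(Box<Fld>),
    Struct(Vec<Fld>),
}

#[derive(Clone, Debug, PartialEq)]
struct Fld {
    name: String,
    ty: Ty,
    nullable: bool,
}

// Helper: a nullable field with the given name and type.
fn fld(name: &str, ty: Ty) -> Fld {
    Fld { name: name.to_string(), ty, nullable: true }
}

// Helper: a list whose element field carries the given name.
fn list_of(elem_name: &str, elem: Ty) -> Ty {
    Ty::List(Box::new(fld(elem_name, elem)))
}

// Permissive comparison (assumption: roughly what a broad
// `equals_datatype`-style check does): ignores every nested field name,
// so reordered struct fields of the same type compare equal.
fn permissive_eq(a: &Ty, b: &Ty) -> bool {
    match (a, b) {
        (Ty::List(x), Ty::List(y)) => {
            x.nullable == y.nullable && permissive_eq(&x.ty, &y.ty)
        }
        (Ty::Struct(xs), Ty::Struct(ys)) => {
            xs.len() == ys.len()
                && xs.iter().zip(ys).all(|(x, y)| {
                    x.nullable == y.nullable && permissive_eq(&x.ty, &y.ty)
                })
        }
        _ => a == b,
    }
}

// Narrow comparison: ignores only the list element's field name.
// Struct fields must still match exactly (names and order).
fn narrow_eq(a: &Ty, b: &Ty) -> bool {
    match (a, b) {
        _ if a == b => true,
        (Ty::List(x), Ty::List(y)) => {
            x.nullable == y.nullable && narrow_eq(&x.ty, &y.ty)
        }
        _ => false,
    }
}
```

In this model, `list_of("item", Ty::Int64)` and `list_of("element", Ty::Int64)` differ under `==` but match under both relaxed comparisons, while two structs with fields `(a, b)` and `(b, a)` match only under `permissive_eq`.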

Contributor

Thank you! That makes sense: the key point is that some Arrow kernels depend on struct field ordering, but the "field name" of a list has no influence on the representation of the data. Can we add a brief comment to data_type_matches to explain the rationale for the kinda-structural-equality we are implementing?

It seems like Map has the same behavior as the List variants: the "field name" does not impact the representation of the data. Should we handle that as well, for completeness?

    fn field_matches(valid: &FieldRef, current: &FieldRef) -> bool {
        valid.is_nullable() == current.is_nullable()
            && data_type_matches(valid.data_type(), current.data_type())
    }

    fn data_type_matches(valid: &DataType, current: &DataType) -> bool {
        match (valid, current) {
            (valid, current) if valid == current => true,
            (DataType::List(valid), DataType::List(current))
            | (DataType::LargeList(valid), DataType::LargeList(current)) => {
                field_matches(valid, current)
            }
            (
                DataType::FixedSizeList(valid, valid_size),
                DataType::FixedSizeList(current, current_size),
            ) => valid_size == current_size && field_matches(valid, current),
            _ => false,
        }
    }

    valid_types.len() == current_types.len()
        && valid_types
            .iter()
            .zip(current_types)
            .all(|(valid_type, current_type)| data_type_matches(valid_type, current_type))
}

fn get_valid_types_with_udf<F: UDFCoercionExt>(
signature: &TypeSignature,
current_types: &[DataType],
@@ -757,6 +785,10 @@ fn maybe_data_types(
for (i, valid_type) in valid_types.iter().enumerate() {
let current_type = &current_types[i];

// Keep exact equality here. Some kernels such as `make_array`
// require nested field names/order to match exactly at runtime.
// Structural-equivalence short-circuiting is handled earlier by
// `data_types_match`.
Comment on lines +805 to +808
Contributor

I don't quite understand this reasoning, or why it doesn't also apply to the fix made in fields_with_udf. fields_with_udf first checks whether the input fields already match the UDF's valid types and returns early if so (this is where the PR relaxes the equality check). If they don't match, it tries to coerce, which is this code path. So why does the coercion path need nested fields to match exactly at runtime when the early-return path doesn't?
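
The two stages this comment describes can be sketched as follows. This is a simplified model under stated assumptions, not the DataFusion implementation: `Ty`, `structurally_matches`, `can_cast`, `resolve`, and `coerce` are hypothetical stand-ins for `DataType`, `data_types_match`, `can_cast_types`, `fields_with_udf`, and `maybe_data_types_without_coercion` respectively.

```rust
// Toy stand-in for Arrow's `DataType` (hypothetical, for illustration only).
#[derive(Clone, Debug, PartialEq)]
enum Ty {
    Int32,
    Int64,
    // A list with a named element field, e.g. "item" or "element".
    List(String, Box<Ty>),
}

// Relaxed check used by the fast path: list element names are ignored.
fn structurally_matches(valid: &Ty, current: &Ty) -> bool {
    match (valid, current) {
        _ if valid == current => true,
        (Ty::List(_, v), Ty::List(_, c)) => structurally_matches(v, c),
        _ => false,
    }
}

// Stand-in for a cast check: Int32 widens to Int64, and lists cast
// when their element types cast.
fn can_cast(from: &Ty, to: &Ty) -> bool {
    match (from, to) {
        _ if from == to => true,
        (Ty::Int32, Ty::Int64) => true,
        (Ty::List(_, f), Ty::List(_, t)) => can_cast(f, t),
        _ => false,
    }
}

// Stage 2: per-argument coercion with exact equality. An argument that is
// not exactly equal but is castable takes the signature's type.
fn coerce(valid: &[Ty], current: &[Ty]) -> Option<Vec<Ty>> {
    if valid.len() != current.len() {
        return None;
    }
    valid
        .iter()
        .zip(current)
        .map(|(v, c)| {
            if c == v {
                Some(c.clone())
            } else if can_cast(c, v) {
                Some(v.clone())
            } else {
                None
            }
        })
        .collect()
}

// Stage 1 then stage 2: return the inputs unchanged when any candidate
// signature already matches structurally; otherwise fall back to coercion.
fn resolve(candidates: &[Vec<Ty>], current: &[Ty]) -> Option<Vec<Ty>> {
    let fast_path = candidates.iter().any(|valid| {
        valid.len() == current.len()
            && valid
                .iter()
                .zip(current)
                .all(|(v, c)| structurally_matches(v, c))
    });
    if fast_path {
        return Some(current.to_vec());
    }
    candidates.iter().find_map(|valid| coerce(valid, current))
}
```

In this model the same input gets different results depending on the stage that handles it: a list with element name "item" checked against a signature using "element" is returned unchanged by `resolve`, whereas calling `coerce` directly rewrites it to the signature's type. That difference is the asymmetry the question is pointing at.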

if current_type == valid_type {
new_type.push(current_type.clone())
} else {
@@ -789,6 +821,10 @@ fn maybe_data_types_without_coercion(
for (i, valid_type) in valid_types.iter().enumerate() {
let current_type = &current_types[i];

// Keep exact equality here. Some kernels such as `make_array`
// require nested field names/order to match exactly at runtime.
// Structural-equivalence short-circuiting is handled earlier by
// `data_types_match`.
if current_type == valid_type {
new_type.push(current_type.clone())
} else if can_cast_types(current_type, valid_type) {
@@ -1044,6 +1080,99 @@ mod tests {
}
}

#[test]
fn test_maybe_data_types_uses_exact_nested_types() {
let struct_fields = vec![
Field::new("id", DataType::Utf8, true),
Field::new("prim", DataType::Boolean, true),
];
let current_type = DataType::List(Arc::new(Field::new(
"item",
DataType::Struct(struct_fields.clone().into()),
true,
)));
let valid_type = DataType::List(Arc::new(Field::new(
"element",
DataType::Struct(struct_fields.into()),
true,
)));

assert!(current_type.equals_datatype(&valid_type));
assert_ne!(current_type, valid_type);
assert_eq!(
maybe_data_types(std::slice::from_ref(&valid_type), &[current_type]),
Some(vec![valid_type])
);
}

#[test]
fn test_maybe_data_types_without_coercion_uses_exact_nested_types() {
let valid_type = DataType::Struct(
vec![
Field::new("a", DataType::Int64, true),
Field::new("b", DataType::Int64, true),
]
.into(),
);
let current_type = DataType::Struct(
vec![
Field::new("b", DataType::Int64, true),
Field::new("a", DataType::Int64, true),
]
.into(),
);

assert!(current_type.equals_datatype(&valid_type));
assert_ne!(current_type, valid_type);
assert_eq!(
maybe_data_types_without_coercion(
std::slice::from_ref(&valid_type),
&[current_type],
),
Some(vec![valid_type])
);
}

#[test]
fn test_data_types_match_ignores_list_field_name() {
let struct_fields = vec![
Field::new("id", DataType::Utf8, true),
Field::new("prim", DataType::Boolean, true),
];
let current_type = DataType::List(Arc::new(Field::new(
"item",
DataType::Struct(struct_fields.clone().into()),
true,
)));
let valid_type = DataType::List(Arc::new(Field::new(
"element",
DataType::Struct(struct_fields.into()),
true,
)));

assert!(data_types_match(&[valid_type], &[current_type]));
}

#[test]
fn test_data_types_match_respects_struct_field_order() {
let valid_type = DataType::Struct(
vec![
Field::new("a", DataType::Int64, true),
Field::new("b", DataType::Int64, true),
]
.into(),
);
let current_type = DataType::Struct(
vec![
Field::new("b", DataType::Int64, true),
Field::new("a", DataType::Int64, true),
]
.into(),
);

assert!(!data_types_match(&[valid_type], &[current_type]));
}

#[test]
fn test_get_valid_types_numeric() -> Result<()> {
let get_valid_types_flatten =
@@ -1223,6 +1352,36 @@ mod tests {
Ok(())
}

#[test]
fn test_fields_with_udf_preserves_equivalent_nested_types() -> Result<()> {
let struct_fields = vec![
Field::new("id", DataType::Utf8, true),
Field::new("prim", DataType::Boolean, true),
];
let current_type = DataType::List(Arc::new(Field::new(
"item",
DataType::Struct(struct_fields.clone().into()),
true,
)));
let signature_type = DataType::List(Arc::new(Field::new(
"element",
DataType::Struct(struct_fields.into()),
true,
)));

assert!(current_type.equals_datatype(&signature_type));

let current_fields = vec![Arc::new(Field::new("field", current_type, true))];
let coerced_fields = fields_with_udf(
&current_fields,
&MockUdf(Signature::exact(vec![signature_type], Volatility::Stable)),
)?;

assert_eq!(coerced_fields, current_fields);

Ok(())
}

#[test]
fn test_nested_wildcard_fixed_size_lists() -> Result<()> {
let type_into = DataType::FixedSizeList(
36 changes: 36 additions & 0 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -1770,6 +1770,42 @@ mod test {
)
}

#[test]
fn scalar_function_preserves_equivalent_nested_types() -> Result<()> {
let struct_fields = vec![
Field::new("id", Utf8, true),
Field::new("prim", DataType::Boolean, true),
];
let current_type = DataType::List(Arc::new(Field::new(
"item",
DataType::Struct(struct_fields.clone().into()),
true,
)));
let signature_type = DataType::List(Arc::new(Field::new(
"element",
DataType::Struct(struct_fields.into()),
true,
)));
let empty = empty_with_type(current_type);
let fun = ScalarUDF::new_from_impl(TestScalarUDF {
signature: Signature::exact(vec![signature_type], Volatility::Stable),
});
let scalar_function_expr =
Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(fun), vec![col("a")]));
let plan = LogicalPlan::Projection(Projection::try_new(
vec![scalar_function_expr],
empty,
)?);

assert_analyzed_plan_eq!(
plan,
@r"
Projection: TestScalarUDF(a)
EmptyRelation: rows=0
"
)
}

#[test]
fn agg_udaf() -> Result<()> {
let empty = empty();