From 557d2214043c53eae56c8d4122cb74f9a72621a0 Mon Sep 17 00:00:00 2001
From: Christian McArthur
Date: Mon, 4 May 2026 20:45:49 -0400
Subject: [PATCH] feat: add array_normalize scalar function

---
 .../functions-nested/src/array_normalize.rs   | 203 ++++++++++++++++++
 datafusion/functions-nested/src/lib.rs        |   3 +
 .../test_files/array_normalize.slt            | 146 ++++++++++++
 .../source/user-guide/sql/scalar_functions.md |  33 +++
 4 files changed, 385 insertions(+)
 create mode 100644 datafusion/functions-nested/src/array_normalize.rs
 create mode 100644 datafusion/sqllogictest/test_files/array_normalize.slt

diff --git a/datafusion/functions-nested/src/array_normalize.rs b/datafusion/functions-nested/src/array_normalize.rs
new file mode 100644
index 0000000000000..eaa2b29bcfe7e
--- /dev/null
+++ b/datafusion/functions-nested/src/array_normalize.rs
@@ -0,0 +1,203 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`ScalarUDFImpl`] definitions for array_normalize function.
+
+use crate::utils::make_scalar_function;
+use arrow::array::{Array, ArrayRef, Float64Array, GenericListArray, OffsetSizeTrait};
+use arrow::buffer::{NullBuffer, OffsetBuffer};
+use arrow::datatypes::{
+    DataType,
+    DataType::{FixedSizeList, LargeList, List, Null},
+    Field,
+};
+use datafusion_common::cast::{as_float64_array, as_generic_list_array};
+use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only};
+use datafusion_common::{Result, internal_err, plan_err, utils::take_function_args};
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
+    Volatility,
+};
+use datafusion_macros::user_doc;
+use std::sync::Arc;
+
+make_udf_expr_and_func!(
+    ArrayNormalize,
+    array_normalize,
+    array,
+    "returns the L2-normalized vector for a numeric array.",
+    array_normalize_udf
+);
+
+#[user_doc(
+    doc_section(label = "Array Functions"),
+    description = "Returns the L2-normalized vector for the input numeric array, computed as `array[i] / sqrt(sum(array[i]^2))` per element. Returns NULL if the input is NULL, contains NULL elements, or has zero magnitude (all elements are zero). Returns an empty array for an empty input array.",
+    syntax_example = "array_normalize(array)",
+    sql_example = r#"```sql
+> select array_normalize([3.0, 4.0]);
++----------------------------------+
+| array_normalize(List([3.0,4.0])) |
++----------------------------------+
+| [0.6, 0.8]                       |
++----------------------------------+
+```"#,
+    argument(
+        name = "array",
+        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
+    )
+)]
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct ArrayNormalize {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl Default for ArrayNormalize {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ArrayNormalize {
+    /// Creates the function with a user-defined signature (argument types are
+    /// resolved by `coerce_types`) and the `list_normalize` alias.
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::user_defined(Volatility::Immutable),
+            aliases: vec!["list_normalize".to_string()],
+        }
+    }
+}
+
+impl ScalarUDFImpl for ArrayNormalize {
+    fn name(&self) -> &str {
+        "array_normalize"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        // After `coerce_types`, the argument is List(Float64) or LargeList(Float64).
+        // The kernel reuses the input's element field, so the result has exactly
+        // this type.
+        let [arg_type] = take_function_args(self.name(), arg_types)?;
+        Ok(arg_type.clone())
+    }
+
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        let [arg_type] = take_function_args(self.name(), arg_types)?;
+        let coercion = Some(&ListCoercion::FixedSizedListToList);
+
+        if !matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) {
+            return plan_err!("{} does not support type {arg_type}", self.name());
+        }
+
+        // A bare NULL becomes a NULL List(Float64); any list keeps its shape (with
+        // FixedSizeList turned into List) and gets a Float64 base type.
+        let coerced = if matches!(arg_type, Null) {
+            List(Arc::new(Field::new_list_field(DataType::Float64, true)))
+        } else {
+            coerced_type_with_base_type_only(arg_type, &DataType::Float64, coercion)
+        };
+
+        Ok(vec![coerced])
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        make_scalar_function(array_normalize_inner)(&args.args)
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        self.doc()
+    }
+}
+
+/// Dispatches to the kernel matching the offset width of the coerced list type.
+fn array_normalize_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
+    let [array] = take_function_args("array_normalize", args)?;
+    match array.data_type() {
+        List(_) => general_array_normalize::<i32>(args),
+        LargeList(_) => general_array_normalize::<i64>(args),
+        arg_type => internal_err!(
+            "array_normalize received unexpected type after coercion: {arg_type}"
+        ),
+    }
+}
+
+/// L2-normalizes every row of the Float64 list array in `arrays[0]`.
+///
+/// An output row is NULL when the input row is NULL, contains a NULL element, or
+/// has a squared magnitude of exactly zero. An empty row stays an empty row.
+fn general_array_normalize<O: OffsetSizeTrait>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
+    let list_array = as_generic_list_array::<O>(&arrays[0])?;
+    let values = as_float64_array(list_array.values())?;
+    let offsets = list_array.value_offsets();
+
+    // Reuse the input's element field (name and nullability) so the output type is
+    // identical to the input type, which is what `return_type` reports.
+    let field = match list_array.data_type() {
+        List(field) | LargeList(field) => Arc::clone(field),
+        data_type => {
+            return internal_err!(
+                "array_normalize received unexpected type after coercion: {data_type}"
+            );
+        }
+    };
+
+    let mut new_values: Vec<f64> = Vec::with_capacity(values.len());
+    let mut new_offsets: Vec<O> = Vec::with_capacity(list_array.len() + 1);
+    new_offsets.push(O::usize_as(0));
+    let mut validity: Vec<bool> = Vec::with_capacity(list_array.len());
+
+    for row in 0..list_array.len() {
+        let start = offsets[row].as_usize();
+        let end = offsets[row + 1].as_usize();
+        let elements = values.slice(start, end - start);
+
+        // NULL rows and rows containing a NULL element produce NULL.
+        let mut valid = list_array.is_valid(row) && elements.null_count() == 0;
+        if valid && !elements.is_empty() {
+            let vals = elements.values();
+            let sq_sum = vals.iter().fold(0.0_f64, |acc, &v| acc + v * v);
+            if sq_sum == 0.0 {
+                // Zero magnitude: normalization is undefined, so emit NULL.
+                valid = false;
+            } else {
+                let mag = sq_sum.sqrt();
+                new_values.extend(vals.iter().map(|&v| v / mag));
+            }
+        }
+
+        // Rows that contribute no values simply repeat the previous offset.
+        validity.push(valid);
+        new_offsets.push(O::usize_as(new_values.len()));
+    }
+
+    let normalized = GenericListArray::<O>::try_new(
+        field,
+        OffsetBuffer::<O>::new(new_offsets.into()),
+        Arc::new(Float64Array::from(new_values)),
+        Some(NullBuffer::from(validity)),
+    )?;
+    Ok(Arc::new(normalized))
+}
diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs
index e4d08a9a5e860..6c4556902b543 100644
--- a/datafusion/functions-nested/src/lib.rs
+++ b/datafusion/functions-nested/src/lib.rs
@@ -43,6 +43,7 @@ pub mod macros_lambda;
 pub mod array_any_match;
 pub mod array_compact;
 pub mod array_has;
+pub mod array_normalize;
 pub mod array_transform;
 pub mod arrays_zip;
 pub mod cardinality;
@@ -90,6 +91,7 @@ pub mod expr_fn {
     pub use super::array_has::array_has;
     pub use super::array_has::array_has_all;
     pub use super::array_has::array_has_any;
+    pub use super::array_normalize::array_normalize;
     pub use super::array_transform::array_transform;
     pub use super::arrays_zip::arrays_zip;
     pub use super::cardinality::cardinality;
@@ -164,6 +166,7 @@ pub fn all_default_nested_functions() -> Vec<Arc<ScalarUDF>> {
         array_has::array_has_any_udf(),
         empty::array_empty_udf(),
         length::array_length_udf(),
+        array_normalize::array_normalize_udf(),
         cosine_distance::cosine_distance_udf(),
         inner_product::inner_product_udf(),
         distance::array_distance_udf(),
diff --git a/datafusion/sqllogictest/test_files/array_normalize.slt b/datafusion/sqllogictest/test_files/array_normalize.slt
new file mode 100644
index 0000000000000..ba4711d02cf9d
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/array_normalize.slt
@@ -0,0 +1,146 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+## array_normalize
+
+# 3-4-5 right triangle: [3,4] / 5 = [0.6, 0.8]
+query ?
+select array_normalize([3.0, 4.0]);
+----
+[0.6, 0.8]
+
+# Already-unit vector along axis: [1,0] -> [1,0]
+query ?
+select array_normalize([1.0, 0.0]);
+----
+[1.0, 0.0]
+
+# 3D vector: [1,2,2] / 3 = [0.333..., 0.666..., 0.666...]
+query ?
+select array_normalize([1.0, 2.0, 2.0]);
+----
+[0.3333333333333333, 0.6666666666666666, 0.6666666666666666]
+
+# Negative components preserved
+query ?
+select array_normalize([-3.0, 4.0]);
+----
+[-0.6, 0.8]
+
+# Bare NULL input returns NULL
+query ?
+select array_normalize(NULL);
+----
+NULL
+
+# NULL element inside a list returns NULL for that row
+query ?
+select array_normalize([1.0, NULL, 2.0]);
+----
+NULL
+
+# Zero vector returns NULL (undefined normalization)
+query ?
+select array_normalize([0.0, 0.0]);
+----
+NULL
+
+# Single non-zero component: [5] -> [1]
+query ?
+select array_normalize([5.0]);
+----
+[1.0]
+
+# LargeList support
+query ?
+select array_normalize(arrow_cast([3.0, 4.0], 'LargeList(Float64)'));
+----
+[0.6, 0.8]
+
+# FixedSizeList input (coerced to List)
+query ?
+select array_normalize(arrow_cast([3.0, 4.0], 'FixedSizeList(2, Float64)'));
+----
+[0.6, 0.8]
+
+# Float32 inner type (coerced to Float64)
+query ?
+select array_normalize(arrow_cast([3.0, 4.0], 'List(Float32)'));
+----
+[0.6, 0.8]
+
+# Int64 inner type (coerced to Float64)
+query ?
+select array_normalize(arrow_cast([3, 4], 'List(Int64)'));
+----
+[0.6, 0.8]
+
+# Integer literals (coerced to Float64)
+query ?
+select array_normalize([3, 4]);
+----
+[0.6, 0.8]
+
+# Unsupported non-list input (plan error)
+query error array_normalize does not support type
+select array_normalize(1);
+
+# Multi-row query: normal row, NULL row, zero-vector row, NULL-element row
+query ?
+select array_normalize(column1) from (values
+  (make_array(3.0, 4.0)),
+  (NULL),
+  (make_array(0.0, 0.0)),
+  (make_array(1.0, NULL))
+) as t(column1);
+----
+[0.6, 0.8]
+NULL
+NULL
+NULL
+
+# Empty array: returns empty array (no normalization needed, no division by zero)
+query ?
+select array_normalize(arrow_cast(make_array(), 'List(Float64)'));
+----
+[]
+
+# No arguments error
+query error array_normalize function requires 1 argument, got 0
+select array_normalize();
+
+# Return type matches input variant (List of Float64)
+query ?T
+select array_normalize([3.0, 4.0]), arrow_typeof(array_normalize([3.0, 4.0]));
+----
+[0.6, 0.8] List(Float64)
+
+# list_normalize alias produces the same result
+query ?
+select list_normalize([3.0, 4.0]);
+----
+[0.6, 0.8]
+
+# list_normalize alias multi-row with NULL row
+query ?
+select list_normalize(column1) from (values
+  (make_array(3.0, 4.0)),
+  (NULL)
+) as t(column1);
+----
+[0.6, 0.8]
+NULL
diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md
index 260f69f737d1b..b1350a36af8ea 100644
--- a/docs/source/user-guide/sql/scalar_functions.md
+++ b/docs/source/user-guide/sql/scalar_functions.md
@@ -3259,6 +3259,7 @@ _Alias of [current_date](#current_date)._
 - [array_max](#array_max)
 - [array_min](#array_min)
 - [array_ndims](#array_ndims)
+- [array_normalize](#array_normalize)
 - [array_pop_back](#array_pop_back)
 - [array_pop_front](#array_pop_front)
 - [array_position](#array_position)
@@ -3312,6 +3313,7 @@ _Alias of [current_date](#current_date)._
 - [list_length](#list_length)
 - [list_max](#list_max)
 - [list_ndims](#list_ndims)
+- [list_normalize](#list_normalize)
 - [list_pop_back](#list_pop_back)
 - [list_pop_front](#list_pop_front)
 - [list_position](#list_position)
@@ -3884,6 +3886,33 @@ array_ndims(array, element)
 
 - list_ndims
 
+### `array_normalize`
+
+Returns the L2-normalized vector for the input numeric array, computed as `array[i] / sqrt(sum(array[i]^2))` per element. Returns NULL if the input is NULL, contains NULL elements, or has zero magnitude (all elements are zero). Returns an empty array for an empty input array.
+
+```sql
+array_normalize(array)
+```
+
+#### Arguments
+
+- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators.
+
+#### Example
+
+```sql
+> select array_normalize([3.0, 4.0]);
++----------------------------------+
+| array_normalize(List([3.0,4.0])) |
++----------------------------------+
+| [0.6, 0.8]                       |
++----------------------------------+
+```
+
+#### Aliases
+
+- list_normalize
+
 ### `array_pop_back`
 
 Returns the array without the last element.
@@ -4769,6 +4798,10 @@ _Alias of [array_max](#array_max)._
 
 _Alias of [array_ndims](#array_ndims)._
 
+### `list_normalize`
+
+_Alias of [array_normalize](#array_normalize)._
+
 ### `list_pop_back`
 
 _Alias of [array_pop_back](#array_pop_back)._