Skip to content

Commit 40b788e

Browse files
timsaucer and claude committed
feat: expose dot_product alias for inner_product
Match upstream DataFusion SQL alias surface (inner_product UDF registers `dot_product` in its alias list).

Also expand `inner_product` docstring with NULL/length-mismatch behavior to match peer distance fns added in this PR.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1 parent 1755a78 commit 40b788e

2 files changed

Lines changed: 49 additions & 2 deletions

File tree

python/datafusion/functions.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@
174174
"degrees",
175175
"dense_rank",
176176
"digest",
177+
"dot_product",
177178
"element_at",
178179
"empty",
179180
"encode",
@@ -3319,7 +3320,16 @@ def cosine_distance(array1: Expr, array2: Expr) -> Expr:
33193320
def inner_product(array1: Expr, array2: Expr) -> Expr:
33203321
"""Returns the inner (dot) product of two numeric arrays.
33213322
3322-
The SQL name ``dot_product`` is an alias for this function in raw SQL.
3323+
Treats each input as a vector and returns the sum of the element-wise
3324+
products: ``sum(array1[i] * array2[i])``. For ``[1, 2, 3]`` and
3325+
``[4, 5, 6]`` the result is ``1*4 + 2*5 + 3*6 = 32``.
3326+
3327+
Also available as :py:func:`dot_product` (and as ``dot_product`` in
3328+
raw SQL).
3329+
3330+
Both arrays must have the same length; otherwise execution fails. NULL
3331+
is returned when either input array is NULL or when any element of
3332+
either array is NULL.
33233333
33243334
Examples:
33253335
>>> ctx = dfn.SessionContext()
@@ -3333,10 +3343,32 @@ def inner_product(array1: Expr, array2: Expr) -> Expr:
33333343
... )
33343344
>>> result.collect_column("result")[0].as_py()
33353345
32.0
3346+
3347+
NULL elements propagate to NULL output:
3348+
3349+
>>> df_null = ctx.from_pydict(
3350+
... {"a": [[1.0, None, 3.0]], "b": [[4.0, 5.0, 6.0]]}
3351+
... )
3352+
>>> result = df_null.select(
3353+
... dfn.functions.inner_product(
3354+
... dfn.col("a"), dfn.col("b")
3355+
... ).alias("result")
3356+
... )
3357+
>>> result.collect_column("result")[0].as_py() is None
3358+
True
33363359
"""
33373360
return Expr(f.inner_product(array1.expr, array2.expr))
33383361

33393362

3363+
def dot_product(array1: Expr, array2: Expr) -> Expr:
3364+
"""Returns the inner (dot) product of two numeric arrays.
3365+
3366+
See Also:
3367+
This is an alias for :py:func:`inner_product`.
3368+
"""
3369+
return inner_product(array1, array2)
3370+
3371+
33403372
def list_cat(*args: Expr) -> Expr:
33413373
"""Concatenates the input arrays.
33423374

python/tests/test_functions.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -735,7 +735,22 @@ def test_array_function_aliases(alias_fn, primary_fn, data):
735735
)
736736

737737

738-
@pytest.mark.parametrize("fn", [f.cosine_distance, f.inner_product])
738+
def test_dot_product_alias_matches_inner_product():
739+
"""dot_product should be an exact alias for inner_product."""
740+
ctx = SessionContext()
741+
df = ctx.from_pydict({"a": [[1.0, 2.0, 3.0]], "b": [[4.0, 5.0, 6.0]]})
742+
alias_result = df.select(
743+
f.dot_product(column("a"), column("b")).alias("r")
744+
).collect()
745+
primary_result = df.select(
746+
f.inner_product(column("a"), column("b")).alias("r")
747+
).collect()
748+
assert (
749+
alias_result[0].column(0).to_pylist() == primary_result[0].column(0).to_pylist()
750+
)
751+
752+
753+
@pytest.mark.parametrize("fn", [f.cosine_distance, f.inner_product, f.dot_product])
739754
def test_array_distance_length_mismatch_raises(fn):
740755
"""Length-mismatched inputs to vector distance fns should raise at execute."""
741756
ctx = SessionContext()

0 commit comments

Comments
 (0)