Skip to content

Commit ae04178

Browse files
committed
feat(dataframe): update group_by to accept None and normalize to empty list
- Updated the `group_by` parameter of `aggregate` to accept `None` and normalize it to an empty list.
- Improved the docstring for clarity.
- Added a regression test in `test_dataframe.py` to verify that `group_by=None` produces the same result as an empty list.
- Updated the documentation to mention that `group_by=None` is now supported.
1 parent bfa14f4 commit ae04178

3 files changed

Lines changed: 25 additions & 14 deletions

File tree

docs/source/user-guide/common-operations/aggregations.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ to form a single summary value. For performing an aggregation, DataFusion provid
4141
f.approx_median(col_speed).alias("Median Speed"),
4242
f.approx_percentile_cont(col_speed, 0.9).alias("90% Speed")])
4343
44-
When the :code:`group_by` list is empty the aggregation is done over the whole :class:`.DataFrame`.
45-
For grouping the :code:`group_by` list must contain at least one column.
44+
When :code:`group_by` is :code:`None` or an empty list, the aggregation is done over the whole
45+
:class:`.DataFrame`. For grouping, the :code:`group_by` list must contain at least one column.
4646

4747
.. ipython:: python
4848

python/datafusion/dataframe.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -798,7 +798,7 @@ def with_column_renamed(self, old_name: str, new_name: str) -> DataFrame:
798798

799799
def aggregate(
800800
self,
801-
group_by: Sequence[Expr | str] | Expr | str,
801+
group_by: Sequence[Expr | str] | Expr | str | None,
802802
aggs: Sequence[Expr] | Expr,
803803
) -> DataFrame:
804804
"""Aggregates the rows of the current DataFrame.
@@ -816,23 +816,24 @@ def aggregate(
816816
817817
Args:
818818
group_by: Sequence of expressions or column names to group
819-
by. A :py:class:`~datafusion.expr.GroupingSet`
820-
expression may be included to produce multiple grouping
821-
levels (rollup, cube, or explicit grouping sets).
819+
by, or ``None`` for aggregation over the whole DataFrame.
820+
A :py:class:`~datafusion.expr.GroupingSet` expression may
821+
be included to produce multiple grouping levels (rollup,
822+
cube, or explicit grouping sets).
822823
aggs: Sequence of expressions to aggregate.
823824
824825
Returns:
825826
DataFrame after aggregation.
826827
827828
Examples:
828-
Aggregate without grouping — an empty ``group_by`` produces a
829-
single row:
829+
Aggregate without grouping — ``None`` or an empty ``group_by``
830+
produces a single row:
830831
831832
>>> ctx = dfn.SessionContext()
832833
>>> df = ctx.from_pydict(
833834
... {"team": ["x", "x", "y"], "score": [1, 2, 5]}
834835
... )
835-
>>> df.aggregate([], [F.sum(col("score")).alias("total")]).to_pydict()
836+
>>> df.aggregate(None, [F.sum(col("score")).alias("total")]).to_pydict()
836837
{'total': [8]}
837838
838839
Group by a column and produce one row per group:
@@ -842,11 +843,15 @@ def aggregate(
842843
... ).sort("team").to_pydict()
843844
{'team': ['x', 'y'], 'total': [3, 5]}
844845
"""
845-
group_by_list = (
846-
list(group_by)
847-
if isinstance(group_by, Sequence) and not isinstance(group_by, Expr | str)
848-
else [group_by]
849-
)
846+
if group_by is None:
847+
group_by_list = []
848+
else:
849+
group_by_list = (
850+
list(group_by)
851+
if isinstance(group_by, Sequence)
852+
and not isinstance(group_by, Expr | str)
853+
else [group_by]
854+
)
850855
aggs_list = (
851856
list(aggs)
852857
if isinstance(aggs, Sequence) and not isinstance(aggs, Expr)

python/tests/test_dataframe.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,12 @@ def test_aggregate_tuple_group_by(df):
475475
assert result_tuple == result_list
476476

477477

478+
def test_aggregate_none_group_by_equivalent_to_empty_list(df):
479+
result_none = df.aggregate(None, [f.count()]).to_pydict()
480+
result_empty = df.aggregate([], [f.count()]).to_pydict()
481+
assert result_none == result_empty
482+
483+
478484
def test_aggregate_tuple_aggs(df):
479485
result_list = df.aggregate("a", [f.count()]).sort("a").to_pydict()
480486
result_tuple = df.aggregate("a", (f.count(),)).sort("a").to_pydict()

0 commit comments

Comments
 (0)