From 21be6ffda404733e6f33be2b6cd41ccf21eaa2bf Mon Sep 17 00:00:00 2001 From: MarcoGorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 23 Oct 2023 14:40:57 +0100 Subject: [PATCH 1/6] add Aggregation API --- .../dataframe_api/__init__.py | 1 + .../dataframe_api/groupby_object.py | 75 ++++++++++++++++++- .../API_specification/dataframe_api/typing.py | 7 +- spec/API_specification/examples/tpch/q5.py | 5 +- spec/conf.py | 1 + 5 files changed, 82 insertions(+), 7 deletions(-) diff --git a/spec/API_specification/dataframe_api/__init__.py b/spec/API_specification/dataframe_api/__init__.py index 7f4d17d4..fd11ca70 100644 --- a/spec/API_specification/dataframe_api/__init__.py +++ b/spec/API_specification/dataframe_api/__init__.py @@ -40,6 +40,7 @@ "Duration", "String", "is_dtype", + "Aggregation", ] diff --git a/spec/API_specification/dataframe_api/groupby_object.py b/spec/API_specification/dataframe_api/groupby_object.py index 0ccefebe..f1fb1163 100644 --- a/spec/API_specification/dataframe_api/groupby_object.py +++ b/spec/API_specification/dataframe_api/groupby_object.py @@ -1,12 +1,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Protocol if TYPE_CHECKING: from .dataframe_object import DataFrame -__all__ = ['GroupBy'] +__all__ = [ + "Aggregation", + "GroupBy", +] class GroupBy: @@ -51,3 +54,71 @@ def var(self, *, correction: int | float = 1, skip_nulls: bool = True) -> DataFr def size(self) -> DataFrame: ... + + def aggregate(self, *aggregation: Aggregation) -> DataFrame: + """ + Aggregate columns according to given aggregation function. + + Examples + -------- + >>> df: DataFrame + >>> namespace = df.__dataframe_namespace__() + >>> df.group_by('year').aggregate( + ... namespace.Aggregation.sum('l_quantity').rename('sum_qty'), + ... namespace.Aggregation.mean('l_quantity').rename('avg_qty'), + ... namespace.Aggregation.mean('l_extended_price').rename('avg_price'), + ... 
namespace.Aggregation.mean('l_discount').rename('avg_disc'), + ... namespace.Aggregation.size().rename('count_order'), + ... ) + """ + ... + +class Aggregation(Protocol): + def rename(self, name: str) -> Aggregation: + """Assign given name to output of aggregation. """ + ... + + @classmethod + def any(cls, column: str, *, skip_nulls: bool = True) -> Aggregation: + ... + + @classmethod + def all(cls, column: str, *, skip_nulls: bool = True) -> Aggregation: + ... + + @classmethod + def min(cls, column: str, *, skip_nulls: bool = True) -> Aggregation: + ... + + @classmethod + def max(cls, column: str, *, skip_nulls: bool = True) -> Aggregation: + ... + + @classmethod + def sum(cls, column: str, *, skip_nulls: bool = True) -> Aggregation: + ... + + @classmethod + def prod(cls, column: str, *, skip_nulls: bool = True) -> Aggregation: + ... + + @classmethod + def median(cls, column: str, *, skip_nulls: bool = True) -> Aggregation: + ... + + @classmethod + def mean(cls, column: str, *, skip_nulls: bool=True) -> Aggregation: + ... + + @classmethod + def std(cls, column: str, *, correction: int|float=1, skip_nulls: bool=True) -> Aggregation: + ... + + @classmethod + def var(cls, column: str, *, correction: int|float=1, skip_nulls: bool=True) -> Aggregation: + ... + + @classmethod + def size(cls) -> Aggregation: + ... 
+ diff --git a/spec/API_specification/dataframe_api/typing.py b/spec/API_specification/dataframe_api/typing.py index 4b157a9d..85033b52 100644 --- a/spec/API_specification/dataframe_api/typing.py +++ b/spec/API_specification/dataframe_api/typing.py @@ -15,7 +15,7 @@ from dataframe_api.column_object import Column from dataframe_api.dataframe_object import DataFrame -from dataframe_api.groupby_object import GroupBy +from dataframe_api.groupby_object import GroupBy, Aggregation as AggregationT if TYPE_CHECKING: from .dtypes import ( @@ -147,7 +147,9 @@ def is_null(value: object, /) -> bool: @staticmethod def is_dtype(dtype: Any, kind: str | tuple[str, ...]) -> bool: ... - + + class Aggregation(AggregationT): + ... class SupportsDataFrameAPI(Protocol): def __dataframe_consortium_standard__( @@ -163,6 +165,7 @@ def __column_consortium_standard__( __all__ = [ + "Aggregation", "Column", "DataFrame", "DType", diff --git a/spec/API_specification/examples/tpch/q5.py b/spec/API_specification/examples/tpch/q5.py index cdca0806..861741fe 100644 --- a/spec/API_specification/examples/tpch/q5.py +++ b/spec/API_specification/examples/tpch/q5.py @@ -65,10 +65,9 @@ def query( new_column = ( result.get_column_by_name("l_extendedprice") - * (1 - result.get_column_by_name("l_discount")) + * (result.get_column_by_name("l_discount") * -1 + 1) ).rename("revenue") result = result.assign(new_column) - result = result.select(["revenue", "n_name"]) - result = result.group_by("n_name").sum() + result = result.group_by("n_name").aggregate(namespace.Aggregation.sum("revenue")) return result.dataframe diff --git a/spec/conf.py b/spec/conf.py index 8302ee43..ef7eedb2 100644 --- a/spec/conf.py +++ b/spec/conf.py @@ -84,6 +84,7 @@ ('py:class', 'Scalar'), ('py:class', 'Bool'), ('py:class', 'optional'), + ('py:class', 'Aggregation'), ('py:class', 'NullType'), ('py:class', 'Namespace'), ('py:class', 'SupportsDataFrameAPI'), From e0681ab18af782eea3952708be177b1695db9fc1 Mon Sep 17 00:00:00 2001 From: 
MarcoGorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 23 Oct 2023 14:47:54 +0100 Subject: [PATCH 2/6] fixup --- spec/API_specification/examples/tpch/q5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spec/API_specification/examples/tpch/q5.py b/spec/API_specification/examples/tpch/q5.py index 861741fe..c1d604bd 100644 --- a/spec/API_specification/examples/tpch/q5.py +++ b/spec/API_specification/examples/tpch/q5.py @@ -65,7 +65,7 @@ def query( new_column = ( result.get_column_by_name("l_extendedprice") - * (result.get_column_by_name("l_discount") * -1 + 1) + * (1 - result.get_column_by_name("l_discount")) ).rename("revenue") result = result.assign(new_column) result = result.group_by("n_name").aggregate(namespace.Aggregation.sum("revenue")) From 4ac1d5de7491741bd4a3033036aab571aa2531ee Mon Sep 17 00:00:00 2001 From: MarcoGorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 23 Oct 2023 14:48:44 +0100 Subject: [PATCH 3/6] add q1 --- spec/API_specification/examples/tpch/q1.py | 39 ++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 spec/API_specification/examples/tpch/q1.py diff --git a/spec/API_specification/examples/tpch/q1.py b/spec/API_specification/examples/tpch/q1.py new file mode 100644 index 00000000..42ecdf08 --- /dev/null +++ b/spec/API_specification/examples/tpch/q1.py @@ -0,0 +1,39 @@ +from typing import Any, TYPE_CHECKING + +if TYPE_CHECKING: + from dataframe_api.typing import SupportsDataFrameAPI + + +def query(lineitem_raw: "SupportsDataFrameAPI") -> Any: + lineitem = lineitem_raw.__dataframe_consortium_standard__() + namespace = lineitem.__dataframe_namespace__() + + mask = lineitem.get_column_by_name("l_shipdate") <= namespace.date(1998, 9, 2) # type: ignore + lineitem = lineitem.assign( + [ + ( + lineitem.get_column_by_name("l_extendedprice") + * (1 - lineitem.get_column_by_name("l_discount")) + ).rename("l_disc_price"), + ( + lineitem.get_column_by_name("l_extendedprice") +
* (1 - lineitem.get_column_by_name("l_discount")) + * (1 + lineitem.get_column_by_name("l_tax")) + ).rename("l_charge"), + ] + ) + result = ( + lineitem.filter(mask) + .group_by(["l_returnflag", "l_linestatus"]) + .aggregate( + namespace.Aggregation.sum("l_quantity").rename("sum_qty"), + namespace.Aggregation.sum("l_extendedprice").rename("sum_base_price"), + namespace.Aggregation.sum("l_disc_price").rename("sum_disc_price"), + namespace.Aggregation.sum("l_charge").rename("sum_charge"), + namespace.Aggregation.mean("l_quantity").rename("avg_qty"), + namespace.Aggregation.mean("l_discount").rename("avg_disc"), + namespace.Aggregation.size().rename("count_order"), + ) + .sort(["l_returnflag", "l_linestatus"]) + ) + return result.dataframe From abc309257da26f55765294fffa60beb96efa1291 Mon Sep 17 00:00:00 2001 From: MarcoGorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 23 Oct 2023 16:05:20 +0100 Subject: [PATCH 4/6] note what happens if rename isnt called --- spec/API_specification/dataframe_api/groupby_object.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/spec/API_specification/dataframe_api/groupby_object.py b/spec/API_specification/dataframe_api/groupby_object.py index f1fb1163..c387dac0 100644 --- a/spec/API_specification/dataframe_api/groupby_object.py +++ b/spec/API_specification/dataframe_api/groupby_object.py @@ -75,7 +75,11 @@ def aggregate(self, *aggregation: Aggregation) -> DataFrame: class Aggregation(Protocol): def rename(self, name: str) -> Aggregation: - """Assign given name to output of aggregation. """ + """ + Assign given name to output of aggregation. + + If not called, the column's name will be used as the output name. + """ ...
@classmethod From e55ebd820ae449a9bed133c3c5e43d6444fd62b0 Mon Sep 17 00:00:00 2001 From: MarcoGorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 26 Oct 2023 10:42:47 +0100 Subject: [PATCH 5/6] typing --- spec/API_specification/dataframe_api/typing.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/spec/API_specification/dataframe_api/typing.py b/spec/API_specification/dataframe_api/typing.py index 8a74aae1..5efeb1b1 100644 --- a/spec/API_specification/dataframe_api/typing.py +++ b/spec/API_specification/dataframe_api/typing.py @@ -112,8 +112,7 @@ def __init__( class String(): ... - class Aggregation: - ... + Aggregation: AggregationT def concat(self, dataframes: Sequence[DataFrame]) -> DataFrame: ... From 5112b1244290c517a8b2fac2fd4ba7eb6d2affd1 Mon Sep 17 00:00:00 2001 From: MarcoGorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 26 Oct 2023 10:45:11 +0100 Subject: [PATCH 6/6] fixup; --- spec/API_specification/examples/tpch/q1.py | 26 ++++++++++------------ 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/spec/API_specification/examples/tpch/q1.py b/spec/API_specification/examples/tpch/q1.py index 42ecdf08..b5c11287 100644 --- a/spec/API_specification/examples/tpch/q1.py +++ b/spec/API_specification/examples/tpch/q1.py @@ -8,23 +8,21 @@ def query(lineitem_raw: "SupportsDataFrameAPI") -> Any: lineitem = lineitem_raw.__dataframe_consortium_standard__() namespace = lineitem.__dataframe_namespace__() - mask = lineitem.get_column_by_name("l_shipdate") <= namespace.date(1998, 9, 2) # type: ignore + mask = lineitem.get_column_by_name("l_shipdate") <= namespace.date(1998, 9, 2) lineitem = lineitem.assign( - [ - ( - lineitem.get_column_by_name("l_extendedprice") - * (1 - lineitem.get_column_by_name("l_discount")) - ).rename("l_disc_price"), - ( - lineitem.get_column_by_name("l_extendedprice") - * (1 - lineitem.get_column_by_name("l_discount")) - * (1 + lineitem.get_column_by_name("l_tax")) -
).rename("l_charge"), - ] + ( + lineitem.get_column_by_name("l_extendedprice") + * (1 - lineitem.get_column_by_name("l_discount")) + ).rename("l_disc_price"), + ( + lineitem.get_column_by_name("l_extendedprice") + * (1 - lineitem.get_column_by_name("l_discount")) + * (1 + lineitem.get_column_by_name("l_tax")) + ).rename("l_charge"), ) result = ( lineitem.filter(mask) - .group_by(["l_returnflag", "l_linestatus"]) + .group_by("l_returnflag", "l_linestatus") .aggregate( namespace.Aggregation.sum("l_quantity").rename("sum_qty"), namespace.Aggregation.sum("l_extendedprice").rename("sum_base_price"), @@ -34,6 +32,6 @@ def query(lineitem_raw: "SupportsDataFrameAPI") -> Any: namespace.Aggregation.mean("l_discount").rename("avg_disc"), namespace.Aggregation.size().rename("count_order"), ) - .sort(["l_returnflag", "l_linestatus"]) + .sort("l_returnflag", "l_linestatus") ) return result.dataframe