Skip to content

Commit 1addc2a

Browse files
authored
Add GroupBy.aggregate (and tpch-1 query to examples) (#286)
* add Aggregation API * fixup * add q1 * note what happens if rename isnt called * typing * fixup;
1 parent fcfb54f commit 1addc2a

File tree

6 files changed

+121
-5
lines changed

6 files changed

+121
-5
lines changed

spec/API_specification/dataframe_api/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .typing import DType, Scalar
1616

1717
__all__ = [
18+
"Aggregation",
1819
"Bool",
1920
"Column",
2021
"DataFrame",

spec/API_specification/dataframe_api/groupby_object.py

+76-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
from .dataframe_object import DataFrame
77

88

9-
__all__ = ['GroupBy']
9+
__all__ = [
10+
"Aggregation",
11+
"GroupBy",
12+
]
1013

1114

1215
class GroupBy(Protocol):
@@ -51,3 +54,75 @@ def var(self, *, correction: int | float = 1, skip_nulls: bool = True) -> DataFr
5154

5255
def size(self) -> DataFrame:
5356
...
57+
58+
def aggregate(self, *aggregation: Aggregation) -> DataFrame:
59+
"""
60+
Aggregate columns according to given aggregation function.
61+
62+
Examples
63+
--------
64+
>>> df: DataFrame
65+
>>> namespace = df.__dataframe_namespace__()
66+
>>> df.group_by('year').aggregate(
67+
... namespace.Aggregation.sum('l_quantity').rename('sum_qty'),
68+
... namespace.Aggregation.mean('l_quantity').rename('avg_qty'),
69+
... namespace.Aggregation.mean('l_extended_price').rename('avg_price'),
70+
... namespace.Aggregation.mean('l_discount').rename('avg_disc'),
71+
... namespace.Aggregation.size().rename('count_order'),
72+
... )
73+
"""
74+
...
75+
76+
class Aggregation(Protocol):
77+
def rename(self, name: str) -> Aggregation:
78+
"""
79+
Assign given name to output of aggregation.
80+
81+
If not called, the column's name will be used as the output name.
82+
"""
83+
...
84+
85+
@classmethod
86+
def any(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
87+
...
88+
89+
@classmethod
90+
def all(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
91+
...
92+
93+
@classmethod
94+
def min(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
95+
...
96+
97+
@classmethod
98+
def max(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
99+
...
100+
101+
@classmethod
102+
def sum(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
103+
...
104+
105+
@classmethod
106+
def prod(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
107+
...
108+
109+
@classmethod
110+
def median(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
111+
...
112+
113+
@classmethod
114+
def mean(cls, column: str, *, skip_nulls: bool=True) -> Aggregation:
115+
...
116+
117+
@classmethod
118+
def std(cls, column: str, *, correction: int|float=1, skip_nulls: bool=True) -> Aggregation:
119+
...
120+
121+
@classmethod
122+
def var(cls, column: str, *, correction: int|float=1, skip_nulls: bool=True) -> Aggregation:
123+
...
124+
125+
@classmethod
126+
def size(cls) -> Aggregation:
127+
...
128+

spec/API_specification/dataframe_api/typing.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from dataframe_api.column_object import Column
1717
from dataframe_api.dataframe_object import DataFrame
18-
from dataframe_api.groupby_object import GroupBy
18+
from dataframe_api.groupby_object import GroupBy, Aggregation as AggregationT
1919

2020
if TYPE_CHECKING:
2121
from .dtypes import (
@@ -112,6 +112,8 @@ def __init__(
112112
class String():
113113
...
114114

115+
Aggregation: AggregationT
116+
115117
def concat(self, dataframes: Sequence[DataFrame]) -> DataFrame:
116118
...
117119

@@ -146,7 +148,7 @@ def is_null(self, value: object, /) -> bool:
146148

147149
def is_dtype(self, dtype: Any, kind: str | tuple[str, ...]) -> bool:
148150
...
149-
151+
150152
def date(self, year: int, month: int, day: int) -> Scalar:
151153
...
152154

@@ -164,6 +166,7 @@ def __column_consortium_standard__(
164166

165167

166168
__all__ = [
169+
"Aggregation",
167170
"Column",
168171
"DataFrame",
169172
"DType",
+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from typing import Any, TYPE_CHECKING
2+
3+
if TYPE_CHECKING:
4+
from dataframe_api.typing import SupportsDataFrameAPI
5+
6+
7+
def query(lineitem_raw: SupportsDataFrameAPI) -> Any:
8+
lineitem = lineitem_raw.__dataframe_consortium_standard__()
9+
namespace = lineitem.__dataframe_namespace__()
10+
11+
mask = lineitem.get_column_by_name("l_shipdate") <= namespace.date(1998, 9, 2)
12+
lineitem = lineitem.assign(
13+
(
14+
lineitem.get_column_by_name("l_extended_price")
15+
* (1 - lineitem.get_column_by_name("l_discount"))
16+
).rename("l_disc_price"),
17+
(
18+
lineitem.get_column_by_name("l_extended_price")
19+
* (1 - lineitem.get_column_by_name("l_discount"))
20+
* (1 + lineitem.get_column_by_name("l_tax"))
21+
).rename("l_charge"),
22+
)
23+
result = (
24+
lineitem.filter(mask)
25+
.group_by("l_returnflag", "l_linestatus")
26+
.aggregate(
27+
namespace.Aggregation.sum("l_quantity").rename("sum_qty"),
28+
namespace.Aggregation.sum("l_extendedprice").rename("sum_base_price"),
29+
namespace.Aggregation.sum("l_disc_price").rename("sum_disc_price"),
30+
namespace.Aggregation.sum("change").rename("sum_charge"),
31+
namespace.Aggregation.mean("l_quantity").rename("avg_qty"),
32+
namespace.Aggregation.mean("l_discount").rename("avg_disc"),
33+
namespace.Aggregation.size().rename("count_order"),
34+
)
35+
.sort("l_returnflag", "l_linestatus")
36+
)
37+
return result.dataframe

spec/API_specification/examples/tpch/q5.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ def query(
6868
* (1 - result.get_column_by_name("l_discount"))
6969
).rename("revenue")
7070
result = result.assign(new_column)
71-
result = result.select("revenue", "n_name")
72-
result = result.group_by("n_name").sum()
71+
result = result.group_by("n_name").aggregate(namespace.Aggregation.sum("revenue"))
7372

7473
return result.dataframe

spec/conf.py

+1
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
('py:class', 'Scalar'),
8585
('py:class', 'Bool'),
8686
('py:class', 'optional'),
87+
('py:class', 'Aggregation'),
8788
('py:class', 'NullType'),
8889
('py:class', 'Namespace'),
8990
('py:class', 'SupportsDataFrameAPI'),

0 commit comments

Comments
 (0)