Skip to content

Add GroupBy.aggregate (and tpch-1 query to examples) #286

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions spec/API_specification/dataframe_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"Duration",
"String",
"is_dtype",
"Aggregation",
]


Expand Down
75 changes: 73 additions & 2 deletions spec/API_specification/dataframe_api/groupby_object.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Protocol

if TYPE_CHECKING:
from .dataframe_object import DataFrame


__all__ = ['GroupBy']
__all__ = [
"Aggregation",
"GroupBy",
]


class GroupBy:
Expand Down Expand Up @@ -51,3 +54,71 @@ def var(self, *, correction: int | float = 1, skip_nulls: bool = True) -> DataFr

def size(self) -> DataFrame:
...

def aggregate(self, *aggregation: Aggregation) -> DataFrame:
    """
    Aggregate columns according to the given aggregations.

    Parameters
    ----------
    *aggregation : Aggregation
        Aggregations to compute, passed as separate positional
        arguments. In the example below each one is built from a
        classmethod of ``namespace.Aggregation`` and given an output
        name with ``rename``.

    Returns
    -------
    DataFrame

    Examples
    --------
    >>> df: DataFrame
    >>> namespace = df.__dataframe_namespace__()
    >>> df.group_by('year').aggregate(
    ...     namespace.Aggregation.sum('l_quantity').rename('sum_qty'),
    ...     namespace.Aggregation.mean('l_quantity').rename('avg_qty'),
    ...     namespace.Aggregation.mean('l_extended_price').rename('avg_price'),
    ...     namespace.Aggregation.mean('l_discount').rename('avg_disc'),
    ...     namespace.Aggregation.size().rename('count_order'),
    ... )
    """
    ...

class Aggregation(Protocol):
    """
    Protocol for an aggregation that can be passed to ``GroupBy.aggregate``.

    Instances are created through the classmethods below, each of which
    takes the name of the column to aggregate (``size`` takes no column),
    and can be given an output name with ``rename``.

    NOTE(review): the classmethod names and keyword arguments parallel the
    same-named ``GroupBy`` reductions (e.g. ``var`` with ``correction`` and
    ``skip_nulls``); presumably the semantics match - confirm against the
    ``GroupBy`` documentation.
    """

    def rename(self, name: str) -> Aggregation:
        """Assign given name to output of aggregation."""
        ...

    @classmethod
    def any(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
        ...

    @classmethod
    def all(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
        ...

    @classmethod
    def min(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
        ...

    @classmethod
    def max(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
        ...

    @classmethod
    def sum(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
        ...

    @classmethod
    def prod(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
        ...

    @classmethod
    def median(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
        ...

    @classmethod
    def mean(cls, column: str, *, skip_nulls: bool = True) -> Aggregation:
        ...

    @classmethod
    def std(
        cls,
        column: str,
        *,
        correction: int | float = 1,
        skip_nulls: bool = True,
    ) -> Aggregation:
        ...

    @classmethod
    def var(
        cls,
        column: str,
        *,
        correction: int | float = 1,
        skip_nulls: bool = True,
    ) -> Aggregation:
        ...

    @classmethod
    def size(cls) -> Aggregation:
        ...

7 changes: 5 additions & 2 deletions spec/API_specification/dataframe_api/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from dataframe_api.column_object import Column
from dataframe_api.dataframe_object import DataFrame
from dataframe_api.groupby_object import GroupBy
from dataframe_api.groupby_object import GroupBy, Aggregation as AggregationT

if TYPE_CHECKING:
from .dtypes import (
Expand Down Expand Up @@ -147,7 +147,9 @@ def is_null(value: object, /) -> bool:
@staticmethod
def is_dtype(dtype: Any, kind: str | tuple[str, ...]) -> bool:
...


class Aggregation(AggregationT):
    # Subclass of ``groupby_object.Aggregation`` (imported above as
    # ``AggregationT``) that adds no members of its own.
    # NOTE(review): the examples reach this as ``namespace.Aggregation``, so
    # it is presumably nested inside the namespace protocol - the enclosing
    # scope is not visible here; confirm.
    ...

class SupportsDataFrameAPI(Protocol):
def __dataframe_consortium_standard__(
Expand All @@ -163,6 +165,7 @@ def __column_consortium_standard__(


__all__ = [
"Aggregation",
"Column",
"DataFrame",
"DType",
Expand Down
39 changes: 39 additions & 0 deletions spec/API_specification/examples/tpch/q1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import Any, TYPE_CHECKING

if TYPE_CHECKING:
from dataframe_api.typing import SupportsDataFrameAPI


def query(lineitem_raw: SupportsDataFrameAPI) -> Any:
    """
    TPC-H query 1 written against the dataframe API standard.

    Converts ``lineitem_raw`` to a standard-compliant dataframe, adds the
    derived columns ``l_disc_price`` and ``l_charge``, keeps rows whose
    ``l_shipdate`` is on or before 1998-09-02, then computes sum / mean /
    size aggregations per ``(l_returnflag, l_linestatus)`` group, sorted by
    those two keys.

    Parameters
    ----------
    lineitem_raw : SupportsDataFrameAPI
        Object exposing ``__dataframe_consortium_standard__``.

    Returns
    -------
    Any
        The ``dataframe`` attribute of the standard-compliant result.
    """
    lineitem = lineitem_raw.__dataframe_consortium_standard__()
    namespace = lineitem.__dataframe_namespace__()

    # The row mask is built here and applied later through ``filter``.
    mask = lineitem.get_column_by_name("l_shipdate") <= namespace.date(1998, 9, 2) # type: ignore
    lineitem = lineitem.assign(
        [
            (
                # l_disc_price = l_extended_price * (1 - l_discount)
                lineitem.get_column_by_name("l_extended_price")
                * (1 - lineitem.get_column_by_name("l_discount"))
            ).rename("l_disc_price"),
            (
                # l_charge = l_extended_price * (1 - l_discount) * (1 + l_tax)
                lineitem.get_column_by_name("l_extended_price")
                * (1 - lineitem.get_column_by_name("l_discount"))
                * (1 + lineitem.get_column_by_name("l_tax"))
            ).rename("l_charge"),
        ]
    )
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this syntax though...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @shwina - are you ok with this syntax?

Personally I think it's worse than what's in any existing dataframe library, and I can't imagine any user ever wanting to write code like this

but maybe it's just me

Copy link
Contributor

@shwina shwina Oct 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please correct me if I'm wrong, but I thought the goal of the standard right now is to provide an API focused on third-party library developers (not end users). This is why we have been comfortable sacrificing syntactic crispness or an expressive API in favor of being the "lowest common denominator" that all libraries can implement.

I think this necessarily means the API isn't quite as nice to work with for the end-user.

For example, changing get_column_by_name to just [ ] in the code above would be a massive boost in readability, but we explicitly decided against it because (IIRC) we wanted library authors to have the freedom to decide what [ ] should mean for their library

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That being said, I agree with you 100% that this looks a mess. It's a question whether library developers are going to be OK with dealing with a messy API to get cross-library compatibility in return...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with you 100% that this looks a mess

Well I'm glad we could find some common ground 😄

Let's discuss more next week - I'm genuinely interested in finding a solution that works for everybody

My current prediction is that, unless the standard drastically improves, libraries will just support pandas and Polars and ignore the standard completely

The end result for cudf will be that you'll be no better off than you are now

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I was saying...(emphasis mine)

I'm pretty upset about having to use df.get_column_by_name("a") instead of a simpler df["a"] or col("a"). This will obfuscate our code and impair readability, and therefore we may consider keeping our duplicate logic

#287

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's fair. We should shorten the name.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any suggestions?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Being addressed in #290

result = (
lineitem.filter(mask)
.group_by(["l_returnflag", "l_linestatus"])
.aggregate(
namespace.Aggregation.sum("l_quantity").rename("sum_qty"),
namespace.Aggregation.sum("l_extendedprice").rename("sum_base_price"),
namespace.Aggregation.sum("l_disc_price").rename("sum_disc_price"),
namespace.Aggregation.sum("change").rename("sum_charge"),
namespace.Aggregation.mean("l_quantity").rename("avg_qty"),
namespace.Aggregation.mean("l_discount").rename("avg_disc"),
namespace.Aggregation.size().rename("count_order"),
)
.sort(["l_returnflag", "l_linestatus"])
)
return result.dataframe
3 changes: 1 addition & 2 deletions spec/API_specification/examples/tpch/q5.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def query(
* (1 - result.get_column_by_name("l_discount"))
).rename("revenue")
result = result.assign(new_column)
result = result.select(["revenue", "n_name"])
result = result.group_by("n_name").sum()
result = result.group_by("n_name").aggregate(namespace.Aggregation.sum("revenue"))

return result.dataframe
1 change: 1 addition & 0 deletions spec/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
('py:class', 'Scalar'),
('py:class', 'Bool'),
('py:class', 'optional'),
('py:class', 'Aggregation'),
('py:class', 'NullType'),
('py:class', 'Namespace'),
('py:class', 'SupportsDataFrameAPI'),
Expand Down