Skip to content

Commit a6ef75f

Browse files
committed
if it accepts expression, it should accept column
1 parent d404a73 commit a6ef75f

File tree

7 files changed

+40
-50
lines changed

7 files changed

+40
-50
lines changed

spec/API_specification/dataframe_api/_types.py

-8
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,9 @@
1717
TypeVar,
1818
Union,
1919
Protocol,
20-
TYPE_CHECKING,
21-
TypeAlias
2220
)
2321
from enum import Enum
2422

25-
if TYPE_CHECKING:
26-
from .expression_object import Expression
27-
from .eagercolumn_object import EagerColumn
28-
29-
IntoExpression: TypeAlias = Expression | EagerColumn
30-
3123
# Type alias: Mypy needs Any, but for readability we need to make clear this
3224
# is a Python scalar (i.e., an instance of `bool`, `int`, `float`, `str`, etc.)
3325
Scalar = Any

spec/API_specification/dataframe_api/dataframe_object.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .eagerframe_object import EagerFrame
99
from .eagercolumn_object import EagerColumn
1010
from .groupby_object import GroupBy
11-
from ._types import NullType, Scalar, IntoExpression
11+
from ._types import NullType, Scalar
1212

1313

1414
__all__ = ["DataFrame"]
@@ -92,7 +92,7 @@ def groupby(self, *keys: str) -> GroupBy:
9292
"""
9393
...
9494

95-
def select(self, *names: str | Expression) -> DataFrame:
95+
def select(self, *names: str | Expression | EagerColumn[Any]) -> DataFrame:
9696
"""
9797
Select multiple columns, either by name or by expressions.
9898
@@ -137,7 +137,7 @@ def slice_rows(
137137
"""
138138
...
139139

140-
def filter(self, mask: IntoExpression) -> DataFrame:
140+
def filter(self, mask: Expression | EagerColumn[bool]) -> DataFrame:
141141
"""
142142
Select a subset of rows corresponding to a mask.
143143
@@ -216,7 +216,7 @@ def update_columns(self, *columns: Expression | EagerColumn[Any]) -> DataFrame:
216216
217217
Parameters
218218
----------
219-
columns : Expression, EagerColumn, or sequence of either
219+
columns : Expression | EagerColumn
220220
Column(s) to update. If updating multiple columns, they must all have
221221
different names.
222222
@@ -273,7 +273,7 @@ def column_names(self) -> list[str]:
273273

274274
def sort(
275275
self,
276-
*keys: str | Expression,
276+
*keys: str | Expression | EagerColumn[Any],
277277
ascending: Sequence[bool] | bool = True,
278278
nulls_position: Literal['first', 'last'] = 'last',
279279
) -> DataFrame:

spec/API_specification/dataframe_api/eagercolumn_object.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class EagerColumn(Generic[DType]):
1717
"""
1818
EagerColumn object
1919
20-
Instantiate via :meth:`EagerFrame.get_column_by_name`.
20+
Instantiate via :meth:`EagerFrame.get_column`.
2121
2222
If you need to use this within the context of a :class`DataFrame` operation
2323
(such as `:meth:`DataFrame.filter`) then you can convert it to an expression
@@ -106,7 +106,7 @@ def slice_rows(
106106
...
107107

108108

109-
def filter(self: EagerColumn[DType], mask: EagerColumn[Bool]) -> EagerColumn[DType]:
109+
def filter(self: EagerColumn[DType], mask: Expression | EagerColumn[Bool]) -> EagerColumn[DType]:
110110
"""
111111
Select a subset of rows corresponding to a mask.
112112

spec/API_specification/dataframe_api/eagerframe_object.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .expression_object import Expression
99
from .dataframe_object import DataFrame
1010
from .groupby_object import GroupBy
11-
from ._types import NullType, Scalar, IntoExpression
11+
from ._types import NullType, Scalar
1212

1313

1414
__all__ = ["EagerFrame"]
@@ -89,7 +89,7 @@ def get_column(self, name: str, /) -> EagerColumn[Any]:
8989
"""
9090
...
9191

92-
def select(self, *columns: str | Expression) -> EagerFrame:
92+
def select(self, *columns: str | Expression | EagerColumn[Any]) -> EagerFrame:
9393
"""
9494
Select multiple columns by name.
9595
@@ -115,7 +115,7 @@ def select(self, *columns: str | Expression) -> EagerFrame:
115115
"""
116116
...
117117

118-
def get_rows(self, indices: Expression) -> EagerFrame:
118+
def get_rows(self, indices: Expression | EagerColumn[Any]) -> EagerFrame:
119119
"""
120120
Select a subset of rows, similar to `ndarray.take`.
121121
@@ -148,7 +148,7 @@ def slice_rows(
148148
"""
149149
...
150150

151-
def filter(self, mask: IntoExpression) -> EagerFrame:
151+
def filter(self, mask: Expression | EagerColumn[bool]) -> EagerFrame:
152152
"""
153153
Select a subset of rows corresponding to a mask.
154154
@@ -173,15 +173,15 @@ def insert_columns(self, *columns: Expression | EagerColumn[Any]) -> EagerFrame:
173173
174174
.. code-block:: python
175175
176-
new_column = df.get_column_by_name('a') + 1
176+
new_column = df.get_column('a') + 1
177177
df = df.insert_columns(new_column.rename('a_plus_1'))
178178
179179
If you need to insert the column at a different location, combine with
180180
:meth:`select`, e.g.:
181181
182182
.. code-block:: python
183183
184-
new_column = df.get_column_by_name('a') + 1
184+
new_column = df.get_column('a') + 1
185185
new_columns_names = ['a_plus_1'] + df.get_column_names()
186186
df = df.insert_columns(new_column.rename('a_plus_1'))
187187
df = df.select(new_column_names)
@@ -203,12 +203,12 @@ def update_columns(self, *columns: Expression | EagerColumn[Any]) -> EagerFrame:
203203
204204
.. code-block:: python
205205
206-
new_column = df.get_column_by_name('a') + 1
206+
new_column = df.get_column('a') + 1
207207
df = df.update_column(new_column.rename('b').to_expression())
208208
209209
Parameters
210210
----------
211-
columns : IntoExpression | Sequence[IntoExpression]
211+
columns : Expression | EagerColumn
212212
Column(s) to update. If updating multiple columns, they must all have
213213
different names.
214214
@@ -265,7 +265,7 @@ def column_names(self) -> list[str]:
265265

266266
def sort(
267267
self,
268-
*keys: str | Expression,
268+
*keys: str | Expression | EagerColumn[Any],
269269
ascending: Sequence[bool] | bool = True,
270270
nulls_position: Literal['first', 'last'] = 'last',
271271
) -> EagerFrame:

spec/API_specification/dataframe_api/expression_object.py

+23-24
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44

55

66
if TYPE_CHECKING:
7-
from ._types import DType
8-
from . import Bool
97
from ._types import NullType, Scalar
8+
from .eagercolumn_object import EagerColumn
109

1110

1211
__all__ = ['Expression']
@@ -92,7 +91,7 @@ def len(self) -> Expression:
9291
Return the number of rows.
9392
"""
9493

95-
def get_rows(self: Expression, indices: Expression) -> Expression:
94+
def get_rows(self, indices: Expression | EagerColumn[Any]) -> Expression:
9695
"""
9796
Select a subset of rows, similar to `ndarray.take`.
9897
@@ -104,7 +103,7 @@ def get_rows(self: Expression, indices: Expression) -> Expression:
104103
...
105104

106105
def slice_rows(
107-
self: Expression, start: int | None, stop: int | None, step: int | None
106+
self, start: int | None, stop: int | None, step: int | None
108107
) -> Expression:
109108
"""
110109
Select a subset of rows corresponding to a slice.
@@ -121,7 +120,7 @@ def slice_rows(
121120
"""
122121
...
123122

124-
def filter(self, mask: Expression) -> Expression:
123+
def filter(self, mask: Expression | EagerColumn[bool]) -> Expression:
125124
"""
126125
Select a subset of rows corresponding to a mask.
127126
@@ -225,7 +224,7 @@ def __eq__(self, other: Expression | Scalar) -> Expression: # type: ignore[over
225224
Expression
226225
"""
227226

228-
def __ne__(self: Expression, other: Expression | Scalar) -> Expression: # type: ignore[override]
227+
def __ne__(self, other: Expression | Scalar) -> Expression: # type: ignore[override]
229228
"""
230229
Compare for non-equality.
231230
@@ -243,7 +242,7 @@ def __ne__(self: Expression, other: Expression | Scalar) -> Expression: # type:
243242
Expression
244243
"""
245244

246-
def __ge__(self: Expression, other: Expression | Scalar) -> Expression:
245+
def __ge__(self, other: Expression | Scalar) -> Expression:
247246
"""
248247
Compare for "greater than or equal to" `other`.
249248
@@ -259,7 +258,7 @@ def __ge__(self: Expression, other: Expression | Scalar) -> Expression:
259258
Expression
260259
"""
261260

262-
def __gt__(self: Expression, other: Expression | Scalar) -> Expression:
261+
def __gt__(self, other: Expression | Scalar) -> Expression:
263262
"""
264263
Compare for "greater than" `other`.
265264
@@ -275,7 +274,7 @@ def __gt__(self: Expression, other: Expression | Scalar) -> Expression:
275274
Expression
276275
"""
277276

278-
def __le__(self: Expression, other: Expression | Scalar) -> Expression:
277+
def __le__(self, other: Expression | Scalar) -> Expression:
279278
"""
280279
Compare for "less than or equal to" `other`.
281280
@@ -291,7 +290,7 @@ def __le__(self: Expression, other: Expression | Scalar) -> Expression:
291290
Expression
292291
"""
293292

294-
def __lt__(self: Expression, other: Expression | Scalar) -> Expression:
293+
def __lt__(self, other: Expression | Scalar) -> Expression:
295294
"""
296295
Compare for "less than" `other`.
297296
@@ -307,7 +306,7 @@ def __lt__(self: Expression, other: Expression | Scalar) -> Expression:
307306
Expression
308307
"""
309308

310-
def __and__(self: Expression, other: Expression | bool) -> Expression:
309+
def __and__(self, other: Expression | bool) -> Expression:
311310
"""
312311
Apply logical 'and' to `other` expression (or scalar) and this expression.
313312
@@ -328,7 +327,7 @@ def __and__(self: Expression, other: Expression | bool) -> Expression:
328327
If `self` or `other` is not boolean.
329328
"""
330329

331-
def __or__(self: Expression, other: Expression | bool) -> Expression:
330+
def __or__(self, other: Expression | bool) -> Expression:
332331
"""
333332
Apply logical 'or' to `other` expression (or scalar) and this expression.
334333
@@ -349,7 +348,7 @@ def __or__(self: Expression, other: Expression | bool) -> Expression:
349348
If `self` or `other` is not boolean.
350349
"""
351350

352-
def __add__(self: Expression, other: Expression | Scalar) -> Expression:
351+
def __add__(self, other: Expression | Scalar) -> Expression:
353352
"""
354353
Add `other` expression or scalar to this expression.
355354
@@ -365,7 +364,7 @@ def __add__(self: Expression, other: Expression | Scalar) -> Expression:
365364
Expression
366365
"""
367366

368-
def __sub__(self: Expression, other: Expression | Scalar) -> Expression:
367+
def __sub__(self, other: Expression | Scalar) -> Expression:
369368
"""
370369
Subtract `other` expression or scalar from this expression.
371370
@@ -481,7 +480,7 @@ def __divmod__(self, other: Expression | Scalar) -> tuple[Expression, Expression
481480
tuple[Expression, Expression]
482481
"""
483482

484-
def __invert__(self: Expression) -> Expression:
483+
def __invert__(self) -> Expression:
485484
"""
486485
Invert truthiness of (boolean) elements.
487486
@@ -491,7 +490,7 @@ def __invert__(self: Expression) -> Expression:
491490
If any of the expression's expressions is not boolean.
492491
"""
493492

494-
def any(self: Expression, *, skip_nulls: bool = True) -> Expression:
493+
def any(self, *, skip_nulls: bool = True) -> Expression:
495494
"""
496495
Reduction returns a bool.
497496
@@ -501,7 +500,7 @@ def any(self: Expression, *, skip_nulls: bool = True) -> Expression:
501500
If expression is not boolean.
502501
"""
503502

504-
def all(self: Expression, *, skip_nulls: bool = True) -> Expression:
503+
def all(self, *, skip_nulls: bool = True) -> Expression:
505504
"""
506505
Reduction returns a bool.
507506
@@ -595,26 +594,26 @@ def var(self, *, correction: int | float = 1, skip_nulls: bool = True) -> Expres
595594
Whether to skip null values.
596595
"""
597596

598-
def cumulative_max(self: Expression) -> Expression:
597+
def cumulative_max(self) -> Expression:
599598
"""
600599
Reduction returns a expression. Any data type that supports comparisons
601600
must be supported. The returned value has the same dtype as the expression.
602601
"""
603602

604-
def cumulative_min(self: Expression) -> Expression:
603+
def cumulative_min(self) -> Expression:
605604
"""
606605
Reduction returns a expression. Any data type that supports comparisons
607606
must be supported. The returned value has the same dtype as the expression.
608607
"""
609608

610-
def cumulative_sum(self: Expression) -> Expression:
609+
def cumulative_sum(self) -> Expression:
611610
"""
612611
Reduction returns a expression. Must be supported for numerical and
613612
datetime data types. The returned value has the same dtype as the
614613
expression.
615614
"""
616615

617-
def cumulative_prod(self: Expression) -> Expression:
616+
def cumulative_prod(self) -> Expression:
618617
"""
619618
Reduction returns a expression. Must be supported for numerical and
620619
datetime data types. The returned value has the same dtype as the
@@ -659,7 +658,7 @@ def is_nan(self) -> Expression:
659658
In particular, does not check for `np.timedelta64('NaT')`.
660659
"""
661660

662-
def is_in(self: Expression, values: Expression) -> Expression:
661+
def is_in(self, values: Expression | EagerColumn[Any]) -> Expression:
663662
"""
664663
Indicate whether the value at each row matches any value in `values`.
665664
@@ -698,7 +697,7 @@ def unique_indices(self, *, skip_nulls: bool = True) -> Expression:
698697
"""
699698
...
700699

701-
def fill_nan(self: Expression, value: float | NullType, /) -> Expression:
700+
def fill_nan(self, value: float | NullType, /) -> Expression:
702701
"""
703702
Fill floating point ``nan`` values with the given fill value.
704703
@@ -712,7 +711,7 @@ def fill_nan(self: Expression, value: float | NullType, /) -> Expression:
712711
"""
713712
...
714713

715-
def fill_null(self: Expression, value: Scalar, /) -> Expression:
714+
def fill_null(self, value: Scalar, /) -> Expression:
716715
"""
717716
Fill null values with the given fill value.
718717

spec/conf.py

-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@
8585
('py:class', 'optional'),
8686
('py:class', 'NullType'),
8787
('py:class', 'GroupBy'),
88-
('py:class', 'IntoExpression'),
8988
]
9089
# NOTE: this alias handling isn't used yet - added in anticipation of future
9190
# need based on dataframe API aliases.

spec/design_topics/python_builtin_types.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class EagerColumn:
2121
def mean(self, skip_nulls: bool = True) -> float | NullType:
2222
...
2323

24-
larger = df2 > df1.get_column_by_name('foo').mean()
24+
larger = df2 > df1.get_column('foo').mean()
2525
```
2626

2727
For a GPU dataframe library, it is desirable for all data to reside on the GPU,

0 commit comments

Comments
 (0)