Skip to content

Commit 7043f64

Browse files
committed
ergonomics
1 parent 9ebd79d commit 7043f64

12 files changed

+168
-153
lines changed

spec/API_specification/dataframe_api/__init__.py

+17-13
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,18 @@
66
from typing import Mapping, Sequence, Any, Literal
77

88
from .eagercolumn_object import EagerColumn
9+
from .eagerframe_object import EagerFrame
910
from .expression_object import Expression
1011
from .dataframe_object import DataFrame
12+
from .groupby_object import GroupBy
1113

1214
__all__ = [
1315
"__dataframe_api_version__",
1416
"DataFrame",
17+
"EagerFrame",
1518
"EagerColumn",
19+
"Expression",
20+
"GroupBy",
1621
"column_from_sequence",
1722
"column_from_1d_array",
1823
"col",
@@ -35,6 +40,8 @@
3540
"Float32",
3641
"Bool",
3742
"is_dtype",
43+
"any_rowwise",
44+
"all_rowwise",
3845
]
3946

4047

@@ -202,7 +209,7 @@ def dataframe_from_2d_array(array: Any, *, names: Sequence[str], dtypes: Mapping
202209
"""
203210
...
204211

205-
def any_rowwise(keys: list[str] | None = None, *, skip_nulls: bool = True) -> Expression:
212+
def any_rowwise(*keys: str, skip_nulls: bool = True) -> Expression:
206213
"""
207214
Reduction returns an Expression.
208215
@@ -211,8 +218,8 @@ def any_rowwise(keys: list[str] | None = None, *, skip_nulls: bool = True) -> Ex
211218
212219
Parameters
213220
----------
214-
keys : list[str]
215-
Column names to consider. If `None`, all columns are considered.
221+
keys : str
222+
Column names to consider.
216223
217224
Raises
218225
------
@@ -221,7 +228,7 @@ def any_rowwise(keys: list[str] | None = None, *, skip_nulls: bool = True) -> Ex
221228
"""
222229
...
223230

224-
def all_rowwise(keys: list[str] | None = None, *, skip_nulls: bool = True) -> Expression:
231+
def all_rowwise(*keys: str, skip_nulls: bool = True) -> Expression:
225232
"""
226233
Reduction returns an Expression.
227234
@@ -230,8 +237,8 @@ def all_rowwise(keys: list[str] | None = None, *, skip_nulls: bool = True) -> Ex
230237
231238
Parameters
232239
----------
233-
keys : list[str]
234-
Column names to consider. If `None`, all columns are considered.
240+
keys : str
241+
Column names to consider.
235242
236243
Raises
237244
------
@@ -241,8 +248,7 @@ def all_rowwise(keys: list[str] | None = None, *, skip_nulls: bool = True) -> Ex
241248
...
242249

243250
def sorted_indices(
244-
keys: str | list[str] | None = None,
245-
*,
251+
*keys: str,
246252
ascending: Sequence[bool] | bool = True,
247253
nulls_position: Literal['first', 'last'] = 'last',
248254
) -> Expression:
@@ -253,9 +259,8 @@ def sorted_indices(
253259
254260
Parameters
255261
----------
256-
keys : str | list[str], optional
262+
keys : str
257263
Names of columns to sort by.
258-
If `None`, sort by all columns.
259264
ascending : Sequence[bool] or bool
260265
If `True`, sort by all keys in ascending order.
261266
If `False`, sort by all keys in descending order.
@@ -280,15 +285,14 @@ def sorted_indices(
280285
...
281286

282287

283-
def unique_indices(keys: str | list[str] | None = None, *, skip_nulls: bool = True) -> Expression:
288+
def unique_indices(*keys: str, skip_nulls: bool = True) -> Expression:
284289
"""
285290
Return indices corresponding to unique values across selected columns.
286291
287292
Parameters
288293
----------
289-
keys : str | list[str], optional
294+
keys : str
290295
Column names to consider when finding unique values.
291-
If `None`, all columns are considered.
292296
293297
Returns
294298
-------

spec/API_specification/dataframe_api/_types.py

+8
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,17 @@
1717
TypeVar,
1818
Union,
1919
Protocol,
20+
TYPE_CHECKING,
21+
TypeAlias
2022
)
2123
from enum import Enum
2224

25+
if TYPE_CHECKING:
26+
from .expression_object import Expression
27+
from .eagercolumn_object import EagerColumn
28+
29+
IntoExpression: TypeAlias = Expression | EagerColumn
30+
2331
# Type alias: Mypy needs Any, but for readability we need to make clear this
2432
# is a Python scalar (i.e., an instance of `bool`, `int`, `float`, `str`, etc.)
2533
Scalar = Any

spec/API_specification/dataframe_api/dataframe_object.py

+18-58
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,11 @@
88
from .eagerframe_object import EagerFrame
99
from .eagercolumn_object import EagerColumn
1010
from .groupby_object import GroupBy
11-
from ._types import NullType, Scalar
11+
from ._types import NullType, Scalar, IntoExpression
1212

1313

1414
__all__ = ["DataFrame"]
1515

16-
IntoExpression = EagerColumn[Any] | Expression
1716

1817

1918
class DataFrame:
@@ -68,13 +67,13 @@ def dataframe(self) -> object:
6867
"""
6968
...
7069

71-
def groupby(self, keys: str | list[str], /) -> GroupBy:
70+
def groupby(self, *keys: str) -> GroupBy:
7271
"""
7372
Group the DataFrame by the given columns.
7473
7574
Parameters
7675
----------
77-
keys : str | list[str]
76+
keys : str
7877
7978
Returns
8079
-------
@@ -93,7 +92,7 @@ def groupby(self, keys: str | list[str], /) -> GroupBy:
9392
"""
9493
...
9594

96-
def select(self, names: str | Expression | Sequence[str | Expression], /) -> DataFrame:
95+
def select(self, *names: str | Expression) -> DataFrame:
9796
"""
9897
Select multiple columns, either by name or by expressions.
9998
@@ -107,16 +106,9 @@ def select(self, names: str | Expression | Sequence[str | Expression], /) -> Dat
107106
108107
Examples
109108
--------
110-
Select columns 'a' and 'b':
111-
112-
>>> df: DataFrame
113-
>>> df.select(['a', 'b'])
114-
115-
You can also pass expressions:
116-
117109
>>> df: DataFrame
118-
>>> namespace = df.__dataframe_namespace__()
119-
>>> df.select(['a', namespace.col('b')+1])
110+
>>> col = df.__dataframe_namespace__().col
111+
>>> df = df.select('a', col('b'), (col('c')+col('d')+1).rename('e'))
120112
121113
Raises
122114
------
@@ -145,13 +137,13 @@ def slice_rows(
145137
"""
146138
...
147139

148-
def filter(self, mask: Expression) -> DataFrame:
140+
def filter(self, mask: IntoExpression) -> DataFrame:
149141
"""
150142
Select a subset of rows corresponding to a mask.
151143
152144
Parameters
153145
----------
154-
mask : Expression
146+
mask : Expression or EagerColumn
155147
156148
Returns
157149
-------
@@ -170,7 +162,7 @@ def filter(self, mask: Expression) -> DataFrame:
170162
"""
171163
...
172164

173-
def insert_columns(self, columns: IntoExpression | Sequence[IntoExpression]) -> DataFrame:
165+
def insert_columns(self, *columns: Expression | EagerColumn[Any]) -> DataFrame:
174166
"""
175167
Insert column(s) into DataFrame at rightmost location.
176168
@@ -184,7 +176,7 @@ def insert_columns(self, columns: IntoExpression | Sequence[IntoExpression]) ->
184176
namespace = df.__dataframe_namespace__()
185177
col = namespace.col
186178
new_column = namespace.col('a') + 1
187-
df = df.insert_column(new_column.rename('a_plus_1'))
179+
df = df.insert_columns(new_column.rename('a_plus_1'))
188180
189181
If you need to insert the column at a different location, combine with
190182
:meth:`select`, e.g.:
@@ -196,7 +188,7 @@ def insert_columns(self, columns: IntoExpression | Sequence[IntoExpression]) ->
196188
col = namespace.col
197189
new_column = namespace.col('a') + 1
198190
new_columns_names = ['a_plus_1'] + df.column_names
199-
df = df.insert_column(new_column.rename('a_plus_1'))
191+
df = df.insert_columns(new_column.rename('a_plus_1'))
200192
df = df.select(*new_columns_names)
201193
202194
Parameters
@@ -206,7 +198,7 @@ def insert_columns(self, columns: IntoExpression | Sequence[IntoExpression]) ->
206198
"""
207199
...
208200

209-
def update_columns(self, columns: Expression | EagerColumn | Sequence[Expression | EagerColumn], /) -> DataFrame:
201+
def update_columns(self, *columns: Expression | EagerColumn[Any]) -> DataFrame:
210202
"""
211203
Update values in existing column(s) from Dataframe.
212204
@@ -224,7 +216,7 @@ def update_columns(self, columns: Expression | EagerColumn | Sequence[Expression
224216
225217
Parameters
226218
----------
227-
columns : Expression | Sequence[Expression]
219+
columns : Expression or EagerColumn
228220
Column(s) to update. If updating multiple columns, they must all have
229221
different names.
230222
@@ -268,7 +260,8 @@ def rename_columns(self, mapping: Mapping[str, str]) -> DataFrame:
268260
"""
269261
...
270262

271-
def get_column_names(self) -> list[str]:
263+
@property
264+
def column_names(self) -> list[str]:
272265
"""
273266
Get column names.
274267
@@ -280,8 +273,7 @@ def get_column_names(self) -> list[str]:
280273

281274
def sort(
282275
self,
283-
keys: str | Expression | list[str | Expression] | None = None,
284-
*,
276+
*keys: str | Expression,
285277
ascending: Sequence[bool] | bool = True,
286278
nulls_position: Literal['first', 'last'] = 'last',
287279
) -> DataFrame:
@@ -293,9 +285,9 @@ def sort(
293285
294286
Parameters
295287
----------
296-
keys : str | list[str], optional
288+
keys : str | Expression
297289
Names of columns to sort by.
298-
If `None`, sort by all columns.
290+
If not passed, will sort by all columns.
299291
ascending : Sequence[bool] or bool
300292
If `True`, sort by all keys in ascending order.
301293
If `False`, sort by all keys in descending order.
@@ -759,38 +751,6 @@ def fill_nan(self, value: float | NullType, /) -> DataFrame:
759751
"""
760752
...
761753

762-
def fill_null(
763-
self, value: Scalar, /, *, column_names : list[str] | None = None
764-
) -> DataFrame:
765-
"""
766-
Fill null values with the given fill value.
767-
768-
This method can only be used if all columns that are to be filled are
769-
of the same dtype (e.g., all of ``Float64`` or all of string dtype).
770-
If that is not the case, it is not possible to use a single Python
771-
scalar type that matches the dtype of all columns to which
772-
``fill_null`` is being applied, and hence an exception will be raised.
773-
774-
Parameters
775-
----------
776-
value : Scalar
777-
Value used to replace any ``null`` values in the dataframe with.
778-
Must be of the Python scalar type matching the dtype(s) of the dataframe.
779-
column_names : list[str] | None
780-
A list of column names for which to replace nulls with the given
781-
scalar value. If ``None``, nulls will be replaced in all columns.
782-
783-
Raises
784-
------
785-
TypeError
786-
If the columns of the dataframe are not all of the same kind.
787-
KeyError
788-
If ``column_names`` contains a column name that is not present in
789-
the dataframe.
790-
791-
"""
792-
...
793-
794754
def join(
795755
self,
796756
other: DataFrame,

0 commit comments

Comments
 (0)