Skip to content

Commit 75cdf89

Browse files
authored
Merge pull request #16 from data-apis/datetime-feats
temporal functions
2 parents d9aa344 + de750ec commit 75cdf89

File tree

6 files changed

+432
-64
lines changed

6 files changed

+432
-64
lines changed

dataframe_api_compat/pandas_standard/pandas_standard.py

Lines changed: 77 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,9 @@ def __init__(
112112
def __column_namespace__(self) -> Any:
113113
return dataframe_api_compat.pandas_standard
114114

115-
@property
116115
def root_names(self):
117116
return sorted(set(self._root_names))
118117

119-
@property
120118
def output_name(self):
121119
return self._output_name
122120

@@ -128,12 +126,12 @@ def _record_call(
128126
) -> PandasColumn:
129127
calls = [*self._calls, (func, self, rhs)]
130128
if isinstance(rhs, PandasColumn):
131-
root_names = self.root_names + rhs.root_names
129+
root_names = self.root_names() + rhs.root_names()
132130
else:
133-
root_names = self.root_names
131+
root_names = self.root_names()
134132
return PandasColumn(
135133
root_names=root_names,
136-
output_name=output_name or self.output_name,
134+
output_name=output_name or self.output_name(),
137135
extra_calls=calls,
138136
)
139137

@@ -320,12 +318,12 @@ def func(ser, _rhs):
320318
if ascending:
321319
return (
322320
ser.sort_values()
323-
.index.to_series(name=self.output_name)
321+
.index.to_series(name=self.output_name())
324322
.reset_index(drop=True)
325323
)
326324
return (
327325
ser.sort_values()
328-
.index.to_series(name=self.output_name)[::-1]
326+
.index.to_series(name=self.output_name())[::-1]
329327
.reset_index(drop=True)
330328
)
331329

@@ -376,7 +374,7 @@ def func(ser, value):
376374
ser = num / other
377375
else:
378376
ser = ser.fillna(value)
379-
return ser.rename(self.output_name)
377+
return ser.rename(self.output_name())
380378

381379
return self._record_call(
382380
lambda ser, _rhs: func(ser, value),
@@ -413,6 +411,66 @@ def rename(self, name: str) -> PandasColumn:
413411
)
414412
return expr
415413

414+
@property
415+
def dt(self) -> ColumnDatetimeAccessor:
416+
"""
417+
Return accessor with functions which work on temporal dtypes.
418+
"""
419+
return ColumnDatetimeAccessor(self)
420+
421+
422+
class ColumnDatetimeAccessor:
    """
    Accessor exposing functions which work on temporal dtypes.

    Wraps either a ``PandasColumn`` expression or a ``PandasPermissiveColumn``.
    A ``PandasPermissiveColumn`` is converted to an expression up front and the
    accessor is flagged as eager, so that every result is evaluated again
    before being handed back to the caller.
    """

    def __init__(self, column: PandasColumn | PandasPermissiveColumn) -> None:
        if isinstance(column, PandasPermissiveColumn):
            self.eager = True
            self.column = column._to_expression()
            self._api_version = column._api_version
        else:
            # ``_api_version`` is only read on the eager path of ``_return``,
            # so it is not needed here.
            self.eager = False
            self.column = column

    def _return(self, expr: PandasColumn):
        """
        Return ``expr`` unchanged when lazy, or its evaluated column when eager.

        In eager mode the expression is selected on an empty ``PandasDataFrame``,
        collected, and the column named by the wrapped column's output name is
        fetched from the result.
        """
        if not self.eager:
            return expr
        return (
            PandasDataFrame(pd.DataFrame(), api_version=self._api_version)
            .select(expr)
            .collect()
            .get_column_by_name(self.column.output_name())
        )

    def _apply(self, func):
        """
        Record ``func`` as a call on the wrapped column and return the result.

        ``func`` takes ``(ser, _rhs)``. Routing every accessor through this
        helper keeps eager and lazy handling consistent: previously only
        ``year`` went through ``_return``, so the other accessors returned an
        unevaluated expression even for eager columns.
        """
        expr = self.column._record_call(func, None)
        return self._return(expr)

    def year(self) -> Column:
        """Return the year component (``ser.dt.year``)."""
        return self._apply(lambda ser, _rhs: ser.dt.year)

    def month(self) -> Column:
        """Return the month component (``ser.dt.month``)."""
        return self._apply(lambda ser, _rhs: ser.dt.month)

    def day(self) -> Column:
        """Return the day component (``ser.dt.day``)."""
        return self._apply(lambda ser, _rhs: ser.dt.day)

    def hour(self) -> Column:
        """Return the hour component (``ser.dt.hour``)."""
        return self._apply(lambda ser, _rhs: ser.dt.hour)

    def minute(self) -> Column:
        """Return the minute component (``ser.dt.minute``)."""
        return self._apply(lambda ser, _rhs: ser.dt.minute)

    def second(self) -> Column:
        """Return the second component (``ser.dt.second``)."""
        return self._apply(lambda ser, _rhs: ser.dt.second)

    def microsecond(self) -> Column:
        """Return the microsecond component (``ser.dt.microsecond``)."""
        return self._apply(lambda ser, _rhs: ser.dt.microsecond)

    def iso_weekday(self) -> Column:
        """
        Return the ISO weekday, computed as ``ser.dt.weekday + 1``.

        pandas numbers weekdays from Monday=0, so the result runs from
        Monday=1 to Sunday=7.
        """
        return self._apply(lambda ser, _rhs: ser.dt.weekday + 1)
473+
416474

417475
class PandasGroupBy(GroupBy):
418476
def __init__(self, df: pd.DataFrame, keys: Sequence[str], api_version: str) -> None:
@@ -518,6 +576,9 @@ def __init__(self, column: pd.Series[Any], api_version: str) -> None:
518576
"Try updating dataframe-api-compat?"
519577
)
520578

579+
def __repr__(self) -> str: # pragma: no cover
580+
return self.column.__repr__()
581+
521582
def _to_expression(self) -> PandasColumn:
522583
return PandasColumn(
523584
root_names=[],
@@ -737,6 +798,13 @@ def to_array_object(self, dtype: str) -> Any:
737798
)
738799
return self.column.to_numpy(dtype=dtype)
739800

801+
@property
802+
def dt(self) -> ColumnDatetimeAccessor:
803+
"""
804+
Return accessor with functions which work on temporal dtypes.
805+
"""
806+
return ColumnDatetimeAccessor(self)
807+
740808

741809
class PandasDataFrame(DataFrame):
742810
# Not technically part of the standard
@@ -879,7 +947,7 @@ def _resolve_expression(
879947
return expression
880948
if not expression._calls:
881949
return expression._base_call(self.dataframe)
882-
output_name = expression.output_name
950+
output_name = expression.output_name()
883951
for func, lhs, rhs in expression._calls:
884952
lhs = self._resolve_expression(lhs)
885953
rhs = self._resolve_expression(rhs)

0 commit comments

Comments
 (0)