Skip to content

temporal functions #16

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Oct 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 77 additions & 9 deletions dataframe_api_compat/pandas_standard/pandas_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,9 @@ def __init__(
def __column_namespace__(self) -> Any:
return dataframe_api_compat.pandas_standard

@property
def root_names(self):
return sorted(set(self._root_names))

@property
def output_name(self):
return self._output_name

Expand All @@ -128,12 +126,12 @@ def _record_call(
) -> PandasColumn:
calls = [*self._calls, (func, self, rhs)]
if isinstance(rhs, PandasColumn):
root_names = self.root_names + rhs.root_names
root_names = self.root_names() + rhs.root_names()
else:
root_names = self.root_names
root_names = self.root_names()
return PandasColumn(
root_names=root_names,
output_name=output_name or self.output_name,
output_name=output_name or self.output_name(),
extra_calls=calls,
)

Expand Down Expand Up @@ -320,12 +318,12 @@ def func(ser, _rhs):
if ascending:
return (
ser.sort_values()
.index.to_series(name=self.output_name)
.index.to_series(name=self.output_name())
.reset_index(drop=True)
)
return (
ser.sort_values()
.index.to_series(name=self.output_name)[::-1]
.index.to_series(name=self.output_name())[::-1]
.reset_index(drop=True)
)

Expand Down Expand Up @@ -376,7 +374,7 @@ def func(ser, value):
ser = num / other
else:
ser = ser.fillna(value)
return ser.rename(self.output_name)
return ser.rename(self.output_name())

return self._record_call(
lambda ser, _rhs: func(ser, value),
Expand Down Expand Up @@ -413,6 +411,66 @@ def rename(self, name: str) -> PandasColumn:
)
return expr

@property
def dt(self) -> ColumnDatetimeAccessor:
    """Expose the operations that apply to temporal (datetime) data.

    Returns:
        A ``ColumnDatetimeAccessor`` constructed around this column.
    """
    accessor = ColumnDatetimeAccessor(self)
    return accessor


class ColumnDatetimeAccessor:
    """Datetime-component accessor reached through a column's ``dt`` property.

    Every public method records one call on the underlying expression that
    pulls a single component out of a datetime Series via pandas'
    ``Series.dt``.  When the accessor was built from a
    ``PandasPermissiveColumn`` the recorded expression is evaluated
    immediately and the resulting column is returned; otherwise the
    un-evaluated expression is returned.
    """

    def __init__(self, column: PandasColumn | PandasPermissiveColumn) -> None:
        """Wrap ``column`` and remember whether results must be evaluated.

        Args:
            column: The column the accessor operates on.  A
                ``PandasPermissiveColumn`` is converted to an expression with
                ``_to_expression()`` and marks the accessor as eager.
        """
        if isinstance(column, PandasPermissiveColumn):
            self.eager = True
            self.column = column._to_expression()
            # Only needed on the eager path, to build the throwaway frame
            # used by ``_return``.
            self._api_version = column._api_version
        else:
            self.eager = False
            self.column = column

    def _return(self, expr: PandasColumn):
        """Return ``expr`` as is, or evaluate it first when in eager mode.

        In eager mode the expression is selected from an empty
        ``PandasDataFrame``, collected, and the column named after the wrapped
        expression's ``output_name()`` is returned.
        """
        if not self.eager:
            return expr
        return (
            PandasDataFrame(pd.DataFrame(), api_version=self._api_version)
            .select(expr)
            .collect()
            .get_column_by_name(self.column.output_name())
        )

    # NOTE: every component below goes through ``_return``.  Previously only
    # ``year`` did, so an accessor built from an eager column handed back an
    # un-evaluated expression for all the other components.

    def year(self) -> Column:
        """Return the year component (``Series.dt.year``)."""
        expr = self.column._record_call(lambda ser, _rhs: ser.dt.year, None)
        return self._return(expr)

    def month(self) -> Column:
        """Return the month component (``Series.dt.month``)."""
        expr = self.column._record_call(lambda ser, _rhs: ser.dt.month, None)
        return self._return(expr)

    def day(self) -> Column:
        """Return the day-of-month component (``Series.dt.day``)."""
        expr = self.column._record_call(lambda ser, _rhs: ser.dt.day, None)
        return self._return(expr)

    def hour(self) -> Column:
        """Return the hour component (``Series.dt.hour``)."""
        expr = self.column._record_call(lambda ser, _rhs: ser.dt.hour, None)
        return self._return(expr)

    def minute(self) -> Column:
        """Return the minute component (``Series.dt.minute``)."""
        expr = self.column._record_call(lambda ser, _rhs: ser.dt.minute, None)
        return self._return(expr)

    def second(self) -> Column:
        """Return the second component (``Series.dt.second``)."""
        expr = self.column._record_call(lambda ser, _rhs: ser.dt.second, None)
        return self._return(expr)

    def microsecond(self) -> Column:
        """Return the microsecond component (``Series.dt.microsecond``)."""
        expr = self.column._record_call(lambda ser, _rhs: ser.dt.microsecond, None)
        return self._return(expr)

    def iso_weekday(self) -> Column:
        """Return the ISO weekday number.

        pandas' ``Series.dt.weekday`` counts Monday as 0, so one is added to
        shift the result to the ISO range where Monday is 1 and Sunday is 7.
        """
        expr = self.column._record_call(lambda ser, _rhs: ser.dt.weekday + 1, None)
        return self._return(expr)


class PandasGroupBy(GroupBy):
def __init__(self, df: pd.DataFrame, keys: Sequence[str], api_version: str) -> None:
Expand Down Expand Up @@ -518,6 +576,9 @@ def __init__(self, column: pd.Series[Any], api_version: str) -> None:
"Try updating dataframe-api-compat?"
)

def __repr__(self) -> str:  # pragma: no cover
    """Delegate to the ``__repr__`` of the wrapped ``column`` attribute."""
    wrapped = self.column
    return wrapped.__repr__()

def _to_expression(self) -> PandasColumn:
return PandasColumn(
root_names=[],
Expand Down Expand Up @@ -737,6 +798,13 @@ def to_array_object(self, dtype: str) -> Any:
)
return self.column.to_numpy(dtype=dtype)

@property
def dt(self) -> ColumnDatetimeAccessor:
    """Give access to the functions that work on temporal dtypes.

    Returns:
        A ``ColumnDatetimeAccessor`` built from this column.
    """
    temporal_accessor = ColumnDatetimeAccessor(self)
    return temporal_accessor


class PandasDataFrame(DataFrame):
# Not technically part of the standard
Expand Down Expand Up @@ -879,7 +947,7 @@ def _resolve_expression(
return expression
if not expression._calls:
return expression._base_call(self.dataframe)
output_name = expression.output_name
output_name = expression.output_name()
for func, lhs, rhs in expression._calls:
lhs = self._resolve_expression(lhs)
rhs = self._resolve_expression(rhs)
Expand Down
Loading