Skip to content

API (string): return str dtype for .dt methods, DatetimeIndex methods #59526

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import numpy as np

from pandas._config import using_string_dtype
from pandas._config.config import get_option

from pandas._libs import (
Expand Down Expand Up @@ -1759,6 +1760,10 @@ def strftime(self, date_format: str) -> npt.NDArray[np.object_]:
dtype='object')
"""
result = self._format_native_types(date_format=date_format, na_rep=np.nan)
if using_string_dtype():
from pandas import StringDtype

return pd_array(result, dtype=StringDtype(na_value=np.nan)) # type: ignore[return-value]
return result.astype(object, copy=False)


Expand Down
16 changes: 16 additions & 0 deletions pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import numpy as np

from pandas._config import using_string_dtype
from pandas._config.config import get_option

from pandas._libs import (
Expand Down Expand Up @@ -1332,6 +1333,13 @@ def month_name(self, locale=None) -> npt.NDArray[np.object_]:
values, "month_name", locale=locale, reso=self._creso
)
result = self._maybe_mask_results(result, fill_value=None)
if using_string_dtype():
from pandas import (
StringDtype,
array as pd_array,
)

return pd_array(result, dtype=StringDtype(na_value=np.nan)) # type: ignore[return-value]
return result

def day_name(self, locale=None) -> npt.NDArray[np.object_]:
Expand Down Expand Up @@ -1393,6 +1401,14 @@ def day_name(self, locale=None) -> npt.NDArray[np.object_]:
values, "day_name", locale=locale, reso=self._creso
)
result = self._maybe_mask_results(result, fill_value=None)
if using_string_dtype():
# TODO: no tests that check for dtype of result as of 2024-08-15
from pandas import (
StringDtype,
array as pd_array,
)

return pd_array(result, dtype=StringDtype(na_value=np.nan)) # type: ignore[return-value]
return result

@property
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexes/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def _engine_type(self) -> type[libindex.DatetimeEngine]:
@doc(DatetimeArray.strftime)
def strftime(self, date_format) -> Index:
arr = self._data.strftime(date_format)
return Index(arr, name=self.name, dtype=object)
return Index(arr, name=self.name, dtype=arr.dtype)

@doc(DatetimeArray.tz_convert)
def tz_convert(self, tz) -> Self:
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/indexes/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def fget(self):
return type(self)._simple_new(result, name=self.name)
elif isinstance(result, ABCDataFrame):
return result.set_index(self)
return Index(result, name=self.name)
return Index(result, name=self.name, dtype=result.dtype)
return result

def fset(self, value) -> None:
Expand All @@ -101,7 +101,7 @@ def method(self, *args, **kwargs): # type: ignore[misc]
return type(self)._simple_new(result, name=self.name)
elif isinstance(result, ABCDataFrame):
return result.set_index(self)
return Index(result, name=self.name)
return Index(result, name=self.name, dtype=result.dtype)
return result

# error: "property" has no attribute "__name__"
Expand Down
24 changes: 16 additions & 8 deletions pandas/tests/arrays/test_datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,20 +891,24 @@ def test_concat_same_type_different_freq(self, unit):

tm.assert_datetime_array_equal(result, expected)

def test_strftime(self, arr1d):
def test_strftime(self, arr1d, using_infer_string):
arr = arr1d

result = arr.strftime("%Y %b")
expected = np.array([ts.strftime("%Y %b") for ts in arr], dtype=object)
tm.assert_numpy_array_equal(result, expected)
if using_infer_string:
expected = pd.array(expected, dtype=pd.StringDtype(na_value=np.nan))
tm.assert_equal(result, expected)

def test_strftime_nat(self):
def test_strftime_nat(self, using_infer_string):
# GH 29578
arr = DatetimeIndex(["2019-01-01", NaT])._data

result = arr.strftime("%Y-%m-%d")
expected = np.array(["2019-01-01", np.nan], dtype=object)
tm.assert_numpy_array_equal(result, expected)
if using_infer_string:
expected = pd.array(expected, dtype=pd.StringDtype(na_value=np.nan))
tm.assert_equal(result, expected)


class TestTimedeltaArray(SharedTests):
Expand Down Expand Up @@ -1161,20 +1165,24 @@ def test_array_interface(self, arr1d):
expected = np.asarray(arr).astype("S20")
tm.assert_numpy_array_equal(result, expected)

def test_strftime(self, arr1d):
def test_strftime(self, arr1d, using_infer_string):
arr = arr1d

result = arr.strftime("%Y")
expected = np.array([per.strftime("%Y") for per in arr], dtype=object)
tm.assert_numpy_array_equal(result, expected)
if using_infer_string:
expected = pd.array(expected, dtype=pd.StringDtype(na_value=np.nan))
tm.assert_equal(result, expected)

def test_strftime_nat(self):
def test_strftime_nat(self, using_infer_string):
# GH 29578
arr = PeriodArray(PeriodIndex(["2019-01-01", NaT], dtype="period[D]"))

result = arr.strftime("%Y-%m-%d")
expected = np.array(["2019-01-01", np.nan], dtype=object)
tm.assert_numpy_array_equal(result, expected)
if using_infer_string:
expected = pd.array(expected, dtype=pd.StringDtype(na_value=np.nan))
tm.assert_equal(result, expected)


@pytest.mark.parametrize(
Expand Down
1 change: 0 additions & 1 deletion pandas/tests/io/excel/test_writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,6 @@ def test_excel_multindex_roundtrip(
)
tm.assert_frame_equal(df, act, check_names=check_names)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_read_excel_parse_dates(self, tmp_excel):
# see gh-11544, gh-12051
df = DataFrame(
Expand Down
8 changes: 3 additions & 5 deletions pandas/tests/series/accessors/test_dt_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Period,
PeriodIndex,
Series,
StringDtype,
TimedeltaIndex,
date_range,
period_range,
Expand Down Expand Up @@ -513,7 +514,6 @@ def test_dt_accessor_datetime_name_accessors(self, time_locale):
ser = pd.concat([ser, Series([pd.NaT])])
assert np.isnan(ser.dt.month_name(locale=time_locale).iloc[-1])

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_strftime(self):
# GH 10086
ser = Series(date_range("20130101", periods=5))
Expand Down Expand Up @@ -584,10 +584,9 @@ def test_strftime_period_days(self, using_infer_string):
dtype="=U10",
)
if using_infer_string:
expected = expected.astype("str")
expected = expected.astype(StringDtype(na_value=np.nan))
tm.assert_index_equal(result, expected)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_strftime_dt64_microsecond_resolution(self):
ser = Series([datetime(2013, 1, 1, 2, 32, 59), datetime(2013, 1, 2, 14, 32, 1)])
result = ser.dt.strftime("%Y-%m-%d %H:%M:%S")
Expand Down Expand Up @@ -620,7 +619,6 @@ def test_strftime_period_minutes(self):
)
tm.assert_series_equal(result, expected)

@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
@pytest.mark.parametrize(
"data",
[
Expand All @@ -643,7 +641,7 @@ def test_strftime_all_nat(self, data):
ser = Series(data)
with tm.assert_produces_warning(None):
result = ser.dt.strftime("%Y-%m-%d")
expected = Series([np.nan], dtype=object)
expected = Series([np.nan], dtype="str")
tm.assert_series_equal(result, expected)

def test_valid_dt_with_missing_values(self):
Expand Down