Skip to content

Commit d98026e

Browse files
jbrockmendel authored and WillAyd committed
API (string): return str dtype for .dt methods, DatetimeIndex methods (pandas-dev#59526)
* API (string): return str dtype for .dt methods, DatetimeIndex methods
* mypy fixup
1 parent d47cca7 commit d98026e

File tree

6 files changed

+45
-16
lines changed

6 files changed

+45
-16
lines changed

pandas/core/arrays/datetimelike.py

+6
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,8 @@
2020

2121
import numpy as np
2222

23+
from pandas._config import using_string_dtype
24+
2325
from pandas._libs import (
2426
algos,
2527
lib,
@@ -1789,6 +1791,10 @@ def strftime(self, date_format: str) -> npt.NDArray[np.object_]:
17891791
dtype='object')
17901792
"""
17911793
result = self._format_native_types(date_format=date_format, na_rep=np.nan)
1794+
if using_string_dtype():
1795+
from pandas import StringDtype
1796+
1797+
return pd_array(result, dtype=StringDtype(na_value=np.nan)) # type: ignore[return-value]
17921798
return result.astype(object, copy=False)
17931799

17941800

pandas/core/arrays/datetimes.py

+17
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,8 @@
1414

1515
import numpy as np
1616

17+
from pandas._config import using_string_dtype
18+
1719
from pandas._libs import (
1820
lib,
1921
tslib,
@@ -1306,6 +1308,13 @@ def month_name(self, locale=None) -> npt.NDArray[np.object_]:
13061308
values, "month_name", locale=locale, reso=self._creso
13071309
)
13081310
result = self._maybe_mask_results(result, fill_value=None)
1311+
if using_string_dtype():
1312+
from pandas import (
1313+
StringDtype,
1314+
array as pd_array,
1315+
)
1316+
1317+
return pd_array(result, dtype=StringDtype(na_value=np.nan)) # type: ignore[return-value]
13091318
return result
13101319

13111320
def day_name(self, locale=None) -> npt.NDArray[np.object_]:
@@ -1363,6 +1372,14 @@ def day_name(self, locale=None) -> npt.NDArray[np.object_]:
13631372
values, "day_name", locale=locale, reso=self._creso
13641373
)
13651374
result = self._maybe_mask_results(result, fill_value=None)
1375+
if using_string_dtype():
1376+
# TODO: no tests that check for dtype of result as of 2024-08-15
1377+
from pandas import (
1378+
StringDtype,
1379+
array as pd_array,
1380+
)
1381+
1382+
return pd_array(result, dtype=StringDtype(na_value=np.nan)) # type: ignore[return-value]
13661383
return result
13671384

13681385
@property

pandas/core/indexes/datetimes.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -276,7 +276,7 @@ def _engine_type(self) -> type[libindex.DatetimeEngine]:
276276
@doc(DatetimeArray.strftime)
277277
def strftime(self, date_format) -> Index:
278278
arr = self._data.strftime(date_format)
279-
return Index(arr, name=self.name, dtype=object)
279+
return Index(arr, name=self.name, dtype=arr.dtype)
280280

281281
@doc(DatetimeArray.tz_convert)
282282
def tz_convert(self, tz) -> Self:

pandas/core/indexes/extension.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,7 @@ def fget(self):
7171
return type(self)._simple_new(result, name=self.name)
7272
elif isinstance(result, ABCDataFrame):
7373
return result.set_index(self)
74-
return Index(result, name=self.name)
74+
return Index(result, name=self.name, dtype=result.dtype)
7575
return result
7676

7777
def fset(self, value) -> None:
@@ -98,7 +98,7 @@ def method(self, *args, **kwargs): # type: ignore[misc]
9898
return type(self)._simple_new(result, name=self.name)
9999
elif isinstance(result, ABCDataFrame):
100100
return result.set_index(self)
101-
return Index(result, name=self.name)
101+
return Index(result, name=self.name, dtype=result.dtype)
102102
return result
103103

104104
# error: "property" has no attribute "__name__"

pandas/tests/arrays/test_datetimelike.py

+16-8
Original file line number | Diff line number | Diff line change
@@ -889,20 +889,24 @@ def test_concat_same_type_different_freq(self, unit):
889889

890890
tm.assert_datetime_array_equal(result, expected)
891891

892-
def test_strftime(self, arr1d):
892+
def test_strftime(self, arr1d, using_infer_string):
893893
arr = arr1d
894894

895895
result = arr.strftime("%Y %b")
896896
expected = np.array([ts.strftime("%Y %b") for ts in arr], dtype=object)
897-
tm.assert_numpy_array_equal(result, expected)
897+
if using_infer_string:
898+
expected = pd.array(expected, dtype=pd.StringDtype(na_value=np.nan))
899+
tm.assert_equal(result, expected)
898900

899-
def test_strftime_nat(self):
901+
def test_strftime_nat(self, using_infer_string):
900902
# GH 29578
901903
arr = DatetimeIndex(["2019-01-01", NaT])._data
902904

903905
result = arr.strftime("%Y-%m-%d")
904906
expected = np.array(["2019-01-01", np.nan], dtype=object)
905-
tm.assert_numpy_array_equal(result, expected)
907+
if using_infer_string:
908+
expected = pd.array(expected, dtype=pd.StringDtype(na_value=np.nan))
909+
tm.assert_equal(result, expected)
906910

907911

908912
class TestTimedeltaArray(SharedTests):
@@ -1159,20 +1163,24 @@ def test_array_interface(self, arr1d):
11591163
expected = np.asarray(arr).astype("S20")
11601164
tm.assert_numpy_array_equal(result, expected)
11611165

1162-
def test_strftime(self, arr1d):
1166+
def test_strftime(self, arr1d, using_infer_string):
11631167
arr = arr1d
11641168

11651169
result = arr.strftime("%Y")
11661170
expected = np.array([per.strftime("%Y") for per in arr], dtype=object)
1167-
tm.assert_numpy_array_equal(result, expected)
1171+
if using_infer_string:
1172+
expected = pd.array(expected, dtype=pd.StringDtype(na_value=np.nan))
1173+
tm.assert_equal(result, expected)
11681174

1169-
def test_strftime_nat(self):
1175+
def test_strftime_nat(self, using_infer_string):
11701176
# GH 29578
11711177
arr = PeriodArray(PeriodIndex(["2019-01-01", NaT], dtype="period[D]"))
11721178

11731179
result = arr.strftime("%Y-%m-%d")
11741180
expected = np.array(["2019-01-01", np.nan], dtype=object)
1175-
tm.assert_numpy_array_equal(result, expected)
1181+
if using_infer_string:
1182+
expected = pd.array(expected, dtype=pd.StringDtype(na_value=np.nan))
1183+
tm.assert_equal(result, expected)
11761184

11771185

11781186
@pytest.mark.parametrize(

pandas/tests/series/accessors/test_dt_accessor.py

+3-5
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
Period,
3030
PeriodIndex,
3131
Series,
32+
StringDtype,
3233
TimedeltaIndex,
3334
date_range,
3435
period_range,
@@ -528,7 +529,6 @@ def test_dt_accessor_datetime_name_accessors(self, time_locale):
528529
ser = pd.concat([ser, Series([pd.NaT])])
529530
assert np.isnan(ser.dt.month_name(locale=time_locale).iloc[-1])
530531

531-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
532532
def test_strftime(self):
533533
# GH 10086
534534
ser = Series(date_range("20130101", periods=5))
@@ -599,10 +599,9 @@ def test_strftime_period_days(self, using_infer_string):
599599
dtype="=U10",
600600
)
601601
if using_infer_string:
602-
expected = expected.astype("str")
602+
expected = expected.astype(StringDtype(na_value=np.nan))
603603
tm.assert_index_equal(result, expected)
604604

605-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
606605
def test_strftime_dt64_microsecond_resolution(self):
607606
ser = Series([datetime(2013, 1, 1, 2, 32, 59), datetime(2013, 1, 2, 14, 32, 1)])
608607
result = ser.dt.strftime("%Y-%m-%d %H:%M:%S")
@@ -635,7 +634,6 @@ def test_strftime_period_minutes(self):
635634
)
636635
tm.assert_series_equal(result, expected)
637636

638-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
639637
@pytest.mark.parametrize(
640638
"data",
641639
[
@@ -658,7 +656,7 @@ def test_strftime_all_nat(self, data):
658656
ser = Series(data)
659657
with tm.assert_produces_warning(None):
660658
result = ser.dt.strftime("%Y-%m-%d")
661-
expected = Series([np.nan], dtype=object)
659+
expected = Series([np.nan], dtype="str")
662660
tm.assert_series_equal(result, expected)
663661

664662
def test_valid_dt_with_missing_values(self):

0 commit comments

Comments
 (0)