Skip to content

Commit 40e2dd2

Browse files
jbrockmendelKevin D Smith
authored and
Kevin D Smith
committed
EA: Tighten signature on DatetimeArray._from_sequence (pandas-dev#36718)
1 parent f8e53e6 commit 40e2dd2

File tree

8 files changed

+41
-15
lines changed

8 files changed

+41
-15
lines changed

pandas/core/arrays/datetimes.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,11 @@ def _simple_new(
299299
return result
300300

301301
@classmethod
302-
def _from_sequence(
302+
def _from_sequence(cls, scalars, dtype=None, copy: bool = False):
303+
return cls._from_sequence_not_strict(scalars, dtype=dtype, copy=copy)
304+
305+
@classmethod
306+
def _from_sequence_not_strict(
303307
cls,
304308
data,
305309
dtype=None,

pandas/core/indexes/datetimes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ def __new__(
295295

296296
name = maybe_extract_name(name, data, cls)
297297

298-
dtarr = DatetimeArray._from_sequence(
298+
dtarr = DatetimeArray._from_sequence_not_strict(
299299
data,
300300
dtype=dtype,
301301
copy=copy,

pandas/core/nanops.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1616,7 +1616,9 @@ def na_accum_func(values: ArrayLike, accum_func, skipna: bool) -> ArrayLike:
16161616
result = result.view(orig_dtype)
16171617
else:
16181618
# DatetimeArray
1619-
result = type(values)._from_sequence(result, dtype=orig_dtype)
1619+
result = type(values)._simple_new( # type: ignore[attr-defined]
1620+
result, dtype=orig_dtype
1621+
)
16201622

16211623
elif skipna and not issubclass(values.dtype.type, (np.integer, np.bool_)):
16221624
vals = values.copy()

pandas/tests/arrays/test_array.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,9 @@ def test_array_copy():
210210
datetime.datetime(2000, 1, 1, tzinfo=cet),
211211
datetime.datetime(2001, 1, 1, tzinfo=cet),
212212
],
213-
DatetimeArray._from_sequence(["2000", "2001"], tz=cet),
213+
DatetimeArray._from_sequence(
214+
["2000", "2001"], dtype=pd.DatetimeTZDtype(tz=cet)
215+
),
214216
),
215217
# timedelta
216218
(

pandas/tests/arrays/test_datetimes.py

+19-8
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def test_mixing_naive_tzaware_raises(self, meth):
7171
def test_from_pandas_array(self):
7272
arr = pd.array(np.arange(5, dtype=np.int64)) * 3600 * 10 ** 9
7373

74-
result = DatetimeArray._from_sequence(arr, freq="infer")
74+
result = DatetimeArray._from_sequence(arr)._with_freq("infer")
7575

7676
expected = pd.date_range("1970-01-01", periods=5, freq="H")._data
7777
tm.assert_datetime_array_equal(result, expected)
@@ -162,7 +162,9 @@ def test_cmp_dt64_arraylike_tznaive(self, all_compare_operators):
162162

163163
class TestDatetimeArray:
164164
def test_astype_to_same(self):
165-
arr = DatetimeArray._from_sequence(["2000"], tz="US/Central")
165+
arr = DatetimeArray._from_sequence(
166+
["2000"], dtype=DatetimeTZDtype(tz="US/Central")
167+
)
166168
result = arr.astype(DatetimeTZDtype(tz="US/Central"), copy=False)
167169
assert result is arr
168170

@@ -193,7 +195,9 @@ def test_astype_int(self, dtype):
193195
tm.assert_numpy_array_equal(result, expected)
194196

195197
def test_tz_setter_raises(self):
196-
arr = DatetimeArray._from_sequence(["2000"], tz="US/Central")
198+
arr = DatetimeArray._from_sequence(
199+
["2000"], dtype=DatetimeTZDtype(tz="US/Central")
200+
)
197201
with pytest.raises(AttributeError, match="tz_localize"):
198202
arr.tz = "UTC"
199203

@@ -282,7 +286,8 @@ def test_fillna_preserves_tz(self, method):
282286

283287
fill_val = dti[1] if method == "pad" else dti[3]
284288
expected = DatetimeArray._from_sequence(
285-
[dti[0], dti[1], fill_val, dti[3], dti[4]], freq=None, tz="US/Central"
289+
[dti[0], dti[1], fill_val, dti[3], dti[4]],
290+
dtype=DatetimeTZDtype(tz="US/Central"),
286291
)
287292

288293
result = arr.fillna(method=method)
@@ -434,19 +439,24 @@ def test_shift_value_tzawareness_mismatch(self):
434439

435440
class TestSequenceToDT64NS:
436441
def test_tz_dtype_mismatch_raises(self):
437-
arr = DatetimeArray._from_sequence(["2000"], tz="US/Central")
442+
arr = DatetimeArray._from_sequence(
443+
["2000"], dtype=DatetimeTZDtype(tz="US/Central")
444+
)
438445
with pytest.raises(TypeError, match="data is already tz-aware"):
439446
sequence_to_dt64ns(arr, dtype=DatetimeTZDtype(tz="UTC"))
440447

441448
def test_tz_dtype_matches(self):
442-
arr = DatetimeArray._from_sequence(["2000"], tz="US/Central")
449+
arr = DatetimeArray._from_sequence(
450+
["2000"], dtype=DatetimeTZDtype(tz="US/Central")
451+
)
443452
result, _, _ = sequence_to_dt64ns(arr, dtype=DatetimeTZDtype(tz="US/Central"))
444453
tm.assert_numpy_array_equal(arr._data, result)
445454

446455

447456
class TestReductions:
448457
@pytest.mark.parametrize("tz", [None, "US/Central"])
449458
def test_min_max(self, tz):
459+
dtype = DatetimeTZDtype(tz=tz) if tz is not None else np.dtype("M8[ns]")
450460
arr = DatetimeArray._from_sequence(
451461
[
452462
"2000-01-03",
@@ -456,7 +466,7 @@ def test_min_max(self, tz):
456466
"2000-01-05",
457467
"2000-01-04",
458468
],
459-
tz=tz,
469+
dtype=dtype,
460470
)
461471

462472
result = arr.min()
@@ -476,7 +486,8 @@ def test_min_max(self, tz):
476486
@pytest.mark.parametrize("tz", [None, "US/Central"])
477487
@pytest.mark.parametrize("skipna", [True, False])
478488
def test_min_max_empty(self, skipna, tz):
479-
arr = DatetimeArray._from_sequence([], tz=tz)
489+
dtype = DatetimeTZDtype(tz=tz) if tz is not None else np.dtype("M8[ns]")
490+
arr = DatetimeArray._from_sequence([], dtype=dtype)
480491
result = arr.min(skipna=skipna)
481492
assert result is pd.NaT
482493

pandas/tests/extension/test_datetime.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,10 @@ def test_concat_mixed_dtypes(self, data):
181181
@pytest.mark.parametrize("obj", ["series", "frame"])
182182
def test_unstack(self, obj):
183183
# GH-13287: can't use base test, since building the expected fails.
184+
dtype = DatetimeTZDtype(tz="US/Central")
184185
data = DatetimeArray._from_sequence(
185-
["2000", "2001", "2002", "2003"], tz="US/Central"
186+
["2000", "2001", "2002", "2003"],
187+
dtype=dtype,
186188
)
187189
index = pd.MultiIndex.from_product(([["A", "B"], ["a", "b"]]), names=["a", "b"])
188190

pandas/tests/indexes/datetimes/test_constructors.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616

1717

1818
class TestDatetimeIndex:
19-
@pytest.mark.parametrize("dt_cls", [DatetimeIndex, DatetimeArray._from_sequence])
19+
@pytest.mark.parametrize(
20+
"dt_cls", [DatetimeIndex, DatetimeArray._from_sequence_not_strict]
21+
)
2022
def test_freq_validation_with_nat(self, dt_cls):
2123
# GH#11587 make sure we get a useful error message when generate_range
2224
# raises

pandas/tests/scalar/test_nat.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from pandas import (
1414
DatetimeIndex,
15+
DatetimeTZDtype,
1516
Index,
1617
NaT,
1718
Period,
@@ -440,7 +441,9 @@ def test_nat_rfloordiv_timedelta(val, expected):
440441
DatetimeIndex(["2011-01-01", "2011-01-02"], name="x"),
441442
DatetimeIndex(["2011-01-01", "2011-01-02"], tz="US/Eastern", name="x"),
442443
DatetimeArray._from_sequence(["2011-01-01", "2011-01-02"]),
443-
DatetimeArray._from_sequence(["2011-01-01", "2011-01-02"], tz="US/Pacific"),
444+
DatetimeArray._from_sequence(
445+
["2011-01-01", "2011-01-02"], dtype=DatetimeTZDtype(tz="US/Pacific")
446+
),
444447
TimedeltaIndex(["1 day", "2 day"], name="x"),
445448
],
446449
)

0 commit comments

Comments
 (0)