Skip to content

Commit 6792f5c

Browse files
authored
REF: simplify DTA.__init__ (#47574)
1 parent 9f38929 commit 6792f5c

File tree

2 files changed

+18
-14
lines changed

2 files changed

+18
-14
lines changed

pandas/core/arrays/datetimes.py

+10-13
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def _scalar_type(self) -> type[Timestamp]:
256256
_freq = None
257257

258258
def __init__(
259-
self, values, dtype=DT64NS_DTYPE, freq=lib.no_default, copy: bool = False
259+
self, values, dtype=None, freq=lib.no_default, copy: bool = False
260260
) -> None:
261261
values = extract_array(values, extract_numpy=True)
262262
if isinstance(values, IntegerArray):
@@ -276,22 +276,19 @@ def __init__(
276276
freq = to_offset(freq)
277277
freq, _ = dtl.validate_inferred_freq(freq, values.freq, False)
278278

279-
# validation
280-
dtz = getattr(dtype, "tz", None)
281-
if dtz and values.tz is None:
282-
dtype = DatetimeTZDtype(tz=dtype.tz)
283-
elif dtz and values.tz:
284-
if not timezones.tz_compare(dtz, values.tz):
285-
msg = (
286-
"Timezone of the array and 'dtype' do not match. "
287-
f"'{dtz}' != '{values.tz}'"
279+
if dtype is not None:
280+
dtype = pandas_dtype(dtype)
281+
if not is_dtype_equal(dtype, values.dtype):
282+
raise TypeError(
283+
f"dtype={dtype} does not match data dtype {values.dtype}"
288284
)
289-
raise TypeError(msg)
290-
elif values.tz:
291-
dtype = values.dtype
292285

286+
dtype = values.dtype
293287
values = values._ndarray
294288

289+
elif dtype is None:
290+
dtype = DT64NS_DTYPE
291+
295292
if not isinstance(values, np.ndarray):
296293
raise ValueError(
297294
f"Unexpected type '{type(values).__name__}'. 'values' must be a "

pandas/tests/arrays/datetimes/test_constructors.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,16 @@ def test_mismatched_timezone_raises(self):
7777
dtype=DatetimeTZDtype(tz="US/Central"),
7878
)
7979
dtype = DatetimeTZDtype(tz="US/Eastern")
80-
with pytest.raises(TypeError, match="Timezone of the array"):
80+
msg = r"dtype=datetime64\[ns.*\] does not match data dtype datetime64\[ns.*\]"
81+
with pytest.raises(TypeError, match=msg):
8182
DatetimeArray(arr, dtype=dtype)
8283

84+
# also with mismatched tzawareness
85+
with pytest.raises(TypeError, match=msg):
86+
DatetimeArray(arr, dtype=np.dtype("M8[ns]"))
87+
with pytest.raises(TypeError, match=msg):
88+
DatetimeArray(arr.tz_localize(None), dtype=arr.dtype)
89+
8390
def test_non_array_raises(self):
8491
with pytest.raises(ValueError, match="list"):
8592
DatetimeArray([1, 2, 3])

0 commit comments

Comments
 (0)