Skip to content

Commit e48df1c

Browse files
authored
BUG: DTA/TDA constructors with mismatched values/dtype resolutions (#55658)
* BUG: DTA/TDA constructors with mismatched values/dtype resolutions
* mypy fixup
1 parent 5622b1b commit e48df1c

File tree

11 files changed: +107 lines added, -33 lines removed

pandas/core/arrays/_mixins.py

+17-3
Original file line number | Diff line number | Diff line change
@@ -128,17 +128,31 @@ def view(self, dtype: Dtype | None = None) -> ArrayLike:
128128
dtype = pandas_dtype(dtype)
129129
arr = self._ndarray
130130

131-
if isinstance(dtype, (PeriodDtype, DatetimeTZDtype)):
131+
if isinstance(dtype, PeriodDtype):
132132
cls = dtype.construct_array_type()
133133
return cls(arr.view("i8"), dtype=dtype)
134+
elif isinstance(dtype, DatetimeTZDtype):
135+
# error: Incompatible types in assignment (expression has type
136+
# "type[DatetimeArray]", variable has type "type[PeriodArray]")
137+
cls = dtype.construct_array_type() # type: ignore[assignment]
138+
dt64_values = arr.view(f"M8[{dtype.unit}]")
139+
return cls(dt64_values, dtype=dtype)
134140
elif dtype == "M8[ns]":
135141
from pandas.core.arrays import DatetimeArray
136142

137-
return DatetimeArray(arr.view("i8"), dtype=dtype)
143+
# error: Argument 1 to "view" of "ndarray" has incompatible type
144+
# "ExtensionDtype | dtype[Any]"; expected "dtype[Any] | type[Any]
145+
# | _SupportsDType[dtype[Any]]"
146+
dt64_values = arr.view(dtype) # type: ignore[arg-type]
147+
return DatetimeArray(dt64_values, dtype=dtype)
138148
elif dtype == "m8[ns]":
139149
from pandas.core.arrays import TimedeltaArray
140150

141-
return TimedeltaArray(arr.view("i8"), dtype=dtype)
151+
# error: Argument 1 to "view" of "ndarray" has incompatible type
152+
# "ExtensionDtype | dtype[Any]"; expected "dtype[Any] | type[Any]
153+
# | _SupportsDType[dtype[Any]]"
154+
td64_values = arr.view(dtype) # type: ignore[arg-type]
155+
return TimedeltaArray(td64_values, dtype=dtype)
142156

143157
# error: Argument "dtype" to "view" of "_ArrayOrScalarCommon" has incompatible
144158
# type "Union[ExtensionDtype, dtype[Any]]"; expected "Union[dtype[Any], None,

pandas/core/arrays/datetimelike.py

+19-8
Original file line number | Diff line number | Diff line change
@@ -1918,6 +1918,9 @@ class TimelikeOps(DatetimeLikeArrayMixin):
19181918
def __init__(
19191919
self, values, dtype=None, freq=lib.no_default, copy: bool = False
19201920
) -> None:
1921+
if dtype is not None:
1922+
dtype = pandas_dtype(dtype)
1923+
19211924
values = extract_array(values, extract_numpy=True)
19221925
if isinstance(values, IntegerArray):
19231926
values = values.to_numpy("int64", na_value=iNaT)
@@ -1936,13 +1939,11 @@ def __init__(
19361939
freq = to_offset(freq)
19371940
freq, _ = validate_inferred_freq(freq, values.freq, False)
19381941

1939-
if dtype is not None:
1940-
dtype = pandas_dtype(dtype)
1941-
if dtype != values.dtype:
1942-
# TODO: we only have tests for this for DTA, not TDA (2022-07-01)
1943-
raise TypeError(
1944-
f"dtype={dtype} does not match data dtype {values.dtype}"
1945-
)
1942+
if dtype is not None and dtype != values.dtype:
1943+
# TODO: we only have tests for this for DTA, not TDA (2022-07-01)
1944+
raise TypeError(
1945+
f"dtype={dtype} does not match data dtype {values.dtype}"
1946+
)
19461947

19471948
dtype = values.dtype
19481949
values = values._ndarray
@@ -1952,6 +1953,8 @@ def __init__(
19521953
dtype = values.dtype
19531954
else:
19541955
dtype = self._default_dtype
1956+
if isinstance(values, np.ndarray) and values.dtype == "i8":
1957+
values = values.view(dtype)
19551958

19561959
if not isinstance(values, np.ndarray):
19571960
raise ValueError(
@@ -1966,7 +1969,15 @@ def __init__(
19661969
# for compat with datetime/timedelta/period shared methods,
19671970
# we can sometimes get here with int64 values. These represent
19681971
# nanosecond UTC (or tz-naive) unix timestamps
1969-
values = values.view(self._default_dtype)
1972+
if dtype is None:
1973+
dtype = self._default_dtype
1974+
values = values.view(self._default_dtype)
1975+
elif lib.is_np_dtype(dtype, "mM"):
1976+
values = values.view(dtype)
1977+
elif isinstance(dtype, DatetimeTZDtype):
1978+
kind = self._default_dtype.kind
1979+
new_dtype = f"{kind}8[{dtype.unit}]"
1980+
values = values.view(new_dtype)
19701981

19711982
dtype = self._validate_dtype(values, dtype)
19721983

pandas/core/arrays/datetimes.py

+8-1
Original file line number | Diff line number | Diff line change
@@ -278,8 +278,15 @@ def _from_scalars(cls, scalars, *, dtype: DtypeObj) -> Self:
278278
@classmethod
279279
def _validate_dtype(cls, values, dtype):
280280
# used in TimelikeOps.__init__
281-
_validate_dt64_dtype(values.dtype)
282281
dtype = _validate_dt64_dtype(dtype)
282+
_validate_dt64_dtype(values.dtype)
283+
if isinstance(dtype, np.dtype):
284+
if values.dtype != dtype:
285+
raise ValueError("Values resolution does not match dtype.")
286+
else:
287+
vunit = np.datetime_data(values.dtype)[0]
288+
if vunit != dtype.unit:
289+
raise ValueError("Values resolution does not match dtype.")
283290
return dtype
284291

285292
# error: Signature of "_simple_new" incompatible with supertype "NDArrayBacked"

pandas/core/arrays/timedeltas.py

+7-7
Original file line number | Diff line number | Diff line change
@@ -205,8 +205,10 @@ def dtype(self) -> np.dtype[np.timedelta64]: # type: ignore[override]
205205
@classmethod
206206
def _validate_dtype(cls, values, dtype):
207207
# used in TimelikeOps.__init__
208-
_validate_td64_dtype(values.dtype)
209208
dtype = _validate_td64_dtype(dtype)
209+
_validate_td64_dtype(values.dtype)
210+
if dtype != values.dtype:
211+
raise ValueError("Values resolution does not match dtype.")
210212
return dtype
211213

212214
# error: Signature of "_simple_new" incompatible with supertype "NDArrayBacked"
@@ -1202,11 +1204,9 @@ def _validate_td64_dtype(dtype) -> DtypeObj:
12021204
)
12031205
raise ValueError(msg)
12041206

1205-
if (
1206-
not isinstance(dtype, np.dtype)
1207-
or dtype.kind != "m"
1208-
or not is_supported_unit(get_unit_from_dtype(dtype))
1209-
):
1210-
raise ValueError(f"dtype {dtype} cannot be converted to timedelta64[ns]")
1207+
if not lib.is_np_dtype(dtype, "m"):
1208+
raise ValueError(f"dtype '{dtype}' is invalid, should be np.timedelta64 dtype")
1209+
elif not is_supported_unit(get_unit_from_dtype(dtype)):
1210+
raise ValueError("Supported timedelta64 resolutions are 's', 'ms', 'us', 'ns'")
12111211

12121212
return dtype

pandas/core/internals/managers.py

+2-1
Original file line number | Diff line number | Diff line change
@@ -2307,7 +2307,8 @@ def make_na_array(dtype: DtypeObj, shape: Shape, fill_value) -> ArrayLike:
23072307
# NB: exclude e.g. pyarrow[dt64tz] dtypes
23082308
ts = Timestamp(fill_value).as_unit(dtype.unit)
23092309
i8values = np.full(shape, ts._value)
2310-
return DatetimeArray(i8values, dtype=dtype)
2310+
dt64values = i8values.view(f"M8[{dtype.unit}]")
2311+
return DatetimeArray(dt64values, dtype=dtype)
23112312

23122313
elif is_1d_only_ea_dtype(dtype):
23132314
dtype = cast(ExtensionDtype, dtype)

pandas/core/tools/datetimes.py

+3-1
Original file line number | Diff line number | Diff line change
@@ -494,7 +494,9 @@ def _convert_listlike_datetimes(
494494
if tz_parsed is not None:
495495
# We can take a shortcut since the datetime64 numpy array
496496
# is in UTC
497-
dta = DatetimeArray(result, dtype=tz_to_dtype(tz_parsed))
497+
dtype = cast(DatetimeTZDtype, tz_to_dtype(tz_parsed))
498+
dt64_values = result.view(f"M8[{dtype.unit}]")
499+
dta = DatetimeArray(dt64_values, dtype=dtype)
498500
return DatetimeIndex._simple_new(dta, name=name)
499501

500502
return _box_as_indexlike(result, utc=utc, name=name)

pandas/tests/arrays/datetimes/test_constructors.py

+18
Original file line number | Diff line number | Diff line change
@@ -112,6 +112,24 @@ def test_incorrect_dtype_raises(self):
112112
with pytest.raises(ValueError, match="Unexpected value for 'dtype'."):
113113
DatetimeArray(np.array([1, 2, 3], dtype="i8"), dtype="category")
114114

115+
with pytest.raises(ValueError, match="Unexpected value for 'dtype'."):
116+
DatetimeArray(np.array([1, 2, 3], dtype="i8"), dtype="m8[s]")
117+
118+
with pytest.raises(ValueError, match="Unexpected value for 'dtype'."):
119+
DatetimeArray(np.array([1, 2, 3], dtype="i8"), dtype="M8[D]")
120+
121+
def test_mismatched_values_dtype_units(self):
122+
arr = np.array([1, 2, 3], dtype="M8[s]")
123+
dtype = np.dtype("M8[ns]")
124+
msg = "Values resolution does not match dtype."
125+
126+
with pytest.raises(ValueError, match=msg):
127+
DatetimeArray(arr, dtype=dtype)
128+
129+
dtype2 = DatetimeTZDtype(tz="UTC", unit="ns")
130+
with pytest.raises(ValueError, match=msg):
131+
DatetimeArray(arr, dtype=dtype2)
132+
115133
def test_freq_infer_raises(self):
116134
with pytest.raises(ValueError, match="Frequency inference"):
117135
DatetimeArray(np.array([1, 2, 3], dtype="i8"), freq="infer")

pandas/tests/arrays/test_datetimes.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -468,7 +468,7 @@ def test_repeat_preserves_tz(self):
468468
repeated = arr.repeat([1, 1])
469469

470470
# preserves tz and values, but not freq
471-
expected = DatetimeArray(arr.asi8, freq=None, dtype=arr.dtype)
471+
expected = DatetimeArray._from_sequence(arr.asi8, dtype=arr.dtype)
472472
tm.assert_equal(repeated, expected)
473473

474474
def test_value_counts_preserves_tz(self):

pandas/tests/arrays/timedeltas/test_constructors.py

+28-9
Original file line number | Diff line number | Diff line change
@@ -33,21 +33,40 @@ def test_non_array_raises(self):
3333
TimedeltaArray([1, 2, 3])
3434

3535
def test_other_type_raises(self):
36-
with pytest.raises(ValueError, match="dtype bool cannot be converted"):
36+
msg = "dtype 'bool' is invalid, should be np.timedelta64 dtype"
37+
with pytest.raises(ValueError, match=msg):
3738
TimedeltaArray(np.array([1, 2, 3], dtype="bool"))
3839

3940
def test_incorrect_dtype_raises(self):
40-
# TODO: why TypeError for 'category' but ValueError for i8?
41-
with pytest.raises(
42-
ValueError, match=r"category cannot be converted to timedelta64\[ns\]"
43-
):
41+
msg = "dtype 'category' is invalid, should be np.timedelta64 dtype"
42+
with pytest.raises(ValueError, match=msg):
4443
TimedeltaArray(np.array([1, 2, 3], dtype="i8"), dtype="category")
4544

46-
with pytest.raises(
47-
ValueError, match=r"dtype int64 cannot be converted to timedelta64\[ns\]"
48-
):
45+
msg = "dtype 'int64' is invalid, should be np.timedelta64 dtype"
46+
with pytest.raises(ValueError, match=msg):
4947
TimedeltaArray(np.array([1, 2, 3], dtype="i8"), dtype=np.dtype("int64"))
5048

49+
msg = r"dtype 'datetime64\[ns\]' is invalid, should be np.timedelta64 dtype"
50+
with pytest.raises(ValueError, match=msg):
51+
TimedeltaArray(np.array([1, 2, 3], dtype="i8"), dtype=np.dtype("M8[ns]"))
52+
53+
msg = (
54+
r"dtype 'datetime64\[us, UTC\]' is invalid, should be np.timedelta64 dtype"
55+
)
56+
with pytest.raises(ValueError, match=msg):
57+
TimedeltaArray(np.array([1, 2, 3], dtype="i8"), dtype="M8[us, UTC]")
58+
59+
msg = "Supported timedelta64 resolutions are 's', 'ms', 'us', 'ns'"
60+
with pytest.raises(ValueError, match=msg):
61+
TimedeltaArray(np.array([1, 2, 3], dtype="i8"), dtype=np.dtype("m8[Y]"))
62+
63+
def test_mismatched_values_dtype_units(self):
64+
arr = np.array([1, 2, 3], dtype="m8[s]")
65+
dtype = np.dtype("m8[ns]")
66+
msg = r"Values resolution does not match dtype"
67+
with pytest.raises(ValueError, match=msg):
68+
TimedeltaArray(arr, dtype=dtype)
69+
5170
def test_copy(self):
5271
data = np.array([1, 2, 3], dtype="m8[ns]")
5372
arr = TimedeltaArray(data, copy=False)
@@ -58,6 +77,6 @@ def test_copy(self):
5877
assert arr._ndarray.base is not data
5978

6079
def test_from_sequence_dtype(self):
61-
msg = "dtype .*object.* cannot be converted to timedelta64"
80+
msg = "dtype 'object' is invalid, should be np.timedelta64 dtype"
6281
with pytest.raises(ValueError, match=msg):
6382
TimedeltaArray._from_sequence([], dtype=object)

pandas/tests/base/test_conversion.py

+3-1
Original file line number | Diff line number | Diff line change
@@ -314,7 +314,9 @@ def test_array_multiindex_raises():
314314
),
315315
# Timedelta
316316
(
317-
TimedeltaArray(np.array([0, 3600000000000], dtype="i8"), freq="h"),
317+
TimedeltaArray(
318+
np.array([0, 3600000000000], dtype="i8").view("m8[ns]"), freq="h"
319+
),
318320
np.array([0, 3600000000000], dtype="m8[ns]"),
319321
),
320322
# GH#26406 tz is preserved in Categorical[dt64tz]

pandas/tests/indexes/timedeltas/test_constructors.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -240,7 +240,7 @@ def test_constructor_no_precision_raises(self):
240240
pd.Index(["2000"], dtype="timedelta64")
241241

242242
def test_constructor_wrong_precision_raises(self):
243-
msg = r"dtype timedelta64\[D\] cannot be converted to timedelta64\[ns\]"
243+
msg = "Supported timedelta64 resolutions are 's', 'ms', 'us', 'ns'"
244244
with pytest.raises(ValueError, match=msg):
245245
TimedeltaIndex(["2000"], dtype="timedelta64[D]")
246246

0 commit comments

Comments
 (0)