Skip to content

Commit bf84995

Browse files
jbrockmendel authored and JulianWgs committed
BUG: constructing DTA/TDA from xarray/dask/pandasarray (pandas-dev#40210)
1 parent d6f0b3f commit bf84995

File tree

3 files changed

+164
-47
lines changed

3 files changed

+164
-47
lines changed

pandas/core/arrays/datetimes.py

+29-27
Original file line number | Diff line number | Diff line change
@@ -59,17 +59,18 @@
5959
pandas_dtype,
6060
)
6161
from pandas.core.dtypes.dtypes import DatetimeTZDtype
62-
from pandas.core.dtypes.generic import (
63-
ABCIndex,
64-
ABCPandasArray,
65-
ABCSeries,
66-
)
62+
from pandas.core.dtypes.generic import ABCMultiIndex
6763
from pandas.core.dtypes.missing import isna
6864

6965
from pandas.core.algorithms import checked_add_with_arr
70-
from pandas.core.arrays import datetimelike as dtl
66+
from pandas.core.arrays import (
67+
ExtensionArray,
68+
datetimelike as dtl,
69+
)
7170
from pandas.core.arrays._ranges import generate_regular_range
71+
from pandas.core.arrays.integer import IntegerArray
7272
import pandas.core.common as com
73+
from pandas.core.construction import extract_array
7374

7475
from pandas.tseries.frequencies import get_period_alias
7576
from pandas.tseries.offsets import (
@@ -239,8 +240,9 @@ class DatetimeArray(dtl.TimelikeOps, dtl.DatelikeOps):
239240
_freq = None
240241

241242
def __init__(self, values, dtype=DT64NS_DTYPE, freq=None, copy=False):
242-
if isinstance(values, (ABCSeries, ABCIndex)):
243-
values = values._values
243+
values = extract_array(values, extract_numpy=True)
244+
if isinstance(values, IntegerArray):
245+
values = values.to_numpy("int64", na_value=iNaT)
244246

245247
inferred_freq = getattr(values, "_freq", None)
246248

@@ -266,7 +268,7 @@ def __init__(self, values, dtype=DT64NS_DTYPE, freq=None, copy=False):
266268
if not isinstance(values, np.ndarray):
267269
raise ValueError(
268270
f"Unexpected type '{type(values).__name__}'. 'values' must be "
269-
"a DatetimeArray ndarray, or Series or Index containing one of those."
271+
"a DatetimeArray, ndarray, or Series or Index containing one of those."
270272
)
271273
if values.ndim not in [1, 2]:
272274
raise ValueError("Only 1-dimensional input arrays are supported.")
@@ -1978,30 +1980,29 @@ def sequence_to_dt64ns(
19781980
dtype = _validate_dt64_dtype(dtype)
19791981
tz = timezones.maybe_get_tz(tz)
19801982

1983+
# if dtype has an embedded tz, capture it
1984+
tz = validate_tz_from_dtype(dtype, tz)
1985+
19811986
if not hasattr(data, "dtype"):
19821987
# e.g. list, tuple
19831988
if np.ndim(data) == 0:
19841989
# i.e. generator
19851990
data = list(data)
19861991
data = np.asarray(data)
19871992
copy = False
1988-
elif isinstance(data, ABCSeries):
1989-
data = data._values
1990-
if isinstance(data, ABCPandasArray):
1991-
data = data.to_numpy()
1992-
1993-
if hasattr(data, "freq"):
1994-
# i.e. DatetimeArray/Index
1995-
inferred_freq = data.freq
1993+
elif isinstance(data, ABCMultiIndex):
1994+
raise TypeError("Cannot create a DatetimeArray from a MultiIndex.")
1995+
else:
1996+
data = extract_array(data, extract_numpy=True)
19961997

1997-
# if dtype has an embedded tz, capture it
1998-
tz = validate_tz_from_dtype(dtype, tz)
1998+
if isinstance(data, IntegerArray):
1999+
data = data.to_numpy("int64", na_value=iNaT)
2000+
elif not isinstance(data, (np.ndarray, ExtensionArray)):
2001+
# GH#24539 e.g. xarray, dask object
2002+
data = np.asarray(data)
19992003

2000-
if isinstance(data, ABCIndex):
2001-
if data.nlevels > 1:
2002-
# Without this check, data._data below is None
2003-
raise TypeError("Cannot create a DatetimeArray from a MultiIndex.")
2004-
data = data._data
2004+
if isinstance(data, DatetimeArray):
2005+
inferred_freq = data.freq
20052006

20062007
# By this point we are assured to have either a numpy array or Index
20072008
data, copy = maybe_convert_dtype(data, copy)
@@ -2045,13 +2046,14 @@ def sequence_to_dt64ns(
20452046
if is_datetime64tz_dtype(data_dtype):
20462047
# DatetimeArray -> ndarray
20472048
tz = _maybe_infer_tz(tz, data.tz)
2048-
result = data._data
2049+
result = data._ndarray
20492050

20502051
elif is_datetime64_dtype(data_dtype):
20512052
# tz-naive DatetimeArray or ndarray[datetime64]
2052-
data = getattr(data, "_data", data)
2053+
data = getattr(data, "_ndarray", data)
20532054
if data.dtype != DT64NS_DTYPE:
20542055
data = conversion.ensure_datetime64ns(data)
2056+
copy = False
20552057

20562058
if tz is not None:
20572059
# Convert tz-naive to UTC
@@ -2088,7 +2090,7 @@ def sequence_to_dt64ns(
20882090

20892091

20902092
def objects_to_datetime64ns(
2091-
data,
2093+
data: np.ndarray,
20922094
dayfirst,
20932095
yearfirst,
20942096
utc=False,

pandas/core/arrays/timedeltas.py

+19-16
Original file line number | Diff line number | Diff line change
@@ -53,15 +53,13 @@
5353
pandas_dtype,
5454
)
5555
from pandas.core.dtypes.dtypes import DatetimeTZDtype
56-
from pandas.core.dtypes.generic import (
57-
ABCSeries,
58-
ABCTimedeltaIndex,
59-
)
56+
from pandas.core.dtypes.generic import ABCMultiIndex
6057
from pandas.core.dtypes.missing import isna
6158

6259
from pandas.core import nanops
6360
from pandas.core.algorithms import checked_add_with_arr
6461
from pandas.core.arrays import (
62+
ExtensionArray,
6563
IntegerArray,
6664
datetimelike as dtl,
6765
)
@@ -172,7 +170,9 @@ def dtype(self) -> np.dtype: # type: ignore[override]
172170
_freq = None
173171

174172
def __init__(self, values, dtype=TD64NS_DTYPE, freq=lib.no_default, copy=False):
175-
values = extract_array(values)
173+
values = extract_array(values, extract_numpy=True)
174+
if isinstance(values, IntegerArray):
175+
values = values.to_numpy("int64", na_value=tslibs.iNaT)
176176

177177
inferred_freq = getattr(values, "_freq", None)
178178
explicit_none = freq is None
@@ -192,7 +192,7 @@ def __init__(self, values, dtype=TD64NS_DTYPE, freq=lib.no_default, copy=False):
192192
if not isinstance(values, np.ndarray):
193193
msg = (
194194
f"Unexpected type '{type(values).__name__}'. 'values' must be a "
195-
"TimedeltaArray ndarray, or Series or Index containing one of those."
195+
"TimedeltaArray, ndarray, or Series or Index containing one of those."
196196
)
197197
raise ValueError(msg)
198198
if values.ndim not in [1, 2]:
@@ -960,20 +960,23 @@ def sequence_to_td64ns(
960960
# i.e. generator
961961
data = list(data)
962962
data = np.array(data, copy=False)
963-
elif isinstance(data, ABCSeries):
964-
data = data._values
965-
elif isinstance(data, ABCTimedeltaIndex):
966-
inferred_freq = data.freq
967-
data = data._data._ndarray
968-
elif isinstance(data, TimedeltaArray):
969-
inferred_freq = data.freq
970-
data = data._ndarray
971-
elif isinstance(data, IntegerArray):
972-
data = data.to_numpy("int64", na_value=tslibs.iNaT)
963+
elif isinstance(data, ABCMultiIndex):
964+
raise TypeError("Cannot create a DatetimeArray from a MultiIndex.")
965+
else:
966+
data = extract_array(data, extract_numpy=True)
967+
968+
if isinstance(data, IntegerArray):
969+
data = data.to_numpy("int64", na_value=iNaT)
970+
elif not isinstance(data, (np.ndarray, ExtensionArray)):
971+
# GH#24539 e.g. xarray, dask object
972+
data = np.asarray(data)
973973
elif is_categorical_dtype(data.dtype):
974974
data = data.categories.take(data.codes, fill_value=NaT)._values
975975
copy = False
976976

977+
if isinstance(data, TimedeltaArray):
978+
inferred_freq = data.freq
979+
977980
# Convert whatever we have into timedelta64[ns] dtype
978981
if is_object_dtype(data.dtype) or is_string_dtype(data.dtype):
979982
# no need to make a copy, need to convert if string-dtyped

pandas/tests/arrays/test_datetimelike.py

+116-4
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
Timestamp,
1414
)
1515
from pandas.compat import np_version_under1p18
16+
import pandas.util._test_decorators as td
1617

1718
import pandas as pd
1819
from pandas import (
@@ -28,6 +29,8 @@
2829
PeriodArray,
2930
TimedeltaArray,
3031
)
32+
from pandas.core.arrays.datetimes import sequence_to_dt64ns
33+
from pandas.core.arrays.timedeltas import sequence_to_td64ns
3134

3235

3336
# TODO: more freq variants
@@ -224,7 +227,7 @@ def test_unbox_scalar(self):
224227
result = arr._unbox_scalar(NaT)
225228
assert isinstance(result, expected)
226229

227-
msg = f"'value' should be a {self.dtype.__name__}."
230+
msg = f"'value' should be a {self.scalar_type.__name__}."
228231
with pytest.raises(ValueError, match=msg):
229232
arr._unbox_scalar("foo")
230233

@@ -614,11 +617,21 @@ def test_median(self, arr1d):
614617
result = arr2.median(axis=1, skipna=False)
615618
tm.assert_equal(result, arr)
616619

620+
def test_from_integer_array(self):
621+
arr = np.array([1, 2, 3], dtype=np.int64)
622+
expected = self.array_cls(arr, dtype=self.example_dtype)
623+
624+
data = pd.array(arr, dtype="Int64")
625+
result = self.array_cls(data, dtype=self.example_dtype)
626+
627+
tm.assert_extension_array_equal(result, expected)
628+
617629

618630
class TestDatetimeArray(SharedTests):
619631
index_cls = DatetimeIndex
620632
array_cls = DatetimeArray
621-
dtype = Timestamp
633+
scalar_type = Timestamp
634+
example_dtype = "M8[ns]"
622635

623636
@pytest.fixture
624637
def arr1d(self, tz_naive_fixture, freqstr):
@@ -918,7 +931,8 @@ def test_strftime_nat(self):
918931
class TestTimedeltaArray(SharedTests):
919932
index_cls = TimedeltaIndex
920933
array_cls = TimedeltaArray
921-
dtype = pd.Timedelta
934+
scalar_type = pd.Timedelta
935+
example_dtype = "m8[ns]"
922936

923937
def test_from_tdi(self):
924938
tdi = TimedeltaIndex(["1 Day", "3 Hours"])
@@ -1037,7 +1051,8 @@ def test_take_fill_valid(self, timedelta_index):
10371051
class TestPeriodArray(SharedTests):
10381052
index_cls = PeriodIndex
10391053
array_cls = PeriodArray
1040-
dtype = Period
1054+
scalar_type = Period
1055+
example_dtype = PeriodIndex([], freq="W").dtype
10411056

10421057
@pytest.fixture
10431058
def arr1d(self, period_index):
@@ -1305,3 +1320,100 @@ def test_period_index_construction_from_strings(klass):
13051320
result = PeriodIndex(data, freq="Q")
13061321
expected = PeriodIndex([Period(s) for s in strings])
13071322
tm.assert_index_equal(result, expected)
1323+
1324+
1325+
@pytest.mark.parametrize("dtype", ["M8[ns]", "m8[ns]"])
1326+
def test_from_pandas_array(dtype):
1327+
# GH#24615
1328+
data = np.array([1, 2, 3], dtype=dtype)
1329+
arr = PandasArray(data)
1330+
1331+
cls = {"M8[ns]": DatetimeArray, "m8[ns]": TimedeltaArray}[dtype]
1332+
1333+
result = cls(arr)
1334+
expected = cls(data)
1335+
tm.assert_extension_array_equal(result, expected)
1336+
1337+
result = cls._from_sequence(arr)
1338+
expected = cls._from_sequence(data)
1339+
tm.assert_extension_array_equal(result, expected)
1340+
1341+
func = {"M8[ns]": sequence_to_dt64ns, "m8[ns]": sequence_to_td64ns}[dtype]
1342+
result = func(arr)[0]
1343+
expected = func(data)[0]
1344+
tm.assert_equal(result, expected)
1345+
1346+
func = {"M8[ns]": pd.to_datetime, "m8[ns]": pd.to_timedelta}[dtype]
1347+
result = func(arr).array
1348+
expected = func(data).array
1349+
tm.assert_equal(result, expected)
1350+
1351+
# Let's check the Indexes while we're here
1352+
idx_cls = {"M8[ns]": DatetimeIndex, "m8[ns]": TimedeltaIndex}[dtype]
1353+
result = idx_cls(arr)
1354+
expected = idx_cls(data)
1355+
tm.assert_index_equal(result, expected)
1356+
1357+
1358+
@pytest.fixture(
1359+
params=[
1360+
"memoryview",
1361+
"array",
1362+
pytest.param("dask", marks=td.skip_if_no("dask.array")),
1363+
pytest.param("xarray", marks=td.skip_if_no("xarray")),
1364+
]
1365+
)
1366+
def array_likes(request):
1367+
# GH#24539 recognize e.g xarray, dask, ...
1368+
arr = np.array([1, 2, 3], dtype=np.int64)
1369+
1370+
name = request.param
1371+
if name == "memoryview":
1372+
data = memoryview(arr)
1373+
elif name == "array":
1374+
# stdlib array
1375+
from array import array
1376+
1377+
data = array("i", arr)
1378+
elif name == "dask":
1379+
import dask.array
1380+
1381+
data = dask.array.array(arr)
1382+
elif name == "xarray":
1383+
import xarray as xr
1384+
1385+
data = xr.DataArray(arr)
1386+
1387+
return arr, data
1388+
1389+
1390+
@pytest.mark.parametrize("dtype", ["M8[ns]", "m8[ns]"])
1391+
def test_from_obscure_array(dtype, array_likes):
1392+
# GH#24539 recognize e.g xarray, dask, ...
1393+
# Note: we dont do this for PeriodArray bc _from_sequence won't accept
1394+
# an array of integers
1395+
# TODO: could check with arraylike of Period objects
1396+
arr, data = array_likes
1397+
1398+
cls = {"M8[ns]": DatetimeArray, "m8[ns]": TimedeltaArray}[dtype]
1399+
1400+
expected = cls(arr)
1401+
result = cls._from_sequence(data)
1402+
tm.assert_extension_array_equal(result, expected)
1403+
1404+
func = {"M8[ns]": sequence_to_dt64ns, "m8[ns]": sequence_to_td64ns}[dtype]
1405+
result = func(arr)[0]
1406+
expected = func(data)[0]
1407+
tm.assert_equal(result, expected)
1408+
1409+
# FIXME: dask and memoryview both break on these
1410+
# func = {"M8[ns]": pd.to_datetime, "m8[ns]": pd.to_timedelta}[dtype]
1411+
# result = func(arr).array
1412+
# expected = func(data).array
1413+
# tm.assert_equal(result, expected)
1414+
1415+
# Let's check the Indexes while we're here
1416+
idx_cls = {"M8[ns]": DatetimeIndex, "m8[ns]": TimedeltaIndex}[dtype]
1417+
result = idx_cls(arr)
1418+
expected = idx_cls(data)
1419+
tm.assert_index_equal(result, expected)

0 commit comments

Comments
 (0)