Skip to content

Commit 1d4f685

Browse files
jbrockmendel authored and yehoshuadimarsky committed
REF: share some constructor code (pandas-dev#47555)
* REF: share some constructor code * mypy fixup
1 parent e43dfde commit 1d4f685

File tree

3 files changed

+43
-55
lines changed

3 files changed

+43
-55
lines changed

pandas/core/arrays/datetimelike.py

+36
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@
9494
DatetimeTZDtype,
9595
ExtensionDtype,
9696
)
97+
from pandas.core.dtypes.generic import (
98+
ABCCategorical,
99+
ABCMultiIndex,
100+
)
97101
from pandas.core.dtypes.missing import (
98102
is_valid_na_for_dtype,
99103
isna,
@@ -114,6 +118,8 @@
114118
NDArrayBackedExtensionArray,
115119
ravel_compat,
116120
)
121+
from pandas.core.arrays.base import ExtensionArray
122+
from pandas.core.arrays.integer import IntegerArray
117123
import pandas.core.common as com
118124
from pandas.core.construction import (
119125
array as pd_array,
@@ -2024,6 +2030,36 @@ def factorize( # type:ignore[override]
20242030
# Shared Constructor Helpers
20252031

20262032

2033+
def ensure_arraylike_for_datetimelike(data, copy: bool, cls_name: str):
    """
    Unwrap ``data`` into an ndarray or ExtensionArray for the shared
    DatetimeArray / TimedeltaArray sequence constructors.

    Parameters
    ----------
    data : object
        Input to unwrap: a list, tuple or generator, or any object that
        exposes a ``dtype`` attribute.
    copy : bool
        Whether the caller still intends to copy ``data``.  Reset to False
        in the branches below that replace ``data`` with a converted object.
    cls_name : str
        Name of the array class being constructed; only used in the
        MultiIndex error message.

    Returns
    -------
    tuple
        ``(data, copy)`` - the unwrapped data and the possibly-updated flag.

    Raises
    ------
    TypeError
        If ``data`` is a MultiIndex.
    """
    if not hasattr(data, "dtype"):
        # e.g. list, tuple
        if np.ndim(data) == 0:
            # i.e. generator
            data = list(data)
        data = np.asarray(data)
        copy = False
    elif isinstance(data, ABCMultiIndex):
        raise TypeError(f"Cannot create a {cls_name} from a MultiIndex.")
    else:
        data = extract_array(data, extract_numpy=True)

    if isinstance(data, IntegerArray):
        # Missing entries are written as iNaT in the int64 result.
        # NOTE(review): copy=False presumes to_numpy hands back a fresh array;
        #  confirm this also holds when there are no missing values.
        data = data.to_numpy("int64", na_value=iNaT)
        copy = False
    elif not isinstance(data, (np.ndarray, ExtensionArray)):
        # GH#24539 e.g. xarray, dask object
        data = np.asarray(data)

    elif isinstance(data, ABCCategorical):
        # GH#18664 preserve tz in going DTI->Categorical->DTI
        # TODO: cases where we need to do another pass through maybe_convert_dtype,
        #  e.g. the categories are timedelta64s
        data = data.categories.take(data.codes, fill_value=NaT)._values
        copy = False

    return data, copy
2061+
2062+
20272063
def validate_periods(periods):
20282064
"""
20292065
If a `periods` argument is passed to the Datetime/Timedelta Array/Index

pandas/core/arrays/datetimes.py

+4-30
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
DT64NS_DTYPE,
5454
INT64_DTYPE,
5555
is_bool_dtype,
56-
is_categorical_dtype,
5756
is_datetime64_any_dtype,
5857
is_datetime64_dtype,
5958
is_datetime64_ns_dtype,
@@ -69,13 +68,9 @@
6968
pandas_dtype,
7069
)
7170
from pandas.core.dtypes.dtypes import DatetimeTZDtype
72-
from pandas.core.dtypes.generic import ABCMultiIndex
7371
from pandas.core.dtypes.missing import isna
7472

75-
from pandas.core.arrays import (
76-
ExtensionArray,
77-
datetimelike as dtl,
78-
)
73+
from pandas.core.arrays import datetimelike as dtl
7974
from pandas.core.arrays._ranges import generate_regular_range
8075
from pandas.core.arrays.integer import IntegerArray
8176
import pandas.core.common as com
@@ -2064,23 +2059,9 @@ def _sequence_to_dt64ns(
20642059
# if dtype has an embedded tz, capture it
20652060
tz = validate_tz_from_dtype(dtype, tz)
20662061

2067-
if not hasattr(data, "dtype"):
2068-
# e.g. list, tuple
2069-
if np.ndim(data) == 0:
2070-
# i.e. generator
2071-
data = list(data)
2072-
data = np.asarray(data)
2073-
copy = False
2074-
elif isinstance(data, ABCMultiIndex):
2075-
raise TypeError("Cannot create a DatetimeArray from a MultiIndex.")
2076-
else:
2077-
data = extract_array(data, extract_numpy=True)
2078-
2079-
if isinstance(data, IntegerArray):
2080-
data = data.to_numpy("int64", na_value=iNaT)
2081-
elif not isinstance(data, (np.ndarray, ExtensionArray)):
2082-
# GH#24539 e.g. xarray, dask object
2083-
data = np.asarray(data)
2062+
data, copy = dtl.ensure_arraylike_for_datetimelike(
2063+
data, copy, cls_name="DatetimeArray"
2064+
)
20842065

20852066
if isinstance(data, DatetimeArray):
20862067
inferred_freq = data.freq
@@ -2320,13 +2301,6 @@ def maybe_convert_dtype(data, copy: bool, tz: tzinfo | None = None):
23202301
"Passing PeriodDtype data is invalid. Use `data.to_timestamp()` instead"
23212302
)
23222303

2323-
elif is_categorical_dtype(data.dtype):
2324-
# GH#18664 preserve tz in going DTI->Categorical->DTI
2325-
# TODO: cases where we need to do another pass through this func,
2326-
# e.g. the categories are timedelta64s
2327-
data = data.categories.take(data.codes, fill_value=NaT)._values
2328-
copy = False
2329-
23302304
elif is_extension_array_dtype(data.dtype) and not is_datetime64tz_dtype(data.dtype):
23312305
# TODO: We have no tests for these
23322306
data = np.array(data, dtype=np.object_)

pandas/core/arrays/timedeltas.py

+3-25
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,10 @@
5151
is_timedelta64_dtype,
5252
pandas_dtype,
5353
)
54-
from pandas.core.dtypes.generic import (
55-
ABCCategorical,
56-
ABCMultiIndex,
57-
)
5854
from pandas.core.dtypes.missing import isna
5955

6056
from pandas.core import nanops
6157
from pandas.core.arrays import (
62-
ExtensionArray,
6358
IntegerArray,
6459
datetimelike as dtl,
6560
)
@@ -936,26 +931,9 @@ def sequence_to_td64ns(
936931
if unit is not None:
937932
unit = parse_timedelta_unit(unit)
938933

939-
# Unwrap whatever we have into a np.ndarray
940-
if not hasattr(data, "dtype"):
941-
# e.g. list, tuple
942-
if np.ndim(data) == 0:
943-
# i.e. generator
944-
data = list(data)
945-
data = np.array(data, copy=False)
946-
elif isinstance(data, ABCMultiIndex):
947-
raise TypeError("Cannot create a TimedeltaArray from a MultiIndex.")
948-
else:
949-
data = extract_array(data, extract_numpy=True)
950-
951-
if isinstance(data, IntegerArray):
952-
data = data.to_numpy("int64", na_value=iNaT)
953-
elif not isinstance(data, (np.ndarray, ExtensionArray)):
954-
# GH#24539 e.g. xarray, dask object
955-
data = np.asarray(data)
956-
elif isinstance(data, ABCCategorical):
957-
data = data.categories.take(data.codes, fill_value=NaT)._values
958-
copy = False
934+
data, copy = dtl.ensure_arraylike_for_datetimelike(
935+
data, copy, cls_name="TimedeltaArray"
936+
)
959937

960938
if isinstance(data, TimedeltaArray):
961939
inferred_freq = data.freq

0 commit comments

Comments
 (0)