Skip to content

Commit a3bb751

Browse files
authored
REF: back DatetimeBlock, TimedeltaBlock by DTA/TDA (#40456)
1 parent 83b16b5 commit a3bb751

File tree

16 files changed

+181
-85
lines changed

16 files changed

+181
-85
lines changed

pandas/core/array_algos/quantile.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,14 @@ def _quantile_ea_compat(
133133

134134
if not is_sparse(orig.dtype):
135135
# shape[0] should be 1 as long as EAs are 1D
136-
assert result.shape == (1, len(qs)), result.shape
137-
result = type(orig)._from_factorized(result[0], orig)
136+
137+
if orig.ndim == 2:
138+
# i.e. DatetimeArray
139+
result = type(orig)._from_factorized(result, orig)
140+
141+
else:
142+
assert result.shape == (1, len(qs)), result.shape
143+
result = type(orig)._from_factorized(result[0], orig)
138144

139145
# error: Incompatible return value type (got "ndarray", expected "ExtensionArray")
140146
return result # type: ignore[return-value]

pandas/core/array_algos/take.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import functools
44
from typing import (
55
TYPE_CHECKING,
6+
cast,
67
overload,
78
)
89

@@ -21,6 +22,7 @@
2122
from pandas.core.construction import ensure_wrapped_if_datetimelike
2223

2324
if TYPE_CHECKING:
25+
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
2426
from pandas.core.arrays.base import ExtensionArray
2527

2628

@@ -89,7 +91,12 @@ def take_nd(
8991

9092
if not isinstance(arr, np.ndarray):
9193
# i.e. ExtensionArray,
92-
# includes for EA to catch DatetimeArray, TimedeltaArray
94+
if arr.ndim == 2:
95+
# e.g. DatetimeArray, TimedeltaArray
96+
arr = cast("NDArrayBackedExtensionArray", arr)
97+
return arr.take(
98+
indexer, fill_value=fill_value, allow_fill=allow_fill, axis=axis
99+
)
93100
return arr.take(indexer, fill_value=fill_value, allow_fill=allow_fill)
94101

95102
arr = np.asarray(arr)

pandas/core/arrays/_mixins.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@
2424
cache_readonly,
2525
doc,
2626
)
27-
from pandas.util._validators import validate_fillna_kwargs
27+
from pandas.util._validators import (
28+
validate_bool_kwarg,
29+
validate_fillna_kwargs,
30+
)
2831

2932
from pandas.core.dtypes.common import is_dtype_equal
3033
from pandas.core.dtypes.missing import array_equivalent
@@ -39,6 +42,7 @@
3942
from pandas.core.arrays.base import ExtensionArray
4043
from pandas.core.construction import extract_array
4144
from pandas.core.indexers import check_array_indexer
45+
from pandas.core.sorting import nargminmax
4246

4347
NDArrayBackedExtensionArrayT = TypeVar(
4448
"NDArrayBackedExtensionArrayT", bound="NDArrayBackedExtensionArray"
@@ -189,6 +193,22 @@ def equals(self, other) -> bool:
189193
def _values_for_argsort(self):
190194
return self._ndarray
191195

196+
# Signature of "argmin" incompatible with supertype "ExtensionArray"
197+
def argmin(self, axis: int = 0, skipna: bool = True): # type:ignore[override]
198+
# override base class by adding axis keyword
199+
validate_bool_kwarg(skipna, "skipna")
200+
if not skipna and self.isna().any():
201+
raise NotImplementedError
202+
return nargminmax(self, "argmin", axis=axis)
203+
204+
# Signature of "argmax" incompatible with supertype "ExtensionArray"
205+
def argmax(self, axis: int = 0, skipna: bool = True): # type:ignore[override]
206+
# override base class by adding axis keyword
207+
validate_bool_kwarg(skipna, "skipna")
208+
if not skipna and self.isna().any():
209+
raise NotImplementedError
210+
return nargminmax(self, "argmax", axis=axis)
211+
192212
def copy(self: NDArrayBackedExtensionArrayT) -> NDArrayBackedExtensionArrayT:
193213
new_data = self._ndarray.copy()
194214
return self._from_backing_data(new_data)

pandas/core/frame.py

+3
Original file line numberDiff line numberDiff line change
@@ -9544,6 +9544,9 @@ def func(values: np.ndarray):
95449544

95459545
def blk_func(values, axis=1):
95469546
if isinstance(values, ExtensionArray):
9547+
if values.ndim == 2:
9548+
# i.e. DatetimeArray, TimedeltaArray
9549+
return values._reduce(name, axis=1, skipna=skipna, **kwds)
95479550
return values._reduce(name, skipna=skipna, **kwds)
95489551
else:
95499552
return op(values, axis=axis, skipna=skipna, **kwds)

pandas/core/groupby/ops.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
maybe_fill,
7373
)
7474

75+
from pandas.core.arrays import ExtensionArray
7576
from pandas.core.base import SelectionMixin
7677
import pandas.core.common as com
7778
from pandas.core.frame import DataFrame
@@ -267,7 +268,9 @@ def apply(self, f: F, data: FrameOrSeries, axis: int = 0):
267268
group_keys = self._get_group_keys()
268269
result_values = None
269270

270-
if data.ndim == 2 and np.any(data.dtypes.apply(is_extension_array_dtype)):
271+
if data.ndim == 2 and any(
272+
isinstance(x, ExtensionArray) for x in data._iter_column_arrays()
273+
):
271274
# calling splitter.fast_apply will raise TypeError via apply_frame_axis0
272275
# if we pass EA instead of ndarray
273276
# TODO: can we have a workaround for EAs backed by ndarray?

pandas/core/internals/array_manager.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -493,9 +493,12 @@ def apply_with_block(self: T, f, align_keys=None, swap_axis=True, **kwargs) -> T
493493
if isinstance(applied, list):
494494
applied = applied[0]
495495
arr = applied.values
496-
if self.ndim == 2:
497-
if isinstance(arr, np.ndarray):
498-
arr = arr[0, :]
496+
if self.ndim == 2 and arr.ndim == 2:
497+
assert len(arr) == 1
498+
# error: Invalid index type "Tuple[int, slice]" for
499+
# "Union[ndarray, ExtensionArray]"; expected type
500+
# "Union[int, slice, ndarray]"
501+
arr = arr[0, :] # type: ignore[index]
499502
result_arrays.append(arr)
500503

501504
return type(self)(result_arrays, self._axes)

pandas/core/internals/blocks.py

+35-52
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
writers,
2828
)
2929
from pandas._libs.internals import BlockPlacement
30-
from pandas._libs.tslibs import conversion
3130
from pandas._typing import (
3231
ArrayLike,
3332
Dtype,
@@ -47,7 +46,6 @@
4746
maybe_downcast_numeric,
4847
maybe_downcast_to_dtype,
4948
maybe_upcast,
50-
sanitize_to_nanoseconds,
5149
soft_convert_objects,
5250
)
5351
from pandas.core.dtypes.common import (
@@ -938,7 +936,11 @@ def setitem(self, indexer, value):
938936
return self.coerce_to_target_dtype(value).setitem(indexer, value)
939937

940938
if self.dtype.kind in ["m", "M"]:
941-
arr = self.array_values.T
939+
arr = self.values
940+
if self.ndim > 1:
941+
# Don't transpose with ndim=1 because we would fail to invalidate
942+
# arr.freq
943+
arr = arr.T
942944
arr[indexer] = value
943945
return self
944946

@@ -1172,6 +1174,7 @@ def _interpolate_with_fill(
11721174
limit_area=limit_area,
11731175
)
11741176

1177+
values = maybe_coerce_values(values)
11751178
blocks = [self.make_block_same_class(values)]
11761179
return self._maybe_downcast(blocks, downcast)
11771180

@@ -1227,6 +1230,7 @@ def func(yvalues: np.ndarray) -> np.ndarray:
12271230

12281231
# interp each column independently
12291232
interp_values = np.apply_along_axis(func, axis, data)
1233+
interp_values = maybe_coerce_values(interp_values)
12301234

12311235
blocks = [self.make_block_same_class(interp_values)]
12321236
return self._maybe_downcast(blocks, downcast)
@@ -1788,27 +1792,32 @@ class NDArrayBackedExtensionBlock(HybridMixin, Block):
17881792
Block backed by an NDArrayBackedExtensionArray
17891793
"""
17901794

1795+
values: NDArrayBackedExtensionArray
1796+
1797+
@property
1798+
def is_view(self) -> bool:
1799+
""" return a boolean if I am possibly a view """
1800+
# check the ndarray values of the DatetimeIndex values
1801+
return self.values._ndarray.base is not None
1802+
17911803
def internal_values(self):
17921804
# Override to return DatetimeArray and TimedeltaArray
1793-
return self.array_values
1805+
return self.values
17941806

17951807
def get_values(self, dtype: Optional[DtypeObj] = None) -> np.ndarray:
17961808
"""
17971809
return object dtype as boxed values, such as Timestamps/Timedelta
17981810
"""
1799-
values = self.array_values
1811+
values = self.values
18001812
if is_object_dtype(dtype):
1801-
# DTA/TDA constructor and astype can handle 2D
1802-
# error: "Callable[..., Any]" has no attribute "astype"
1803-
values = values.astype(object) # type: ignore[attr-defined]
1813+
values = values.astype(object)
18041814
# TODO(EA2D): reshape not needed with 2D EAs
18051815
return np.asarray(values).reshape(self.shape)
18061816

18071817
def iget(self, key):
18081818
# GH#31649 we need to wrap scalars in Timestamp/Timedelta
18091819
# TODO(EA2D): this can be removed if we ever have 2D EA
1810-
# error: "Callable[..., Any]" has no attribute "reshape"
1811-
return self.array_values.reshape(self.shape)[key] # type: ignore[attr-defined]
1820+
return self.values.reshape(self.shape)[key]
18121821

18131822
def putmask(self, mask, new) -> List[Block]:
18141823
mask = extract_bool_array(mask)
@@ -1817,16 +1826,13 @@ def putmask(self, mask, new) -> List[Block]:
18171826
return self.astype(object).putmask(mask, new)
18181827

18191828
# TODO(EA2D): reshape unnecessary with 2D EAs
1820-
# error: "Callable[..., Any]" has no attribute "reshape"
1821-
arr = self.array_values.reshape(self.shape) # type: ignore[attr-defined]
1822-
arr = cast("NDArrayBackedExtensionArray", arr)
1829+
arr = self.values.reshape(self.shape)
18231830
arr.T.putmask(mask, new)
18241831
return [self]
18251832

18261833
def where(self, other, cond, errors="raise") -> List[Block]:
18271834
# TODO(EA2D): reshape unnecessary with 2D EAs
1828-
# error: "Callable[..., Any]" has no attribute "reshape"
1829-
arr = self.array_values.reshape(self.shape) # type: ignore[attr-defined]
1835+
arr = self.values.reshape(self.shape)
18301836

18311837
cond = extract_bool_array(cond)
18321838

@@ -1837,7 +1843,6 @@ def where(self, other, cond, errors="raise") -> List[Block]:
18371843

18381844
# TODO(EA2D): reshape not needed with 2D EAs
18391845
res_values = res_values.reshape(self.values.shape)
1840-
res_values = maybe_coerce_values(res_values)
18411846
nb = self.make_block_same_class(res_values)
18421847
return [nb]
18431848

@@ -1862,19 +1867,15 @@ def diff(self, n: int, axis: int = 0) -> List[Block]:
18621867
by apply.
18631868
"""
18641869
# TODO(EA2D): reshape not necessary with 2D EAs
1865-
# error: "Callable[..., Any]" has no attribute "reshape"
1866-
values = self.array_values.reshape(self.shape) # type: ignore[attr-defined]
1870+
values = self.values.reshape(self.shape)
18671871

18681872
new_values = values - values.shift(n, axis=axis)
1869-
new_values = maybe_coerce_values(new_values)
18701873
return [self.make_block(new_values)]
18711874

18721875
def shift(self, periods: int, axis: int = 0, fill_value: Any = None) -> List[Block]:
1873-
# TODO(EA2D) this is unnecessary if these blocks are backed by 2D EA
1874-
# error: "Callable[..., Any]" has no attribute "reshape"
1875-
values = self.array_values.reshape(self.shape) # type: ignore[attr-defined]
1876+
# TODO(EA2D) this is unnecessary if these blocks are backed by 2D EAs
1877+
values = self.values.reshape(self.shape)
18761878
new_values = values.shift(periods, fill_value=fill_value, axis=axis)
1877-
new_values = maybe_coerce_values(new_values)
18781879
return [self.make_block_same_class(new_values)]
18791880

18801881
def fillna(
@@ -1887,38 +1888,27 @@ def fillna(
18871888
# TODO: don't special-case td64
18881889
return self.astype(object).fillna(value, limit, inplace, downcast)
18891890

1890-
values = self.array_values
1891-
# error: "Callable[..., Any]" has no attribute "copy"
1892-
values = values if inplace else values.copy() # type: ignore[attr-defined]
1893-
# error: "Callable[..., Any]" has no attribute "fillna"
1894-
new_values = values.fillna( # type: ignore[attr-defined]
1895-
value=value, limit=limit
1896-
)
1897-
new_values = maybe_coerce_values(new_values)
1891+
values = self.values
1892+
values = values if inplace else values.copy()
1893+
new_values = values.fillna(value=value, limit=limit)
18981894
return [self.make_block_same_class(values=new_values)]
18991895

19001896

19011897
class DatetimeLikeBlockMixin(NDArrayBackedExtensionBlock):
19021898
"""Mixin class for DatetimeBlock, DatetimeTZBlock, and TimedeltaBlock."""
19031899

1900+
values: Union[DatetimeArray, TimedeltaArray]
1901+
19041902
is_numeric = False
19051903

19061904
@cache_readonly
19071905
def array_values(self):
1908-
return ensure_wrapped_if_datetimelike(self.values)
1906+
return self.values
19091907

19101908

19111909
class DatetimeBlock(DatetimeLikeBlockMixin):
19121910
__slots__ = ()
19131911

1914-
def set_inplace(self, locs, values):
1915-
"""
1916-
See Block.set.__doc__
1917-
"""
1918-
values = conversion.ensure_datetime64ns(values, copy=False)
1919-
1920-
self.values[locs] = values
1921-
19221912

19231913
class DatetimeTZBlock(ExtensionBlock, DatetimeLikeBlockMixin):
19241914
""" implement a datetime64 block with a tz attribute """
@@ -1936,13 +1926,10 @@ class DatetimeTZBlock(ExtensionBlock, DatetimeLikeBlockMixin):
19361926
putmask = DatetimeLikeBlockMixin.putmask
19371927
fillna = DatetimeLikeBlockMixin.fillna
19381928

1939-
array_values = ExtensionBlock.array_values
1940-
1941-
@property
1942-
def is_view(self) -> bool:
1943-
""" return a boolean if I am possibly a view """
1944-
# check the ndarray values of the DatetimeIndex values
1945-
return self.values._data.base is not None
1929+
# error: Incompatible types in assignment (expression has type
1930+
# "Callable[[NDArrayBackedExtensionBlock], bool]", base class "ExtensionBlock"
1931+
# defined the type as "bool") [assignment]
1932+
is_view = NDArrayBackedExtensionBlock.is_view # type: ignore[assignment]
19461933

19471934

19481935
class TimeDeltaBlock(DatetimeLikeBlockMixin):
@@ -2029,15 +2016,11 @@ def maybe_coerce_values(values) -> ArrayLike:
20292016
values = extract_array(values, extract_numpy=True)
20302017

20312018
if isinstance(values, np.ndarray):
2032-
values = sanitize_to_nanoseconds(values)
2019+
values = ensure_wrapped_if_datetimelike(values)
20332020

20342021
if issubclass(values.dtype.type, str):
20352022
values = np.array(values, dtype=object)
20362023

2037-
elif isinstance(values.dtype, np.dtype):
2038-
# i.e. not datetime64tz, extract DTA/TDA -> ndarray
2039-
values = values._data
2040-
20412024
return values
20422025

20432026

pandas/core/internals/concat.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -423,10 +423,17 @@ def _concatenate_join_units(
423423
concat_values = concat_values.copy()
424424
else:
425425
concat_values = concat_values.copy()
426-
elif any(isinstance(t, ExtensionArray) for t in to_concat):
426+
elif any(isinstance(t, ExtensionArray) and t.ndim == 1 for t in to_concat):
427427
# concatting with at least one EA means we are concatting a single column
428428
# the non-EA values are 2D arrays with shape (1, n)
429-
to_concat = [t if isinstance(t, ExtensionArray) else t[0, :] for t in to_concat]
429+
# error: Invalid index type "Tuple[int, slice]" for
430+
# "Union[ExtensionArray, ndarray]"; expected type "Union[int, slice, ndarray]"
431+
to_concat = [
432+
t
433+
if (isinstance(t, ExtensionArray) and t.ndim == 1)
434+
else t[0, :] # type: ignore[index]
435+
for t in to_concat
436+
]
430437
concat_values = concat_compat(to_concat, axis=0, ea_compat_axis=True)
431438
concat_values = ensure_block_shape(concat_values, 2)
432439

0 commit comments

Comments
 (0)