Skip to content

Commit 6b6d8fd

Browse files
authored
REF: Back DatetimeTZBlock with sometimes-2D DTA (#41082)
1 parent a15aff6 commit 6b6d8fd

File tree

15 files changed

+183
-103
lines changed

15 files changed

+183
-103
lines changed

pandas/core/dtypes/common.py

+27
Original file line numberDiff line numberDiff line change
@@ -1413,6 +1413,33 @@ def is_extension_type(arr) -> bool:
14131413
return False
14141414

14151415

1416+
def is_1d_only_ea_obj(obj: Any) -> bool:
1417+
"""
1418+
ExtensionArray that does not support 2D, or more specifically that does
1419+
not use HybridBlock.
1420+
"""
1421+
from pandas.core.arrays import (
1422+
DatetimeArray,
1423+
ExtensionArray,
1424+
TimedeltaArray,
1425+
)
1426+
1427+
return isinstance(obj, ExtensionArray) and not isinstance(
1428+
obj, (DatetimeArray, TimedeltaArray)
1429+
)
1430+
1431+
1432+
def is_1d_only_ea_dtype(dtype: Optional[DtypeObj]) -> bool:
1433+
"""
1434+
Analogue to is_extension_array_dtype but excluding DatetimeTZDtype.
1435+
"""
1436+
# Note: if other EA dtypes are ever held in HybridBlock, exclude those
1437+
# here too.
1438+
# NB: need to check DatetimeTZDtype and not is_datetime64tz_dtype
1439+
# to exclude ArrowTimestampUSDtype
1440+
return isinstance(dtype, ExtensionDtype) and not isinstance(dtype, DatetimeTZDtype)
1441+
1442+
14161443
def is_extension_array_dtype(arr_or_dtype) -> bool:
14171444
"""
14181445
Check if an object is a pandas extension array type.

pandas/core/dtypes/concat.py

+4-12
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,15 @@ def is_nonempty(x) -> bool:
113113
to_concat = non_empties
114114

115115
kinds = {obj.dtype.kind for obj in to_concat}
116+
contains_datetime = any(kind in ["m", "M"] for kind in kinds)
116117

117118
all_empty = not len(non_empties)
118119
single_dtype = len({x.dtype for x in to_concat}) == 1
119120
any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)
120121

122+
if contains_datetime:
123+
return _concat_datetime(to_concat, axis=axis)
124+
121125
if any_ea:
122126
# we ignore axis here, as internally concatting with EAs is always
123127
# for axis=0
@@ -131,9 +135,6 @@ def is_nonempty(x) -> bool:
131135
else:
132136
return np.concatenate(to_concat)
133137

134-
elif any(kind in ["m", "M"] for kind in kinds):
135-
return _concat_datetime(to_concat, axis=axis)
136-
137138
elif all_empty:
138139
# we have all empties, but may need to coerce the result dtype to
139140
# object if we have non-numeric type operands (numpy would otherwise
@@ -349,14 +350,5 @@ def _concat_datetime(to_concat, axis=0):
349350
# in Timestamp/Timedelta
350351
return _concatenate_2d([x.astype(object) for x in to_concat], axis=axis)
351352

352-
if axis == 1:
353-
# TODO(EA2D): kludge not necessary with 2D EAs
354-
to_concat = [x.reshape(1, -1) if x.ndim == 1 else x for x in to_concat]
355-
356353
result = type(to_concat[0])._concat_same_type(to_concat, axis=axis)
357-
358-
if result.ndim == 2 and isinstance(result.dtype, ExtensionDtype):
359-
# TODO(EA2D): kludge not necessary with 2D EAs
360-
assert result.shape[0] == 1
361-
result = result[0]
362354
return result

pandas/core/frame.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@
9898
from pandas.core.dtypes.common import (
9999
ensure_platform_int,
100100
infer_dtype_from_object,
101+
is_1d_only_ea_dtype,
101102
is_bool_dtype,
102103
is_dataclass,
103104
is_datetime64_any_dtype,
@@ -845,7 +846,9 @@ def _can_fast_transpose(self) -> bool:
845846
if len(blocks) != 1:
846847
return False
847848

848-
return not self._mgr.any_extension_types
849+
dtype = blocks[0].dtype
850+
# TODO(EA2D) special case would be unnecessary with 2D EAs
851+
return not is_1d_only_ea_dtype(dtype)
849852

850853
# ----------------------------------------------------------------------
851854
# Rendering Methods

pandas/core/internals/api.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
2) Use only functions exposed here (or in core.internals)
77
88
"""
9-
from typing import Optional
9+
from __future__ import annotations
1010

1111
import numpy as np
1212

@@ -23,14 +23,15 @@
2323
Block,
2424
DatetimeTZBlock,
2525
check_ndim,
26+
ensure_block_shape,
2627
extract_pandas_array,
2728
get_block_type,
2829
maybe_coerce_values,
2930
)
3031

3132

3233
def make_block(
33-
values, placement, klass=None, ndim=None, dtype: Optional[Dtype] = None
34+
values, placement, klass=None, ndim=None, dtype: Dtype | None = None
3435
) -> Block:
3536
"""
3637
This is a pseudo-public analogue to blocks.new_block.
@@ -48,24 +49,29 @@ def make_block(
4849

4950
values, dtype = extract_pandas_array(values, dtype, ndim)
5051

52+
needs_reshape = False
5153
if klass is None:
5254
dtype = dtype or values.dtype
5355
klass = get_block_type(values, dtype)
5456

5557
elif klass is DatetimeTZBlock and not is_datetime64tz_dtype(values.dtype):
5658
# pyarrow calls get here
5759
values = DatetimeArray._simple_new(values, dtype=dtype)
60+
needs_reshape = True
5861

5962
if not isinstance(placement, BlockPlacement):
6063
placement = BlockPlacement(placement)
6164

6265
ndim = maybe_infer_ndim(values, placement, ndim)
66+
if needs_reshape:
67+
values = ensure_block_shape(values, ndim)
68+
6369
check_ndim(values, placement, ndim)
6470
values = maybe_coerce_values(values)
6571
return klass(values, ndim=ndim, placement=placement)
6672

6773

68-
def maybe_infer_ndim(values, placement: BlockPlacement, ndim: Optional[int]) -> int:
74+
def maybe_infer_ndim(values, placement: BlockPlacement, ndim: int | None) -> int:
6975
"""
7076
If `ndim` is not provided, infer it from placment and values.
7177
"""

pandas/core/internals/blocks.py

+25-33
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
soft_convert_objects,
4343
)
4444
from pandas.core.dtypes.common import (
45+
is_1d_only_ea_dtype,
46+
is_1d_only_ea_obj,
4547
is_categorical_dtype,
4648
is_dtype_equal,
4749
is_extension_array_dtype,
@@ -224,7 +226,6 @@ def get_values(self, dtype: DtypeObj | None = None) -> np.ndarray:
224226
# expected "ndarray")
225227
return self.values # type: ignore[return-value]
226228

227-
@final
228229
def get_block_values_for_json(self) -> np.ndarray:
229230
"""
230231
This is used in the JSON C code.
@@ -415,7 +416,11 @@ def _split_op_result(self, result) -> list[Block]:
415416
# if we get a 2D ExtensionArray, we need to split it into 1D pieces
416417
nbs = []
417418
for i, loc in enumerate(self._mgr_locs):
418-
vals = result[i]
419+
if not is_1d_only_ea_obj(result):
420+
vals = result[i : i + 1]
421+
else:
422+
vals = result[i]
423+
419424
block = self.make_block(values=vals, placement=loc)
420425
nbs.append(block)
421426
return nbs
@@ -1670,7 +1675,7 @@ class NumericBlock(NumpyBlock):
16701675
is_numeric = True
16711676

16721677

1673-
class NDArrayBackedExtensionBlock(EABackedBlock):
1678+
class NDArrayBackedExtensionBlock(libinternals.Block, EABackedBlock):
16741679
"""
16751680
Block backed by an NDArrayBackedExtensionArray
16761681
"""
@@ -1683,11 +1688,6 @@ def is_view(self) -> bool:
16831688
# check the ndarray values of the DatetimeIndex values
16841689
return self.values._ndarray.base is not None
16851690

1686-
def iget(self, key):
1687-
# GH#31649 we need to wrap scalars in Timestamp/Timedelta
1688-
# TODO(EA2D): this can be removed if we ever have 2D EA
1689-
return self.values.reshape(self.shape)[key]
1690-
16911691
def setitem(self, indexer, value):
16921692
if not self._can_hold_element(value):
16931693
# TODO: general case needs casting logic.
@@ -1707,24 +1707,21 @@ def putmask(self, mask, new) -> list[Block]:
17071707
if not self._can_hold_element(new):
17081708
return self.astype(object).putmask(mask, new)
17091709

1710-
# TODO(EA2D): reshape unnecessary with 2D EAs
1711-
arr = self.values.reshape(self.shape)
1710+
arr = self.values
17121711
arr.T.putmask(mask, new)
17131712
return [self]
17141713

17151714
def where(self, other, cond, errors="raise") -> list[Block]:
17161715
# TODO(EA2D): reshape unnecessary with 2D EAs
1717-
arr = self.values.reshape(self.shape)
1716+
arr = self.values
17181717

17191718
cond = extract_bool_array(cond)
17201719

17211720
try:
17221721
res_values = arr.T.where(cond, other).T
17231722
except (ValueError, TypeError):
1724-
return super().where(other, cond, errors=errors)
1723+
return Block.where(self, other, cond, errors=errors)
17251724

1726-
# TODO(EA2D): reshape not needed with 2D EAs
1727-
res_values = res_values.reshape(self.values.shape)
17281725
nb = self.make_block_same_class(res_values)
17291726
return [nb]
17301727

@@ -1748,15 +1745,13 @@ def diff(self, n: int, axis: int = 0) -> list[Block]:
17481745
The arguments here are mimicking shift so they are called correctly
17491746
by apply.
17501747
"""
1751-
# TODO(EA2D): reshape not necessary with 2D EAs
1752-
values = self.values.reshape(self.shape)
1748+
values = self.values
17531749

17541750
new_values = values - values.shift(n, axis=axis)
17551751
return [self.make_block(new_values)]
17561752

17571753
def shift(self, periods: int, axis: int = 0, fill_value: Any = None) -> list[Block]:
1758-
# TODO(EA2D) this is unnecessary if these blocks are backed by 2D EAs
1759-
values = self.values.reshape(self.shape)
1754+
values = self.values
17601755
new_values = values.shift(periods, fill_value=fill_value, axis=axis)
17611756
return [self.make_block_same_class(new_values)]
17621757

@@ -1776,31 +1771,27 @@ def fillna(
17761771
return [self.make_block_same_class(values=new_values)]
17771772

17781773

1779-
class DatetimeLikeBlock(libinternals.Block, NDArrayBackedExtensionBlock):
1774+
class DatetimeLikeBlock(NDArrayBackedExtensionBlock):
17801775
"""Block for datetime64[ns], timedelta64[ns]."""
17811776

17821777
__slots__ = ()
17831778
is_numeric = False
17841779
values: DatetimeArray | TimedeltaArray
17851780

1781+
def get_block_values_for_json(self):
1782+
# Not necessary to override, but helps perf
1783+
return self.values._ndarray
17861784

1787-
class DatetimeTZBlock(ExtensionBlock, NDArrayBackedExtensionBlock):
1785+
1786+
class DatetimeTZBlock(DatetimeLikeBlock):
17881787
""" implement a datetime64 block with a tz attribute """
17891788

17901789
values: DatetimeArray
17911790

17921791
__slots__ = ()
17931792
is_extension = True
1794-
is_numeric = False
1795-
1796-
diff = NDArrayBackedExtensionBlock.diff
1797-
where = NDArrayBackedExtensionBlock.where
1798-
putmask = NDArrayBackedExtensionBlock.putmask
1799-
fillna = NDArrayBackedExtensionBlock.fillna
1800-
1801-
get_values = NDArrayBackedExtensionBlock.get_values
1802-
1803-
is_view = NDArrayBackedExtensionBlock.is_view
1793+
_validate_ndim = True
1794+
_can_consolidate = False
18041795

18051796

18061797
class ObjectBlock(NumpyBlock):
@@ -1967,7 +1958,7 @@ def check_ndim(values, placement: BlockPlacement, ndim: int):
19671958
f"values.ndim > ndim [{values.ndim} > {ndim}]"
19681959
)
19691960

1970-
elif isinstance(values.dtype, np.dtype):
1961+
elif not is_1d_only_ea_dtype(values.dtype):
19711962
# TODO(EA2D): special case not needed with 2D EAs
19721963
if values.ndim != ndim:
19731964
raise ValueError(
@@ -1981,7 +1972,7 @@ def check_ndim(values, placement: BlockPlacement, ndim: int):
19811972
)
19821973
elif ndim == 2 and len(placement) != 1:
19831974
# TODO(EA2D): special case unnecessary with 2D EAs
1984-
raise AssertionError("block.size != values.size")
1975+
raise ValueError("need to split")
19851976

19861977

19871978
def extract_pandas_array(
@@ -2026,8 +2017,9 @@ def ensure_block_shape(values: ArrayLike, ndim: int = 1) -> ArrayLike:
20262017
"""
20272018
Reshape if possible to have values.ndim == ndim.
20282019
"""
2020+
20292021
if values.ndim < ndim:
2030-
if not is_extension_array_dtype(values.dtype):
2022+
if not is_1d_only_ea_dtype(values.dtype):
20312023
# TODO(EA2D): https://github.com/pandas-dev/pandas/issues/23023
20322024
# block.shape is incorrect for "2D" ExtensionArrays
20332025
# We can't, and don't need to, reshape.

pandas/core/internals/concat.py

+18-9
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import (
66
TYPE_CHECKING,
77
Sequence,
8+
cast,
89
)
910

1011
import numpy as np
@@ -23,6 +24,8 @@
2324
find_common_type,
2425
)
2526
from pandas.core.dtypes.common import (
27+
is_1d_only_ea_dtype,
28+
is_1d_only_ea_obj,
2629
is_datetime64tz_dtype,
2730
is_dtype_equal,
2831
is_extension_array_dtype,
@@ -210,8 +213,8 @@ def concatenate_managers(
210213
values = np.concatenate(vals, axis=blk.ndim - 1)
211214
else:
212215
# TODO(EA2D): special-casing not needed with 2D EAs
213-
values = concat_compat(vals)
214-
values = ensure_block_shape(values, ndim=2)
216+
values = concat_compat(vals, axis=1)
217+
values = ensure_block_shape(values, blk.ndim)
215218

216219
values = ensure_wrapped_if_datetimelike(values)
217220

@@ -412,13 +415,16 @@ def get_reindexed_values(self, empty_dtype: DtypeObj, upcasted_na) -> ArrayLike:
412415
fill_value = None
413416

414417
if is_datetime64tz_dtype(empty_dtype):
415-
# TODO(EA2D): special case unneeded with 2D EAs
416-
i8values = np.full(self.shape[1], fill_value.value)
418+
i8values = np.full(self.shape, fill_value.value)
417419
return DatetimeArray(i8values, dtype=empty_dtype)
420+
418421
elif is_extension_array_dtype(blk_dtype):
419422
pass
420-
elif isinstance(empty_dtype, ExtensionDtype):
423+
424+
elif is_1d_only_ea_dtype(empty_dtype):
425+
empty_dtype = cast(ExtensionDtype, empty_dtype)
421426
cls = empty_dtype.construct_array_type()
427+
422428
missing_arr = cls._from_sequence([], dtype=empty_dtype)
423429
ncols, nrows = self.shape
424430
assert ncols == 1, ncols
@@ -429,6 +435,7 @@ def get_reindexed_values(self, empty_dtype: DtypeObj, upcasted_na) -> ArrayLike:
429435
else:
430436
# NB: we should never get here with empty_dtype integer or bool;
431437
# if we did, the missing_arr.fill would cast to gibberish
438+
empty_dtype = cast(np.dtype, empty_dtype)
432439

433440
missing_arr = np.empty(self.shape, dtype=empty_dtype)
434441
missing_arr.fill(fill_value)
@@ -493,15 +500,17 @@ def _concatenate_join_units(
493500
concat_values = concat_values.copy()
494501
else:
495502
concat_values = concat_values.copy()
496-
elif any(isinstance(t, ExtensionArray) and t.ndim == 1 for t in to_concat):
503+
504+
elif any(is_1d_only_ea_obj(t) for t in to_concat):
505+
# TODO(EA2D): special case not needed if all EAs used HybridBlocks
506+
# NB: we are still assuming here that Hybrid blocks have shape (1, N)
497507
# concatting with at least one EA means we are concatting a single column
498508
# the non-EA values are 2D arrays with shape (1, n)
509+
499510
# error: Invalid index type "Tuple[int, slice]" for
500511
# "Union[ExtensionArray, ndarray]"; expected type "Union[int, slice, ndarray]"
501512
to_concat = [
502-
t
503-
if (isinstance(t, ExtensionArray) and t.ndim == 1)
504-
else t[0, :] # type: ignore[index]
513+
t if is_1d_only_ea_obj(t) else t[0, :] # type: ignore[index]
505514
for t in to_concat
506515
]
507516
concat_values = concat_compat(to_concat, axis=0, ea_compat_axis=True)

0 commit comments

Comments
 (0)