Skip to content

Commit e1dd032

Browse files
authored
REF: combine Block _can_hold_element methods (#40709)
1 parent 65860fa commit e1dd032

File tree

5 files changed

+73
-72
lines changed

5 files changed

+73
-72
lines changed

pandas/core/dtypes/cast.py

+50-4
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@
113113
from pandas.core.arrays import (
114114
DatetimeArray,
115115
ExtensionArray,
116+
IntervalArray,
117+
PeriodArray,
118+
TimedeltaArray,
116119
)
117120

118121
_int8_max = np.iinfo(np.int8).max
@@ -2169,32 +2172,68 @@ def validate_numeric_casting(dtype: np.dtype, value: Scalar) -> None:
21692172
raise ValueError(f"Cannot assign {type(value).__name__} to bool series")
21702173

21712174

2172-
def can_hold_element(dtype: np.dtype, element: Any) -> bool:
2175+
def can_hold_element(arr: ArrayLike, element: Any) -> bool:
21732176
"""
21742177
Can we do an inplace setitem with this element in an array with this dtype?
21752178
21762179
Parameters
21772180
----------
2178-
dtype : np.dtype
2181+
arr : np.ndarray or ExtensionArray
21792182
element : Any
21802183
21812184
Returns
21822185
-------
21832186
bool
21842187
"""
2188+
dtype = arr.dtype
2189+
if not isinstance(dtype, np.dtype) or dtype.kind in ["m", "M"]:
2190+
if isinstance(dtype, (PeriodDtype, IntervalDtype, DatetimeTZDtype, np.dtype)):
2191+
# np.dtype here catches datetime64ns and timedelta64ns; we assume
2192+
# in this case that we have DatetimeArray/TimedeltaArray
2193+
arr = cast(
2194+
"PeriodArray | DatetimeArray | TimedeltaArray | IntervalArray", arr
2195+
)
2196+
try:
2197+
arr._validate_setitem_value(element)
2198+
return True
2199+
except (ValueError, TypeError):
2200+
return False
2201+
2202+
# This is technically incorrect, but maintains the behavior of
2203+
# ExtensionBlock._can_hold_element
2204+
return True
2205+
21852206
tipo = maybe_infer_dtype_type(element)
21862207

21872208
if dtype.kind in ["i", "u"]:
21882209
if tipo is not None:
2189-
return tipo.kind in ["i", "u"] and dtype.itemsize >= tipo.itemsize
2210+
if tipo.kind not in ["i", "u"]:
2211+
# Anything other than integer we cannot hold
2212+
return False
2213+
elif dtype.itemsize < tipo.itemsize:
2214+
return False
2215+
elif not isinstance(tipo, np.dtype):
2216+
# i.e. nullable IntegerDtype; we can put this into an ndarray
2217+
# losslessly iff it has no NAs
2218+
return not element._mask.any()
2219+
return True
21902220

21912221
# We have not inferred an integer from the dtype
21922222
# check if we have a builtin int or a float equal to an int
21932223
return is_integer(element) or (is_float(element) and element.is_integer())
21942224

21952225
elif dtype.kind == "f":
21962226
if tipo is not None:
2197-
return tipo.kind in ["f", "i", "u"]
2227+
# TODO: itemsize check?
2228+
if tipo.kind not in ["f", "i", "u"]:
2229+
# Anything other than float/integer we cannot hold
2230+
return False
2231+
elif not isinstance(tipo, np.dtype):
2232+
# i.e. nullable IntegerDtype or FloatingDtype;
2233+
# we can put this into an ndarray losslessly iff it has no NAs
2234+
return not element._mask.any()
2235+
return True
2236+
21982237
return lib.is_integer(element) or lib.is_float(element)
21992238

22002239
elif dtype.kind == "c":
@@ -2212,4 +2251,11 @@ def can_hold_element(dtype: np.dtype, element: Any) -> bool:
22122251
elif dtype == object:
22132252
return True
22142253

2254+
elif dtype.kind == "S":
2255+
# TODO: tests in tests.frame.methods.test_replace get here,
2256+
# need more targeted tests. xref phofl has a PR about this
2257+
if tipo is not None:
2258+
return tipo.kind == "S" and tipo.itemsize <= dtype.itemsize
2259+
return isinstance(element, bytes) and len(element) <= dtype.itemsize
2260+
22152261
raise NotImplementedError(dtype)

pandas/core/indexes/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4482,7 +4482,7 @@ def _validate_fill_value(self, value):
44824482
TypeError
44834483
If the value cannot be inserted into an array of this dtype.
44844484
"""
4485-
if not can_hold_element(self.dtype, value):
4485+
if not can_hold_element(self._values, value):
44864486
raise TypeError
44874487
return value
44884488

pandas/core/internals/blocks.py

+10-57
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import (
66
TYPE_CHECKING,
77
Any,
8-
Callable,
98
List,
109
Optional,
1110
Tuple,
@@ -18,8 +17,6 @@
1817
import numpy as np
1918

2019
from pandas._libs import (
21-
Interval,
22-
Period,
2320
Timestamp,
2421
algos as libalgos,
2522
internals as libinternals,
@@ -102,6 +99,7 @@
10299
PeriodArray,
103100
TimedeltaArray,
104101
)
102+
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
105103
from pandas.core.base import PandasObject
106104
import pandas.core.common as com
107105
import pandas.core.computation.expressions as expressions
@@ -122,7 +120,6 @@
122120
Float64Index,
123121
Index,
124122
)
125-
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
126123

127124
# comparison is faster than is_object_dtype
128125
_dtype_obj = np.dtype("object")
@@ -625,9 +622,11 @@ def convert(
625622
"""
626623
return [self.copy()] if copy else [self]
627624

625+
@final
628626
def _can_hold_element(self, element: Any) -> bool:
629627
""" require the same dtype as ourselves """
630-
raise NotImplementedError("Implemented on subclasses")
628+
element = extract_array(element, extract_numpy=True)
629+
return can_hold_element(self.values, element)
631630

632631
@final
633632
def should_store(self, value: ArrayLike) -> bool:
@@ -1545,7 +1544,7 @@ def setitem(self, indexer, value):
15451544
be a compatible shape.
15461545
"""
15471546
if not self._can_hold_element(value):
1548-
# This is only relevant for DatetimeTZBlock, ObjectValuesExtensionBlock,
1547+
# This is only relevant for DatetimeTZBlock, PeriodDtype, IntervalDtype,
15491548
# which have a non-trivial `_can_hold_element`.
15501549
# https://github.com/pandas-dev/pandas/issues/24020
15511550
# Need a dedicated setitem until GH#24020 (type promotion in setitem
@@ -1597,10 +1596,6 @@ def take_nd(
15971596

15981597
return self.make_block_same_class(new_values, new_mgr_locs)
15991598

1600-
def _can_hold_element(self, element: Any) -> bool:
1601-
# TODO: We may need to think about pushing this onto the array.
1602-
return True
1603-
16041599
def _slice(self, slicer):
16051600
"""
16061601
Return a slice of my values.
@@ -1746,54 +1741,22 @@ def _unstack(self, unstacker, fill_value, new_placement):
17461741
return blocks, mask
17471742

17481743

1749-
class HybridMixin:
1750-
"""
1751-
Mixin for Blocks backed (maybe indirectly) by ExtensionArrays.
1752-
"""
1753-
1754-
array_values: Callable
1755-
1756-
def _can_hold_element(self, element: Any) -> bool:
1757-
values = self.array_values
1758-
1759-
try:
1760-
# error: "Callable[..., Any]" has no attribute "_validate_setitem_value"
1761-
values._validate_setitem_value(element) # type: ignore[attr-defined]
1762-
return True
1763-
except (ValueError, TypeError):
1764-
return False
1765-
1766-
1767-
class ObjectValuesExtensionBlock(HybridMixin, ExtensionBlock):
1768-
"""
1769-
Block providing backwards-compatibility for `.values`.
1770-
1771-
Used by PeriodArray and IntervalArray to ensure that
1772-
Series[T].values is an ndarray of objects.
1773-
"""
1774-
1775-
pass
1776-
1777-
17781744
class NumericBlock(Block):
17791745
__slots__ = ()
17801746
is_numeric = True
17811747

1782-
def _can_hold_element(self, element: Any) -> bool:
1783-
element = extract_array(element, extract_numpy=True)
1784-
if isinstance(element, (IntegerArray, FloatingArray)):
1785-
if element._mask.any():
1786-
return False
1787-
return can_hold_element(self.dtype, element)
17881748

1789-
1790-
class NDArrayBackedExtensionBlock(HybridMixin, Block):
1749+
class NDArrayBackedExtensionBlock(Block):
17911750
"""
17921751
Block backed by an NDArrayBackedExtensionArray
17931752
"""
17941753

17951754
values: NDArrayBackedExtensionArray
17961755

1756+
@property
1757+
def array_values(self) -> NDArrayBackedExtensionArray:
1758+
return self.values
1759+
17971760
@property
17981761
def is_view(self) -> bool:
17991762
""" return a boolean if I am possibly a view """
@@ -1901,10 +1864,6 @@ class DatetimeLikeBlockMixin(NDArrayBackedExtensionBlock):
19011864

19021865
is_numeric = False
19031866

1904-
@cache_readonly
1905-
def array_values(self):
1906-
return self.values
1907-
19081867

19091868
class DatetimeBlock(DatetimeLikeBlockMixin):
19101869
__slots__ = ()
@@ -1920,7 +1879,6 @@ class DatetimeTZBlock(ExtensionBlock, DatetimeLikeBlockMixin):
19201879
is_numeric = False
19211880

19221881
internal_values = Block.internal_values
1923-
_can_hold_element = DatetimeBlock._can_hold_element
19241882
diff = DatetimeBlock.diff
19251883
where = DatetimeBlock.where
19261884
putmask = DatetimeLikeBlockMixin.putmask
@@ -1983,9 +1941,6 @@ def convert(
19831941
res_values = ensure_block_shape(res_values, self.ndim)
19841942
return [self.make_block(res_values)]
19851943

1986-
def _can_hold_element(self, element: Any) -> bool:
1987-
return True
1988-
19891944

19901945
class CategoricalBlock(ExtensionBlock):
19911946
# this Block type is kept for backwards-compatibility
@@ -2052,8 +2007,6 @@ def get_block_type(values, dtype: Optional[Dtype] = None):
20522007
cls = CategoricalBlock
20532008
elif vtype is Timestamp:
20542009
cls = DatetimeTZBlock
2055-
elif vtype is Interval or vtype is Period:
2056-
cls = ObjectValuesExtensionBlock
20572010
elif isinstance(dtype, ExtensionDtype):
20582011
# Note: need to be sure PandasArray is unwrapped before we get here
20592012
cls = ExtensionBlock

pandas/core/internals/managers.py

-9
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@
7373
CategoricalBlock,
7474
DatetimeTZBlock,
7575
ExtensionBlock,
76-
ObjectValuesExtensionBlock,
7776
ensure_block_shape,
7877
extend_blocks,
7978
get_block_type,
@@ -1841,14 +1840,6 @@ def _form_blocks(
18411840

18421841
blocks.extend(external_blocks)
18431842

1844-
if len(items_dict["ObjectValuesExtensionBlock"]):
1845-
external_blocks = [
1846-
new_block(array, klass=ObjectValuesExtensionBlock, placement=i, ndim=2)
1847-
for i, array in items_dict["ObjectValuesExtensionBlock"]
1848-
]
1849-
1850-
blocks.extend(external_blocks)
1851-
18521843
if len(extra_locs):
18531844
shape = (len(extra_locs),) + tuple(len(x) for x in axes[1:])
18541845

pandas/tests/extension/test_numpy.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import pandas.util._test_decorators as td
2020

21+
from pandas.core.dtypes.cast import can_hold_element
2122
from pandas.core.dtypes.dtypes import (
2223
ExtensionDtype,
2324
PandasDtype,
@@ -27,7 +28,10 @@
2728
import pandas as pd
2829
import pandas._testing as tm
2930
from pandas.core.arrays.numpy_ import PandasArray
30-
from pandas.core.internals import managers
31+
from pandas.core.internals import (
32+
blocks,
33+
managers,
34+
)
3135
from pandas.tests.extension import base
3236

3337
# TODO(ArrayManager) PandasArray
@@ -45,6 +49,12 @@ def _extract_array_patched(obj):
4549
return obj
4650

4751

52+
def _can_hold_element_patched(obj, element) -> bool:
53+
if isinstance(element, PandasArray):
54+
element = element.to_numpy()
55+
return can_hold_element(obj, element)
56+
57+
4858
@pytest.fixture(params=["float", "object"])
4959
def dtype(request):
5060
return PandasDtype(np.dtype(request.param))
@@ -70,6 +80,7 @@ def allow_in_pandas(monkeypatch):
7080
with monkeypatch.context() as m:
7181
m.setattr(PandasArray, "_typ", "extension")
7282
m.setattr(managers, "_extract_array", _extract_array_patched)
83+
m.setattr(blocks, "can_hold_element", _can_hold_element_patched)
7384
yield
7485

7586

0 commit comments

Comments
 (0)