Skip to content

REF: combine Block _can_hold_element methods #40709

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits on Apr 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 50 additions & 4 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@
from pandas.core.arrays import (
DatetimeArray,
ExtensionArray,
IntervalArray,
PeriodArray,
TimedeltaArray,
)

_int8_max = np.iinfo(np.int8).max
Expand Down Expand Up @@ -2169,32 +2172,68 @@ def validate_numeric_casting(dtype: np.dtype, value: Scalar) -> None:
raise ValueError(f"Cannot assign {type(value).__name__} to bool series")


def can_hold_element(dtype: np.dtype, element: Any) -> bool:
def can_hold_element(arr: ArrayLike, element: Any) -> bool:
"""
Can we do an inplace setitem with this element in an array with this dtype?

Parameters
----------
dtype : np.dtype
arr : np.ndarray or ExtensionArray
element : Any

Returns
-------
bool
"""
dtype = arr.dtype
if not isinstance(dtype, np.dtype) or dtype.kind in ["m", "M"]:
if isinstance(dtype, (PeriodDtype, IntervalDtype, DatetimeTZDtype, np.dtype)):
# np.dtype here catches datetime64ns and timedelta64ns; we assume
# in this case that we have DatetimeArray/TimedeltaArray
arr = cast(
"PeriodArray | DatetimeArray | TimedeltaArray | IntervalArray", arr
)
try:
arr._validate_setitem_value(element)
return True
except (ValueError, TypeError):
return False

# This is technically incorrect, but maintains the behavior of
# ExtensionBlock._can_hold_element
return True

tipo = maybe_infer_dtype_type(element)

if dtype.kind in ["i", "u"]:
if tipo is not None:
return tipo.kind in ["i", "u"] and dtype.itemsize >= tipo.itemsize
if tipo.kind not in ["i", "u"]:
# Anything other than integer we cannot hold
return False
elif dtype.itemsize < tipo.itemsize:
return False
elif not isinstance(tipo, np.dtype):
# i.e. nullable IntegerDtype; we can put this into an ndarray
# losslessly iff it has no NAs
return not element._mask.any()
return True

# We have not inferred an integer from the dtype
# check if we have a builtin int or a float equal to an int
return is_integer(element) or (is_float(element) and element.is_integer())

elif dtype.kind == "f":
if tipo is not None:
return tipo.kind in ["f", "i", "u"]
# TODO: itemsize check?
if tipo.kind not in ["f", "i", "u"]:
# Anything other than float/integer we cannot hold
return False
elif not isinstance(tipo, np.dtype):
# i.e. nullable IntegerDtype or FloatingDtype;
# we can put this into an ndarray losslessly iff it has no NAs
return not element._mask.any()
return True

return lib.is_integer(element) or lib.is_float(element)

elif dtype.kind == "c":
Expand All @@ -2212,4 +2251,11 @@ def can_hold_element(dtype: np.dtype, element: Any) -> bool:
elif dtype == object:
return True

elif dtype.kind == "S":
# TODO: tests in tests.frame.methods.test_replace get here; we need
# more targeted tests. xref: phofl has a PR about this
if tipo is not None:
return tipo.kind == "S" and tipo.itemsize <= dtype.itemsize
return isinstance(element, bytes) and len(element) <= dtype.itemsize

raise NotImplementedError(dtype)
2 changes: 1 addition & 1 deletion pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4482,7 +4482,7 @@ def _validate_fill_value(self, value):
TypeError
If the value cannot be inserted into an array of this dtype.
"""
if not can_hold_element(self.dtype, value):
if not can_hold_element(self._values, value):
raise TypeError
return value

Expand Down
67 changes: 10 additions & 57 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import (
TYPE_CHECKING,
Any,
Callable,
List,
Optional,
Tuple,
Expand All @@ -18,8 +17,6 @@
import numpy as np

from pandas._libs import (
Interval,
Period,
Timestamp,
algos as libalgos,
internals as libinternals,
Expand Down Expand Up @@ -102,6 +99,7 @@
PeriodArray,
TimedeltaArray,
)
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
from pandas.core.base import PandasObject
import pandas.core.common as com
import pandas.core.computation.expressions as expressions
Expand All @@ -122,7 +120,6 @@
Float64Index,
Index,
)
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray

# comparison is faster than is_object_dtype
_dtype_obj = np.dtype("object")
Expand Down Expand Up @@ -625,9 +622,11 @@ def convert(
"""
return [self.copy()] if copy else [self]

@final
def _can_hold_element(self, element: Any) -> bool:
""" require the same dtype as ourselves """
raise NotImplementedError("Implemented on subclasses")
element = extract_array(element, extract_numpy=True)
return can_hold_element(self.values, element)

@final
def should_store(self, value: ArrayLike) -> bool:
Expand Down Expand Up @@ -1545,7 +1544,7 @@ def setitem(self, indexer, value):
be a compatible shape.
"""
if not self._can_hold_element(value):
# This is only relevant for DatetimeTZBlock, ObjectValuesExtensionBlock,
# This is only relevant for DatetimeTZBlock, PeriodDtype, IntervalDtype,
# which have a non-trivial `_can_hold_element`.
# https://github.com/pandas-dev/pandas/issues/24020
# Need a dedicated setitem until GH#24020 (type promotion in setitem
Expand Down Expand Up @@ -1597,10 +1596,6 @@ def take_nd(

return self.make_block_same_class(new_values, new_mgr_locs)

def _can_hold_element(self, element: Any) -> bool:
# TODO: We may need to think about pushing this onto the array.
return True

def _slice(self, slicer):
"""
Return a slice of my values.
Expand Down Expand Up @@ -1746,54 +1741,22 @@ def _unstack(self, unstacker, fill_value, new_placement):
return blocks, mask


class HybridMixin:
"""
Mixin for Blocks backed (maybe indirectly) by ExtensionArrays.
"""

array_values: Callable

def _can_hold_element(self, element: Any) -> bool:
values = self.array_values

try:
# error: "Callable[..., Any]" has no attribute "_validate_setitem_value"
values._validate_setitem_value(element) # type: ignore[attr-defined]
return True
except (ValueError, TypeError):
return False


class ObjectValuesExtensionBlock(HybridMixin, ExtensionBlock):
"""
Block providing backwards-compatibility for `.values`.

Used by PeriodArray and IntervalArray to ensure that
Series[T].values is an ndarray of objects.
"""

pass


class NumericBlock(Block):
__slots__ = ()
is_numeric = True

def _can_hold_element(self, element: Any) -> bool:
element = extract_array(element, extract_numpy=True)
if isinstance(element, (IntegerArray, FloatingArray)):
if element._mask.any():
return False
return can_hold_element(self.dtype, element)


class NDArrayBackedExtensionBlock(HybridMixin, Block):
class NDArrayBackedExtensionBlock(Block):
"""
Block backed by an NDArrayBackedExtensionArray
"""

values: NDArrayBackedExtensionArray

@property
def array_values(self) -> NDArrayBackedExtensionArray:
return self.values

@property
def is_view(self) -> bool:
""" return a boolean if I am possibly a view """
Expand Down Expand Up @@ -1901,10 +1864,6 @@ class DatetimeLikeBlockMixin(NDArrayBackedExtensionBlock):

is_numeric = False

@cache_readonly
def array_values(self):
return self.values


class DatetimeBlock(DatetimeLikeBlockMixin):
__slots__ = ()
Expand All @@ -1920,7 +1879,6 @@ class DatetimeTZBlock(ExtensionBlock, DatetimeLikeBlockMixin):
is_numeric = False

internal_values = Block.internal_values
_can_hold_element = DatetimeBlock._can_hold_element
diff = DatetimeBlock.diff
where = DatetimeBlock.where
putmask = DatetimeLikeBlockMixin.putmask
Expand Down Expand Up @@ -1983,9 +1941,6 @@ def convert(
res_values = ensure_block_shape(res_values, self.ndim)
return [self.make_block(res_values)]

def _can_hold_element(self, element: Any) -> bool:
return True


class CategoricalBlock(ExtensionBlock):
# this Block type is kept for backwards-compatibility
Expand Down Expand Up @@ -2052,8 +2007,6 @@ def get_block_type(values, dtype: Optional[Dtype] = None):
cls = CategoricalBlock
elif vtype is Timestamp:
cls = DatetimeTZBlock
elif vtype is Interval or vtype is Period:
cls = ObjectValuesExtensionBlock
elif isinstance(dtype, ExtensionDtype):
# Note: need to be sure PandasArray is unwrapped before we get here
cls = ExtensionBlock
Expand Down
9 changes: 0 additions & 9 deletions pandas/core/internals/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@
CategoricalBlock,
DatetimeTZBlock,
ExtensionBlock,
ObjectValuesExtensionBlock,
ensure_block_shape,
extend_blocks,
get_block_type,
Expand Down Expand Up @@ -1841,14 +1840,6 @@ def _form_blocks(

blocks.extend(external_blocks)

if len(items_dict["ObjectValuesExtensionBlock"]):
external_blocks = [
new_block(array, klass=ObjectValuesExtensionBlock, placement=i, ndim=2)
for i, array in items_dict["ObjectValuesExtensionBlock"]
]

blocks.extend(external_blocks)

if len(extra_locs):
shape = (len(extra_locs),) + tuple(len(x) for x in axes[1:])

Expand Down
13 changes: 12 additions & 1 deletion pandas/tests/extension/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import pandas.util._test_decorators as td

from pandas.core.dtypes.cast import can_hold_element
from pandas.core.dtypes.dtypes import (
ExtensionDtype,
PandasDtype,
Expand All @@ -27,7 +28,10 @@
import pandas as pd
import pandas._testing as tm
from pandas.core.arrays.numpy_ import PandasArray
from pandas.core.internals import managers
from pandas.core.internals import (
blocks,
managers,
)
from pandas.tests.extension import base

# TODO(ArrayManager) PandasArray
Expand All @@ -45,6 +49,12 @@ def _extract_array_patched(obj):
return obj


def _can_hold_element_patched(obj, element) -> bool:
    """
    Wrapper around ``can_hold_element`` that unwraps PandasArray elements.

    Parameters
    ----------
    obj : object
        Passed through unchanged as the first argument of
        ``can_hold_element``.
    element : Any
        Candidate value; a PandasArray is converted with ``to_numpy()``
        before the check, anything else is forwarded as-is.

    Returns
    -------
    bool
        Whatever ``can_hold_element`` returns for the (possibly unwrapped)
        element.
    """
    candidate = element.to_numpy() if isinstance(element, PandasArray) else element
    return can_hold_element(obj, candidate)


@pytest.fixture(params=["float", "object"])
def dtype(request):
return PandasDtype(np.dtype(request.param))
Expand All @@ -70,6 +80,7 @@ def allow_in_pandas(monkeypatch):
with monkeypatch.context() as m:
m.setattr(PandasArray, "_typ", "extension")
m.setattr(managers, "_extract_array", _extract_array_patched)
m.setattr(blocks, "can_hold_element", _can_hold_element_patched)
yield


Expand Down