ENH: EADtype._find_compatible_dtype #53106

Status: Closed. Wants to merge 18 commits (changes shown from 1 commit).
pandas/core/arrays/arrow/dtype.py (30 additions, 1 deletion)

@@ -8,10 +8,14 @@
)
from decimal import Decimal
import re
-from typing import TYPE_CHECKING
+from typing import (
+    TYPE_CHECKING,
+    Any,
+)

import numpy as np

from pandas._libs import missing as libmissing
from pandas._libs.tslibs import (
Timedelta,
Timestamp,
@@ -23,6 +27,7 @@
StorageExtensionDtype,
register_extension_dtype,
)
from pandas.core.dtypes.cast import maybe_promote
from pandas.core.dtypes.dtypes import CategoricalDtypeType

if not pa_version_under7p0:
@@ -321,3 +326,27 @@ def __from_arrow__(self, array: pa.Array | pa.ChunkedArray):
array_class = self.construct_array_type()
arr = array.cast(self.pyarrow_dtype, safe=True)
return array_class(arr)

def _maybe_promote(self, item: Any) -> tuple[DtypeObj, Any]:
if isinstance(item, pa.Scalar):
if not item.is_valid:
Member:
Shouldn't NA always be able to be inserted into ArrowExtensionArray?

Author:
I'm not sure. pyarrow nulls are typed, so we could plausibly want to disallow e.g. <pyarrow.TimestampScalar: None> in a pyarrow integer dtype.
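
As an illustration of that point (editor's sketch, not part of the diff): pyarrow null scalars carry their type with them.

import pyarrow as pa

ts_null = pa.scalar(None, type=pa.timestamp("us"))
ts_null.is_valid   # False
ts_null.type       # timestamp[us] -- the null is still typed
int_null = pa.scalar(None, type=pa.int64())
int_null.type      # int64 -- a differently-typed null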

# TODO: ask joris for help making these checks more robust
if item.type == self.pyarrow_dtype:
return self, item.as_py()
if item.type.to_pandas_dtype() == np.int64 and self.kind == "i":
Member:
Why is this needed specifically?

Author:
This was just to get the tests copied from #52833 passing.

# FIXME: kludge
return self, item.as_py()

item = item.as_py()

elif item is None or item is libmissing.NA:
# TODO: np.nan? use is_valid_na_for_dtype
Member:
Since pyarrow supports nan vs NA, possibly we want to allow nan if pa.types.is_floating(self.pyarrow_dtype)

Author:
What to do here depends on making a decision about when/how to distinguish between np.nan and pd.NA (which I hope to finally nail down at the sprint). Doing this The Right Way would involve something like implementing EA._is_valid_na_for_dtype.
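
A sketch of the distinction in question (illustration only, not part of the diff): pyarrow keeps nan and null separate for floating types.

import pyarrow as pa
import pyarrow.compute as pc

arr = pa.array([1.0, float("nan"), None], type=pa.float64())
arr.is_null()   # [false, false, true] -- nan does not count as null by default
pc.is_nan(arr)  # [false, true, null] -- nan and NA stay distinct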

return self, item

dtype, item = maybe_promote(self.numpy_dtype, item)

if dtype == self.numpy_dtype:
return self, item

# TODO: implement from_numpy_dtype analogous to MaskedDtype.from_numpy_dtype
return np.dtype(object), item
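
Roughly, the intended behavior of the method above (editor's sketch based on this diff, outputs not verified):

import pyarrow as pa
import pandas as pd

dtype = pd.ArrowDtype(pa.int32())
dtype._maybe_promote(pa.scalar(4, type=pa.int32()))  # (int32[pyarrow], 4) -- matching type, keep dtype
dtype._maybe_promote(pd.NA)                          # (int32[pyarrow], <NA>) -- NA is allowed
dtype._maybe_promote("foo")                          # (dtype('O'), 'foo') -- falls back to object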
pandas/core/dtypes/base.py (3 additions, 0 deletions)

@@ -391,6 +391,9 @@ def _can_hold_na(self) -> bool:
"""
return True

def _maybe_promote(self, item: Any) -> tuple[DtypeObj, Any]:
return np.dtype(object), item
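
The base-class hook simply punts to object dtype. A sketch of what that means for an extension dtype that does not override it (assuming, e.g., PeriodDtype keeps the default):

import pandas as pd
from pandas.core.dtypes.cast import maybe_promote

dtype = pd.PeriodDtype("D")
maybe_promote(dtype, "anything")  # (dtype('O'), 'anything') -- default fallback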


class StorageExtensionDtype(ExtensionDtype):
"""ExtensionDtype that may be backed by more than one implementation."""
pandas/core/dtypes/cast.py (7 additions, 10 deletions)

@@ -46,7 +46,6 @@
ensure_int16,
ensure_int32,
ensure_int64,
-    ensure_object,
ensure_str,
is_bool,
is_complex,
@@ -539,13 +538,13 @@ def ensure_dtype_can_hold_na(dtype: DtypeObj) -> DtypeObj:
}


-def maybe_promote(dtype: np.dtype, fill_value=np.nan):
+def maybe_promote(dtype: DtypeObj, fill_value=np.nan):
"""
Find the minimal dtype that can hold both the given dtype and fill_value.

Parameters
----------
-    dtype : np.dtype
+    dtype : np.dtype or ExtensionDtype
fill_value : scalar, default np.nan

Returns
@@ -593,9 +592,13 @@ def _maybe_promote_cached(dtype, fill_value, fill_value_type):
return _maybe_promote(dtype, fill_value)


-def _maybe_promote(dtype: np.dtype, fill_value=np.nan):
+def _maybe_promote(dtype: DtypeObj, fill_value=np.nan):
# The actual implementation of the function, use `maybe_promote` above for
# a cached version.

if not isinstance(dtype, np.dtype):
return dtype._maybe_promote(fill_value)

if not is_scalar(fill_value):
# with object dtype there is nothing to promote, and the user can
# pass pretty much any weird fill_value they like
@@ -611,12 +614,6 @@ def _maybe_promote(dtype: np.dtype, fill_value=np.nan):
fv = na_value_for_dtype(dtype)
return dtype, fv

-    elif isinstance(dtype, CategoricalDtype):
-        if fill_value in dtype.categories or isna(fill_value):
-            return dtype, fill_value
-        else:
-            return object, ensure_object(fill_value)

elif isna(fill_value):
dtype = _dtype_obj
if fill_value is None:
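With this dispatch in place, the removed CategoricalDtype special case is handled by the dtype itself. An editor's sketch of the intended behavior (outputs not verified):

import numpy as np
import pandas as pd
from pandas.core.dtypes.cast import maybe_promote

cat = pd.CategoricalDtype(["a", "b"])
maybe_promote(cat, "a")                   # (CategoricalDtype(...), 'a') -- existing category, keep dtype
maybe_promote(cat, "z")                   # (dtype('O'), 'z') -- unknown category, fall back to object
maybe_promote(np.dtype("int64"), np.nan)  # (dtype('float64'), nan) -- numpy path unchanged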
pandas/core/dtypes/dtypes.py (21 additions, 0 deletions)

@@ -635,6 +635,15 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:

return find_common_type(non_cat_dtypes)

def _maybe_promote(self, item) -> tuple[DtypeObj, Any]:
from pandas.core.dtypes.missing import is_valid_na_for_dtype

if item in self.categories or is_valid_na_for_dtype(
item, self.categories.dtype
):
return self, item
return np.dtype(object), item


@register_extension_dtype
class DatetimeTZDtype(PandasExtensionDtype):
@@ -1500,3 +1509,15 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
return type(self).from_numpy_dtype(new_dtype)
except (KeyError, NotImplementedError):
return None

def _maybe_promote(self, item) -> tuple[DtypeObj, Any]:
from pandas.core.dtypes.cast import maybe_promote
from pandas.core.dtypes.missing import is_valid_na_for_dtype

if is_valid_na_for_dtype(item, self):
return self, item

dtype, item = maybe_promote(self.numpy_dtype, item)
if dtype.kind in "iufb":
return type(self).from_numpy_dtype(dtype), item
return dtype, item
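
A sketch of how this plays out for a masked dtype (illustrative, outputs not verified):

import pandas as pd

dtype = pd.Int32Dtype()
dtype._maybe_promote(pd.NA)  # (Int32Dtype(), <NA>) -- NA is valid, dtype kept
dtype._maybe_promote(4.5)    # (Float64Dtype(), 4.5) -- promoted within the masked family
dtype._maybe_promote("foo")  # (dtype('O'), 'foo') -- no masked equivalent, fall back to object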
pandas/core/indexing.py (2 additions, 3 deletions)

@@ -2091,7 +2091,7 @@ def _setitem_with_indexer_missing(self, indexer, value):
return self._setitem_with_indexer(new_indexer, value, "loc")

# this preserves dtype of the value and of the object
-        if not is_scalar(value):
+        if is_list_like(value):
Member:
Note: for ArrowDtype with a pa.list_ type, we would want to treat value like a scalar, e.g.

ser = pd.Series([[1, 1]], dtype=pd.ArrowDtype(pa.list_(pa.int64())))
ser[0] = [1, 2]

Author:
Yeah, getting rid of this is_list_like check causes us to incorrectly raise on numpy non-object cases when using a list value (for which we don't have any tests). Can fix that in this PR or separately, as it is a bit more invasive.

new_dtype = None

elif is_valid_na_for_dtype(value, self.obj.dtype):
@@ -2107,8 +2107,7 @@ def _setitem_with_indexer_missing(self, indexer, value):
# We should not cast, if we have object dtype because we can
# set timedeltas into object series
curr_dtype = self.obj.dtype
-            curr_dtype = getattr(curr_dtype, "numpy_dtype", curr_dtype)
-            new_dtype = maybe_promote(curr_dtype, value)[0]
+            new_dtype, value = maybe_promote(curr_dtype, value)
else:
new_dtype = None

pandas/tests/extension/test_arrow.py (23 additions, 0 deletions)

@@ -2855,6 +2855,29 @@ def test_describe_timedelta_data(pa_type):
tm.assert_series_equal(result, expected)


@pytest.mark.parametrize(
"value, target_value, dtype",
[
(pa.scalar(4, type="int32"), 4, "int32[pyarrow]"),
(pa.scalar(4, type="int64"), 4, "int32[pyarrow]"),
# (pa.scalar(4.5, type="float64"), 4, "int32[pyarrow]"),
Member:
What happens here?

Also, what happens with an int64 scalar and int32 dtype?

Author:
I'd want to follow the same logic we do for numpy dtypes, but was punting here in expectation of doing it in a follow-up (likely involving Joris expressing an opinion).
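
For reference, the numpy-dtype logic being alluded to (editor's sketch):

import numpy as np
from pandas.core.dtypes.cast import maybe_promote

maybe_promote(np.dtype("int32"), 4)    # (dtype('int32'), 4) -- value fits, dtype kept
maybe_promote(np.dtype("int32"), 4.5)  # (dtype('float64'), 4.5) -- promoted to hold the float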

(4, 4, "int32[pyarrow]"),
(pd.NA, None, "int32[pyarrow]"),
(None, None, "int32[pyarrow]"),
(pa.scalar(None, type="int32"), None, "int32[pyarrow]"),
(pa.scalar(None, type="int64"), None, "int32[pyarrow]"),
],
)
def test_series_setitem_with_enlargement(value, target_value, dtype):
# GH#52235
# similar to series/indexing/test_setitem.py::test_setitem_keep_precision
# and test_setitem_enlarge_with_na, but for arrow dtypes
ser = pd.Series([1, 2, 3], dtype=dtype)
ser[3] = value
expected = pd.Series([1, 2, 3, target_value], dtype=dtype)
tm.assert_series_equal(ser, expected)


@pytest.mark.parametrize("pa_type", tm.DATETIME_PYARROW_DTYPES)
def test_describe_datetime_data(pa_type):
# GH53001