Skip to content

Commit ab5c5a3

Browse files
authored
REF: implement find_result_type (#45304)
1 parent a1ed4b2 commit ab5c5a3

File tree

5 files changed

+66
-48
lines changed

5 files changed

+66
-48
lines changed

pandas/core/dtypes/cast.py

+52
Original file line numberDiff line numberDiff line change
@@ -1418,6 +1418,58 @@ def ensure_nanosecond_dtype(dtype: DtypeObj) -> DtypeObj:
14181418
return dtype
14191419

14201420

1421+
# TODO: other value-dependent functions to standardize here include
1422+
# dtypes.concat.cast_to_common_type and Index._find_common_type_compat
1423+
def find_result_type(left: ArrayLike, right: Any) -> DtypeObj:
1424+
"""
1425+
Find the type/dtype for a the result of an operation between these objects.
1426+
1427+
This is similar to find_common_type, but looks at the objects instead
1428+
of just their dtypes. This can be useful in particular when one of the
1429+
objects does not have a `dtype`.
1430+
1431+
Parameters
1432+
----------
1433+
left : np.ndarray or ExtensionArray
1434+
right : Any
1435+
1436+
Returns
1437+
-------
1438+
np.dtype or ExtensionDtype
1439+
1440+
See also
1441+
--------
1442+
find_common_type
1443+
numpy.result_type
1444+
"""
1445+
new_dtype: DtypeObj
1446+
1447+
if left.dtype.kind in ["i", "u", "c"] and (
1448+
lib.is_integer(right) or lib.is_float(right)
1449+
):
1450+
# e.g. with int8 dtype and right=512, we want to end up with
1451+
# np.int16, whereas infer_dtype_from(512) gives np.int64,
1452+
# which will make us upcast too far.
1453+
if lib.is_float(right) and right.is_integer() and left.dtype.kind != "f":
1454+
right = int(right)
1455+
1456+
# Argument 1 to "result_type" has incompatible type "Union[ExtensionArray,
1457+
# ndarray[Any, Any]]"; expected "Union[Union[_SupportsArray[dtype[Any]],
1458+
# _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float, complex,
1459+
# str, bytes, _NestedSequence[Union[bool, int, float, complex, str, bytes]]],
1460+
# Union[dtype[Any], None, Type[Any], _SupportsDType[dtype[Any]], str,
1461+
# Union[Tuple[Any, int], Tuple[Any, Union[SupportsIndex,
1462+
# Sequence[SupportsIndex]]], List[Any], _DTypeDict, Tuple[Any, Any]]]]"
1463+
new_dtype = np.result_type(left, right) # type:ignore[arg-type]
1464+
1465+
else:
1466+
dtype, _ = infer_dtype_from(right, pandas_dtype=True)
1467+
1468+
new_dtype = find_common_type([left.dtype, dtype])
1469+
1470+
return new_dtype
1471+
1472+
14211473
@overload
14221474
def find_common_type(types: list[np.dtype]) -> np.dtype:
14231475
...

pandas/core/internals/blocks.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from pandas.core.dtypes.astype import astype_array_safe
3939
from pandas.core.dtypes.cast import (
4040
can_hold_element,
41-
find_common_type,
41+
find_result_type,
4242
infer_dtype_from,
4343
maybe_downcast_numeric,
4444
maybe_downcast_to_dtype,
@@ -1031,10 +1031,7 @@ def coerce_to_target_dtype(self, other) -> Block:
10311031
we can also safely try to coerce to the same dtype
10321032
and will receive the same block
10331033
"""
1034-
# if we cannot then coerce to object
1035-
dtype, _ = infer_dtype_from(other, pandas_dtype=True)
1036-
1037-
new_dtype = find_common_type([self.dtype, dtype])
1034+
new_dtype = find_result_type(self.values, other)
10381035

10391036
return self.astype(new_dtype, copy=False)
10401037

pandas/tests/frame/indexing/test_where.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -753,13 +753,15 @@ def test_where_try_cast_deprecated(frame_or_series):
753753
obj.where(mask, -1, try_cast=False)
754754

755755

756-
@pytest.mark.xfail(
757-
reason="After fixing a bug in can_hold_element, we don't go through "
758-
"the deprecated path, and also up-cast to int64 instead of int32 "
759-
"(for now)."
760-
)
761-
def test_where_int_downcasting_deprecated(using_array_manager):
756+
def test_where_int_downcasting_deprecated(using_array_manager, request):
762757
# GH#44597
758+
if not using_array_manager:
759+
mark = pytest.mark.xfail(
760+
reason="After fixing a bug in can_hold_element, we don't go through "
761+
"the deprecated path, and also up-cast both columns to int32 "
762+
"instead of just 1."
763+
)
764+
request.node.add_marker(mark)
763765
arr = np.arange(6).astype(np.int16).reshape(3, 2)
764766
df = DataFrame(arr)
765767

pandas/tests/indexing/test_loc.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -2734,14 +2734,14 @@ def test_loc_getitem_nullable_index_with_duplicates():
27342734
tm.assert_series_equal(res, expected)
27352735

27362736

2737-
def test_loc_setitem_uint8_upcast():
2737+
@pytest.mark.parametrize("value", [300, np.uint16(300), np.int16(300)])
2738+
def test_loc_setitem_uint8_upcast(value):
27382739
# GH#26049
27392740

27402741
df = DataFrame([1, 2, 3, 4], columns=["col1"], dtype="uint8")
2741-
df.loc[2, "col1"] = 300 # value that can't be held in uint8
2742+
df.loc[2, "col1"] = value # value that can't be held in uint8
27422743

2743-
# TODO: would be better to get uint16?
2744-
expected = DataFrame([1, 2, 300, 4], columns=["col1"], dtype="int64")
2744+
expected = DataFrame([1, 2, 300, 4], columns=["col1"], dtype="uint16")
27452745
tm.assert_frame_equal(df, expected)
27462746

27472747

pandas/tests/series/indexing/test_setitem.py

-33
Original file line numberDiff line numberDiff line change
@@ -1077,27 +1077,6 @@ def key(self):
10771077
def expected(self):
10781078
return Series([1, 512, 3], dtype=np.int16)
10791079

1080-
def test_int_key(self, obj, key, expected, val, indexer_sli, is_inplace, request):
1081-
if not isinstance(val, np.int16):
1082-
# with python int we end up with int64
1083-
mark = pytest.mark.xfail
1084-
request.node.add_marker(mark)
1085-
super().test_int_key(obj, key, expected, val, indexer_sli, is_inplace)
1086-
1087-
def test_mask_key(self, obj, key, expected, val, indexer_sli, request):
1088-
if not isinstance(val, np.int16):
1089-
# with python int we end up with int64
1090-
mark = pytest.mark.xfail
1091-
request.node.add_marker(mark)
1092-
super().test_mask_key(obj, key, expected, val, indexer_sli)
1093-
1094-
def test_series_where(self, obj, key, expected, val, is_inplace, request):
1095-
if not isinstance(val, np.int16):
1096-
# with python int we end up with int64
1097-
mark = pytest.mark.xfail
1098-
request.node.add_marker(mark)
1099-
super().test_series_where(obj, key, expected, val, is_inplace)
1100-
11011080

11021081
@pytest.mark.parametrize("val", [2 ** 33 + 1.0, 2 ** 33 + 1.1, 2 ** 62])
11031082
class TestSmallIntegerSetitemUpcast(SetitemCastingEquivalents):
@@ -1118,18 +1097,6 @@ def expected(self, val):
11181097
dtype = "i8"
11191098
return Series([val, 2, 3], dtype=dtype)
11201099

1121-
def test_int_key(self, obj, key, expected, val, indexer_sli, is_inplace, request):
1122-
if val % 1 == 0 and isinstance(val, float):
1123-
mark = pytest.mark.xfail
1124-
request.node.add_marker(mark)
1125-
super().test_int_key(obj, key, expected, val, indexer_sli, is_inplace)
1126-
1127-
def test_mask_key(self, obj, key, expected, val, indexer_sli, request):
1128-
if val % 1 == 0 and isinstance(val, float):
1129-
mark = pytest.mark.xfail
1130-
request.node.add_marker(mark)
1131-
super().test_mask_key(obj, key, expected, val, indexer_sli)
1132-
11331100

11341101
class CoercionTest(SetitemCastingEquivalents):
11351102
# Tests ported from tests.indexing.test_coercion

0 commit comments

Comments
 (0)