diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 27b4539488e40..3ea716d93aca0 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -1418,6 +1418,58 @@ def ensure_nanosecond_dtype(dtype: DtypeObj) -> DtypeObj: return dtype +# TODO: other value-dependent functions to standardize here include +# dtypes.concat.cast_to_common_type and Index._find_common_type_compat +def find_result_type(left: ArrayLike, right: Any) -> DtypeObj: + """ + Find the type/dtype for a the result of an operation between these objects. + + This is similar to find_common_type, but looks at the objects instead + of just their dtypes. This can be useful in particular when one of the + objects does not have a `dtype`. + + Parameters + ---------- + left : np.ndarray or ExtensionArray + right : Any + + Returns + ------- + np.dtype or ExtensionDtype + + See also + -------- + find_common_type + numpy.result_type + """ + new_dtype: DtypeObj + + if left.dtype.kind in ["i", "u", "c"] and ( + lib.is_integer(right) or lib.is_float(right) + ): + # e.g. with int8 dtype and right=512, we want to end up with + # np.int16, whereas infer_dtype_from(512) gives np.int64, + # which will make us upcast too far. + if lib.is_float(right) and right.is_integer() and left.dtype.kind != "f": + right = int(right) + + # Argument 1 to "result_type" has incompatible type "Union[ExtensionArray, + # ndarray[Any, Any]]"; expected "Union[Union[_SupportsArray[dtype[Any]], + # _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float, complex, + # str, bytes, _NestedSequence[Union[bool, int, float, complex, str, bytes]]], + # Union[dtype[Any], None, Type[Any], _SupportsDType[dtype[Any]], str, + # Union[Tuple[Any, int], Tuple[Any, Union[SupportsIndex, + # Sequence[SupportsIndex]]], List[Any], _DTypeDict, Tuple[Any, Any]]]]" + new_dtype = np.result_type(left, right) # type:ignore[arg-type] + + else: + dtype, _ = infer_dtype_from(right, pandas_dtype=True) + + new_dtype = find_common_type([left.dtype, dtype]) + + return new_dtype + + @overload def find_common_type(types: list[np.dtype]) -> np.dtype: ... diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index ec3b9261dd1f5..d8543a15b1ea0 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -38,7 +38,7 @@ from pandas.core.dtypes.astype import astype_array_safe from pandas.core.dtypes.cast import ( can_hold_element, - find_common_type, + find_result_type, infer_dtype_from, maybe_downcast_numeric, maybe_downcast_to_dtype, @@ -1031,10 +1031,7 @@ def coerce_to_target_dtype(self, other) -> Block: we can also safely try to coerce to the same dtype and will receive the same block """ - # if we cannot then coerce to object - dtype, _ = infer_dtype_from(other, pandas_dtype=True) - - new_dtype = find_common_type([self.dtype, dtype]) + new_dtype = find_result_type(self.values, other) return self.astype(new_dtype, copy=False) diff --git a/pandas/tests/frame/indexing/test_where.py b/pandas/tests/frame/indexing/test_where.py index 90b348f401437..29605d0cfad3e 100644 --- a/pandas/tests/frame/indexing/test_where.py +++ b/pandas/tests/frame/indexing/test_where.py @@ -753,13 +753,15 @@ def test_where_try_cast_deprecated(frame_or_series): obj.where(mask, -1, try_cast=False) -@pytest.mark.xfail( - reason="After fixing a bug in can_hold_element, we don't go through " - "the deprecated path, and also up-cast to int64 instead of int32 " - "(for now)." -) -def test_where_int_downcasting_deprecated(using_array_manager): +def test_where_int_downcasting_deprecated(using_array_manager, request): # GH#44597 + if not using_array_manager: + mark = pytest.mark.xfail( + reason="After fixing a bug in can_hold_element, we don't go through " + "the deprecated path, and also up-cast both columns to int32 " + "instead of just 1." + ) + request.node.add_marker(mark) arr = np.arange(6).astype(np.int16).reshape(3, 2) df = DataFrame(arr) diff --git a/pandas/tests/indexing/test_loc.py b/pandas/tests/indexing/test_loc.py index 699018f08adee..f551363063999 100644 --- a/pandas/tests/indexing/test_loc.py +++ b/pandas/tests/indexing/test_loc.py @@ -2734,14 +2734,14 @@ def test_loc_getitem_nullable_index_with_duplicates(): tm.assert_series_equal(res, expected) -def test_loc_setitem_uint8_upcast(): +@pytest.mark.parametrize("value", [300, np.uint16(300), np.int16(300)]) +def test_loc_setitem_uint8_upcast(value): # GH#26049 df = DataFrame([1, 2, 3, 4], columns=["col1"], dtype="uint8") - df.loc[2, "col1"] = 300 # value that can't be held in uint8 + df.loc[2, "col1"] = value # value that can't be held in uint8 - # TODO: would be better to get uint16? - expected = DataFrame([1, 2, 300, 4], columns=["col1"], dtype="int64") + expected = DataFrame([1, 2, 300, 4], columns=["col1"], dtype="uint16") tm.assert_frame_equal(df, expected) diff --git a/pandas/tests/series/indexing/test_setitem.py b/pandas/tests/series/indexing/test_setitem.py index d4b69abb6aba9..c4e5164206126 100644 --- a/pandas/tests/series/indexing/test_setitem.py +++ b/pandas/tests/series/indexing/test_setitem.py @@ -1077,27 +1077,6 @@ def key(self): def expected(self): return Series([1, 512, 3], dtype=np.int16) - def test_int_key(self, obj, key, expected, val, indexer_sli, is_inplace, request): - if not isinstance(val, np.int16): - # with python int we end up with int64 - mark = pytest.mark.xfail - request.node.add_marker(mark) - super().test_int_key(obj, key, expected, val, indexer_sli, is_inplace) - - def test_mask_key(self, obj, key, expected, val, indexer_sli, request): - if not isinstance(val, np.int16): - # with python int we end up with int64 - mark = pytest.mark.xfail - request.node.add_marker(mark) - super().test_mask_key(obj, key, expected, val, indexer_sli) - - def test_series_where(self, obj, key, expected, val, is_inplace, request): - if not isinstance(val, np.int16): - # with python int we end up with int64 - mark = pytest.mark.xfail - request.node.add_marker(mark) - super().test_series_where(obj, key, expected, val, is_inplace) - @pytest.mark.parametrize("val", [2 ** 33 + 1.0, 2 ** 33 + 1.1, 2 ** 62]) class TestSmallIntegerSetitemUpcast(SetitemCastingEquivalents): @@ -1118,18 +1097,6 @@ def expected(self, val): dtype = "i8" return Series([val, 2, 3], dtype=dtype) - def test_int_key(self, obj, key, expected, val, indexer_sli, is_inplace, request): - if val % 1 == 0 and isinstance(val, float): - mark = pytest.mark.xfail - request.node.add_marker(mark) - super().test_int_key(obj, key, expected, val, indexer_sli, is_inplace) - - def test_mask_key(self, obj, key, expected, val, indexer_sli, request): - if val % 1 == 0 and isinstance(val, float): - mark = pytest.mark.xfail - request.node.add_marker(mark) - super().test_mask_key(obj, key, expected, val, indexer_sli) - class CoercionTest(SetitemCastingEquivalents): # Tests ported from tests.indexing.test_coercion