From 4a2922f1380ecd2bdce2ec1518ef5b30a7f5c5a0 Mon Sep 17 00:00:00 2001 From: Brock Date: Sat, 8 Jan 2022 10:42:40 -0800 Subject: [PATCH 1/4] BUG: can_hold_element size checks on ints/floats --- pandas/tests/frame/indexing/test_where.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/tests/frame/indexing/test_where.py b/pandas/tests/frame/indexing/test_where.py index 2ee777cf53d29..90b348f401437 100644 --- a/pandas/tests/frame/indexing/test_where.py +++ b/pandas/tests/frame/indexing/test_where.py @@ -758,7 +758,7 @@ def test_where_try_cast_deprecated(frame_or_series): "the deprecated path, and also up-cast to int64 instead of int32 " "(for now)." ) -def test_where_int_downcasting_deprecated(using_array_manager, request): +def test_where_int_downcasting_deprecated(using_array_manager): # GH#44597 arr = np.arange(6).astype(np.int16).reshape(3, 2) df = DataFrame(arr) From d93e122930ccb462b6e036ca28b3b54808625397 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 10 Jan 2022 16:33:56 -0800 Subject: [PATCH 2/4] REF: implement find_result_type --- pandas/core/dtypes/cast.py | 42 ++++++++++++++++++++ pandas/core/internals/blocks.py | 7 +--- pandas/tests/indexing/test_loc.py | 8 ++-- pandas/tests/series/indexing/test_setitem.py | 33 --------------- 4 files changed, 48 insertions(+), 42 deletions(-) diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index d51f878a1c85a..23ccbaa3ff917 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -1418,6 +1418,48 @@ def ensure_nanosecond_dtype(dtype: DtypeObj) -> DtypeObj: return dtype +# TODO: other value-dependent functions to standardize here include +# dtypes.concat.cast_to_common_type and Index._find_common_type_compat +def find_result_type(left: ArrayLike, right: Any) -> DtypeObj: + """ + Find the type/dtype for a the result of an operation between these objects. + + This is similar to find_common_type, but looks at the objects instead + of just their dtypes. This can be useful in particular when one of the + objects does not have a `dtype`. + + Parameters + ---------- + left : np.ndarray or ExtensionArray + right : Any + + Returns + ------- + np.dtype or ExtensionDtype + + See also + -------- + find_common_type + numpy.result_type + """ + if left.dtype.kind in ["i", "u", "c"] and ( + lib.is_integer(right) or lib.is_float(right) + ): + # e.g. with int8 dtype and right=512, we want to end up with + # np.int16, whereas infer_dtype_from(512) gives np.int64, + # which will make us upcast too far. + if lib.is_float(right) and right.is_integer() and left.dtype.kind != "f": + right = int(right) + new_dtype = np.result_type(left, right) + + else: + dtype, _ = infer_dtype_from(right, pandas_dtype=True) + + new_dtype = find_common_type([left.dtype, dtype]) + + return new_dtype + + @overload def find_common_type(types: list[np.dtype]) -> np.dtype: ... diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index 0ced984bbc568..9976121056083 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -38,7 +38,7 @@ from pandas.core.dtypes.astype import astype_array_safe from pandas.core.dtypes.cast import ( can_hold_element, - find_common_type, + find_result_type, infer_dtype_from, maybe_downcast_numeric, maybe_downcast_to_dtype, @@ -1031,10 +1031,7 @@ def coerce_to_target_dtype(self, other) -> Block: we can also safely try to coerce to the same dtype and will receive the same block """ - # if we cannot then coerce to object - dtype, _ = infer_dtype_from(other, pandas_dtype=True) - - new_dtype = find_common_type([self.dtype, dtype]) + new_dtype = find_result_type(self.values, other) return self.astype(new_dtype, copy=False) diff --git a/pandas/tests/indexing/test_loc.py b/pandas/tests/indexing/test_loc.py index 699018f08adee..f551363063999 100644 --- a/pandas/tests/indexing/test_loc.py +++ b/pandas/tests/indexing/test_loc.py @@ -2734,14 +2734,14 @@ def test_loc_getitem_nullable_index_with_duplicates(): tm.assert_series_equal(res, expected) -def test_loc_setitem_uint8_upcast(): +@pytest.mark.parametrize("value", [300, np.uint16(300), np.int16(300)]) +def test_loc_setitem_uint8_upcast(value): # GH#26049 df = DataFrame([1, 2, 3, 4], columns=["col1"], dtype="uint8") - df.loc[2, "col1"] = 300 # value that can't be held in uint8 + df.loc[2, "col1"] = value # value that can't be held in uint8 - # TODO: would be better to get uint16? - expected = DataFrame([1, 2, 300, 4], columns=["col1"], dtype="int64") + expected = DataFrame([1, 2, 300, 4], columns=["col1"], dtype="uint16") tm.assert_frame_equal(df, expected) diff --git a/pandas/tests/series/indexing/test_setitem.py b/pandas/tests/series/indexing/test_setitem.py index d4b69abb6aba9..c4e5164206126 100644 --- a/pandas/tests/series/indexing/test_setitem.py +++ b/pandas/tests/series/indexing/test_setitem.py @@ -1077,27 +1077,6 @@ def key(self): def expected(self): return Series([1, 512, 3], dtype=np.int16) - def test_int_key(self, obj, key, expected, val, indexer_sli, is_inplace, request): - if not isinstance(val, np.int16): - # with python int we end up with int64 - mark = pytest.mark.xfail - request.node.add_marker(mark) - super().test_int_key(obj, key, expected, val, indexer_sli, is_inplace) - - def test_mask_key(self, obj, key, expected, val, indexer_sli, request): - if not isinstance(val, np.int16): - # with python int we end up with int64 - mark = pytest.mark.xfail - request.node.add_marker(mark) - super().test_mask_key(obj, key, expected, val, indexer_sli) - - def test_series_where(self, obj, key, expected, val, is_inplace, request): - if not isinstance(val, np.int16): - # with python int we end up with int64 - mark = pytest.mark.xfail - request.node.add_marker(mark) - super().test_series_where(obj, key, expected, val, is_inplace) - @pytest.mark.parametrize("val", [2 ** 33 + 1.0, 2 ** 33 + 1.1, 2 ** 62]) class TestSmallIntegerSetitemUpcast(SetitemCastingEquivalents): @@ -1118,18 +1097,6 @@ def expected(self, val): dtype = "i8" return Series([val, 2, 3], dtype=dtype) - def test_int_key(self, obj, key, expected, val, indexer_sli, is_inplace, request): - if val % 1 == 0 and isinstance(val, float): - mark = pytest.mark.xfail - request.node.add_marker(mark) - super().test_int_key(obj, key, expected, val, indexer_sli, is_inplace) - - def test_mask_key(self, obj, key, expected, val, indexer_sli, request): - if val % 1 == 0 and isinstance(val, float): - mark = pytest.mark.xfail - request.node.add_marker(mark) - super().test_mask_key(obj, key, expected, val, indexer_sli) - class CoercionTest(SetitemCastingEquivalents): # Tests ported from tests.indexing.test_coercion From 056902800f12fab8940db1a207c8fb2e5c36120f Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 10 Jan 2022 19:58:26 -0800 Subject: [PATCH 3/4] fix xfail --- pandas/tests/frame/indexing/test_where.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/pandas/tests/frame/indexing/test_where.py b/pandas/tests/frame/indexing/test_where.py index 90b348f401437..29605d0cfad3e 100644 --- a/pandas/tests/frame/indexing/test_where.py +++ b/pandas/tests/frame/indexing/test_where.py @@ -753,13 +753,15 @@ def test_where_try_cast_deprecated(frame_or_series): obj.where(mask, -1, try_cast=False) -@pytest.mark.xfail( - reason="After fixing a bug in can_hold_element, we don't go through " - "the deprecated path, and also up-cast to int64 instead of int32 " - "(for now)." -) -def test_where_int_downcasting_deprecated(using_array_manager): +def test_where_int_downcasting_deprecated(using_array_manager, request): # GH#44597 + if not using_array_manager: + mark = pytest.mark.xfail( + reason="After fixing a bug in can_hold_element, we don't go through " + "the deprecated path, and also up-cast both columns to int32 " + "instead of just 1." + ) + request.node.add_marker(mark) arr = np.arange(6).astype(np.int16).reshape(3, 2) df = DataFrame(arr) From c08b0f2d5c28327cbd8e09fc9b3a6b01c593fff9 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 10 Jan 2022 20:33:11 -0800 Subject: [PATCH 4/4] mypy fixup --- pandas/core/dtypes/cast.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 23ccbaa3ff917..6e83fb421401c 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -1442,6 +1442,8 @@ def find_result_type(left: ArrayLike, right: Any) -> DtypeObj: find_common_type numpy.result_type """ + new_dtype: DtypeObj + if left.dtype.kind in ["i", "u", "c"] and ( lib.is_integer(right) or lib.is_float(right) ): @@ -1450,7 +1452,15 @@ def find_result_type(left: ArrayLike, right: Any) -> DtypeObj: # which will make us upcast too far. if lib.is_float(right) and right.is_integer() and left.dtype.kind != "f": right = int(right) - new_dtype = np.result_type(left, right) + + # Argument 1 to "result_type" has incompatible type "Union[ExtensionArray, + # ndarray[Any, Any]]"; expected "Union[Union[_SupportsArray[dtype[Any]], + # _NestedSequence[_SupportsArray[dtype[Any]]], bool, int, float, complex, + # str, bytes, _NestedSequence[Union[bool, int, float, complex, str, bytes]]], + # Union[dtype[Any], None, Type[Any], _SupportsDType[dtype[Any]], str, + # Union[Tuple[Any, int], Tuple[Any, Union[SupportsIndex, + # Sequence[SupportsIndex]]], List[Any], _DTypeDict, Tuple[Any, Any]]]]" + new_dtype = np.result_type(left, right) # type:ignore[arg-type] else: dtype, _ = infer_dtype_from(right, pandas_dtype=True)