From 7b9feb0d049045792b330fafe2a8f3c9592685fa Mon Sep 17 00:00:00 2001 From: Kaiqi Dong Date: Sat, 1 Feb 2020 15:37:24 +0100 Subject: [PATCH] BUG: Array.__setitem__ failing with nullable boolean mask (#31484) --- doc/source/whatsnew/v1.0.1.rst | 1 + pandas/core/arrays/boolean.py | 2 ++ pandas/core/arrays/categorical.py | 2 ++ pandas/core/arrays/datetimelike.py | 2 ++ pandas/core/arrays/integer.py | 2 ++ pandas/core/arrays/interval.py | 1 + pandas/core/arrays/numpy_.py | 1 + pandas/core/arrays/string_.py | 2 ++ pandas/tests/arrays/test_integer.py | 17 +++++++++++++++++ pandas/tests/extension/base/setitem.py | 12 ++++++++++++ pandas/tests/extension/decimal/array.py | 3 +++ 11 files changed, 45 insertions(+) diff --git a/doc/source/whatsnew/v1.0.1.rst b/doc/source/whatsnew/v1.0.1.rst index cb916cecd4f1b..9e78ff03f5f67 100644 --- a/doc/source/whatsnew/v1.0.1.rst +++ b/doc/source/whatsnew/v1.0.1.rst @@ -71,6 +71,7 @@ Indexing - - +- Bug where assigning to a :class:`Series` using a IntegerArray / BooleanArray as a mask would raise ``TypeError`` (:issue:`31446`) Missing ^^^^^^^ diff --git a/pandas/core/arrays/boolean.py b/pandas/core/arrays/boolean.py index 7b12f3348e7e7..9eeed42124f2a 100644 --- a/pandas/core/arrays/boolean.py +++ b/pandas/core/arrays/boolean.py @@ -26,6 +26,7 @@ from pandas.core.dtypes.missing import isna, notna from pandas.core import nanops, ops +from pandas.core.indexers import check_array_indexer from .masked import BaseMaskedArray @@ -369,6 +370,7 @@ def __setitem__(self, key, value): value = value[0] mask = mask[0] + key = check_array_indexer(self, key) self._data[key] = value self._mask[key] = mask diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index 105b14aa3c3b7..aa84edd413bc9 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -2073,6 +2073,8 @@ def __setitem__(self, key, value): lindexer = self.categories.get_indexer(rvalue) lindexer = self._maybe_coerce_indexer(lindexer) + + key = check_array_indexer(self, key) self._codes[key] = lindexer def _reverse_indexer(self) -> Dict[Hashable, np.ndarray]: diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 82fa9197b39eb..e8d5890d2564f 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -601,6 +601,8 @@ def __setitem__( f"or array of those. Got '{type(value).__name__}' instead." ) raise TypeError(msg) + + key = check_array_indexer(self, key) self._data[key] = value self._maybe_clear_freq() diff --git a/pandas/core/arrays/integer.py b/pandas/core/arrays/integer.py index 022e6a7322872..9a0f5794e7607 100644 --- a/pandas/core/arrays/integer.py +++ b/pandas/core/arrays/integer.py @@ -25,6 +25,7 @@ from pandas.core.dtypes.missing import isna from pandas.core import nanops, ops +from pandas.core.indexers import check_array_indexer from pandas.core.ops import invalid_comparison from pandas.core.ops.common import unpack_zerodim_and_defer from pandas.core.tools.numeric import to_numeric @@ -414,6 +415,7 @@ def __setitem__(self, key, value): value = value[0] mask = mask[0] + key = check_array_indexer(self, key) self._data[key] = value self._mask[key] = mask diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index d890c0c16aecc..23cf5f317ac7d 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -541,6 +541,7 @@ def __setitem__(self, key, value): msg = f"'value' should be an interval type, got {type(value)} instead." raise TypeError(msg) + key = check_array_indexer(self, key) # Need to ensure that left and right are updated atomically, so we're # forced to copy, update the copy, and swap in the new values. left = self.left.copy(deep=True) diff --git a/pandas/core/arrays/numpy_.py b/pandas/core/arrays/numpy_.py index 8b1d1e58dc36c..57cc52ce24f8c 100644 --- a/pandas/core/arrays/numpy_.py +++ b/pandas/core/arrays/numpy_.py @@ -243,6 +243,7 @@ def __getitem__(self, item): def __setitem__(self, key, value): value = extract_array(value, extract_numpy=True) + key = check_array_indexer(self, key) scalar_key = lib.is_scalar(key) scalar_value = lib.is_scalar(value) diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index c485d1f50dc9d..b53484e1892f9 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -15,6 +15,7 @@ from pandas.core import ops from pandas.core.arrays import PandasArray from pandas.core.construction import extract_array +from pandas.core.indexers import check_array_indexer from pandas.core.missing import isna @@ -224,6 +225,7 @@ def __setitem__(self, key, value): # extract_array doesn't extract PandasArray subclasses value = value._ndarray + key = check_array_indexer(self, key) scalar_key = lib.is_scalar(key) scalar_value = lib.is_scalar(value) if scalar_key and not scalar_value: diff --git a/pandas/tests/arrays/test_integer.py b/pandas/tests/arrays/test_integer.py index 0c5ae506ae0ce..c165910777649 100644 --- a/pandas/tests/arrays/test_integer.py +++ b/pandas/tests/arrays/test_integer.py @@ -1072,6 +1072,23 @@ def test_cut(bins, right, include_lowest): tm.assert_categorical_equal(result, expected) +def test_array_setitem_nullable_boolean_mask(): + # GH 31446 + ser = pd.Series([1, 2], dtype="Int64") + result = ser.where(ser > 1) + expected = pd.Series([pd.NA, 2], dtype="Int64") + tm.assert_series_equal(result, expected) + + +def test_array_setitem(): + # GH 31446 + arr = pd.Series([1, 2], dtype="Int64").array + arr[arr > 1] = 1 + + expected = pd.array([1, 1], dtype="Int64") + tm.assert_extension_array_equal(arr, expected) + + # TODO(jreback) - these need testing / are broken # shift diff --git a/pandas/tests/extension/base/setitem.py b/pandas/tests/extension/base/setitem.py index 0bb8aede6298c..e0ca603aaa0ed 100644 --- a/pandas/tests/extension/base/setitem.py +++ b/pandas/tests/extension/base/setitem.py @@ -4,6 +4,7 @@ import pytest import pandas as pd +from pandas.core.arrays.numpy_ import PandasDtype from .base import BaseExtensionTests @@ -195,3 +196,14 @@ def test_setitem_preserves_views(self, data): data[0] = data[1] assert view1[0] == data[1] assert view2[0] == data[1] + + def test_setitem_nullable_mask(self, data): + # GH 31446 + # TODO: there is some issue with PandasArray, therefore, + # TODO: skip the setitem test for now, and fix it later + if data.dtype != PandasDtype("object"): + arr = data[:5] + expected = data.take([0, 0, 0, 3, 4]) + mask = pd.array([True, True, True, False, False]) + arr[mask] = data[0] + self.assert_extension_array_equal(expected, arr) diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py index 743852c35dbd8..8fd4a0171a222 100644 --- a/pandas/tests/extension/decimal/array.py +++ b/pandas/tests/extension/decimal/array.py @@ -10,6 +10,7 @@ import pandas as pd from pandas.api.extensions import no_default, register_extension_dtype from pandas.core.arrays import ExtensionArray, ExtensionScalarOpsMixin +from pandas.core.indexers import check_array_indexer @register_extension_dtype @@ -144,6 +145,8 @@ def __setitem__(self, key, value): value = [decimal.Decimal(v) for v in value] else: value = decimal.Decimal(value) + + key = check_array_indexer(self, key) self._data[key] = value def __len__(self) -> int: