Skip to content

Commit 71c94af

Browse files
authored
BUG/PERF: MaskedArray.__setitem__ validation (#45404)
1 parent 2e5b05e commit 71c94af

File tree

6 files changed

+104
-22
lines changed

6 files changed

+104
-22
lines changed

pandas/core/arrays/masked.py

+37-8
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
from pandas.core.dtypes.inference import is_array_like
5656
from pandas.core.dtypes.missing import (
5757
array_equivalent,
58+
is_valid_na_for_dtype,
5859
isna,
5960
notna,
6061
)
@@ -233,17 +234,45 @@ def _coerce_to_array(
233234
) -> tuple[np.ndarray, np.ndarray]:
234235
raise AbstractMethodError(cls)
235236

236-
def __setitem__(self, key, value) -> None:
237-
_is_scalar = is_scalar(value)
238-
if _is_scalar:
239-
value = [value]
240-
value, mask = self._coerce_to_array(value, dtype=self.dtype)
237+
def _validate_setitem_value(self, value):
238+
"""
239+
Check if we have a scalar that we can cast losslessly.
240+
241+
Raises
242+
------
243+
TypeError
244+
"""
245+
kind = self.dtype.kind
246+
# TODO: get this all from np_can_hold_element?
247+
if kind == "b":
248+
if lib.is_bool(value):
249+
return value
250+
251+
elif kind == "f":
252+
if lib.is_integer(value) or lib.is_float(value):
253+
return value
254+
255+
else:
256+
if lib.is_integer(value) or (lib.is_float(value) and value.is_integer()):
257+
return value
258+
# TODO: unsigned checks
241259

242-
if _is_scalar:
243-
value = value[0]
244-
mask = mask[0]
260+
raise TypeError(f"Invalid value '{value}' for dtype {self.dtype}")
245261

262+
def __setitem__(self, key, value) -> None:
246263
key = check_array_indexer(self, key)
264+
265+
if is_scalar(value):
266+
if is_valid_na_for_dtype(value, self.dtype):
267+
self._mask[key] = True
268+
else:
269+
value = self._validate_setitem_value(value)
270+
self._data[key] = value
271+
self._mask[key] = False
272+
return
273+
274+
value, mask = self._coerce_to_array(value, dtype=self.dtype)
275+
247276
self._data[key] = value
248277
self._mask[key] = mask
249278

pandas/core/dtypes/missing.py

+3
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,9 @@ def is_valid_na_for_dtype(obj, dtype: DtypeObj) -> bool:
648648
elif dtype.kind in ["i", "u", "f", "c"]:
649649
# Numeric
650650
return obj is not NaT and not isinstance(obj, (np.datetime64, np.timedelta64))
651+
elif dtype.kind == "b":
652+
# We allow pd.NA, None, np.nan in BooleanArray (same as IntervalDtype)
653+
return lib.is_float(obj) or obj is None or obj is libmissing.NA
651654

652655
elif dtype == _dtype_str:
653656
# numpy string dtypes to avoid float np.nan

pandas/io/stata.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,8 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame:
593593
missing_loc = data[col].isna()
594594
if missing_loc.any():
595595
# Replace with always safe value
596-
data.loc[missing_loc, col] = 0
596+
fv = 0 if isinstance(data[col].dtype, _IntegerDtype) else False
597+
data.loc[missing_loc, col] = fv
597598
# Replace with NumPy-compatible column
598599
data[col] = data[col].astype(data[col].dtype.numpy_dtype)
599600
dtype = data[col].dtype
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import re
2+
3+
import numpy as np
4+
import pytest
5+
6+
import pandas as pd
7+
8+
9+
class TestSetitemValidation:
10+
def _check_setitem_invalid(self, arr, invalid):
11+
msg = f"Invalid value '{str(invalid)}' for dtype {arr.dtype}"
12+
msg = re.escape(msg)
13+
with pytest.raises(TypeError, match=msg):
14+
arr[0] = invalid
15+
16+
with pytest.raises(TypeError, match=msg):
17+
arr[:] = invalid
18+
19+
with pytest.raises(TypeError, match=msg):
20+
arr[[0]] = invalid
21+
22+
# FIXME: don't leave commented-out
23+
# with pytest.raises(TypeError):
24+
# arr[[0]] = [invalid]
25+
26+
# with pytest.raises(TypeError):
27+
# arr[[0]] = np.array([invalid], dtype=object)
28+
29+
# Series non-coercion, behavior subject to change
30+
ser = pd.Series(arr)
31+
with pytest.raises(TypeError, match=msg):
32+
ser[0] = invalid
33+
# TODO: so, so many other variants of this...
34+
35+
_invalid_scalars = [
36+
1 + 2j,
37+
"True",
38+
"1",
39+
"1.0",
40+
pd.NaT,
41+
np.datetime64("NaT"),
42+
np.timedelta64("NaT"),
43+
]
44+
45+
@pytest.mark.parametrize(
46+
"invalid", _invalid_scalars + [1, 1.0, np.int64(1), np.float64(1)]
47+
)
48+
def test_setitem_validation_scalar_bool(self, invalid):
49+
arr = pd.array([True, False, None], dtype="boolean")
50+
self._check_setitem_invalid(arr, invalid)
51+
52+
@pytest.mark.parametrize("invalid", _invalid_scalars + [True, 1.5, np.float64(1.5)])
53+
def test_setitem_validation_scalar_int(self, invalid, any_int_ea_dtype):
54+
arr = pd.array([1, 2, None], dtype=any_int_ea_dtype)
55+
self._check_setitem_invalid(arr, invalid)
56+
57+
@pytest.mark.parametrize("invalid", _invalid_scalars + [True])
58+
def test_setitem_validation_scalar_float(self, invalid, float_ea_dtype):
59+
arr = pd.array([1, 2, None], dtype=float_ea_dtype)
60+
self._check_setitem_invalid(arr, invalid)

pandas/tests/frame/indexing/test_indexing.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -1240,12 +1240,10 @@ def test_setting_mismatched_na_into_nullable_fails(
12401240

12411241
msg = "|".join(
12421242
[
1243-
r"int\(\) argument must be a string, a bytes-like object or a "
1244-
"(real )?number, not 'NaTType'",
12451243
r"timedelta64\[ns\] cannot be converted to an? (Floating|Integer)Dtype",
12461244
r"datetime64\[ns\] cannot be converted to an? (Floating|Integer)Dtype",
1247-
"object cannot be converted to a FloatingDtype",
12481245
"'values' contains non-numeric NA",
1246+
r"Invalid value '.*' for dtype (U?Int|Float)\d{1,2}",
12491247
]
12501248
)
12511249
with pytest.raises(TypeError, match=msg):

pandas/tests/frame/indexing/test_where.py

+1-10
Original file line numberDiff line numberDiff line change
@@ -907,16 +907,7 @@ def test_where_nullable_invalid_na(frame_or_series, any_numeric_ea_dtype):
907907

908908
mask = np.array([True, True, False], ndmin=obj.ndim).T
909909

910-
msg = "|".join(
911-
[
912-
r"datetime64\[.{1,2}\] cannot be converted to an? (Integer|Floating)Dtype",
913-
r"timedelta64\[.{1,2}\] cannot be converted to an? (Integer|Floating)Dtype",
914-
r"int\(\) argument must be a string, a bytes-like object or a number, "
915-
"not 'NaTType'",
916-
"object cannot be converted to a FloatingDtype",
917-
"'values' contains non-numeric NA",
918-
]
919-
)
910+
msg = r"Invalid value '.*' for dtype (U?Int|Float)\d{1,2}"
920911

921912
for null in tm.NP_NAT_OBJECTS + [pd.NaT]:
922913
# NaT is an NA value that we should *not* cast to pd.NA dtype

0 commit comments

Comments
 (0)