Skip to content

Commit 023fa0c

Browse files
jbrockmendeljreback
authored andcommitted
maybe_upcast_putmask: require other to be a scalar (#29332)
1 parent e9046cc commit 023fa0c

File tree

3 files changed

+49
-60
lines changed

3 files changed

+49
-60
lines changed

pandas/core/dtypes/cast.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def trans(x):
232232
return result
233233

234234

235-
def maybe_upcast_putmask(result, mask, other):
235+
def maybe_upcast_putmask(result: np.ndarray, mask: np.ndarray, other):
236236
"""
237237
A safe version of putmask that potentially upcasts the result.
238238
The result is replaced with the first N elements of other,
@@ -245,8 +245,8 @@ def maybe_upcast_putmask(result, mask, other):
245245
The destination array. This will be mutated in-place if no upcasting is
246246
necessary.
247247
mask : boolean ndarray
248-
other : ndarray or scalar
249-
The source array or value
248+
other : scalar
249+
The source value
250250
251251
Returns
252252
-------
@@ -264,6 +264,10 @@ def maybe_upcast_putmask(result, mask, other):
264264

265265
if not isinstance(result, np.ndarray):
266266
raise ValueError("The result input must be a ndarray.")
267+
if not is_scalar(other):
268+
# We _could_ support non-scalar other, but until we have a compelling
269+
# use case, we assume away the possibility.
270+
raise ValueError("other must be a scalar")
267271

268272
if mask.any():
269273
# Two conversions for date-like dtypes that can't be done automatically

pandas/core/nanops.py

+6
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,12 @@ def _get_values(
273273
fill_value : Any
274274
fill value used
275275
"""
276+
277+
# In _get_values is only called from within nanops, and in all cases
278+
# with scalar fill_value. This guarantee is important for the
279+
# maybe_upcast_putmask call below
280+
assert is_scalar(fill_value)
281+
276282
mask = _maybe_get_mask(values, skipna, mask)
277283

278284
if is_datetime64tz_dtype(values):

pandas/tests/dtypes/cast/test_upcast.py

+36-57
Original file line numberDiff line numberDiff line change
@@ -9,84 +9,63 @@
99

1010
@pytest.mark.parametrize("result", [Series([10, 11, 12]), [10, 11, 12], (10, 11, 12)])
1111
def test_upcast_error(result):
12-
# GH23823
12+
# GH23823 require result arg to be ndarray
1313
mask = np.array([False, True, False])
1414
other = np.array([61, 62, 63])
1515
with pytest.raises(ValueError):
1616
result, _ = maybe_upcast_putmask(result, mask, other)
1717

1818

1919
@pytest.mark.parametrize(
20-
"arr, other, exp_changed, expected",
20+
"arr, other",
2121
[
22-
(np.arange(1, 6), np.array([61, 62, 63]), False, np.array([1, 61, 3, 62, 63])),
22+
(np.arange(1, 6), np.array([61, 62, 63])),
23+
(np.arange(1, 6), np.array([61.1, 62.2, 63.3])),
24+
(np.arange(10, 15), np.array([61, 62])),
25+
(np.arange(10, 15), np.array([61, np.nan])),
2326
(
24-
np.arange(1, 6),
25-
np.array([61.1, 62.2, 63.3]),
26-
True,
27-
np.array([1, 61.1, 3, 62.2, 63.3]),
27+
np.arange("2019-01-01", "2019-01-06", dtype="datetime64[D]"),
28+
np.arange("2018-01-01", "2018-01-04", dtype="datetime64[D]"),
2829
),
29-
(np.arange(1, 6), np.nan, True, np.array([1, np.nan, 3, np.nan, np.nan])),
30-
(np.arange(10, 15), np.array([61, 62]), False, np.array([10, 61, 12, 62, 61])),
3130
(
32-
np.arange(10, 15),
33-
np.array([61, np.nan]),
34-
True,
35-
np.array([10, 61, 12, np.nan, 61]),
31+
np.arange("2019-01-01", "2019-01-06", dtype="datetime64[D]"),
32+
np.arange("2018-01-01", "2018-01-03", dtype="datetime64[D]"),
3633
),
3734
],
3835
)
39-
def test_upcast(arr, other, exp_changed, expected):
36+
def test_upcast_scalar_other(arr, other):
37+
# for now we do not support non-scalar `other`
38+
mask = np.array([False, True, False, True, True])
39+
with pytest.raises(ValueError, match="other must be a scalar"):
40+
maybe_upcast_putmask(arr, mask, other)
41+
42+
43+
def test_upcast():
4044
# GH23823
45+
arr = np.arange(1, 6)
4146
mask = np.array([False, True, False, True, True])
42-
result, changed = maybe_upcast_putmask(arr, mask, other)
47+
result, changed = maybe_upcast_putmask(arr, mask, other=np.nan)
4348

44-
assert changed == exp_changed
49+
expected = np.array([1, np.nan, 3, np.nan, np.nan])
50+
assert changed
4551
tm.assert_numpy_array_equal(result, expected)
4652

4753

48-
@pytest.mark.parametrize(
49-
"arr, other, exp_changed, expected",
50-
[
51-
(
52-
np.arange("2019-01-01", "2019-01-06", dtype="datetime64[D]"),
53-
np.arange("2018-01-01", "2018-01-04", dtype="datetime64[D]"),
54-
False,
55-
np.array(
56-
["2019-01-01", "2018-01-01", "2019-01-03", "2018-01-02", "2018-01-03"],
57-
dtype="datetime64[D]",
58-
),
59-
),
60-
(
61-
np.arange("2019-01-01", "2019-01-06", dtype="datetime64[D]"),
62-
np.nan,
63-
False,
64-
np.array(
65-
[
66-
"2019-01-01",
67-
np.datetime64("NaT"),
68-
"2019-01-03",
69-
np.datetime64("NaT"),
70-
np.datetime64("NaT"),
71-
],
72-
dtype="datetime64[D]",
73-
),
74-
),
75-
(
76-
np.arange("2019-01-01", "2019-01-06", dtype="datetime64[D]"),
77-
np.arange("2018-01-01", "2018-01-03", dtype="datetime64[D]"),
78-
False,
79-
np.array(
80-
["2019-01-01", "2018-01-01", "2019-01-03", "2018-01-02", "2018-01-01"],
81-
dtype="datetime64[D]",
82-
),
83-
),
84-
],
85-
)
86-
def test_upcast_datetime(arr, other, exp_changed, expected):
54+
def test_upcast_datetime():
8755
# GH23823
56+
arr = np.arange("2019-01-01", "2019-01-06", dtype="datetime64[D]")
8857
mask = np.array([False, True, False, True, True])
89-
result, changed = maybe_upcast_putmask(arr, mask, other)
58+
result, changed = maybe_upcast_putmask(arr, mask, other=np.nan)
9059

91-
assert changed == exp_changed
60+
expected = np.array(
61+
[
62+
"2019-01-01",
63+
np.datetime64("NaT"),
64+
"2019-01-03",
65+
np.datetime64("NaT"),
66+
np.datetime64("NaT"),
67+
],
68+
dtype="datetime64[D]",
69+
)
70+
assert not changed
9271
tm.assert_numpy_array_equal(result, expected)

0 commit comments

Comments
 (0)