diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 3e92906be706c..304eeac87f64d 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -232,7 +232,7 @@ def trans(x): return result -def maybe_upcast_putmask(result, mask, other): +def maybe_upcast_putmask(result: np.ndarray, mask: np.ndarray, other): """ A safe version of putmask that potentially upcasts the result. The result is replaced with the first N elements of other, @@ -245,8 +245,8 @@ def maybe_upcast_putmask(result, mask, other): The destination array. This will be mutated in-place if no upcasting is necessary. mask : boolean ndarray - other : ndarray or scalar - The source array or value + other : scalar + The source value Returns ------- @@ -264,6 +264,10 @@ def maybe_upcast_putmask(result, mask, other): if not isinstance(result, np.ndarray): raise ValueError("The result input must be a ndarray.") + if not is_scalar(other): + # We _could_ support non-scalar other, but until we have a compelling + # use case, we assume away the possibility. + raise ValueError("other must be a scalar") if mask.any(): # Two conversions for date-like dtypes that can't be done automatically diff --git a/pandas/core/nanops.py b/pandas/core/nanops.py index b9267db76e1a8..5d4f75d93db1c 100644 --- a/pandas/core/nanops.py +++ b/pandas/core/nanops.py @@ -273,6 +273,12 @@ def _get_values( fill_value : Any fill value used """ + + # In _get_values is only called from within nanops, and in all cases + # with scalar fill_value. This guarantee is important for the + # maybe_upcast_putmask call below + assert is_scalar(fill_value) + mask = _maybe_get_mask(values, skipna, mask) if is_datetime64tz_dtype(values): diff --git a/pandas/tests/dtypes/cast/test_upcast.py b/pandas/tests/dtypes/cast/test_upcast.py index b22ed0bcd0a11..49e850f3e87b5 100644 --- a/pandas/tests/dtypes/cast/test_upcast.py +++ b/pandas/tests/dtypes/cast/test_upcast.py @@ -9,7 +9,7 @@ @pytest.mark.parametrize("result", [Series([10, 11, 12]), [10, 11, 12], (10, 11, 12)]) def test_upcast_error(result): - # GH23823 + # GH23823 require result arg to be ndarray mask = np.array([False, True, False]) other = np.array([61, 62, 63]) with pytest.raises(ValueError): @@ -17,76 +17,55 @@ def test_upcast_error(result): @pytest.mark.parametrize( - "arr, other, exp_changed, expected", + "arr, other", [ - (np.arange(1, 6), np.array([61, 62, 63]), False, np.array([1, 61, 3, 62, 63])), + (np.arange(1, 6), np.array([61, 62, 63])), + (np.arange(1, 6), np.array([61.1, 62.2, 63.3])), + (np.arange(10, 15), np.array([61, 62])), + (np.arange(10, 15), np.array([61, np.nan])), ( - np.arange(1, 6), - np.array([61.1, 62.2, 63.3]), - True, - np.array([1, 61.1, 3, 62.2, 63.3]), + np.arange("2019-01-01", "2019-01-06", dtype="datetime64[D]"), + np.arange("2018-01-01", "2018-01-04", dtype="datetime64[D]"), ), - (np.arange(1, 6), np.nan, True, np.array([1, np.nan, 3, np.nan, np.nan])), - (np.arange(10, 15), np.array([61, 62]), False, np.array([10, 61, 12, 62, 61])), ( - np.arange(10, 15), - np.array([61, np.nan]), - True, - np.array([10, 61, 12, np.nan, 61]), + np.arange("2019-01-01", "2019-01-06", dtype="datetime64[D]"), + np.arange("2018-01-01", "2018-01-03", dtype="datetime64[D]"), ), ], ) -def test_upcast(arr, other, exp_changed, expected): +def test_upcast_scalar_other(arr, other): + # for now we do not support non-scalar `other` + mask = np.array([False, True, False, True, True]) + with pytest.raises(ValueError, match="other must be a scalar"): + maybe_upcast_putmask(arr, mask, other) + + +def test_upcast(): # GH23823 + arr = np.arange(1, 6) mask = np.array([False, True, False, True, True]) - result, changed = maybe_upcast_putmask(arr, mask, other) + result, changed = maybe_upcast_putmask(arr, mask, other=np.nan) - assert changed == exp_changed + expected = np.array([1, np.nan, 3, np.nan, np.nan]) + assert changed tm.assert_numpy_array_equal(result, expected) -@pytest.mark.parametrize( - "arr, other, exp_changed, expected", - [ - ( - np.arange("2019-01-01", "2019-01-06", dtype="datetime64[D]"), - np.arange("2018-01-01", "2018-01-04", dtype="datetime64[D]"), - False, - np.array( - ["2019-01-01", "2018-01-01", "2019-01-03", "2018-01-02", "2018-01-03"], - dtype="datetime64[D]", - ), - ), - ( - np.arange("2019-01-01", "2019-01-06", dtype="datetime64[D]"), - np.nan, - False, - np.array( - [ - "2019-01-01", - np.datetime64("NaT"), - "2019-01-03", - np.datetime64("NaT"), - np.datetime64("NaT"), - ], - dtype="datetime64[D]", - ), - ), - ( - np.arange("2019-01-01", "2019-01-06", dtype="datetime64[D]"), - np.arange("2018-01-01", "2018-01-03", dtype="datetime64[D]"), - False, - np.array( - ["2019-01-01", "2018-01-01", "2019-01-03", "2018-01-02", "2018-01-01"], - dtype="datetime64[D]", - ), - ), - ], -) -def test_upcast_datetime(arr, other, exp_changed, expected): +def test_upcast_datetime(): # GH23823 + arr = np.arange("2019-01-01", "2019-01-06", dtype="datetime64[D]") mask = np.array([False, True, False, True, True]) - result, changed = maybe_upcast_putmask(arr, mask, other) + result, changed = maybe_upcast_putmask(arr, mask, other=np.nan) - assert changed == exp_changed + expected = np.array( + [ + "2019-01-01", + np.datetime64("NaT"), + "2019-01-03", + np.datetime64("NaT"), + np.datetime64("NaT"), + ], + dtype="datetime64[D]", + ) + assert not changed tm.assert_numpy_array_equal(result, expected)