diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 1823a8e8654fd..6ac7fdd2434c7 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -169,7 +169,10 @@ def trans(x): # noqa def maybe_upcast_putmask(result, mask, other): """ - A safe version of putmask that potentially upcasts the result + A safe version of putmask that potentially upcasts the result. + The result is replaced with the first N elements of other, + where N is the number of True values in mask. + If the length of other is shorter than N, other will be repeated. Parameters ---------- @@ -185,8 +188,18 @@ def maybe_upcast_putmask(result, mask, other): result : ndarray changed : boolean Set to true if the result array was upcasted + + Examples + -------- + >>> result, _ = maybe_upcast_putmask(np.arange(1,6), + np.array([False, True, False, True, True]), np.arange(21,23)) + >>> result + array([1, 21, 3, 22, 21]) """ + if not isinstance(result, np.ndarray): + raise ValueError("The result input must be a ndarray.") + if mask.any(): # Two conversions for date-like dtypes that can't be done automatically # in np.place: @@ -241,7 +254,7 @@ def changeit(): # we have an ndarray and the masking has nans in it else: - if isna(other[mask]).any(): + if isna(other).any(): return changeit() try: diff --git a/pandas/tests/dtypes/cast/test_upcast.py b/pandas/tests/dtypes/cast/test_upcast.py new file mode 100644 index 0000000000000..074e89274cc88 --- /dev/null +++ b/pandas/tests/dtypes/cast/test_upcast.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- + +import numpy as np +import pytest + +from pandas.core.dtypes.cast import maybe_upcast_putmask + +from pandas import Series +from pandas.util import testing as tm + + +@pytest.mark.parametrize("result", [ + Series([10, 11, 12]), + [10, 11, 12], + (10, 11, 12) +]) +def test_upcast_error(result): + # GH23823 + mask = np.array([False, True, False]) + other = np.array([61, 62, 63]) + with pytest.raises(ValueError): + result, _ = maybe_upcast_putmask(result, mask, other) + + +@pytest.mark.parametrize("arr, other, exp_changed, expected", [ + (np.arange(1, 6), np.array([61, 62, 63]), + False, np.array([1, 61, 3, 62, 63])), + (np.arange(1, 6), np.array([61.1, 62.2, 63.3]), + True, np.array([1, 61.1, 3, 62.2, 63.3])), + (np.arange(1, 6), np.nan, + True, np.array([1, np.nan, 3, np.nan, np.nan])), + (np.arange(10, 15), np.array([61, 62]), + False, np.array([10, 61, 12, 62, 61])), + (np.arange(10, 15), np.array([61, np.nan]), + True, np.array([10, 61, 12, np.nan, 61])) +]) +def test_upcast(arr, other, exp_changed, expected): + # GH23823 + mask = np.array([False, True, False, True, True]) + result, changed = maybe_upcast_putmask(arr, mask, other) + + assert changed == exp_changed + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize("arr, other, exp_changed, expected", [ + (np.arange('2019-01-01', '2019-01-06', dtype='datetime64[D]'), + np.arange('2018-01-01', '2018-01-04', dtype='datetime64[D]'), + False, np.array(['2019-01-01', '2018-01-01', '2019-01-03', + '2018-01-02', '2018-01-03'], dtype='datetime64[D]')), + (np.arange('2019-01-01', '2019-01-06', dtype='datetime64[D]'), np.nan, + False, np.array(['2019-01-01', np.datetime64('NaT'), + '2019-01-03', np.datetime64('NaT'), + np.datetime64('NaT')], dtype='datetime64[D]')), + (np.arange('2019-01-01', '2019-01-06', dtype='datetime64[D]'), + np.arange('2018-01-01', '2018-01-03', dtype='datetime64[D]'), + False, np.array(['2019-01-01', '2018-01-01', '2019-01-03', + '2018-01-02', '2018-01-01'], dtype='datetime64[D]')) +]) +def test_upcast_datetime(arr, other, exp_changed, expected): + # GH23823 + mask = np.array([False, True, False, True, True]) + result, changed = maybe_upcast_putmask(arr, mask, other) + + assert changed == exp_changed + tm.assert_numpy_array_equal(result, expected)