Skip to content

Commit f2bcb9f

Browse files
makbigcjreback
authored andcommitted
[BUG] maybe_upcast_putmast also handle ndarray (#25431)
1 parent aab42a2 commit f2bcb9f

File tree

2 files changed

+81
-2
lines changed

2 files changed

+81
-2
lines changed

pandas/core/dtypes/cast.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,10 @@ def trans(x): # noqa
169169

170170
def maybe_upcast_putmask(result, mask, other):
171171
"""
172-
A safe version of putmask that potentially upcasts the result
172+
A safe version of putmask that potentially upcasts the result.
173+
The result is replaced with the first N elements of other,
174+
where N is the number of True values in mask.
175+
If the length of other is shorter than N, other will be repeated.
173176
174177
Parameters
175178
----------
@@ -185,8 +188,18 @@ def maybe_upcast_putmask(result, mask, other):
185188
result : ndarray
186189
changed : boolean
187190
Set to true if the result array was upcasted
191+
192+
Examples
193+
--------
194+
>>> result, _ = maybe_upcast_putmask(np.arange(1,6),
195+
np.array([False, True, False, True, True]), np.arange(21,23))
196+
>>> result
197+
array([1, 21, 3, 22, 21])
188198
"""
189199

200+
if not isinstance(result, np.ndarray):
201+
raise ValueError("The result input must be a ndarray.")
202+
190203
if mask.any():
191204
# Two conversions for date-like dtypes that can't be done automatically
192205
# in np.place:
@@ -241,7 +254,7 @@ def changeit():
241254
# we have an ndarray and the masking has nans in it
242255
else:
243256

244-
if isna(other[mask]).any():
257+
if isna(other).any():
245258
return changeit()
246259

247260
try:
+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# -*- coding: utf-8 -*-
2+
3+
import numpy as np
4+
import pytest
5+
6+
from pandas.core.dtypes.cast import maybe_upcast_putmask
7+
8+
from pandas import Series
9+
from pandas.util import testing as tm
10+
11+
12+
@pytest.mark.parametrize("result", [
13+
Series([10, 11, 12]),
14+
[10, 11, 12],
15+
(10, 11, 12)
16+
])
17+
def test_upcast_error(result):
18+
# GH23823
19+
mask = np.array([False, True, False])
20+
other = np.array([61, 62, 63])
21+
with pytest.raises(ValueError):
22+
result, _ = maybe_upcast_putmask(result, mask, other)
23+
24+
25+
@pytest.mark.parametrize("arr, other, exp_changed, expected", [
26+
(np.arange(1, 6), np.array([61, 62, 63]),
27+
False, np.array([1, 61, 3, 62, 63])),
28+
(np.arange(1, 6), np.array([61.1, 62.2, 63.3]),
29+
True, np.array([1, 61.1, 3, 62.2, 63.3])),
30+
(np.arange(1, 6), np.nan,
31+
True, np.array([1, np.nan, 3, np.nan, np.nan])),
32+
(np.arange(10, 15), np.array([61, 62]),
33+
False, np.array([10, 61, 12, 62, 61])),
34+
(np.arange(10, 15), np.array([61, np.nan]),
35+
True, np.array([10, 61, 12, np.nan, 61]))
36+
])
37+
def test_upcast(arr, other, exp_changed, expected):
38+
# GH23823
39+
mask = np.array([False, True, False, True, True])
40+
result, changed = maybe_upcast_putmask(arr, mask, other)
41+
42+
assert changed == exp_changed
43+
tm.assert_numpy_array_equal(result, expected)
44+
45+
46+
@pytest.mark.parametrize("arr, other, exp_changed, expected", [
47+
(np.arange('2019-01-01', '2019-01-06', dtype='datetime64[D]'),
48+
np.arange('2018-01-01', '2018-01-04', dtype='datetime64[D]'),
49+
False, np.array(['2019-01-01', '2018-01-01', '2019-01-03',
50+
'2018-01-02', '2018-01-03'], dtype='datetime64[D]')),
51+
(np.arange('2019-01-01', '2019-01-06', dtype='datetime64[D]'), np.nan,
52+
False, np.array(['2019-01-01', np.datetime64('NaT'),
53+
'2019-01-03', np.datetime64('NaT'),
54+
np.datetime64('NaT')], dtype='datetime64[D]')),
55+
(np.arange('2019-01-01', '2019-01-06', dtype='datetime64[D]'),
56+
np.arange('2018-01-01', '2018-01-03', dtype='datetime64[D]'),
57+
False, np.array(['2019-01-01', '2018-01-01', '2019-01-03',
58+
'2018-01-02', '2018-01-01'], dtype='datetime64[D]'))
59+
])
60+
def test_upcast_datetime(arr, other, exp_changed, expected):
61+
# GH23823
62+
mask = np.array([False, True, False, True, True])
63+
result, changed = maybe_upcast_putmask(arr, mask, other)
64+
65+
assert changed == exp_changed
66+
tm.assert_numpy_array_equal(result, expected)

0 commit comments

Comments
 (0)