Skip to content

[BUG] maybe_upcast_putmask also handle ndarray #25431

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 22, 2019
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,10 @@ def trans(x): # noqa

def maybe_upcast_putmask(result, mask, other):
"""
A safe version of putmask that potentially upcasts the result
A safe version of putmask that potentially upcasts the result.
The result is replaced with the first N elements of other,
where N is the number of True values in mask.
If the length of other is shorter than N, other will be repeated.

Parameters
----------
Expand All @@ -185,8 +188,18 @@ def maybe_upcast_putmask(result, mask, other):
result : ndarray
changed : boolean
Set to true if the result array was upcasted

Examples
--------
>>> result, _ = maybe_upcast_putmask(np.arange(1,6),
np.array([False, True, False, True, True]), np.arange(21,23))
>>> result
array([1, 21, 3, 22, 21])
"""

if not isinstance(result, np.ndarray):
raise ValueError("The result input must be a ndarray.")

if mask.any():
# Two conversions for date-like dtypes that can't be done automatically
# in np.place:
Expand Down Expand Up @@ -241,7 +254,7 @@ def changeit():
# we have an ndarray and the masking has nans in it
else:

if isna(other[mask]).any():
if isna(other).any():
return changeit()

try:
Expand Down
62 changes: 62 additions & 0 deletions pandas/tests/dtypes/cast/test_upcast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-

import numpy as np
import pytest

from pandas.core.dtypes.cast import maybe_upcast_putmask

from pandas import Series
from pandas.util import testing as tm


def test_upcast_error():
    """maybe_upcast_putmask must raise ValueError when ``result`` is a Series."""
    # GH23823
    # Review question (from the PR thread): are there more cases to
    # consider? e.g. non-ndarray
    ser = Series([10, 11, 12])
    mask = np.array([False, True, False])
    other = np.array([61, 62, 63])
    # Pin the failure to the ndarray type check rather than accepting any
    # ValueError; the call's return value is irrelevant, so it is not bound.
    with pytest.raises(ValueError, match="must be a ndarray"):
        maybe_upcast_putmask(ser, mask, other)


@pytest.mark.parametrize(
    "arr, other, exp_changed, expected",
    [
        # integer replacements, one per True in the mask: no upcast expected
        (
            np.arange(1, 6),
            np.array([61, 62, 63]),
            False,
            np.array([1, 61, 3, 62, 63]),
        ),
        # float replacements into an integer array: upcast expected
        (
            np.arange(1, 6),
            np.array([61.1, 62.2, 63.3]),
            True,
            np.array([1, 61.1, 3, 62.2, 63.3]),
        ),
        # scalar NaN replacement: upcast expected
        (
            np.arange(1, 6),
            np.nan,
            True,
            np.array([1, np.nan, 3, np.nan, np.nan]),
        ),
        # fewer replacements than True entries: the values repeat
        (
            np.arange(10, 15),
            np.array([61, 62]),
            False,
            np.array([10, 61, 12, 62, 61]),
        ),
        # short replacement array containing NaN: repeats and upcasts
        (
            np.arange(10, 15),
            np.array([61, np.nan]),
            True,
            np.array([10, 61, 12, np.nan, 61]),
        ),
    ],
)
def test_upcast(arr, other, exp_changed, expected):
    """Check the upcast flag and the resulting values for numeric arrays."""
    # GH23823
    # Positions 1, 3 and 4 are the ones to be overwritten.
    mask = np.array([False, True, False, True, True])
    out, was_upcast = maybe_upcast_putmask(arr, mask, other)

    assert was_upcast == exp_changed
    tm.assert_numpy_array_equal(out, expected)


@pytest.mark.parametrize(
    "arr, other, exp_changed, expected",
    [
        # three replacement dates, one per True in the mask
        (
            np.arange("2019-01-01", "2019-01-06", dtype="datetime64[D]"),
            np.arange("2018-01-01", "2018-01-04", dtype="datetime64[D]"),
            False,
            np.array(
                [
                    "2019-01-01",
                    "2018-01-01",
                    "2019-01-03",
                    "2018-01-02",
                    "2018-01-03",
                ],
                dtype="datetime64[D]",
            ),
        ),
        # scalar NaN replacement: expected to appear as NaT, dtype unchanged
        (
            np.arange("2019-01-01", "2019-01-06", dtype="datetime64[D]"),
            np.nan,
            False,
            np.array(
                [
                    "2019-01-01",
                    np.datetime64("NaT"),
                    "2019-01-03",
                    np.datetime64("NaT"),
                    np.datetime64("NaT"),
                ],
                dtype="datetime64[D]",
            ),
        ),
        # fewer replacement dates than True entries: the values repeat
        (
            np.arange("2019-01-01", "2019-01-06", dtype="datetime64[D]"),
            np.arange("2018-01-01", "2018-01-03", dtype="datetime64[D]"),
            False,
            np.array(
                [
                    "2019-01-01",
                    "2018-01-01",
                    "2019-01-03",
                    "2018-01-02",
                    "2018-01-01",
                ],
                dtype="datetime64[D]",
            ),
        ),
    ],
)
def test_upcast_datetime(arr, other, exp_changed, expected):
    """Check the upcast flag and the resulting values for datetime64 arrays."""
    # GH23823
    # Positions 1, 3 and 4 are the ones to be overwritten.
    mask = np.array([False, True, False, True, True])
    out, was_upcast = maybe_upcast_putmask(arr, mask, other)

    assert was_upcast == exp_changed
    tm.assert_numpy_array_equal(out, expected)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect these tests already have a home in another, existing test file.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, reference the issue number as a comment under each test definition.