-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
BUG/PERF: Avoid listifying in dispatch_to_extension_op #23155
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
044a99e
fe693d6
cf945ee
6507f43
b7906ea
0125669
f1dd665
07a632e
e378c7d
b2f9243
42b91d0
f03e66b
f14cf0b
32757f6
4fc1d1b
03a367e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -280,6 +280,8 @@ def _coerce_to_ndarray(self): | |
data[self._mask] = self._na_value | ||
return data | ||
|
||
__array_priority__ = 1 # higher than ndarray so ops dispatch to us | ||
|
||
def __array__(self, dtype=None): | ||
""" | ||
the array interface, return my values | ||
|
@@ -288,12 +290,6 @@ def __array__(self, dtype=None): | |
return self._coerce_to_ndarray() | ||
|
||
def __iter__(self): | ||
"""Iterate over elements of the array. | ||
|
||
""" | ||
# This needs to be implemented so that pandas recognizes extension | ||
# arrays as list-like. The default implementation makes successive | ||
# calls to ``__getitem__``, which may be slower than necessary. | ||
jorisvandenbossche marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for i in range(len(self)): | ||
if self._mask[i]: | ||
yield self.dtype.na_value | ||
|
@@ -504,8 +500,13 @@ def cmp_method(self, other): | |
|
||
op_name = op.__name__ | ||
mask = None | ||
|
||
if isinstance(other, IntegerArray): | ||
other, mask = other._data, other._mask | ||
|
||
elif getattr(other, 'ndim', None) == 0: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we usually use |
||
other = other.item() | ||
|
||
elif is_list_like(other): | ||
other = np.asarray(other) | ||
if other.ndim > 0 and len(self) != len(other): | ||
|
@@ -586,14 +587,20 @@ def integer_arithmetic_method(self, other): | |
|
||
op_name = op.__name__ | ||
mask = None | ||
|
||
if getattr(other, 'ndim', 0) > 1: | ||
raise NotImplementedError( | ||
"can only perform ops with 1-d structures") | ||
|
||
if isinstance(other, (ABCSeries, ABCIndexClass)): | ||
other = getattr(other, 'values', other) | ||
|
||
if isinstance(other, IntegerArray): | ||
other, mask = other._data, other._mask | ||
elif getattr(other, 'ndim', 0) > 1: | ||
raise NotImplementedError( | ||
"can only perform ops with 1-d structures") | ||
|
||
elif getattr(other, 'ndim', None) == 0: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is moved from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Had to keep this one in an elif, so that we avoid the |
||
other = other.item() | ||
|
||
elif is_list_like(other): | ||
other = np.asarray(other) | ||
if not other.ndim: | ||
|
@@ -612,6 +619,10 @@ def integer_arithmetic_method(self, other): | |
else: | ||
mask = self._mask | mask | ||
|
||
if op_name == 'rpow': | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does pow just work? |
||
# 1 ** np.nan is 1. So we have to unmask those. | ||
mask = np.where(other == 1, False, mask) | ||
|
||
with np.errstate(all='ignore'): | ||
result = op(self._data, other) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -862,10 +862,16 @@ def masked_arith_op(x, y, op): | |
# mask is only meaningful for x | ||
result = np.empty(x.size, dtype=x.dtype) | ||
mask = notna(xrav) | ||
|
||
if mask.any(): | ||
with np.errstate(all='ignore'): | ||
result[mask] = op(xrav[mask], y) | ||
|
||
if op == pow: | ||
result = np.where(~mask, x, result) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you add the same comments as you have above here (e.g. 1 ** np.nan...) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I reworked this to update the mask for both |
||
elif op == rpow: | ||
result = np.where(~mask, y, result) | ||
|
||
result, changed = maybe_upcast_putmask(result, ~mask, np.nan) | ||
result = result.reshape(x.shape) # 2D compat | ||
return result | ||
|
@@ -1202,29 +1208,16 @@ def dispatch_to_extension_op(op, left, right): | |
|
||
# The op calls will raise TypeError if the op is not defined | ||
# on the ExtensionArray | ||
# TODO(jreback) | ||
# we need to listify to avoid ndarray, or non-same-type extension array | ||
# dispatching | ||
|
||
if is_extension_array_dtype(left): | ||
|
||
new_left = left.values | ||
if isinstance(right, np.ndarray): | ||
|
||
# handle numpy scalars, this is a PITA | ||
# TODO(jreback) | ||
new_right = lib.item_from_zerodim(right) | ||
if is_scalar(new_right): | ||
new_right = [new_right] | ||
new_right = list(new_right) | ||
elif is_extension_array_dtype(right) and type(left) != type(right): | ||
new_right = list(right) | ||
else: | ||
new_right = right | ||
|
||
# unbox Series and Index to arrays | ||
if isinstance(left, (ABCSeries, ABCIndexClass)): | ||
new_left = left._values | ||
else: | ||
new_left = left | ||
|
||
new_left = list(left.values) | ||
if isinstance(right, (ABCSeries, ABCIndexClass)): | ||
new_right = right._values | ||
else: | ||
new_right = right | ||
|
||
res_values = op(new_left, new_right) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -218,25 +218,68 @@ def test_arith_integer_array(self, data, all_arithmetic_operators): | |
def test_arith_series_with_scalar(self, data, all_arithmetic_operators): | ||
# scalar | ||
op = all_arithmetic_operators | ||
if op == '__rpow__': | ||
raise pytest.skip("__rpow__ tested separately.") | ||
|
||
s = pd.Series(data) | ||
self._check_op(s, op, 1, exc=TypeError) | ||
|
||
def test_arith_series_with_scalar_rpow(self, data): | ||
s = pd.Series(data) | ||
# 1^x is 1.0 for all x, so test separately | ||
result = 1 ** s | ||
expected = pd.Series(1, index=s.index, dtype=data.dtype.numpy_dtype) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is adding a lot of duplicated code to test, can you use _check_op here for the 1 case There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I started down that route, but couldn't make it work. I found it hard to follow all the indirection. I'm ok with some duplicated code in these tests, to make it clearer what's actually being tested. |
||
assert result.dtype == data.dtype | ||
result = result.astype(data.dtype.numpy_dtype) | ||
self.assert_series_equal(result, expected) | ||
|
||
# test other bases regularly | ||
self._check_op(s, '__rpow__', 2, exc=None) | ||
|
||
def test_arith_frame_with_scalar(self, data, all_arithmetic_operators): | ||
# frame & scalar | ||
op = all_arithmetic_operators | ||
|
||
if op == '__rpow__': | ||
raise pytest.skip("__rpow__ tested separately.") | ||
|
||
df = pd.DataFrame({'A': data}) | ||
self._check_op(df, op, 1, exc=TypeError) | ||
|
||
def test_arith_frame_with_scalar_rpow(self, data): | ||
df = pd.DataFrame({"A": data}) | ||
result = 1.0 ** df | ||
expected = pd.DataFrame(1.0, index=df.index, columns=df.columns) | ||
self.assert_frame_equal(result, expected) | ||
|
||
# test other bases regularly | ||
self._check_op(df, '__rpow__', 2, exc=TypeError) | ||
|
||
def test_arith_series_with_array(self, data, all_arithmetic_operators): | ||
# ndarray & other series | ||
op = all_arithmetic_operators | ||
|
||
if op == '__rpow__': | ||
raise pytest.skip("__rpow__ tested separately.") | ||
|
||
s = pd.Series(data) | ||
other = np.ones(len(s), dtype=s.dtype.type) | ||
self._check_op(s, op, other, exc=TypeError) | ||
|
||
def test_arith_series_with_array_rpow(self, data): | ||
s = pd.Series(data) | ||
other = np.ones(len(s), dtype=data.dtype.numpy_dtype) | ||
expected = pd.Series(1, index=s.index, dtype=data.dtype.numpy_dtype) | ||
result = other ** s | ||
|
||
assert result.dtype == data.dtype | ||
result = result.astype(data.dtype.numpy_dtype) | ||
|
||
self.assert_series_equal(result, expected) | ||
|
||
other = 2 * np.ones(len(s), dtype=s.dtype.type) | ||
self._check_op(s, '__rpow__', other, exc=TypeError) | ||
|
||
def test_arith_coerce_scalar(self, data, all_arithmetic_operators): | ||
|
||
op = all_arithmetic_operators | ||
|
@@ -248,13 +291,20 @@ def test_arith_coerce_scalar(self, data, all_arithmetic_operators): | |
@pytest.mark.parametrize("other", [1., 1.0, np.array(1.), np.array([1.])]) | ||
def test_arithmetic_conversion(self, all_arithmetic_operators, other): | ||
# if we have a float operand we should have a float result | ||
# if if that is equal to an integer | ||
# if that is equal to an integer | ||
op = self.get_op_from_name(all_arithmetic_operators) | ||
|
||
s = pd.Series([1, 2, 3], dtype='Int64') | ||
result = op(s, other) | ||
assert result.dtype is np.dtype('float') | ||
|
||
@pytest.mark.parametrize("other", [0, 0.5]) | ||
def test_arith_zero_dim_ndarray(self, other): | ||
arr = integer_array([1, None, 2]) | ||
result = arr + np.array(other) | ||
expected = arr + other | ||
tm.assert_equal(result, expected) | ||
|
||
def test_error(self, data, all_arithmetic_operators): | ||
# invalid ops | ||
|
||
|
@@ -323,6 +373,14 @@ def test_compare_array(self, data, all_compare_operators): | |
other = pd.Series([0] * len(data)) | ||
self._compare_other(s, data, op_name, other) | ||
|
||
def test_rpow_one_to_na(self): | ||
# https://github.com/pandas-dev/pandas/issues/22022 | ||
# NumPy says 1 ** nan is 1. | ||
arr = integer_array([np.nan, np.nan]) | ||
result = np.array([1.0, 2.0]) ** arr | ||
expected = np.array([1.0, np.nan]) | ||
tm.assert_numpy_array_equal(result, expected) | ||
|
||
|
||
class TestCasting(object): | ||
pass | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Index, Series, and now DataFrame all set this to 1000. Does differing from those matter?