Skip to content

Commit 338ea4d

Browse files
committed
ENH: add ops to extension array
1 parent a854f06 commit 338ea4d

File tree

7 files changed

+194
-11
lines changed

7 files changed

+194
-11
lines changed

pandas/conftest.py

+9
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,16 @@ def observed(request):
9292
def all_arithmetic_operators(request):
9393
"""
9494
Fixture for dunder names for common arithmetic operations
95+
"""
96+
return request.param
97+
98+
99+
@pytest.fixture(params=[
    '__eq__',
    '__ne__',
    '__le__',
    '__lt__',
    '__ge__',
    '__gt__',
])
def all_compare_operators(request):
    """
    Parametrized fixture yielding the dunder name of each rich-comparison
    operation (``__eq__``, ``__ne__``, ``__le__``, ``__lt__``, ``__ge__``,
    ``__gt__``), one name per test invocation.
    """
    dunder_name = request.param
    return dunder_name
97106

98107

pandas/core/arrays/base.py

+54
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99

1010
from pandas.errors import AbstractMethodError
1111
from pandas.compat.numpy import function as nv
12+
from pandas.compat import set_function_name, PY3
13+
from pandas.core import ops
14+
import operator
1215

1316
_not_implemented_message = "{} does not implement {}."
1417

@@ -652,3 +655,54 @@ def _ndarray_values(self):
652655
used for interacting with our indexers.
653656
"""
654657
return np.array(self)
658+
659+
# ------------------------------------------------------------------------
660+
# ops-related methods
661+
# ------------------------------------------------------------------------
662+
663+
@classmethod
def _add_comparison_methods_binary(cls):
    """
    Install the six rich-comparison dunder methods on ``cls``.

    ``__eq__``, ``__ne__``, ``__lt__``, ``__gt__``, ``__le__`` and
    ``__ge__`` are each set to whatever ``cls.make_comparison_op`` returns
    for the matching function from the ``operator`` module.

    Nothing in this section calls the installer; it has to be invoked
    explicitly on the class that should receive the methods.
    """
    # The only comparison factory defined in this section is
    # ``make_comparison_op`` (no leading underscore), so that is the name
    # looked up here; ``_make_comparison_op`` is not defined in this section.
    make_op = cls.make_comparison_op

    cls.__eq__ = make_op(operator.eq)
    cls.__ne__ = make_op(operator.ne)
    cls.__lt__ = make_op(operator.lt)
    cls.__gt__ = make_op(operator.gt)
    cls.__le__ = make_op(operator.le)
    cls.__ge__ = make_op(operator.ge)
671+
672+
@classmethod
def _add_numeric_methods_binary(cls):
    """
    Install the binary arithmetic dunder methods on ``cls``.

    Forward operators are built from the ``operator`` module and reflected
    operators from the ``r*`` helpers in ``ops``; every method is produced
    by ``cls.make_arithmetic_op``.  ``__div__``/``__rdiv__`` are installed
    only when ``PY3`` is false.

    Note that ``__rmod__`` and ``__rdivmod__`` are not installed here.
    Nothing in this section calls the installer; it has to be invoked
    explicitly on the class that should receive the methods.
    """
    # The only arithmetic factory defined in this section is
    # ``make_arithmetic_op`` (no leading underscore), so that is the name
    # looked up here; ``_make_arithmetic_op`` is not defined in this section.
    make_op = cls.make_arithmetic_op

    cls.__add__ = make_op(operator.add)
    cls.__radd__ = make_op(ops.radd)
    cls.__sub__ = make_op(operator.sub)
    cls.__rsub__ = make_op(ops.rsub)
    cls.__mul__ = make_op(operator.mul)
    cls.__rmul__ = make_op(ops.rmul)
    cls.__rpow__ = make_op(ops.rpow)
    cls.__pow__ = make_op(operator.pow)
    cls.__mod__ = make_op(operator.mod)
    cls.__floordiv__ = make_op(operator.floordiv)
    cls.__rfloordiv__ = make_op(ops.rfloordiv)
    cls.__truediv__ = make_op(operator.truediv)
    cls.__rtruediv__ = make_op(ops.rtruediv)
    if not PY3:
        # ``operator.div`` is only referenced on this guarded path.
        cls.__div__ = make_op(operator.div)
        cls.__rdiv__ = make_op(ops.rdiv)

    cls.__divmod__ = make_op(divmod)
693+
694+
@classmethod
def make_comparison_op(cls, op):
    """
    Build the placeholder comparison method for ``op``.

    The returned method takes ``(self, other)`` and always raises
    ``NotImplementedError``.  It is passed through ``set_function_name``
    together with ``'__<op.__name__>__'`` and ``cls`` before being returned.
    """
    dunder_name = '__{}__'.format(op.__name__)

    def cmp_method(self, other):
        # Placeholder body: no comparison is implemented here.
        raise NotImplementedError

    return set_function_name(cmp_method, dunder_name, cls)
701+
702+
@classmethod
def make_arithmetic_op(cls, op):
    """
    Build the placeholder arithmetic method for ``op``.

    The returned method takes ``(self, other)`` and always raises
    ``NotImplementedError``.  It is passed through ``set_function_name``
    together with ``'__<op.__name__>__'`` and ``cls`` before being returned.
    """
    dunder_name = '__{}__'.format(op.__name__)

    def integer_arithmetic_method(self, other):
        # Placeholder body: no arithmetic is implemented here.
        raise NotImplementedError

    return set_function_name(integer_arithmetic_method, dunder_name, cls)

pandas/core/missing.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,8 @@ def fill_zeros(result, x, y, name, fill):
638638
# if we have a fill of inf, then sign it correctly
639639
# (GH 6178 and PR 9308)
640640
if np.isinf(fill):
641-
signs = np.sign(y if name.startswith(('r', '__r')) else x)
641+
signs = y if name.startswith(('r', '__r')) else x
642+
signs = np.sign(signs.astype('float', copy=False))
642643
negative_inf_mask = (signs.ravel() < 0) & mask
643644
np.putmask(result, negative_inf_mask, -fill)
644645

pandas/core/ops.py

+30-8
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
is_integer_dtype, is_categorical_dtype,
2828
is_object_dtype, is_timedelta64_dtype,
2929
is_datetime64_dtype, is_datetime64tz_dtype,
30-
is_bool_dtype,
30+
is_bool_dtype, is_extension_array_dtype,
3131
is_list_like,
3232
is_scalar,
3333
_ensure_object)
@@ -1003,8 +1003,18 @@ def _arith_method_SERIES(cls, op, special):
10031003
if op is divmod else _construct_result)
10041004

10051005
def na_op(x, y):
1006-
import pandas.core.computation.expressions as expressions
1006+
# handle extension array ops
1007+
# TODO(extension)
1008+
# the ops *between* non-same-type extension arrays are not
1009+
# very well defined
1010+
if (is_extension_array_dtype(x) or is_extension_array_dtype(y)):
1011+
if (op_name.startswith('__r') and not
1012+
is_extension_array_dtype(y) and not
1013+
is_scalar(y)):
1014+
y = x.__class__._from_sequence(y)
1015+
return op(x, y)
10071016

1017+
import pandas.core.computation.expressions as expressions
10081018
try:
10091019
result = expressions.evaluate(op, str_rep, x, y, **eval_kwargs)
10101020
except TypeError:
@@ -1025,6 +1035,7 @@ def na_op(x, y):
10251035
return result
10261036

10271037
def safe_na_op(lvalues, rvalues):
1038+
# all others
10281039
try:
10291040
with np.errstate(all='ignore'):
10301041
return na_op(lvalues, rvalues)
@@ -1035,14 +1046,21 @@ def safe_na_op(lvalues, rvalues):
10351046
raise
10361047

10371048
def wrapper(left, right):
1038-
10391049
if isinstance(right, ABCDataFrame):
10401050
return NotImplemented
10411051

10421052
left, right = _align_method_SERIES(left, right)
10431053
res_name = get_op_result_name(left, right)
10441054

1045-
if is_datetime64_dtype(left) or is_datetime64tz_dtype(left):
1055+
if is_categorical_dtype(left):
1056+
raise TypeError("{typ} cannot perform the operation "
1057+
"{op}".format(typ=type(left).__name__, op=str_rep))
1058+
1059+
elif (is_extension_array_dtype(left) or
1060+
is_extension_array_dtype(right)):
1061+
pass
1062+
1063+
elif is_datetime64_dtype(left) or is_datetime64tz_dtype(left):
10461064
result = dispatch_to_index_op(op, left, right, pd.DatetimeIndex)
10471065
return construct_result(left, result,
10481066
index=left.index, name=res_name,
@@ -1054,10 +1072,6 @@ def wrapper(left, right):
10541072
index=left.index, name=res_name,
10551073
dtype=result.dtype)
10561074

1057-
elif is_categorical_dtype(left):
1058-
raise TypeError("{typ} cannot perform the operation "
1059-
"{op}".format(typ=type(left).__name__, op=str_rep))
1060-
10611075
lvalues = left.values
10621076
rvalues = right
10631077
if isinstance(rvalues, ABCSeries):
@@ -1136,6 +1150,14 @@ def na_op(x, y):
11361150
# The `not is_scalar(y)` check excludes the string "category"
11371151
return op(y, x)
11381152

1153+
# handle extension array ops
1154+
# TODO(extension)
1155+
# the ops *between* non-same-type extension arrays are not
1156+
# very well defined
1157+
elif (is_extension_array_dtype(x) or
1158+
is_extension_array_dtype(y)):
1159+
return op(x, y)
1160+
11391161
elif is_object_dtype(x.dtype):
11401162
result = _comp_method_OBJECT_ARRAY(op, x, y)
11411163

pandas/tests/extension/base/ops.py

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import pytest
2+
import numpy as np
3+
import pandas as pd
4+
from .base import BaseExtensionTests
5+
6+
7+
class BaseOpsTests(BaseExtensionTests):
    """Various Series and DataFrame ops methods."""

    def compare(self, s, op, other, exc=NotImplementedError):
        """
        Assert that calling the dunder method ``op`` on ``s`` with ``other``
        raises ``exc``.

        Subclasses may override this to check an actual result instead.
        """
        with pytest.raises(exc):
            getattr(s, op)(other)

    def test_arith_scalar(self, data, all_arithmetic_operators):
        """Arithmetic between a Series of ``data`` and the scalar ``1``."""
        op = all_arithmetic_operators
        s = pd.Series(data)
        self.compare(s, op, 1, exc=TypeError)

    def test_arith_array(self, data, all_arithmetic_operators):
        """Arithmetic between a Series of ``data`` and an ndarray of ones."""
        op = all_arithmetic_operators
        s = pd.Series(data)
        self.compare(s, op, np.ones(len(s), dtype=s.dtype.type), exc=TypeError)

    def test_compare_scalar(self, data, all_compare_operators):
        """
        Comparison of the array and of a Series of it against the scalar ``0``.

        ``__eq__``/``__ne__`` on the array are expected to return
        ``NotImplemented`` while the Series answers all-False / all-True;
        the ordering comparisons on the Series are expected to raise
        ``TypeError``.
        """
        op = all_compare_operators
        s = pd.Series(data)

        # Exact match on the dunder name; a substring test (``in``) would
        # also accept fragments such as '' or '__'.
        if op == '__eq__':
            assert getattr(data, op)(0) is NotImplemented
            assert not getattr(s, op)(0).all()
        elif op == '__ne__':
            assert getattr(data, op)(0) is NotImplemented
            assert getattr(s, op)(0).all()
        else:
            # array
            # NOTE(review): the result is only computed, never asserted.
            # The previous ``... is NotImplementedError`` expression was
            # discarded and so checked nothing; presumably the intent is
            # ``assert ... is NotImplemented`` -- confirm before adding it.
            getattr(data, op)(0)

            # series
            with pytest.raises(TypeError):
                getattr(s, op)(0)

    def test_error(self, data, all_arithmetic_operators):
        """Looking up an arithmetic dunder on the array raises AttributeError."""
        # invalid ops
        op = all_arithmetic_operators
        with pytest.raises(AttributeError):
            getattr(data, op)

pandas/tests/extension/category/test_categorical.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,19 @@ class TestDtype(base.BaseDtypeTests):
5656

5757

5858
class TestOps(base.BaseOpsTests):

    def test_compare_scalar(self, data, all_compare_operators):
        """
        Compare the array itself against the scalar ``0``.

        Equality must be all-False, inequality all-True, and every ordering
        comparison must raise ``TypeError``.
        """
        op = all_compare_operators

        if op not in ('__eq__', '__ne__'):
            # Ordering comparisons against the scalar are rejected.
            with pytest.raises(TypeError):
                getattr(data, op)(0)
            return

        outcome = getattr(data, op)(0)
        if op == '__eq__':
            assert not outcome.all()
        else:
            assert outcome.all()
6072

6173

6274
class TestInterface(base.BaseInterfaceTests):

pandas/tests/extension/decimal/test_decimal.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,38 @@ class TestInterface(BaseDecimal, base.BaseInterfaceTests):
100100

101101

102102
class TestOps(BaseDecimal, base.BaseOpsTests):

    def compare(self, s, op, other, exc=None):
        """
        Placeholder comparison helper: marks the calling test as xfailed.

        ``exc`` is accepted (and ignored) only so the signature stays
        compatible with ``BaseOpsTests.compare``.
        """
        # TODO(extension): compute a real expected value and compare it with
        # ``getattr(s, op)(other)`` via ``self.assert_series_equal``.
        # ``pytest.xfail`` raises immediately, so nothing placed after this
        # call would ever run.
        pytest.xfail("not implemented")

    def test_arith_scalar(self, data, all_arithmetic_operators):
        """Arithmetic between a Series of ``data`` and the scalar ``1``."""
        op = all_arithmetic_operators
        s = pd.Series(data)
        self.compare(s, op, 1)

    def test_arith_array(self, data, all_arithmetic_operators):
        """Arithmetic between a Series of ``data`` and an ndarray of ones."""
        op = all_arithmetic_operators
        s = pd.Series(data)
        self.compare(s, op, np.ones(len(s), dtype=s.dtype.type))

    @pytest.mark.xfail(reason="Not implemented")
    def test_compare_scalar(self, data, all_compare_operators):
        """Compare the array against ``0``; expected to fail for now."""
        op = all_compare_operators

        # array
        result = getattr(data, op)(0)
        expected = getattr(data.data, op)(0)

        tm.assert_series_equal(result, expected)
104135

105136

106137
class TestConstructors(BaseDecimal, base.BaseConstructorsTests):

0 commit comments

Comments
 (0)