Skip to content

Commit d0e870c

Browse files
Chang Shewesm
Chang She
authored andcommitted
ENH: flex comparison operators on DataFrame #652
1 parent d8a0427 commit d0e870c

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed

pandas/core/frame.py

+69
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,57 @@ def f(self, other, axis=default_axis, level=None, fill_value=None):
222222

223223
return f
224224

225+
def flex_comp_method(op, name, default_axis='columns'):
226+
227+
@Appender('Wrapper for flexible comparison methods %s' % name)
228+
def f(self, other, axis=default_axis, level=None):
229+
if isinstance(other, DataFrame): # Another DataFrame
230+
return self._flex_compare_frame(other, op, level)
231+
232+
elif isinstance(other, Series):
233+
try:
234+
return self._combine_series(other, op, None, axis, level)
235+
except Exception:
236+
return self._combine_series_infer(other, op)
237+
238+
elif isinstance(other, (list, tuple)):
239+
if axis is not None and self._get_axis_name(axis) == 'index':
240+
casted = Series(other, index=self.index)
241+
else:
242+
casted = Series(other, index=self.columns)
243+
244+
try:
245+
return self._combine_series(casted, op, None, axis, level)
246+
except Exception:
247+
return self._combine_series_infer(casted, op)
248+
249+
elif isinstance(other, np.ndarray):
250+
if other.ndim == 1:
251+
if axis is not None and self._get_axis_name(axis) == 'index':
252+
casted = Series(other, index=self.index)
253+
else:
254+
casted = Series(other, index=self.columns)
255+
256+
try:
257+
return self._combine_series(casted, op, None, axis, level)
258+
except Exception:
259+
return self._combine_series_infer(casted, op)
260+
261+
elif other.ndim == 2:
262+
casted = DataFrame(other, index=self.index,
263+
columns=self.columns)
264+
return self._flex_compare_frame(casted, op, level)
265+
266+
else: # pragma: no cover
267+
raise ValueError("Bad argument shape")
268+
269+
else:
270+
return self._combine_const(other, op)
271+
272+
f.__name__ = name
273+
274+
return f
275+
225276

226277
def comp_method(func, name):
227278
@Appender('Wrapper for comparison method %s' % name)
@@ -622,6 +673,13 @@ def __neg__(self):
622673
__le__ = comp_method(operator.le, '__le__')
623674
__ge__ = comp_method(operator.ge, '__ge__')
624675

676+
eq = flex_comp_method(operator.eq, 'eq')
677+
ne = flex_comp_method(operator.ne, 'ne')
678+
gt = flex_comp_method(operator.gt, 'gt')
679+
lt = flex_comp_method(operator.lt, 'lt')
680+
ge = flex_comp_method(operator.ge, 'ge')
681+
le = flex_comp_method(operator.le, 'le')
682+
625683
def dot(self, other):
626684
"""
627685
Matrix multiplication with DataFrame objects. Does no data alignment
@@ -2945,6 +3003,17 @@ def _compare_frame(self, other, func):
29453003
return self._constructor(data=new_data, index=self.index,
29463004
columns=self.columns, copy=False)
29473005

3006+
def _flex_compare_frame(self, other, func, level):
3007+
if not self._indexed_same(other):
3008+
self, other = self.align(other, 'outer', level=level)
3009+
3010+
new_data = {}
3011+
for col in self.columns:
3012+
new_data[col] = func(self[col], other[col])
3013+
3014+
return self._constructor(data=new_data, index=self.index,
3015+
columns=self.columns, copy=False)
3016+
29483017
def combine(self, other, func, fill_value=None):
29493018
"""
29503019
Add two DataFrame objects and do not propagate NaN values, so if for a

0 commit comments

Comments
 (0)