Skip to content

Commit f108ff7

Browse files
Chang She authored and wesm committed
TST: tests for flex compare #652
1 parent d0e870c commit f108ff7

File tree

2 files changed: +158 −24 lines changed

pandas/core/frame.py

+47-24
Original file line numberDiff line numberDiff line change
@@ -222,18 +222,40 @@ def f(self, other, axis=default_axis, level=None, fill_value=None):
222222

223223
return f
224224

225-
def flex_comp_method(op, name, default_axis='columns'):
225+
def _flex_comp_method(op, name, default_axis='columns'):
226+
227+
def na_op(x, y):
    """Apply the comparison ``op`` to ``x`` and ``y``, tolerating missing values.

    The comparison is first attempted directly.  If that raises
    ``TypeError``, ``x`` is flattened and ``op`` is evaluated only at the
    positions that ``notnull`` reports as present (in both operands when
    ``y`` is an ndarray, in ``x`` only otherwise).  The remaining positions
    are filled with the answer a comparison against a missing value should
    give: ``True`` for ``operator.ne`` and ``False`` for every other
    operator.  The result is then reshaped back to ``x.shape``.

    Parameters
    ----------
    x : array-like exposing ``ravel``, ``size``, ``dtype`` and ``shape``
        Left operand.
    y : np.ndarray or other
        Right operand; an ndarray is flattened and masked alongside ``x``,
        anything else is passed to ``op`` unchanged.

    Returns
    -------
    Whatever ``op(x, y)`` returns on the direct path; on the fallback path
    an array with ``x``'s shape and dtype.
    """
    try:
        result = op(x, y)
    except TypeError:
        xrav = x.ravel()
        # NOTE(review): the buffer keeps x's dtype, as before; presumably
        # this path is only reached for object arrays -- confirm.
        result = np.empty(x.size, dtype=x.dtype)
        if isinstance(y, np.ndarray):
            yrav = y.ravel()
            mask = notnull(xrav) & notnull(yrav)
            result[mask] = op(xrav[mask], yrav[mask])
        else:
            mask = notnull(xrav)
            result[mask] = op(xrav[mask], y)

        # A missing value is "not equal" to anything, so != must yield True
        # at the masked-out positions; all other comparisons yield False.
        # (Previously both branches wrote False, making `ne` wrong on NA.)
        fill_value = op == operator.ne
        # `~mask` rather than `-mask`: unary minus on a boolean array is not
        # accepted by current NumPy, while `~` is the logical inversion.
        # assumes notnull returns a boolean ndarray -- TODO confirm
        np.putmask(result, ~mask, fill_value)
        result = result.reshape(x.shape)

    return result
226248

227249
@Appender('Wrapper for flexible comparison methods %s' % name)
228250
def f(self, other, axis=default_axis, level=None):
229251
if isinstance(other, DataFrame): # Another DataFrame
230-
return self._flex_compare_frame(other, op, level)
252+
return self._flex_compare_frame(other, na_op, level)
231253

232254
elif isinstance(other, Series):
233255
try:
234-
return self._combine_series(other, op, None, axis, level)
256+
return self._combine_series(other, na_op, None, axis, level)
235257
except Exception:
236-
return self._combine_series_infer(other, op)
258+
return self._combine_series_infer(other, na_op)
237259

238260
elif isinstance(other, (list, tuple)):
239261
if axis is not None and self._get_axis_name(axis) == 'index':
@@ -242,9 +264,9 @@ def f(self, other, axis=default_axis, level=None):
242264
casted = Series(other, index=self.columns)
243265

244266
try:
245-
return self._combine_series(casted, op, None, axis, level)
267+
return self._combine_series(casted, na_op, None, axis, level)
246268
except Exception:
247-
return self._combine_series_infer(casted, op)
269+
return self._combine_series_infer(casted, na_op)
248270

249271
elif isinstance(other, np.ndarray):
250272
if other.ndim == 1:
@@ -254,27 +276,28 @@ def f(self, other, axis=default_axis, level=None):
254276
casted = Series(other, index=self.columns)
255277

256278
try:
257-
return self._combine_series(casted, op, None, axis, level)
279+
return self._combine_series(casted, na_op, None, axis,
280+
level)
258281
except Exception:
259-
return self._combine_series_infer(casted, op)
282+
return self._combine_series_infer(casted, na_op)
260283

261284
elif other.ndim == 2:
262285
casted = DataFrame(other, index=self.index,
263286
columns=self.columns)
264-
return self._flex_compare_frame(casted, op, level)
287+
return self._flex_compare_frame(casted, na_op, level)
265288

266289
else: # pragma: no cover
267290
raise ValueError("Bad argument shape")
268291

269292
else:
270-
return self._combine_const(other, op)
293+
return self._combine_const(other, na_op)
271294

272295
f.__name__ = name
273296

274297
return f
275298

276299

277-
def comp_method(func, name):
300+
def _comp_method(func, name):
278301
@Appender('Wrapper for comparison method %s' % name)
279302
def f(self, other):
280303
if isinstance(other, DataFrame): # Another DataFrame
@@ -666,19 +689,19 @@ def __neg__(self):
666689
return self._wrap_array(arr, self.axes, copy=False)
667690

668691
# Comparison methods
669-
__eq__ = comp_method(operator.eq, '__eq__')
670-
__ne__ = comp_method(operator.ne, '__ne__')
671-
__lt__ = comp_method(operator.lt, '__lt__')
672-
__gt__ = comp_method(operator.gt, '__gt__')
673-
__le__ = comp_method(operator.le, '__le__')
674-
__ge__ = comp_method(operator.ge, '__ge__')
675-
676-
eq = flex_comp_method(operator.eq, 'eq')
677-
ne = flex_comp_method(operator.ne, 'ne')
678-
gt = flex_comp_method(operator.gt, 'gt')
679-
lt = flex_comp_method(operator.lt, 'lt')
680-
ge = flex_comp_method(operator.ge, 'ge')
681-
le = flex_comp_method(operator.le, 'le')
692+
__eq__ = _comp_method(operator.eq, '__eq__')
693+
__ne__ = _comp_method(operator.ne, '__ne__')
694+
__lt__ = _comp_method(operator.lt, '__lt__')
695+
__gt__ = _comp_method(operator.gt, '__gt__')
696+
__le__ = _comp_method(operator.le, '__le__')
697+
__ge__ = _comp_method(operator.ge, '__ge__')
698+
699+
eq = _flex_comp_method(operator.eq, 'eq')
700+
ne = _flex_comp_method(operator.ne, 'ne')
701+
gt = _flex_comp_method(operator.gt, 'gt')
702+
lt = _flex_comp_method(operator.lt, 'lt')
703+
ge = _flex_comp_method(operator.ge, 'ge')
704+
le = _flex_comp_method(operator.le, 'le')
682705

683706
def dot(self, other):
684707
"""

pandas/tests/test_frame.py

+111
Original file line numberDiff line numberDiff line change
@@ -2449,6 +2449,117 @@ def test_arith_flex_frame(self):
24492449
result = self.frame[:0].add(self.frame)
24502450
assert_frame_equal(result, self.frame * np.nan)
24512451

2452+
def test_bool_flex_frame(self):
2453+
data = np.random.randn(5, 3)
2454+
other_data = np.random.randn(5, 3)
2455+
df = DataFrame(data)
2456+
other = DataFrame(other_data)
2457+
2458+
# No NAs
2459+
2460+
# DataFrame
2461+
self.assert_(df.eq(df).values.all())
2462+
self.assert_(not df.ne(df).values.any())
2463+
2464+
assert_frame_equal((df == other), df.eq(other))
2465+
assert_frame_equal((df != other), df.ne(other))
2466+
assert_frame_equal((df > other), df.gt(other))
2467+
assert_frame_equal((df < other), df.lt(other))
2468+
assert_frame_equal((df >= other), df.ge(other))
2469+
assert_frame_equal((df <= other), df.le(other))
2470+
2471+
# Unaligned
2472+
def _check_unaligned_frame(meth, op, df, other, default=False):
2473+
part_o = other.ix[3:, 1:].copy()
2474+
rs = meth(df, part_o)
2475+
xp = op(df, part_o.reindex(index=df.index, columns=df.columns))
2476+
assert_frame_equal(rs, xp)
2477+
2478+
_check_unaligned_frame(DataFrame.eq, operator.eq, df, other)
2479+
_check_unaligned_frame(DataFrame.ne, operator.ne, df, other,
2480+
default=True)
2481+
_check_unaligned_frame(DataFrame.gt, operator.gt, df, other)
2482+
_check_unaligned_frame(DataFrame.lt, operator.lt, df, other)
2483+
_check_unaligned_frame(DataFrame.ge, operator.ge, df, other)
2484+
_check_unaligned_frame(DataFrame.le, operator.le, df, other)
2485+
2486+
# Series
2487+
def _test_seq(df, idx_ser, col_ser):
2488+
idx_eq = df.eq(idx_ser, axis=0)
2489+
col_eq = df.eq(col_ser)
2490+
idx_ne = df.ne(idx_ser, axis=0)
2491+
col_ne = df.ne(col_ser)
2492+
assert_frame_equal(col_eq, df == Series(col_ser))
2493+
assert_frame_equal(col_eq, -col_ne)
2494+
assert_frame_equal(idx_eq, -idx_ne)
2495+
assert_frame_equal(idx_eq, df.T.eq(idx_ser).T)
2496+
2497+
idx_gt = df.gt(idx_ser, axis=0)
2498+
col_gt = df.gt(col_ser)
2499+
idx_le = df.le(idx_ser, axis=0)
2500+
col_le = df.le(col_ser)
2501+
2502+
assert_frame_equal(col_gt, df > Series(col_ser))
2503+
assert_frame_equal(col_gt, -col_le)
2504+
assert_frame_equal(idx_gt, -idx_le)
2505+
assert_frame_equal(idx_gt, df.T.gt(idx_ser).T)
2506+
2507+
idx_ge = df.ge(idx_ser, axis=0)
2508+
col_ge = df.ge(col_ser)
2509+
idx_lt = df.lt(idx_ser, axis=0)
2510+
col_lt = df.lt(col_ser)
2511+
assert_frame_equal(col_ge, df >= Series(col_ser))
2512+
assert_frame_equal(col_ge, -col_lt)
2513+
assert_frame_equal(idx_ge, -idx_lt)
2514+
assert_frame_equal(idx_ge, df.T.ge(idx_ser).T)
2515+
2516+
idx_ser = Series(np.random.randn(5))
2517+
col_ser = Series(np.random.randn(3))
2518+
_test_seq(df, idx_ser, col_ser)
2519+
2520+
# ndarray
2521+
2522+
assert_frame_equal((df == other.values), df.eq(other.values))
2523+
assert_frame_equal((df != other.values), df.ne(other.values))
2524+
assert_frame_equal((df > other.values), df.gt(other.values))
2525+
assert_frame_equal((df < other.values), df.lt(other.values))
2526+
assert_frame_equal((df >= other.values), df.ge(other.values))
2527+
assert_frame_equal((df <= other.values), df.le(other.values))
2528+
2529+
# list/tuple
2530+
_test_seq(df, idx_ser.values, col_ser.values)
2531+
2532+
# NA
2533+
df.ix[0, 0] = np.nan
2534+
rs = df.eq(df)
2535+
self.assert_(not rs.ix[0, 0])
2536+
rs = df.ne(df)
2537+
self.assert_(rs.ix[0, 0])
2538+
rs = df.gt(df)
2539+
self.assert_(not rs.ix[0, 0])
2540+
rs = df.lt(df)
2541+
self.assert_(not rs.ix[0, 0])
2542+
rs = df.ge(df)
2543+
self.assert_(not rs.ix[0, 0])
2544+
rs = df.le(df)
2545+
self.assert_(not rs.ix[0, 0])
2546+
2547+
2548+
# scalar
2549+
assert_frame_equal(df.eq(0), df == 0)
2550+
assert_frame_equal(df.ne(0), df != 0)
2551+
assert_frame_equal(df.gt(0), df > 0)
2552+
assert_frame_equal(df.lt(0), df < 0)
2553+
assert_frame_equal(df.ge(0), df >= 0)
2554+
assert_frame_equal(df.le(0), df <= 0)
2555+
2556+
assert_frame_equal(df.eq(np.nan), df == np.nan)
2557+
assert_frame_equal(df.ne(np.nan), df != np.nan)
2558+
assert_frame_equal(df.gt(np.nan), df > np.nan)
2559+
assert_frame_equal(df.lt(np.nan), df < np.nan)
2560+
assert_frame_equal(df.ge(np.nan), df >= np.nan)
2561+
assert_frame_equal(df.le(np.nan), df <= np.nan)
2562+
24522563
def test_arith_flex_series(self):
24532564
df = self.simple
24542565

0 commit comments

Comments
 (0)