Skip to content

Commit 0de0459

Browse files
committed
ENH: Series lhs, scalar rhs bool comparison support
1 parent fb2bb58 commit 0de0459

File tree

5 files changed

+56
-11
lines changed

5 files changed

+56
-11
lines changed

pandas/core/ops.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -564,21 +564,31 @@ def na_op(x, y):
564564
y = com._ensure_object(y)
565565
result = lib.vec_binop(x, y, op)
566566
else:
567-
result = lib.scalar_binop(x, y, op)
567+
try:
568+
569+
# let null fall thru
570+
if not isnull(y):
571+
y = bool(y)
572+
result = lib.scalar_binop(x, y, op)
573+
except:
574+
raise TypeError("cannot compare a dtyped [{0}] array with "
575+
"a scalar of type [{1}]".format(x.dtype,type(y).__name__))
568576

569577
return result
570578

571579
def wrapper(self, other):
572580
if isinstance(other, pd.Series):
573581
name = _maybe_match_name(self, other)
582+
583+
other = other.reindex_like(self).fillna(False).astype(bool)
574584
return self._constructor(na_op(self.values, other.values),
575-
index=self.index, name=name)
585+
index=self.index, name=name).fillna(False).astype(bool)
576586
elif isinstance(other, pd.DataFrame):
577587
return NotImplemented
578588
else:
579589
# scalars
580590
return self._constructor(na_op(self.values, other),
581-
index=self.index, name=self.name)
591+
index=self.index, name=self.name).fillna(False).astype(bool)
582592
return wrapper
583593

584594

pandas/core/series.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
_values_from_object,
2222
_possibly_cast_to_datetime, _possibly_castable,
2323
_possibly_convert_platform,
24-
ABCSparseArray, _maybe_match_name)
24+
ABCSparseArray, _maybe_match_name, _ensure_object)
25+
2526
from pandas.core.index import (Index, MultiIndex, InvalidIndexError,
2627
_ensure_index, _handle_legacy_indexes)
2728
from pandas.core.indexing import (
@@ -1170,7 +1171,7 @@ def duplicated(self, take_last=False):
11701171
-------
11711172
duplicated : Series
11721173
"""
1173-
keys = com._ensure_object(self.values)
1174+
keys = _ensure_object(self.values)
11741175
duplicated = lib.duplicated(keys, take_last=take_last)
11751176
return self._constructor(duplicated, index=self.index, name=self.name)
11761177

pandas/lib.pyx

+3
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,9 @@ def scalar_binop(ndarray[object] values, object val, object op):
672672
object x
673673

674674
result = np.empty(n, dtype=object)
675+
if util._checknull(val):
676+
result.fill(val)
677+
return result
675678

676679
for i in range(n):
677680
x = values[i]

pandas/tests/test_frame.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4526,7 +4526,7 @@ def test_logical_with_nas(self):
45264526
# GH4947
45274527
# bool comparisons should return bool
45284528
result = d['a'] | d['b']
4529-
expected = Series([True, True])
4529+
expected = Series([False, True])
45304530
assert_series_equal(result, expected)
45314531

45324532
# GH4604, automatic casting here

pandas/tests/test_series.py

+36-5
Original file line numberDiff line numberDiff line change
@@ -2797,7 +2797,7 @@ def test_comparison_label_based(self):
27972797
assert_series_equal(result,expected)
27982798

27992799
result = a | Series([])
2800-
expected = Series([True, True, True], list('bca'))
2800+
expected = Series([True, False, True], list('bca'))
28012801
assert_series_equal(result,expected)
28022802

28032803
# vs non-matching
@@ -2806,14 +2806,43 @@ def test_comparison_label_based(self):
28062806
assert_series_equal(result,expected)
28072807

28082808
result = a | Series([1],['z'])
2809-
expected = Series([True, True, True], list('bca'))
2809+
expected = Series([True, False, True], list('bca'))
28102810
assert_series_equal(result,expected)
28112811

28122812
# identity
28132813
# we would like s[s|e] == s to hold for any e, whether empty or not
28142814
for e in [Series([]),Series([1],['z']),Series(['z']),Series(np.nan,b.index),Series(np.nan,a.index)]:
28152815
result = a[a | e]
2816-
assert_series_equal(result,a)
2816+
assert_series_equal(result,a[a])
2817+
2818+
# vs scalars
2819+
index = list('bca')
2820+
t = Series([True,False,True])
2821+
2822+
for v in [True,1,2]:
2823+
result = Series([True,False,True],index=index) | v
2824+
expected = Series([True,True,True],index=index)
2825+
assert_series_equal(result,expected)
2826+
2827+
for v in [np.nan,'foo']:
2828+
self.assertRaises(TypeError, lambda : t | v)
2829+
2830+
for v in [False,0]:
2831+
result = Series([True,False,True],index=index) | v
2832+
expected = Series([True,False,True],index=index)
2833+
assert_series_equal(result,expected)
2834+
2835+
for v in [True,1]:
2836+
result = Series([True,False,True],index=index) & v
2837+
expected = Series([True,False,True],index=index)
2838+
assert_series_equal(result,expected)
2839+
2840+
for v in [False,0]:
2841+
result = Series([True,False,True],index=index) & v
2842+
expected = Series([False,False,False],index=index)
2843+
assert_series_equal(result,expected)
2844+
for v in [np.nan]:
2845+
self.assertRaises(TypeError, lambda : t & v)
28172846

28182847
def test_between(self):
28192848
s = Series(bdate_range('1/1/2000', periods=20).asobject)
@@ -2851,12 +2880,14 @@ def test_scalar_na_cmp_corners(self):
28512880
def tester(a, b):
28522881
return a & b
28532882

2854-
self.assertRaises(ValueError, tester, s, datetime(2005, 1, 1))
2883+
self.assertRaises(TypeError, tester, s, datetime(2005, 1, 1))
28552884

28562885
s = Series([2, 3, 4, 5, 6, 7, 8, 9, datetime(2005, 1, 1)])
28572886
s[::2] = np.nan
28582887

2859-
assert_series_equal(tester(s, list(s)), s)
2888+
expected = Series(True,index=s.index)
2889+
expected[::2] = False
2890+
assert_series_equal(tester(s, list(s)), expected)
28602891

28612892
d = DataFrame({'A': s})
28622893
# TODO: Fix this exception - needs to be fixed! (see GH5035)

0 commit comments

Comments
 (0)