Skip to content

Commit 5eb1add

Browse files
authored
CLN: Clean series/test_arithmetic.py (#36406)
1 parent 81d24c7 commit 5eb1add

File tree

1 file changed

+42
-56
lines changed

1 file changed

+42
-56
lines changed

pandas/tests/series/test_arithmetic.py

+42-56
Original file line numberDiff line numberDiff line change
@@ -260,73 +260,59 @@ def test_sub_datetimelike_align(self):
260260

261261

262262
class TestSeriesFlexComparison:
263-
def test_comparison_flex_basic(self):
263+
@pytest.mark.parametrize("axis", [0, None, "index"])
264+
def test_comparison_flex_basic(self, axis, all_compare_operators):
265+
op = all_compare_operators.strip("__")
264266
left = pd.Series(np.random.randn(10))
265267
right = pd.Series(np.random.randn(10))
268+
result = getattr(left, op)(right, axis=axis)
269+
expected = getattr(operator, op)(left, right)
270+
tm.assert_series_equal(result, expected)
266271

267-
tm.assert_series_equal(left.eq(right), left == right)
268-
tm.assert_series_equal(left.ne(right), left != right)
269-
tm.assert_series_equal(left.le(right), left < right)
270-
tm.assert_series_equal(left.lt(right), left <= right)
271-
tm.assert_series_equal(left.gt(right), left > right)
272-
tm.assert_series_equal(left.ge(right), left >= right)
273-
274-
for axis in [0, None, "index"]:
275-
tm.assert_series_equal(left.eq(right, axis=axis), left == right)
276-
tm.assert_series_equal(left.ne(right, axis=axis), left != right)
277-
tm.assert_series_equal(left.le(right, axis=axis), left < right)
278-
tm.assert_series_equal(left.lt(right, axis=axis), left <= right)
279-
tm.assert_series_equal(left.gt(right, axis=axis), left > right)
280-
tm.assert_series_equal(left.ge(right, axis=axis), left >= right)
272+
def test_comparison_bad_axis(self, all_compare_operators):
273+
op = all_compare_operators.strip("__")
274+
left = pd.Series(np.random.randn(10))
275+
right = pd.Series(np.random.randn(10))
281276

282277
msg = "No axis named 1 for object type"
283-
for op in ["eq", "ne", "le", "le", "gt", "ge"]:
284-
with pytest.raises(ValueError, match=msg):
285-
getattr(left, op)(right, axis=1)
278+
with pytest.raises(ValueError, match=msg):
279+
getattr(left, op)(right, axis=1)
286280

287-
def test_comparison_flex_alignment(self):
281+
@pytest.mark.parametrize(
282+
"values, op",
283+
[
284+
([False, False, True, False], "eq"),
285+
([True, True, False, True], "ne"),
286+
([False, False, True, False], "le"),
287+
([False, False, False, False], "lt"),
288+
([False, True, True, False], "ge"),
289+
([False, True, False, False], "gt"),
290+
],
291+
)
292+
def test_comparison_flex_alignment(self, values, op):
288293
left = Series([1, 3, 2], index=list("abc"))
289294
right = Series([2, 2, 2], index=list("bcd"))
295+
result = getattr(left, op)(right)
296+
expected = pd.Series(values, index=list("abcd"))
297+
tm.assert_series_equal(result, expected)
290298

291-
exp = pd.Series([False, False, True, False], index=list("abcd"))
292-
tm.assert_series_equal(left.eq(right), exp)
293-
294-
exp = pd.Series([True, True, False, True], index=list("abcd"))
295-
tm.assert_series_equal(left.ne(right), exp)
296-
297-
exp = pd.Series([False, False, True, False], index=list("abcd"))
298-
tm.assert_series_equal(left.le(right), exp)
299-
300-
exp = pd.Series([False, False, False, False], index=list("abcd"))
301-
tm.assert_series_equal(left.lt(right), exp)
302-
303-
exp = pd.Series([False, True, True, False], index=list("abcd"))
304-
tm.assert_series_equal(left.ge(right), exp)
305-
306-
exp = pd.Series([False, True, False, False], index=list("abcd"))
307-
tm.assert_series_equal(left.gt(right), exp)
308-
309-
def test_comparison_flex_alignment_fill(self):
299+
@pytest.mark.parametrize(
300+
"values, op, fill_value",
301+
[
302+
([False, False, True, True], "eq", 2),
303+
([True, True, False, False], "ne", 2),
304+
([False, False, True, True], "le", 0),
305+
([False, False, False, True], "lt", 0),
306+
([True, True, True, False], "ge", 0),
307+
([True, True, False, False], "gt", 0),
308+
],
309+
)
310+
def test_comparison_flex_alignment_fill(self, values, op, fill_value):
310311
left = Series([1, 3, 2], index=list("abc"))
311312
right = Series([2, 2, 2], index=list("bcd"))
312-
313-
exp = pd.Series([False, False, True, True], index=list("abcd"))
314-
tm.assert_series_equal(left.eq(right, fill_value=2), exp)
315-
316-
exp = pd.Series([True, True, False, False], index=list("abcd"))
317-
tm.assert_series_equal(left.ne(right, fill_value=2), exp)
318-
319-
exp = pd.Series([False, False, True, True], index=list("abcd"))
320-
tm.assert_series_equal(left.le(right, fill_value=0), exp)
321-
322-
exp = pd.Series([False, False, False, True], index=list("abcd"))
323-
tm.assert_series_equal(left.lt(right, fill_value=0), exp)
324-
325-
exp = pd.Series([True, True, True, False], index=list("abcd"))
326-
tm.assert_series_equal(left.ge(right, fill_value=0), exp)
327-
328-
exp = pd.Series([True, True, False, False], index=list("abcd"))
329-
tm.assert_series_equal(left.gt(right, fill_value=0), exp)
313+
result = getattr(left, op)(right, fill_value=fill_value)
314+
expected = pd.Series(values, index=list("abcd"))
315+
tm.assert_series_equal(result, expected)
330316

331317

332318
class TestSeriesComparison:

0 commit comments

Comments
 (0)