Skip to content

Commit 04b538a

Browse files
authored
TST: Refactor test_expressions.py (#44778)
* TST: Refactor test_expressions.py * Address comments
1 parent 0b0cac5 commit 04b538a

File tree

1 file changed

+54
-50
lines changed

1 file changed

+54
-50
lines changed

pandas/tests/test_expressions.py

+54-50
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,6 @@
4848
@pytest.mark.skipif(not expr.USE_NUMEXPR, reason="not using numexpr")
4949
class TestExpressions:
5050
def setup_method(self, method):
51-
52-
self.frame = _frame.copy()
53-
self.frame2 = _frame2.copy()
54-
self.mixed = _mixed.copy()
55-
self.mixed2 = _mixed2.copy()
5651
self._MIN_ELEMENTS = expr._MIN_ELEMENTS
5752

5853
def teardown_method(self, method):
@@ -75,50 +70,36 @@ def call_op(df, other, flex: bool, opname: str):
7570
result = op(df, other)
7671
return result, expected
7772

78-
def run_arithmetic(self, df, other, flex: bool):
79-
expr._MIN_ELEMENTS = 0
80-
operations = ["add", "sub", "mul", "mod", "truediv", "floordiv"]
81-
for arith in operations:
82-
result, expected = self.call_op(df, other, flex, arith)
83-
84-
if arith == "truediv":
85-
if expected.ndim == 1:
86-
assert expected.dtype.kind == "f"
87-
else:
88-
assert all(x.kind == "f" for x in expected.dtypes.values)
89-
tm.assert_equal(expected, result)
90-
91-
def run_binary(self, df, other, flex: bool):
92-
"""
93-
tests solely that the result is the same whether or not numexpr is
94-
enabled. Need to test whether the function does the correct thing
95-
elsewhere.
96-
"""
73+
@pytest.mark.parametrize(
74+
"df",
75+
[
76+
_integer,
77+
_integer2,
78+
# randint to get a case with zeros
79+
_integer * np.random.randint(0, 2, size=np.shape(_integer)),
80+
_frame,
81+
_frame2,
82+
_mixed,
83+
_mixed2,
84+
],
85+
)
86+
@pytest.mark.parametrize("flex", [True, False])
87+
@pytest.mark.parametrize(
88+
"arith", ["add", "sub", "mul", "mod", "truediv", "floordiv"]
89+
)
90+
def test_run_arithmetic(self, df, flex, arith):
9791
expr._MIN_ELEMENTS = 0
98-
expr.set_test_mode(True)
99-
operations = ["gt", "lt", "ge", "le", "eq", "ne"]
100-
101-
for arith in operations:
102-
result, expected = self.call_op(df, other, flex, arith)
103-
104-
used_numexpr = expr.get_test_result()
105-
assert used_numexpr, "Did not use numexpr as expected."
106-
tm.assert_equal(expected, result)
92+
result, expected = self.call_op(df, df, flex, arith)
10793

108-
def run_frame(self, df, other, flex: bool):
109-
self.run_arithmetic(df, other, flex)
110-
111-
set_option("compute.use_numexpr", False)
112-
binary_comp = other + 1
113-
set_option("compute.use_numexpr", True)
114-
self.run_binary(df, binary_comp, flex)
94+
if arith == "truediv":
95+
assert all(x.kind == "f" for x in expected.dtypes.values)
96+
tm.assert_equal(expected, result)
11597

11698
for i in range(len(df.columns)):
117-
self.run_arithmetic(df.iloc[:, i], other.iloc[:, i], flex)
118-
# FIXME: dont leave commented-out
119-
# series doesn't uses vec_compare instead of numexpr...
120-
# binary_comp = other.iloc[:, i] + 1
121-
# self.run_binary(df.iloc[:, i], binary_comp, flex)
99+
result, expected = self.call_op(df.iloc[:, i], df.iloc[:, i], flex, arith)
100+
if arith == "truediv":
101+
assert expected.dtype.kind == "f"
102+
tm.assert_equal(expected, result)
122103

123104
@pytest.mark.parametrize(
124105
"df",
@@ -134,8 +115,31 @@ def run_frame(self, df, other, flex: bool):
134115
],
135116
)
136117
@pytest.mark.parametrize("flex", [True, False])
137-
def test_arithmetic(self, df, flex):
138-
self.run_frame(df, df, flex)
118+
def test_run_binary(self, df, flex, comparison_op):
119+
"""
120+
tests solely that the result is the same whether or not numexpr is
121+
enabled. Need to test whether the function does the correct thing
122+
elsewhere.
123+
"""
124+
arith = comparison_op.__name__
125+
set_option("compute.use_numexpr", False)
126+
other = df.copy() + 1
127+
set_option("compute.use_numexpr", True)
128+
129+
expr._MIN_ELEMENTS = 0
130+
expr.set_test_mode(True)
131+
132+
result, expected = self.call_op(df, other, flex, arith)
133+
134+
used_numexpr = expr.get_test_result()
135+
assert used_numexpr, "Did not use numexpr as expected."
136+
tm.assert_equal(expected, result)
137+
138+
# FIXME: dont leave commented-out
139+
# series doesn't uses vec_compare instead of numexpr...
140+
# for i in range(len(df.columns)):
141+
# binary_comp = other.iloc[:, i] + 1
142+
# self.run_binary(df.iloc[:, i], binary_comp, flex)
139143

140144
def test_invalid(self):
141145
array = np.random.randn(1_000_001)
@@ -351,11 +355,11 @@ def test_bool_ops_column_name_dtype(self, test_input, expected):
351355
def test_frame_series_axis(self, axis, arith):
352356
# GH#26736 Dataframe.floordiv(Series, axis=1) fails
353357

354-
df = self.frame
358+
df = _frame
355359
if axis == 1:
356-
other = self.frame.iloc[0, :]
360+
other = df.iloc[0, :]
357361
else:
358-
other = self.frame.iloc[:, 0]
362+
other = df.iloc[:, 0]
359363

360364
expr._MIN_ELEMENTS = 0
361365

0 commit comments

Comments
 (0)