Skip to content

Commit 8f9a75c

Browse files
jbrockmendeljreback
authored andcommitted
TST/CLN: parametrize and clean test_expressions, test_nanops (#28553)
1 parent e448a26 commit 8f9a75c

File tree

2 files changed

+132
-215
lines changed

2 files changed

+132
-215
lines changed

pandas/tests/test_expressions.py

+44-50
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,12 @@ def run_arithmetic(self, df, other):
5454
operations = ["add", "sub", "mul", "mod", "truediv", "floordiv"]
5555
for test_flex in [True, False]:
5656
for arith in operations:
57-
58-
operator_name = arith
59-
57+
# TODO: share with run_binary
6058
if test_flex:
6159
op = lambda x, y: getattr(x, arith)(y)
6260
op.__name__ = arith
6361
else:
64-
op = getattr(operator, operator_name)
62+
op = getattr(operator, arith)
6563
expr.set_use_numexpr(False)
6664
expected = op(df, other)
6765
expr.set_use_numexpr(True)
@@ -87,13 +85,14 @@ def run_binary(self, df, other):
8785
for test_flex in [True, False]:
8886
for arith in operations:
8987
if test_flex:
90-
op = lambda x, y: getattr(df, arith)(y)
88+
op = lambda x, y: getattr(x, arith)(y)
9189
op.__name__ = arith
9290
else:
9391
op = getattr(operator, arith)
9492
expr.set_use_numexpr(False)
9593
expected = op(df, other)
9694
expr.set_use_numexpr(True)
95+
9796
expr.get_test_result()
9897
result = op(df, other)
9998
used_numexpr = expr.get_test_result()
@@ -167,29 +166,29 @@ def test_invalid(self):
167166
"opname,op_str",
168167
[("add", "+"), ("sub", "-"), ("mul", "*"), ("truediv", "/"), ("pow", "**")],
169168
)
170-
def test_binary_ops(self, opname, op_str):
169+
@pytest.mark.parametrize("left,right", [(_frame, _frame2), (_mixed, _mixed2)])
170+
def test_binary_ops(self, opname, op_str, left, right):
171171
def testit():
172172

173-
for f, f2 in [(self.frame, self.frame2), (self.mixed, self.mixed2)]:
173+
if opname == "pow":
174+
# TODO: get this working
175+
return
174176

175-
if opname == "pow":
176-
continue
177+
op = getattr(operator, opname)
177178

178-
op = getattr(operator, opname)
179+
result = expr._can_use_numexpr(op, op_str, left, left, "evaluate")
180+
assert result != left._is_mixed_type
179181

180-
result = expr._can_use_numexpr(op, op_str, f, f, "evaluate")
181-
assert result != f._is_mixed_type
182+
result = expr.evaluate(op, op_str, left, left, use_numexpr=True)
183+
expected = expr.evaluate(op, op_str, left, left, use_numexpr=False)
182184

183-
result = expr.evaluate(op, op_str, f, f, use_numexpr=True)
184-
expected = expr.evaluate(op, op_str, f, f, use_numexpr=False)
185+
if isinstance(result, DataFrame):
186+
tm.assert_frame_equal(result, expected)
187+
else:
188+
tm.assert_numpy_array_equal(result, expected.values)
185189

186-
if isinstance(result, DataFrame):
187-
tm.assert_frame_equal(result, expected)
188-
else:
189-
tm.assert_numpy_array_equal(result, expected.values)
190-
191-
result = expr._can_use_numexpr(op, op_str, f2, f2, "evaluate")
192-
assert not result
190+
result = expr._can_use_numexpr(op, op_str, right, right, "evaluate")
191+
assert not result
193192

194193
expr.set_use_numexpr(False)
195194
testit()
@@ -210,30 +209,26 @@ def testit():
210209
("ne", "!="),
211210
],
212211
)
213-
def test_comparison_ops(self, opname, op_str):
212+
@pytest.mark.parametrize("left,right", [(_frame, _frame2), (_mixed, _mixed2)])
213+
def test_comparison_ops(self, opname, op_str, left, right):
214214
def testit():
215-
for f, f2 in [(self.frame, self.frame2), (self.mixed, self.mixed2)]:
216-
217-
f11 = f
218-
f12 = f + 1
215+
f12 = left + 1
216+
f22 = right + 1
219217

220-
f21 = f2
221-
f22 = f2 + 1
218+
op = getattr(operator, opname)
222219

223-
op = getattr(operator, opname)
220+
result = expr._can_use_numexpr(op, op_str, left, f12, "evaluate")
221+
assert result != left._is_mixed_type
224222

225-
result = expr._can_use_numexpr(op, op_str, f11, f12, "evaluate")
226-
assert result != f11._is_mixed_type
223+
result = expr.evaluate(op, op_str, left, f12, use_numexpr=True)
224+
expected = expr.evaluate(op, op_str, left, f12, use_numexpr=False)
225+
if isinstance(result, DataFrame):
226+
tm.assert_frame_equal(result, expected)
227+
else:
228+
tm.assert_numpy_array_equal(result, expected.values)
227229

228-
result = expr.evaluate(op, op_str, f11, f12, use_numexpr=True)
229-
expected = expr.evaluate(op, op_str, f11, f12, use_numexpr=False)
230-
if isinstance(result, DataFrame):
231-
tm.assert_frame_equal(result, expected)
232-
else:
233-
tm.assert_numpy_array_equal(result, expected.values)
234-
235-
result = expr._can_use_numexpr(op, op_str, f21, f22, "evaluate")
236-
assert not result
230+
result = expr._can_use_numexpr(op, op_str, right, f22, "evaluate")
231+
assert not result
237232

238233
expr.set_use_numexpr(False)
239234
testit()
@@ -244,15 +239,14 @@ def testit():
244239
testit()
245240

246241
@pytest.mark.parametrize("cond", [True, False])
247-
def test_where(self, cond):
242+
@pytest.mark.parametrize("df", [_frame, _frame2, _mixed, _mixed2])
243+
def test_where(self, cond, df):
248244
def testit():
249-
for f in [self.frame, self.frame2, self.mixed, self.mixed2]:
250-
251-
c = np.empty(f.shape, dtype=np.bool_)
252-
c.fill(cond)
253-
result = expr.where(c, f.values, f.values + 1)
254-
expected = np.where(c, f.values, f.values + 1)
255-
tm.assert_numpy_array_equal(result, expected)
245+
c = np.empty(df.shape, dtype=np.bool_)
246+
c.fill(cond)
247+
result = expr.where(c, df.values, df.values + 1)
248+
expected = np.where(c, df.values, df.values + 1)
249+
tm.assert_numpy_array_equal(result, expected)
256250

257251
expr.set_use_numexpr(False)
258252
testit()
@@ -263,7 +257,7 @@ def testit():
263257
testit()
264258

265259
@pytest.mark.parametrize(
266-
"op_str,opname", list(zip(["/", "//", "**"], ["truediv", "floordiv", "pow"]))
260+
"op_str,opname", [("/", "truediv"), ("//", "floordiv"), ("**", "pow")]
267261
)
268262
def test_bool_ops_raise_on_arithmetic(self, op_str, opname):
269263
df = DataFrame({"a": np.random.rand(10) > 0.5, "b": np.random.rand(10) > 0.5})
@@ -291,7 +285,7 @@ def test_bool_ops_raise_on_arithmetic(self, op_str, opname):
291285
f(df, True)
292286

293287
@pytest.mark.parametrize(
294-
"op_str,opname", list(zip(["+", "*", "-"], ["add", "mul", "sub"]))
288+
"op_str,opname", [("+", "add"), ("*", "mul"), ("-", "sub")]
295289
)
296290
def test_bool_ops_warn_on_arithmetic(self, op_str, opname):
297291
n = 10

0 commit comments

Comments
 (0)