Skip to content

Commit f115360

Browse files
PERF: no need to check for DataFrame in pandas.core.computation.expressions (#40445)
1 parent 05a0e98 commit f115360

File tree

2 files changed

+20
-38
lines changed

2 files changed

+20
-38
lines changed

pandas/core/computation/expressions.py

+1 -11
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919

2020
from pandas._typing import FuncType
2121

22-
from pandas.core.dtypes.generic import ABCDataFrame
23-
2422
from pandas.core.computation.check import NUMEXPR_INSTALLED
2523
from pandas.core.ops import roperator
2624

@@ -83,14 +81,8 @@ def _can_use_numexpr(op, op_str, a, b, dtype_check):
8381
# check for dtype compatibility
8482
dtypes: Set[str] = set()
8583
for o in [a, b]:
86-
# Series implements dtypes, check for dimension count as well
87-
if hasattr(o, "dtypes") and o.ndim > 1:
88-
s = o.dtypes.value_counts()
89-
if len(s) > 1:
90-
return False
91-
dtypes |= set(s.index.astype(str))
9284
# ndarray and Series Case
93-
elif hasattr(o, "dtype"):
85+
if hasattr(o, "dtype"):
9486
dtypes |= {o.dtype.name}
9587

9688
# allowed are a superset
@@ -190,8 +182,6 @@ def _where_numexpr(cond, a, b):
190182

191183

192184
def _has_bool_dtype(x):
193-
if isinstance(x, ABCDataFrame):
194-
return "bool" in x.dtypes
195185
try:
196186
return x.dtype == bool
197187
except AttributeError:

pandas/tests/test_expressions.py

+19 -27
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
)
1313
from pandas.core.computation import expressions as expr
1414

15-
_frame = DataFrame(np.random.randn(10000, 4), columns=list("ABCD"), dtype="float64")
15+
_frame = DataFrame(np.random.randn(10001, 4), columns=list("ABCD"), dtype="float64")
1616
_frame2 = DataFrame(np.random.randn(100, 4), columns=list("ABCD"), dtype="float64")
1717
_mixed = DataFrame(
1818
{
@@ -36,14 +36,21 @@
3636
_integer2 = DataFrame(
3737
np.random.randint(1, 100, size=(101, 4)), columns=list("ABCD"), dtype="int64"
3838
)
39+
_array = _frame["A"].values.copy()
40+
_array2 = _frame2["A"].values.copy()
41+
42+
_array_mixed = _mixed["D"].values.copy()
43+
_array_mixed2 = _mixed2["D"].values.copy()
3944

4045

4146
@pytest.mark.skipif(not expr.USE_NUMEXPR, reason="not using numexpr")
4247
class TestExpressions:
4348
def setup_method(self, method):
4449

4550
self.frame = _frame.copy()
51+
self.array = _array.copy()
4652
self.frame2 = _frame2.copy()
53+
self.array2 = _array2.copy()
4754
self.mixed = _mixed.copy()
4855
self.mixed2 = _mixed2.copy()
4956
self._MIN_ELEMENTS = expr._MIN_ELEMENTS
@@ -134,33 +141,29 @@ def test_invalid(self):
134141

135142
# no op
136143
result = expr._can_use_numexpr(
137-
operator.add, None, self.frame, self.frame, "evaluate"
138-
)
139-
assert not result
140-
141-
# mixed
142-
result = expr._can_use_numexpr(
143-
operator.add, "+", self.mixed, self.frame, "evaluate"
144+
operator.add, None, self.array, self.array, "evaluate"
144145
)
145146
assert not result
146147

147148
# min elements
148149
result = expr._can_use_numexpr(
149-
operator.add, "+", self.frame2, self.frame2, "evaluate"
150+
operator.add, "+", self.array2, self.array2, "evaluate"
150151
)
151152
assert not result
152153

153154
# ok, we only check on first part of expression
154155
result = expr._can_use_numexpr(
155-
operator.add, "+", self.frame, self.frame2, "evaluate"
156+
operator.add, "+", self.array, self.array2, "evaluate"
156157
)
157158
assert result
158159

159160
@pytest.mark.parametrize(
160161
"opname,op_str",
161162
[("add", "+"), ("sub", "-"), ("mul", "*"), ("truediv", "/"), ("pow", "**")],
162163
)
163-
@pytest.mark.parametrize("left,right", [(_frame, _frame2), (_mixed, _mixed2)])
164+
@pytest.mark.parametrize(
165+
"left,right", [(_array, _array2), (_array_mixed, _array_mixed2)]
166+
)
164167
def test_binary_ops(self, opname, op_str, left, right):
165168
def testit():
166169

@@ -170,16 +173,9 @@ def testit():
170173

171174
op = getattr(operator, opname)
172175

173-
result = expr._can_use_numexpr(op, op_str, left, left, "evaluate")
174-
assert result != left._is_mixed_type
175-
176176
result = expr.evaluate(op, left, left, use_numexpr=True)
177177
expected = expr.evaluate(op, left, left, use_numexpr=False)
178-
179-
if isinstance(result, DataFrame):
180-
tm.assert_frame_equal(result, expected)
181-
else:
182-
tm.assert_numpy_array_equal(result, expected.values)
178+
tm.assert_numpy_array_equal(result, expected)
183179

184180
result = expr._can_use_numexpr(op, op_str, right, right, "evaluate")
185181
assert not result
@@ -203,23 +199,19 @@ def testit():
203199
("ne", "!="),
204200
],
205201
)
206-
@pytest.mark.parametrize("left,right", [(_frame, _frame2), (_mixed, _mixed2)])
202+
@pytest.mark.parametrize(
203+
"left,right", [(_array, _array2), (_array_mixed, _array_mixed2)]
204+
)
207205
def test_comparison_ops(self, opname, op_str, left, right):
208206
def testit():
209207
f12 = left + 1
210208
f22 = right + 1
211209

212210
op = getattr(operator, opname)
213211

214-
result = expr._can_use_numexpr(op, op_str, left, f12, "evaluate")
215-
assert result != left._is_mixed_type
216-
217212
result = expr.evaluate(op, left, f12, use_numexpr=True)
218213
expected = expr.evaluate(op, left, f12, use_numexpr=False)
219-
if isinstance(result, DataFrame):
220-
tm.assert_frame_equal(result, expected)
221-
else:
222-
tm.assert_numpy_array_equal(result, expected.values)
214+
tm.assert_numpy_array_equal(result, expected)
223215

224216
result = expr._can_use_numexpr(op, op_str, right, f22, "evaluate")
225217
assert not result

0 commit comments

Comments
 (0)