Skip to content

Commit 3ee3b70

Browse files
committed
Split up test_arith_flex_frame (review jreback)
1 parent 054f4e9 commit 3ee3b70

File tree

1 file changed

+55
-56
lines changed

1 file changed

+55
-56
lines changed

pandas/tests/frame/test_arithmetic.py

+55-56
Original file line numberDiff line numberDiff line change
@@ -126,14 +126,17 @@ def test_df_add_flex_filled_mixed_dtypes(self):
126126
'B': ser * 2})
127127
tm.assert_frame_equal(result, expected)
128128

129-
def test_arith_flex_frame(self, all_arithmetic_operators, int_frame,
130-
mixed_int_frame, float_frame, mixed_float_frame):
129+
def test_arith_flex_frame(self, all_arithmetic_operators, float_frame,
130+
mixed_float_frame):
131131

132-
op = all_arithmetic_operators
132+
op = all_arithmetic_operators # one instance of parametrized fixture
133133
if op.startswith('__r'):
134-
pytest.skip('Reverse methods not available in operator library')
134+
# get op without "r" and invert it
135+
tmp = getattr(operator, op[:2] + op[3:])
136+
f = lambda x, y: tmp(y, x)
137+
else:
138+
f = getattr(operator, op)
135139

136-
f = getattr(operator, op)
137140
result = getattr(float_frame, op)(2 * float_frame)
138141
exp = f(float_frame, 2 * float_frame)
139142
tm.assert_frame_equal(result, exp)
@@ -144,58 +147,52 @@ def test_arith_flex_frame(self, all_arithmetic_operators, int_frame,
144147
tm.assert_frame_equal(result, exp)
145148
_check_mixed_float(result, dtype=dict(C=None))
146149

150+
@pytest.mark.parametrize('op', ['__add__', '__sub__', '__mul__'])
151+
def test_arith_flex_frame_mixed(self, op, int_frame, mixed_int_frame,
152+
float_frame, mixed_float_frame):
153+
154+
if op.startswith('__r'):
155+
# get op without "r" and invert it
156+
tmp = getattr(operator, op[:2] + op[3:])
157+
f = lambda x, y: tmp(y, x)
158+
else:
159+
f = getattr(operator, op)
160+
147161
# vs mix int
148-
if op in ['add', 'sub', 'mul']:
149-
result = getattr(mixed_int_frame, op)(2 + mixed_int_frame)
150-
exp = f(mixed_int_frame, 2 + mixed_int_frame)
151-
152-
# no overflow in the uint
153-
dtype = None
154-
if op in ['sub']:
155-
dtype = dict(B='uint64', C=None)
156-
elif op in ['add', 'mul']:
157-
dtype = dict(C=None)
158-
tm.assert_frame_equal(result, exp)
159-
_check_mixed_int(result, dtype=dtype)
160-
161-
# rops
162-
r_f = lambda x, y: f(y, x)
163-
result = getattr(float_frame, 'r' + op)(2 * float_frame)
164-
exp = r_f(float_frame, 2 * float_frame)
165-
tm.assert_frame_equal(result, exp)
166-
167-
# vs mix float
168-
result = getattr(mixed_float_frame, op)(2 * mixed_float_frame)
169-
exp = f(mixed_float_frame, 2 * mixed_float_frame)
170-
tm.assert_frame_equal(result, exp)
171-
_check_mixed_float(result, dtype=dict(C=None))
172-
173-
result = getattr(int_frame, op)(2 * int_frame)
174-
exp = f(int_frame, 2 * int_frame)
175-
tm.assert_frame_equal(result, exp)
176-
177-
# vs mix int
178-
if op in ['add', 'sub', 'mul']:
179-
result = getattr(mixed_int_frame, op)(2 + mixed_int_frame)
180-
exp = f(mixed_int_frame, 2 + mixed_int_frame)
181-
182-
# no overflow in the uint
183-
dtype = None
184-
if op in ['sub']:
185-
dtype = dict(B='uint64', C=None)
186-
elif op in ['add', 'mul']:
187-
dtype = dict(C=None)
188-
tm.assert_frame_equal(result, exp)
189-
_check_mixed_int(result, dtype=dtype)
190-
191-
# ndim >= 3
192-
ndim_5 = np.ones(float_frame.shape + (3, 4, 5))
193-
msg = "Unable to coerce to Series/DataFrame"
194-
with tm.assert_raises_regex(ValueError, msg):
195-
f(float_frame, ndim_5)
196-
197-
with tm.assert_raises_regex(ValueError, msg):
198-
getattr(float_frame, op)(ndim_5)
162+
result = getattr(mixed_int_frame, op)(2 + mixed_int_frame)
163+
exp = f(mixed_int_frame, 2 + mixed_int_frame)
164+
165+
# no overflow in the uint
166+
dtype = None
167+
if op in ['__sub__']:
168+
dtype = dict(B='uint64', C=None)
169+
elif op in ['__add__', '__mul__']:
170+
dtype = dict(C=None)
171+
tm.assert_frame_equal(result, exp)
172+
_check_mixed_int(result, dtype=dtype)
173+
174+
# vs mix float
175+
result = getattr(mixed_float_frame, op)(2 * mixed_float_frame)
176+
exp = f(mixed_float_frame, 2 * mixed_float_frame)
177+
tm.assert_frame_equal(result, exp)
178+
_check_mixed_float(result, dtype=dict(C=None))
179+
180+
# vs plain int
181+
result = getattr(int_frame, op)(2 * int_frame)
182+
exp = f(int_frame, 2 * int_frame)
183+
tm.assert_frame_equal(result, exp)
184+
185+
def test_arith_flex_frame_corner(self, all_arithmetic_operators,
186+
float_frame):
187+
188+
op = all_arithmetic_operators
189+
190+
# Check that arrays with dim >= 3 raise
191+
for dim in range(3, 6):
192+
arr = np.ones((1,) * dim)
193+
msg = "Unable to coerce to Series/DataFrame"
194+
with tm.assert_raises_regex(ValueError, msg):
195+
getattr(float_frame, op)(arr)
199196

200197
const_add = float_frame.add(1)
201198
tm.assert_frame_equal(const_add, float_frame + 1)
@@ -206,8 +203,10 @@ def test_arith_flex_frame(self, all_arithmetic_operators, int_frame,
206203

207204
result = float_frame[:0].add(float_frame)
208205
tm.assert_frame_equal(result, float_frame * np.nan)
206+
209207
with tm.assert_raises_regex(NotImplementedError, 'fill_value'):
210208
float_frame.add(float_frame.iloc[0], fill_value=3)
209+
211210
with tm.assert_raises_regex(NotImplementedError, 'fill_value'):
212211
float_frame.add(float_frame.iloc[0], axis='index', fill_value=3)
213212

0 commit comments

Comments
 (0)