Skip to content

Commit 3080dcf

Browse files
rhshadrachvladu
authored andcommitted
CLN/TST: Some cleanups in apply.test_invalid_arg (pandas-dev#40688)
1 parent 3ab361f commit 3080dcf

File tree

1 file changed

+48
-67
lines changed

1 file changed

+48
-67
lines changed

pandas/tests/apply/test_invalid_arg.py

+48-67
Original file line numberDiff line numberDiff line change
@@ -102,26 +102,6 @@ def test_series_nested_renamer(renamer):
102102
s.agg(renamer)
103103

104104

105-
def test_agg_dict_nested_renaming_depr_agg():
106-
107-
df = DataFrame({"A": range(5), "B": 5})
108-
109-
# nested renaming
110-
msg = r"nested renamer is not supported"
111-
with pytest.raises(SpecificationError, match=msg):
112-
df.agg({"A": {"foo": "min"}, "B": {"bar": "max"}})
113-
114-
115-
def test_agg_dict_nested_renaming_depr_transform():
116-
df = DataFrame({"A": range(5), "B": 5})
117-
118-
# nested renaming
119-
msg = r"nested renamer is not supported"
120-
with pytest.raises(SpecificationError, match=msg):
121-
# mypy identifies the argument as an invalid type
122-
df.transform({"A": {"foo": "min"}, "B": {"bar": "max"}})
123-
124-
125105
def test_apply_dict_depr():
126106

127107
tsdf = DataFrame(
@@ -134,6 +114,17 @@ def test_apply_dict_depr():
134114
tsdf.A.agg({"foo": ["sum", "mean"]})
135115

136116

117+
@pytest.mark.parametrize("method", ["agg", "transform"])
118+
def test_dict_nested_renaming_depr(method):
119+
120+
df = DataFrame({"A": range(5), "B": 5})
121+
122+
# nested renaming
123+
msg = r"nested renamer is not supported"
124+
with pytest.raises(SpecificationError, match=msg):
125+
getattr(df, method)({"A": {"foo": "min"}, "B": {"bar": "max"}})
126+
127+
137128
@pytest.mark.parametrize("method", ["apply", "agg", "transform"])
138129
@pytest.mark.parametrize("func", [{"B": "sum"}, {"B": ["sum"]}])
139130
def test_missing_column(method, func):
@@ -288,25 +279,21 @@ def test_transform_none_to_type():
288279
df.transform({"a": int})
289280

290281

291-
def test_apply_broadcast_error(int_frame_const_col):
282+
@pytest.mark.parametrize(
283+
"func",
284+
[
285+
lambda x: np.array([1, 2]).reshape(-1, 2),
286+
lambda x: [1, 2],
287+
lambda x: Series([1, 2]),
288+
],
289+
)
290+
def test_apply_broadcast_error(int_frame_const_col, func):
292291
df = int_frame_const_col
293292

294293
# > 1 ndim
295-
msg = "too many dims to broadcast"
294+
msg = "too many dims to broadcast|cannot broadcast result"
296295
with pytest.raises(ValueError, match=msg):
297-
df.apply(
298-
lambda x: np.array([1, 2]).reshape(-1, 2),
299-
axis=1,
300-
result_type="broadcast",
301-
)
302-
303-
# cannot broadcast
304-
msg = "cannot broadcast result"
305-
with pytest.raises(ValueError, match=msg):
306-
df.apply(lambda x: [1, 2], axis=1, result_type="broadcast")
307-
308-
with pytest.raises(ValueError, match=msg):
309-
df.apply(lambda x: Series([1, 2]), axis=1, result_type="broadcast")
296+
df.apply(func, axis=1, result_type="broadcast")
310297

311298

312299
def test_transform_and_agg_err_agg(axis, float_frame):
@@ -317,34 +304,47 @@ def test_transform_and_agg_err_agg(axis, float_frame):
317304
float_frame.agg(["max", "sqrt"], axis=axis)
318305

319306

320-
def test_transform_and_agg_err_series(string_series):
307+
@pytest.mark.parametrize(
308+
"func, msg",
309+
[
310+
(["sqrt", "max"], "cannot combine transform and aggregation"),
311+
(
312+
{"foo": np.sqrt, "bar": "sum"},
313+
"cannot perform both aggregation and transformation",
314+
),
315+
],
316+
)
317+
def test_transform_and_agg_err_series(string_series, func, msg):
321318
# we are trying to transform with an aggregator
322-
msg = "cannot combine transform and aggregation"
323319
with pytest.raises(ValueError, match=msg):
324320
with np.errstate(all="ignore"):
325-
string_series.agg(["sqrt", "max"])
321+
string_series.agg(func)
326322

327-
msg = "cannot perform both aggregation and transformation"
328-
with pytest.raises(ValueError, match=msg):
329-
with np.errstate(all="ignore"):
330-
string_series.agg({"foo": np.sqrt, "bar": "sum"})
331323

332-
333-
def test_transform_and_agg_err_frame(axis, float_frame):
324+
@pytest.mark.parametrize("func", [["max", "min"], ["max", "sqrt"]])
325+
def test_transform_wont_agg_frame(axis, float_frame, func):
334326
# GH 35964
335327
# cannot both transform and agg
336328
msg = "Function did not transform"
337329
with pytest.raises(ValueError, match=msg):
338-
float_frame.transform(["max", "min"], axis=axis)
330+
float_frame.transform(func, axis=axis)
331+
339332

333+
@pytest.mark.parametrize("func", [["min", "max"], ["sqrt", "max"]])
334+
def test_transform_wont_agg_series(string_series, func):
335+
# GH 35964
336+
# we are trying to transform with an aggregator
340337
msg = "Function did not transform"
341338
with pytest.raises(ValueError, match=msg):
342-
float_frame.transform(["max", "sqrt"], axis=axis)
339+
string_series.transform(func)
343340

344341

345-
def test_transform_reducer_raises(all_reductions, frame_or_series):
342+
@pytest.mark.parametrize(
343+
"op_wrapper", [lambda x: x, lambda x: [x], lambda x: {"A": x}, lambda x: {"A": [x]}]
344+
)
345+
def test_transform_reducer_raises(all_reductions, frame_or_series, op_wrapper):
346346
# GH 35964
347-
op = all_reductions
347+
op = op_wrapper(all_reductions)
348348

349349
obj = DataFrame({"A": [1, 2, 3]})
350350
if frame_or_series is not DataFrame:
@@ -353,22 +353,3 @@ def test_transform_reducer_raises(all_reductions, frame_or_series):
353353
msg = "Function did not transform"
354354
with pytest.raises(ValueError, match=msg):
355355
obj.transform(op)
356-
with pytest.raises(ValueError, match=msg):
357-
obj.transform([op])
358-
with pytest.raises(ValueError, match=msg):
359-
obj.transform({"A": op})
360-
with pytest.raises(ValueError, match=msg):
361-
obj.transform({"A": [op]})
362-
363-
364-
def test_transform_wont_agg(string_series):
365-
# GH 35964
366-
# we are trying to transform with an aggregator
367-
msg = "Function did not transform"
368-
with pytest.raises(ValueError, match=msg):
369-
string_series.transform(["min", "max"])
370-
371-
msg = "Function did not transform"
372-
with pytest.raises(ValueError, match=msg):
373-
with np.errstate(all="ignore"):
374-
string_series.transform(["sqrt", "max"])

0 commit comments

Comments
 (0)