Skip to content

Commit 7b6ab94

Browse files
committed
CLN: Decouple Series/DataFrame.transform
1 parent c413df6 commit 7b6ab94

File tree

5 files changed

+271
-11
lines changed

5 files changed

+271
-11
lines changed

pandas/core/frame.py

+15 -2
Original file line number | Diff line number | Diff line change
@@ -113,7 +113,7 @@
113113
)
114114
from pandas.core.dtypes.missing import isna, na_value_for_dtype, notna
115115

116-
from pandas.core import algorithms, common as com, nanops, ops
116+
from pandas.core import algorithms, base, common as com, nanops, ops
117117
from pandas.core.accessor import CachedAccessor
118118
from pandas.core.aggregation import reconstruct_func, relabel_result
119119
from pandas.core.arrays import Categorical, ExtensionArray
@@ -7440,7 +7440,20 @@ def transform(self, func, axis=0, *args, **kwargs) -> "DataFrame":
74407440
axis = self._get_axis_number(axis)
74417441
if axis == 1:
74427442
return self.T.transform(func, *args, **kwargs).T
7443-
return super().transform(func, *args, **kwargs)
7443+
7444+
if isinstance(func, list):
7445+
func = {col: func for col in self}
7446+
elif isinstance(func, dict):
7447+
cols = sorted(set(func.keys()) - set(self.columns))
7448+
if len(cols) > 0:
7449+
raise base.SpecificationError(f"Column(s) {cols} do not exist")
7450+
if any(isinstance(v, dict) for v in func.values()):
7451+
# GH 15931 - deprecation of renaming keys
7452+
raise base.SpecificationError("nested renamer is not supported")
7453+
7454+
result = self._transform(func, *args, **kwargs)
7455+
7456+
return result
74447457

74457458
def apply(self, func, axis=0, raw=False, result_type=None, args=(), **kwds):
74467459
"""

pandas/core/generic.py

+42 -3
Original file line number | Diff line number | Diff line change
@@ -10750,9 +10750,48 @@ def transform(self, func, *args, **kwargs):
1075010750
1 1.000000 2.718282
1075110751
2 1.414214 7.389056
1075210752
"""
10753-
result = self.agg(func, *args, **kwargs)
10754-
if is_scalar(result) or len(result) != len(self):
10755-
raise ValueError("transforms cannot produce aggregated results")
10753+
raise NotImplementedError
10754+
10755+
def _transform(self, func, *args, **kwargs):
10756+
if isinstance(func, dict):
10757+
results = {}
10758+
for name, how in func.items():
10759+
colg = self._gotitem(name, ndim=1)
10760+
try:
10761+
results[name] = colg.transform(how, *args, **kwargs)
10762+
except Exception as e:
10763+
if str(e) == "Function did not transform":
10764+
raise e
10765+
10766+
# combine results
10767+
if len(results) == 0:
10768+
raise ValueError("Transform function failed")
10769+
from pandas.core.reshape.concat import concat
10770+
10771+
return concat(results, axis=1)
10772+
10773+
try:
10774+
if isinstance(func, str):
10775+
result = self._try_aggregate_string_function(func, *args, **kwargs)
10776+
else:
10777+
f = self._get_cython_func(func)
10778+
if f and not args and not kwargs:
10779+
result = getattr(self, f)()
10780+
else:
10781+
try:
10782+
result = self.apply(func, args=args, **kwargs)
10783+
except Exception:
10784+
result = func(self, *args, **kwargs)
10785+
10786+
except Exception:
10787+
raise ValueError("Transform function failed")
10788+
10789+
# Functions that transform may return empty Series/DataFrame
10790+
# when the dtype is not appropriate
10791+
if isinstance(result, NDFrame) and result.empty:
10792+
raise ValueError("Transform function failed")
10793+
if not isinstance(result, NDFrame) or not result.index.equals(self.index):
10794+
raise ValueError("Function did not transform")
1075610795

1075710796
return result
1075810797

pandas/core/series.py

+9 -1
Original file line number | Diff line number | Diff line change
@@ -4083,7 +4083,15 @@ def aggregate(self, func=None, axis=0, *args, **kwargs):
40834083
def transform(self, func, axis=0, *args, **kwargs):
40844084
# Validate the axis parameter
40854085
self._get_axis_number(axis)
4086-
return super().transform(func, *args, **kwargs)
4086+
4087+
if isinstance(func, list):
4088+
func = {com.get_callable_name(v) or v: v for v in func}
4089+
elif isinstance(func, dict):
4090+
if any(isinstance(v, dict) for v in func.values()):
4091+
raise base.SpecificationError("nested renamer is not supported")
4092+
4093+
result = self._transform(func, *args, **kwargs)
4094+
return result
40874095

40884096
def apply(self, func, convert_dtype=True, args=(), **kwds):
40894097
"""

pandas/tests/frame/apply/test_frame_apply.py

+112 -1
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
from datetime import datetime
33
from itertools import chain
44
import operator
5+
import re
56
import warnings
67

78
import numpy as np
@@ -14,6 +15,7 @@
1415
import pandas._testing as tm
1516
from pandas.core.apply import frame_apply
1617
from pandas.core.base import SpecificationError
18+
from pandas.core.groupby.base import transformation_kernels
1719

1820

1921
@pytest.fixture
@@ -1131,9 +1133,29 @@ def test_agg_transform(self, axis, float_frame):
11311133
result = float_frame.transform([np.abs, "sqrt"], axis=axis)
11321134
tm.assert_frame_equal(result, expected)
11331135

1136+
# UDF via apply
1137+
def func(x):
1138+
if isinstance(x, DataFrame):
1139+
raise ValueError
1140+
return x + 1
1141+
1142+
result = float_frame.transform(func, axis=axis)
1143+
expected = float_frame + 1
1144+
tm.assert_frame_equal(result, expected)
1145+
1146+
# UDF that maps DataFrame -> DataFrame
1147+
def func(x):
1148+
if not isinstance(x, DataFrame):
1149+
raise ValueError
1150+
return x + 1
1151+
1152+
result = float_frame.transform(func, axis=axis)
1153+
expected = float_frame + 1
1154+
tm.assert_frame_equal(result, expected)
1155+
11341156
def test_transform_and_agg_err(self, axis, float_frame):
11351157
# cannot both transform and agg
1136-
msg = "transforms cannot produce aggregated results"
1158+
msg = "Function did not transform"
11371159
with pytest.raises(ValueError, match=msg):
11381160
float_frame.transform(["max", "min"], axis=axis)
11391161

@@ -1142,6 +1164,7 @@ def test_transform_and_agg_err(self, axis, float_frame):
11421164
with np.errstate(all="ignore"):
11431165
float_frame.agg(["max", "sqrt"], axis=axis)
11441166

1167+
msg = "Function did not transform"
11451168
with pytest.raises(ValueError, match=msg):
11461169
with np.errstate(all="ignore"):
11471170
float_frame.transform(["max", "sqrt"], axis=axis)
@@ -1221,6 +1244,9 @@ def test_agg_dict_nested_renaming_depr(self):
12211244
with pytest.raises(SpecificationError, match=msg):
12221245
df.agg({"A": {"foo": "min"}, "B": {"bar": "max"}})
12231246

1247+
with pytest.raises(SpecificationError, match=msg):
1248+
df.transform({"A": {"foo": "min"}, "B": {"bar": "max"}})
1249+
12241250
def test_agg_reduce(self, axis, float_frame):
12251251
other_axis = 1 if axis in {0, "index"} else 0
12261252
name1, name2 = float_frame.axes[other_axis].unique()[:2].sort_values()
@@ -1550,3 +1576,88 @@ def test_apply_empty_list_reduce():
15501576
result = df.apply(lambda x: [], result_type="reduce")
15511577
expected = pd.Series({"a": [], "b": []}, dtype=object)
15521578
tm.assert_series_equal(result, expected)
1579+
1580+
1581+
def test_transform_reducer_raises(all_reductions):
1582+
op = all_reductions
1583+
s = pd.DataFrame({"A": [1, 2, 3]})
1584+
msg = "Function did not transform"
1585+
with pytest.raises(ValueError, match=msg):
1586+
s.transform(op)
1587+
with pytest.raises(ValueError, match=msg):
1588+
s.transform([op])
1589+
with pytest.raises(ValueError, match=msg):
1590+
s.transform({"A": op})
1591+
with pytest.raises(ValueError, match=msg):
1592+
s.transform({"A": [op]})
1593+
1594+
1595+
# mypy doesn't allow adding lists of different types
1596+
# https://github.com/python/mypy/issues/5492
1597+
@pytest.mark.parametrize("op", [*transformation_kernels, lambda x: x + 1])
1598+
def test_transform_bad_dtype(op):
1599+
s = pd.DataFrame({"A": 3 * [object]}) # DataFrame that will fail on most transforms
1600+
if op in ("backfill", "shift", "pad", "bfill", "ffill"):
1601+
pytest.xfail("Transform function works on any datatype")
1602+
msg = "Transform function failed"
1603+
with pytest.raises(ValueError, match=msg):
1604+
s.transform(op)
1605+
with pytest.raises(ValueError, match=msg):
1606+
s.transform([op])
1607+
with pytest.raises(ValueError, match=msg):
1608+
s.transform({"A": op})
1609+
with pytest.raises(ValueError, match=msg):
1610+
s.transform({"A": [op]})
1611+
1612+
1613+
@pytest.mark.parametrize("op", transformation_kernels)
1614+
def test_transform_multi_dtypes(op):
1615+
df = pd.DataFrame({"A": ["a", "b", "c"], "B": [1, 2, 3]})
1616+
1617+
# Determine which columns op will work on
1618+
columns = []
1619+
for column in df:
1620+
try:
1621+
df[column].transform(op)
1622+
columns.append(column)
1623+
except Exception:
1624+
pass
1625+
1626+
if len(columns) > 0:
1627+
expected = df[columns].transform([op])
1628+
result = df.transform([op])
1629+
tm.assert_equal(result, expected)
1630+
1631+
expected = df[columns].transform({column: op for column in columns})
1632+
result = df.transform({column: op for column in columns})
1633+
tm.assert_equal(result, expected)
1634+
1635+
expected = df[columns].transform({column: [op] for column in columns})
1636+
result = df.transform({column: [op] for column in columns})
1637+
tm.assert_equal(result, expected)
1638+
1639+
1640+
@pytest.mark.parametrize("use_apply", [True, False])
1641+
def test_transform_passes_args(use_apply):
1642+
# transform uses UDF either via apply or passing the entire DataFrame
1643+
expected_args = [1, 2]
1644+
expected_kwargs = {"c": 3}
1645+
1646+
def f(x, a, b, c):
1647+
# transform is using apply iff x is not a DataFrame
1648+
if use_apply == isinstance(x, DataFrame):
1649+
# Force transform to fallback
1650+
raise ValueError
1651+
assert [a, b] == expected_args
1652+
assert c == expected_kwargs["c"]
1653+
return x
1654+
1655+
pd.DataFrame([1]).transform(f, 0, *expected_args, **expected_kwargs)
1656+
1657+
1658+
@pytest.mark.parametrize("axis", [0, "index", 1, "columns"])
1659+
def test_transform_missing_columns(axis):
1660+
df = pd.DataFrame({"A": [1, 2], "B": [3, 4]})
1661+
match = re.escape("Column(s) ['C'] do not exist")
1662+
with pytest.raises(SpecificationError, match=match):
1663+
df.transform({"C": "cumsum"})

pandas/tests/series/apply/test_series_apply.py

+93 -4
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
from pandas import DataFrame, Index, MultiIndex, Series, isna
99
import pandas._testing as tm
1010
from pandas.core.base import SpecificationError
11+
from pandas.core.groupby.base import transformation_kernels
1112

1213

1314
class TestSeriesApply:
@@ -222,7 +223,7 @@ def test_transform(self, string_series):
222223
expected.columns = ["sqrt"]
223224
tm.assert_frame_equal(result, expected)
224225

225-
result = string_series.transform([np.sqrt])
226+
result = string_series.apply([np.sqrt])
226227
tm.assert_frame_equal(result, expected)
227228

228229
result = string_series.transform(["sqrt"])
@@ -248,9 +249,34 @@ def test_transform(self, string_series):
248249
result = string_series.apply({"foo": np.sqrt, "bar": np.abs})
249250
tm.assert_series_equal(result.reindex_like(expected), expected)
250251

252+
expected = pd.concat([f_sqrt, f_abs], axis=1)
253+
expected.columns = ["foo", "bar"]
254+
result = string_series.transform({"foo": np.sqrt, "bar": np.abs})
255+
tm.assert_frame_equal(result, expected)
256+
257+
# UDF via apply
258+
def func(x):
259+
if isinstance(x, Series):
260+
raise ValueError
261+
return x + 1
262+
263+
result = string_series.transform(func)
264+
expected = string_series + 1
265+
tm.assert_series_equal(result, expected)
266+
267+
# UDF that maps Series -> Series
268+
def func(x):
269+
if not isinstance(x, Series):
270+
raise ValueError
271+
return x + 1
272+
273+
result = string_series.transform(func)
274+
expected = string_series + 1
275+
tm.assert_series_equal(result, expected)
276+
251277
def test_transform_and_agg_error(self, string_series):
252278
# we are trying to transform with an aggregator
253-
msg = "transforms cannot produce aggregated results"
279+
msg = "Function did not transform"
254280
with pytest.raises(ValueError, match=msg):
255281
string_series.transform(["min", "max"])
256282

@@ -259,6 +285,7 @@ def test_transform_and_agg_error(self, string_series):
259285
with np.errstate(all="ignore"):
260286
string_series.agg(["sqrt", "max"])
261287

288+
msg = "Function did not transform"
262289
with pytest.raises(ValueError, match=msg):
263290
with np.errstate(all="ignore"):
264291
string_series.transform(["sqrt", "max"])
@@ -467,11 +494,73 @@ def test_transform_none_to_type(self):
467494
# GH34377
468495
df = pd.DataFrame({"a": [None]})
469496

470-
msg = "DataFrame constructor called with incompatible data and dtype"
471-
with pytest.raises(TypeError, match=msg):
497+
msg = "Transform function failed.*"
498+
with pytest.raises(ValueError, match=msg):
472499
df.transform({"a": int})
473500

474501

502+
def test_transform_reducer_raises(all_reductions):
503+
op = all_reductions
504+
s = pd.Series([1, 2, 3])
505+
msg = "Function did not transform"
506+
with pytest.raises(ValueError, match=msg):
507+
s.transform(op)
508+
with pytest.raises(ValueError, match=msg):
509+
s.transform([op])
510+
with pytest.raises(ValueError, match=msg):
511+
s.transform({"A": op})
512+
with pytest.raises(ValueError, match=msg):
513+
s.transform({"A": [op]})
514+
515+
516+
# mypy doesn't allow adding lists of different types
517+
# https://github.com/python/mypy/issues/5492
518+
@pytest.mark.parametrize("op", [*transformation_kernels, lambda x: x + 1])
519+
def test_transform_bad_dtype(op):
520+
s = pd.Series(3 * [object]) # Series that will fail on most transforms
521+
if op in ("backfill", "shift", "pad", "bfill", "ffill"):
522+
pytest.xfail("Transform function works on any datatype")
523+
msg = "Transform function failed"
524+
with pytest.raises(ValueError, match=msg):
525+
s.transform(op)
526+
with pytest.raises(ValueError, match=msg):
527+
s.transform([op])
528+
with pytest.raises(ValueError, match=msg):
529+
s.transform({"A": op})
530+
with pytest.raises(ValueError, match=msg):
531+
s.transform({"A": [op]})
532+
533+
534+
@pytest.mark.parametrize("use_apply", [True, False])
535+
def test_transform_passes_args(use_apply):
536+
# transform uses UDF either via apply or passing the entire Series
537+
expected_args = [1, 2]
538+
expected_kwargs = {"c": 3}
539+
540+
def f(x, a, b, c):
541+
# transform is using apply iff x is not a Series
542+
if use_apply == isinstance(x, Series):
543+
# Force transform to fallback
544+
raise ValueError
545+
assert [a, b] == expected_args
546+
assert c == expected_kwargs["c"]
547+
return x
548+
549+
pd.Series([1]).transform(f, 0, *expected_args, **expected_kwargs)
550+
551+
552+
def test_transform_axis_1_raises():
553+
msg = "No axis named 1 for object type Series"
554+
with pytest.raises(ValueError, match=msg):
555+
pd.Series([1]).transform("sum", axis=1)
556+
557+
558+
def test_transform_nested_renamer():
559+
match = "nested renamer is not supported"
560+
with pytest.raises(SpecificationError, match=match):
561+
pd.Series([1]).transform({"A": {"B": ["sum"]}})
562+
563+
475564
class TestSeriesMap:
476565
def test_map(self, datetime_series):
477566
index, data = tm.getMixedTypeDict()

0 commit comments

Comments
 (0)