|
18 | 18 | Union,
|
19 | 19 | )
|
20 | 20 |
|
21 |
| -from pandas._typing import AggFuncType, FrameOrSeries, Label |
| 21 | +from pandas._typing import AggFuncType, Axis, FrameOrSeries, Label |
22 | 22 |
|
23 | 23 | from pandas.core.dtypes.common import is_dict_like, is_list_like
|
| 24 | +from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries |
24 | 25 |
|
25 | 26 | from pandas.core.base import SpecificationError
|
26 | 27 | import pandas.core.common as com
|
@@ -384,3 +385,98 @@ def validate_func_kwargs(
|
384 | 385 | if not columns:
|
385 | 386 | raise TypeError(no_arg_message)
|
386 | 387 | return columns, func
|
| 388 | + |
| 389 | + |
| 390 | +def transform( |
| 391 | + obj: FrameOrSeries, func: AggFuncType, axis: Axis, *args, **kwargs, |
| 392 | +) -> FrameOrSeries: |
| 393 | + """ |
| 394 | + Transform a DataFrame or Series |
| 395 | +
|
| 396 | + Parameters |
| 397 | + ---------- |
| 398 | + obj : DataFrame or Series |
| 399 | + Object to compute the transform on. |
| 400 | + func : string, function, list, or dictionary |
| 401 | + Function(s) to compute the transform with. |
| 402 | + axis : {0 or 'index', 1 or 'columns'} |
| 403 | + Axis along which the function is applied: |
| 404 | +
|
| 405 | + * 0 or 'index': apply function to each column. |
| 406 | + * 1 or 'columns': apply function to each row. |
| 407 | +
|
| 408 | + Returns |
| 409 | + ------- |
| 410 | + DataFrame or Series |
| 411 | + Result of applying ``func`` along the given axis of the |
| 412 | + Series or DataFrame. |
| 413 | +
|
| 414 | + Raises |
| 415 | + ------ |
| 416 | + ValueError |
| 417 | + If the transform function fails or does not transform. |
| 418 | + """ |
| 419 | + from pandas.core.reshape.concat import concat |
| 420 | + |
| 421 | + is_series = obj.ndim == 1 |
| 422 | + |
| 423 | + if obj._get_axis_number(axis) == 1: |
| 424 | + assert not is_series |
| 425 | + return transform(obj.T, func, 0, *args, **kwargs).T |
| 426 | + |
| 427 | + if isinstance(func, list): |
| 428 | + if is_series: |
| 429 | + func = {com.get_callable_name(v) or v: v for v in func} |
| 430 | + else: |
| 431 | + func = {col: func for col in obj} |
| 432 | + |
| 433 | + if isinstance(func, dict): |
| 434 | + if not is_series: |
| 435 | + cols = sorted(set(func.keys()) - set(obj.columns)) |
| 436 | + if len(cols) > 0: |
| 437 | + raise SpecificationError(f"Column(s) {cols} do not exist") |
| 438 | + |
| 439 | + if any(isinstance(v, dict) for v in func.values()): |
| 440 | + # GH 15931 - deprecation of renaming keys |
| 441 | + raise SpecificationError("nested renamer is not supported") |
| 442 | + |
| 443 | + results = {} |
| 444 | + for name, how in func.items(): |
| 445 | + colg = obj._gotitem(name, ndim=1) |
| 446 | + try: |
| 447 | + results[name] = transform(colg, how, 0, *args, **kwargs) |
| 448 | + except Exception as e: |
| 449 | + if str(e) == "Function did not transform": |
| 450 | + raise e |
| 451 | + |
| 452 | + # combine results |
| 453 | + if len(results) == 0: |
| 454 | + raise ValueError("Transform function failed") |
| 455 | + return concat(results, axis=1) |
| 456 | + |
| 457 | + # func is either str or callable |
| 458 | + try: |
| 459 | + if isinstance(func, str): |
| 460 | + result = obj._try_aggregate_string_function(func, *args, **kwargs) |
| 461 | + else: |
| 462 | + f = obj._get_cython_func(func) |
| 463 | + if f and not args and not kwargs: |
| 464 | + result = getattr(obj, f)() |
| 465 | + else: |
| 466 | + try: |
| 467 | + result = obj.apply(func, args=args, **kwargs) |
| 468 | + except Exception: |
| 469 | + result = func(obj, *args, **kwargs) |
| 470 | + except Exception: |
| 471 | + raise ValueError("Transform function failed") |
| 472 | + |
| 473 | + # Functions that transform may return empty Series/DataFrame |
| 474 | + # when the dtype is not appropriate |
| 475 | + if isinstance(result, (ABCSeries, ABCDataFrame)) and result.empty: |
| 476 | + raise ValueError("Transform function failed") |
| 477 | + if not isinstance(result, (ABCSeries, ABCDataFrame)) or not result.index.equals( |
| 478 | + obj.index |
| 479 | + ): |
| 480 | + raise ValueError("Function did not transform") |
| 481 | + |
| 482 | + return result |
0 commit comments