Skip to content

Commit 5b08abb

Browse files
twoertwein authored and noatamir committed
TYP: tighten Axis (pandas-dev#48612)
* TYP: tighten Axis
* allow 'rows'
1 parent 9795b46 commit 5b08abb

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+385 -259 lines changed

pandas/_typing.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@
106106

107107
NumpyIndexT = TypeVar("NumpyIndexT", np.ndarray, "Index")
108108

109-
Axis = Union[str, int]
109+
AxisInt = int
110+
Axis = Union[AxisInt, Literal["index", "columns", "rows"]]
110111
IndexLabel = Union[Hashable, Sequence[Hashable]]
111112
Level = Hashable
112113
Shape = Tuple[int, ...]

pandas/compat/numpy/function.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@
2929
is_bool,
3030
is_integer,
3131
)
32-
from pandas._typing import Axis
32+
from pandas._typing import (
33+
Axis,
34+
AxisInt,
35+
)
3336
from pandas.errors import UnsupportedFunctionCall
3437
from pandas.util._validators import (
3538
validate_args,
@@ -413,7 +416,7 @@ def validate_resampler_func(method: str, args, kwargs) -> None:
413416
raise TypeError("too many arguments passed in")
414417

415418

416-
def validate_minmax_axis(axis: int | None, ndim: int = 1) -> None:
419+
def validate_minmax_axis(axis: AxisInt | None, ndim: int = 1) -> None:
417420
"""
418421
Ensure that the axis argument passed to min, max, argmin, or argmax is zero
419422
or None, as otherwise it will be incorrectly ignored.

pandas/core/algorithms.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from pandas._typing import (
3030
AnyArrayLike,
3131
ArrayLike,
32+
AxisInt,
3233
DtypeObj,
3334
IndexLabel,
3435
TakeIndexer,
@@ -1105,7 +1106,7 @@ def mode(
11051106

11061107
def rank(
11071108
values: ArrayLike,
1108-
axis: int = 0,
1109+
axis: AxisInt = 0,
11091110
method: str = "average",
11101111
na_option: str = "keep",
11111112
ascending: bool = True,
@@ -1483,7 +1484,7 @@ def get_indexer(current_indexer, other_indexer):
14831484
def take(
14841485
arr,
14851486
indices: TakeIndexer,
1486-
axis: int = 0,
1487+
axis: AxisInt = 0,
14871488
allow_fill: bool = False,
14881489
fill_value=None,
14891490
):
@@ -1675,7 +1676,7 @@ def searchsorted(
16751676
_diff_special = {"float64", "float32", "int64", "int32", "int16", "int8"}
16761677

16771678

1678-
def diff(arr, n: int, axis: int = 0):
1679+
def diff(arr, n: int, axis: AxisInt = 0):
16791680
"""
16801681
difference of n between self,
16811682
analogous to s-s.shift(n)

pandas/core/apply.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
AggFuncTypeDict,
3232
AggObjType,
3333
Axis,
34+
AxisInt,
3435
NDFrameT,
3536
npt,
3637
)
@@ -104,7 +105,7 @@ def frame_apply(
104105

105106

106107
class Apply(metaclass=abc.ABCMeta):
107-
axis: int
108+
axis: AxisInt
108109

109110
def __init__(
110111
self,

pandas/core/array_algos/masked_reductions.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
import numpy as np
1010

1111
from pandas._libs import missing as libmissing
12-
from pandas._typing import npt
12+
from pandas._typing import (
13+
AxisInt,
14+
npt,
15+
)
1316

1417
from pandas.core.nanops import check_below_min_count
1518

@@ -21,7 +24,7 @@ def _reductions(
2124
*,
2225
skipna: bool = True,
2326
min_count: int = 0,
24-
axis: int | None = None,
27+
axis: AxisInt | None = None,
2528
**kwargs,
2629
):
2730
"""
@@ -62,7 +65,7 @@ def sum(
6265
*,
6366
skipna: bool = True,
6467
min_count: int = 0,
65-
axis: int | None = None,
68+
axis: AxisInt | None = None,
6669
):
6770
return _reductions(
6871
np.sum, values=values, mask=mask, skipna=skipna, min_count=min_count, axis=axis
@@ -75,7 +78,7 @@ def prod(
7578
*,
7679
skipna: bool = True,
7780
min_count: int = 0,
78-
axis: int | None = None,
81+
axis: AxisInt | None = None,
7982
):
8083
return _reductions(
8184
np.prod, values=values, mask=mask, skipna=skipna, min_count=min_count, axis=axis
@@ -88,7 +91,7 @@ def _minmax(
8891
mask: npt.NDArray[np.bool_],
8992
*,
9093
skipna: bool = True,
91-
axis: int | None = None,
94+
axis: AxisInt | None = None,
9295
):
9396
"""
9497
Reduction for 1D masked array.
@@ -125,7 +128,7 @@ def min(
125128
mask: npt.NDArray[np.bool_],
126129
*,
127130
skipna: bool = True,
128-
axis: int | None = None,
131+
axis: AxisInt | None = None,
129132
):
130133
return _minmax(np.min, values=values, mask=mask, skipna=skipna, axis=axis)
131134

@@ -135,7 +138,7 @@ def max(
135138
mask: npt.NDArray[np.bool_],
136139
*,
137140
skipna: bool = True,
138-
axis: int | None = None,
141+
axis: AxisInt | None = None,
139142
):
140143
return _minmax(np.max, values=values, mask=mask, skipna=skipna, axis=axis)
141144

@@ -145,7 +148,7 @@ def mean(
145148
mask: npt.NDArray[np.bool_],
146149
*,
147150
skipna: bool = True,
148-
axis: int | None = None,
151+
axis: AxisInt | None = None,
149152
):
150153
if not values.size or mask.all():
151154
return libmissing.NA
@@ -157,7 +160,7 @@ def var(
157160
mask: npt.NDArray[np.bool_],
158161
*,
159162
skipna: bool = True,
160-
axis: int | None = None,
163+
axis: AxisInt | None = None,
161164
ddof: int = 1,
162165
):
163166
if not values.size or mask.all():

pandas/core/array_algos/take.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
)
1616
from pandas._typing import (
1717
ArrayLike,
18+
AxisInt,
1819
npt,
1920
)
2021

@@ -36,7 +37,7 @@
3637
def take_nd(
3738
arr: np.ndarray,
3839
indexer,
39-
axis: int = ...,
40+
axis: AxisInt = ...,
4041
fill_value=...,
4142
allow_fill: bool = ...,
4243
) -> np.ndarray:
@@ -47,7 +48,7 @@ def take_nd(
4748
def take_nd(
4849
arr: ExtensionArray,
4950
indexer,
50-
axis: int = ...,
51+
axis: AxisInt = ...,
5152
fill_value=...,
5253
allow_fill: bool = ...,
5354
) -> ArrayLike:
@@ -57,7 +58,7 @@ def take_nd(
5758
def take_nd(
5859
arr: ArrayLike,
5960
indexer,
60-
axis: int = 0,
61+
axis: AxisInt = 0,
6162
fill_value=lib.no_default,
6263
allow_fill: bool = True,
6364
) -> ArrayLike:
@@ -120,7 +121,7 @@ def take_nd(
120121
def _take_nd_ndarray(
121122
arr: np.ndarray,
122123
indexer: npt.NDArray[np.intp] | None,
123-
axis: int,
124+
axis: AxisInt,
124125
fill_value,
125126
allow_fill: bool,
126127
) -> np.ndarray:
@@ -287,7 +288,7 @@ def take_2d_multi(
287288

288289
@functools.lru_cache(maxsize=128)
289290
def _get_take_nd_function_cached(
290-
ndim: int, arr_dtype: np.dtype, out_dtype: np.dtype, axis: int
291+
ndim: int, arr_dtype: np.dtype, out_dtype: np.dtype, axis: AxisInt
291292
):
292293
"""
293294
Part of _get_take_nd_function below that doesn't need `mask_info` and thus
@@ -324,7 +325,11 @@ def _get_take_nd_function_cached(
324325

325326

326327
def _get_take_nd_function(
327-
ndim: int, arr_dtype: np.dtype, out_dtype: np.dtype, axis: int = 0, mask_info=None
328+
ndim: int,
329+
arr_dtype: np.dtype,
330+
out_dtype: np.dtype,
331+
axis: AxisInt = 0,
332+
mask_info=None,
328333
):
329334
"""
330335
Get the appropriate "take" implementation for the given dimension, axis
@@ -503,7 +508,7 @@ def _take_nd_object(
503508
arr: np.ndarray,
504509
indexer: npt.NDArray[np.intp],
505510
out: np.ndarray,
506-
axis: int,
511+
axis: AxisInt,
507512
fill_value,
508513
mask_info,
509514
) -> None:

pandas/core/array_algos/transforms.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66

77
import numpy as np
88

9+
from pandas._typing import AxisInt
910

10-
def shift(values: np.ndarray, periods: int, axis: int, fill_value) -> np.ndarray:
11+
12+
def shift(values: np.ndarray, periods: int, axis: AxisInt, fill_value) -> np.ndarray:
1113
new_values = values
1214

1315
if periods == 0 or values.size == 0:

pandas/core/arrays/_mixins.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from pandas._libs.arrays import NDArrayBacked
1818
from pandas._typing import (
1919
ArrayLike,
20+
AxisInt,
2021
Dtype,
2122
F,
2223
PositionalIndexer2D,
@@ -157,7 +158,7 @@ def take(
157158
*,
158159
allow_fill: bool = False,
159160
fill_value: Any = None,
160-
axis: int = 0,
161+
axis: AxisInt = 0,
161162
) -> NDArrayBackedExtensionArrayT:
162163
if allow_fill:
163164
fill_value = self._validate_scalar(fill_value)
@@ -192,15 +193,15 @@ def _values_for_factorize(self):
192193
return self._ndarray, self._internal_fill_value
193194

194195
# Signature of "argmin" incompatible with supertype "ExtensionArray"
195-
def argmin(self, axis: int = 0, skipna: bool = True): # type: ignore[override]
196+
def argmin(self, axis: AxisInt = 0, skipna: bool = True): # type: ignore[override]
196197
# override base class by adding axis keyword
197198
validate_bool_kwarg(skipna, "skipna")
198199
if not skipna and self._hasna:
199200
raise NotImplementedError
200201
return nargminmax(self, "argmin", axis=axis)
201202

202203
# Signature of "argmax" incompatible with supertype "ExtensionArray"
203-
def argmax(self, axis: int = 0, skipna: bool = True): # type: ignore[override]
204+
def argmax(self, axis: AxisInt = 0, skipna: bool = True): # type: ignore[override]
204205
# override base class by adding axis keyword
205206
validate_bool_kwarg(skipna, "skipna")
206207
if not skipna and self._hasna:
@@ -216,7 +217,7 @@ def unique(self: NDArrayBackedExtensionArrayT) -> NDArrayBackedExtensionArrayT:
216217
def _concat_same_type(
217218
cls: type[NDArrayBackedExtensionArrayT],
218219
to_concat: Sequence[NDArrayBackedExtensionArrayT],
219-
axis: int = 0,
220+
axis: AxisInt = 0,
220221
) -> NDArrayBackedExtensionArrayT:
221222
dtypes = {str(x.dtype) for x in to_concat}
222223
if len(dtypes) != 1:
@@ -351,7 +352,7 @@ def fillna(
351352
# ------------------------------------------------------------------------
352353
# Reductions
353354

354-
def _wrap_reduction_result(self, axis: int | None, result):
355+
def _wrap_reduction_result(self, axis: AxisInt | None, result):
355356
if axis is None or self.ndim == 1:
356357
return self._box_func(result)
357358
return self._from_backing_data(result)

pandas/core/arrays/base.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from pandas._typing import (
3131
ArrayLike,
3232
AstypeArg,
33+
AxisInt,
3334
Dtype,
3435
FillnaOptions,
3536
PositionalIndexer,
@@ -1137,7 +1138,7 @@ def factorize(
11371138
@Substitution(klass="ExtensionArray")
11381139
@Appender(_extension_array_shared_docs["repeat"])
11391140
def repeat(
1140-
self: ExtensionArrayT, repeats: int | Sequence[int], axis: int | None = None
1141+
self: ExtensionArrayT, repeats: int | Sequence[int], axis: AxisInt | None = None
11411142
) -> ExtensionArrayT:
11421143
nv.validate_repeat((), {"axis": axis})
11431144
ind = np.arange(len(self)).repeat(repeats)
@@ -1567,7 +1568,7 @@ def _fill_mask_inplace(
15671568
def _rank(
15681569
self,
15691570
*,
1570-
axis: int = 0,
1571+
axis: AxisInt = 0,
15711572
method: str = "average",
15721573
na_option: str = "keep",
15731574
ascending: bool = True,

pandas/core/arrays/categorical.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from pandas._typing import (
4040
ArrayLike,
4141
AstypeArg,
42+
AxisInt,
4243
Dtype,
4344
NpDtype,
4445
Ordered,
@@ -1988,7 +1989,7 @@ def sort_values(
19881989
def _rank(
19891990
self,
19901991
*,
1991-
axis: int = 0,
1992+
axis: AxisInt = 0,
19921993
method: str = "average",
19931994
na_option: str = "keep",
19941995
ascending: bool = True,
@@ -2464,7 +2465,7 @@ def equals(self, other: object) -> bool:
24642465

24652466
@classmethod
24662467
def _concat_same_type(
2467-
cls: type[CategoricalT], to_concat: Sequence[CategoricalT], axis: int = 0
2468+
cls: type[CategoricalT], to_concat: Sequence[CategoricalT], axis: AxisInt = 0
24682469
) -> CategoricalT:
24692470
from pandas.core.dtypes.concat import union_categoricals
24702471

0 commit comments

Comments
 (0)