Skip to content

Commit 40274ac

Browse files
authored
TYP: Typing changes for ExtensionArray.astype (#41251)
1 parent 0a9f9ee commit 40274ac

File tree

12 files changed

+126
-40
lines changed

12 files changed

+126
-40
lines changed

pandas/_typing.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,9 @@
126126
]
127127

128128
# dtypes
129-
NpDtype = Union[str, np.dtype]
130-
Dtype = Union[
131-
"ExtensionDtype", NpDtype, type_t[Union[str, float, int, complex, bool, object]]
132-
]
129+
NpDtype = Union[str, np.dtype, type_t[Union[str, float, int, complex, bool, object]]]
130+
Dtype = Union["ExtensionDtype", NpDtype]
131+
AstypeArg = Union["ExtensionDtype", "npt.DTypeLike"]
133132
# DtypeArg specifies all allowable dtypes in a functions its dtype argument
134133
DtypeArg = Union[Dtype, Dict[Hashable, Dtype]]
135134
DtypeObj = Union[np.dtype, "ExtensionDtype"]

pandas/core/arrays/base.py

+27-5
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@
1818
Sequence,
1919
TypeVar,
2020
cast,
21+
overload,
2122
)
2223

2324
import numpy as np
2425

2526
from pandas._libs import lib
2627
from pandas._typing import (
2728
ArrayLike,
29+
AstypeArg,
2830
Dtype,
2931
FillnaOptions,
3032
PositionalIndexer,
@@ -520,9 +522,21 @@ def nbytes(self) -> int:
520522
# Additional Methods
521523
# ------------------------------------------------------------------------
522524

523-
def astype(self, dtype, copy=True):
525+
@overload
526+
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
527+
...
528+
529+
@overload
530+
def astype(self, dtype: ExtensionDtype, copy: bool = ...) -> ExtensionArray:
531+
...
532+
533+
@overload
534+
def astype(self, dtype: AstypeArg, copy: bool = ...) -> ArrayLike:
535+
...
536+
537+
def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:
524538
"""
525-
Cast to a NumPy array with 'dtype'.
539+
Cast to a NumPy array or ExtensionArray with 'dtype'.
526540
527541
Parameters
528542
----------
@@ -535,8 +549,10 @@ def astype(self, dtype, copy=True):
535549
536550
Returns
537551
-------
538-
array : ndarray
539-
NumPy ndarray with 'dtype' for its dtype.
552+
array : np.ndarray or ExtensionArray
553+
An ExtensionArray if dtype is StringDtype,
554+
or same as that of underlying array.
555+
Otherwise a NumPy ndarray with 'dtype' for its dtype.
540556
"""
541557
from pandas.core.arrays.string_ import StringDtype
542558

@@ -552,7 +568,11 @@ def astype(self, dtype, copy=True):
552568
# allow conversion to StringArrays
553569
return dtype.construct_array_type()._from_sequence(self, copy=False)
554570

555-
return np.array(self, dtype=dtype, copy=copy)
571+
# error: Argument "dtype" to "array" has incompatible type
572+
# "Union[ExtensionDtype, dtype[Any]]"; expected "Union[dtype[Any], None, type,
573+
# _SupportsDType, str, Union[Tuple[Any, int], Tuple[Any, Union[int,
574+
# Sequence[int]]], List[Any], _DTypeDict, Tuple[Any, Any]]]"
575+
return np.array(self, dtype=dtype, copy=copy) # type: ignore[arg-type]
556576

557577
def isna(self) -> np.ndarray | ExtensionArraySupportsAnyAll:
558578
"""
@@ -863,6 +883,8 @@ def searchsorted(
863883
# 2. Values between the values in the `data_for_sorting` fixture
864884
# 3. Missing values.
865885
arr = self.astype(object)
886+
if isinstance(value, ExtensionArray):
887+
value = value.astype(object)
866888
return arr.searchsorted(value, side=side, sorter=sorter)
867889

868890
def equals(self, other: object) -> bool:

pandas/core/arrays/boolean.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from __future__ import annotations
22

33
import numbers
4-
from typing import TYPE_CHECKING
4+
from typing import (
5+
TYPE_CHECKING,
6+
overload,
7+
)
58
import warnings
69

710
import numpy as np
@@ -12,7 +15,9 @@
1215
)
1316
from pandas._typing import (
1417
ArrayLike,
18+
AstypeArg,
1519
Dtype,
20+
npt,
1621
type_t,
1722
)
1823
from pandas.compat.numpy import function as nv
@@ -33,6 +38,7 @@
3338
from pandas.core.dtypes.missing import isna
3439

3540
from pandas.core import ops
41+
from pandas.core.arrays import ExtensionArray
3642
from pandas.core.arrays.masked import (
3743
BaseMaskedArray,
3844
BaseMaskedDtype,
@@ -392,7 +398,20 @@ def reconstruct(x):
392398
def _coerce_to_array(self, value) -> tuple[np.ndarray, np.ndarray]:
393399
return coerce_to_array(value)
394400

395-
def astype(self, dtype, copy: bool = True) -> ArrayLike:
401+
@overload
402+
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
403+
...
404+
405+
@overload
406+
def astype(self, dtype: ExtensionDtype, copy: bool = ...) -> ExtensionArray:
407+
...
408+
409+
@overload
410+
def astype(self, dtype: AstypeArg, copy: bool = ...) -> ArrayLike:
411+
...
412+
413+
def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:
414+
396415
"""
397416
Cast to a NumPy array or ExtensionArray with 'dtype'.
398417

pandas/core/arrays/categorical.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
TypeVar,
1212
Union,
1313
cast,
14+
overload,
1415
)
1516
from warnings import (
1617
catch_warnings,
@@ -32,6 +33,7 @@
3233
from pandas._libs.lib import no_default
3334
from pandas._typing import (
3435
ArrayLike,
36+
AstypeArg,
3537
Dtype,
3638
NpDtype,
3739
Ordered,
@@ -482,7 +484,19 @@ def _constructor(self) -> type[Categorical]:
482484
def _from_sequence(cls, scalars, *, dtype: Dtype | None = None, copy=False):
483485
return Categorical(scalars, dtype=dtype, copy=copy)
484486

485-
def astype(self, dtype: Dtype, copy: bool = True) -> ArrayLike:
487+
@overload
488+
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
489+
...
490+
491+
@overload
492+
def astype(self, dtype: ExtensionDtype, copy: bool = ...) -> ExtensionArray:
493+
...
494+
495+
@overload
496+
def astype(self, dtype: AstypeArg, copy: bool = ...) -> ArrayLike:
497+
...
498+
499+
def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:
486500
"""
487501
Coerce this type to another dtype
488502
@@ -2458,11 +2472,7 @@ def _str_get_dummies(self, sep="|"):
24582472
# sep may not be in categories. Just bail on this.
24592473
from pandas.core.arrays import PandasArray
24602474

2461-
# error: Argument 1 to "PandasArray" has incompatible type
2462-
# "ExtensionArray"; expected "Union[ndarray, PandasArray]"
2463-
return PandasArray(self.astype(str))._str_get_dummies( # type: ignore[arg-type]
2464-
sep
2465-
)
2475+
return PandasArray(self.astype(str))._str_get_dummies(sep)
24662476

24672477

24682478
# The Series.cat accessor

pandas/core/arrays/floating.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from typing import overload
34
import warnings
45

56
import numpy as np
@@ -10,7 +11,9 @@
1011
)
1112
from pandas._typing import (
1213
ArrayLike,
14+
AstypeArg,
1315
DtypeObj,
16+
npt,
1417
)
1518
from pandas.compat.numpy import function as nv
1619
from pandas.util._decorators import cache_readonly
@@ -31,6 +34,7 @@
3134
)
3235
from pandas.core.dtypes.missing import isna
3336

37+
from pandas.core.arrays import ExtensionArray
3438
from pandas.core.arrays.numeric import (
3539
NumericArray,
3640
NumericDtype,
@@ -271,7 +275,19 @@ def _from_sequence_of_strings(
271275
def _coerce_to_array(self, value) -> tuple[np.ndarray, np.ndarray]:
272276
return coerce_to_array(value, dtype=self.dtype)
273277

274-
def astype(self, dtype, copy: bool = True) -> ArrayLike:
278+
@overload
279+
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
280+
...
281+
282+
@overload
283+
def astype(self, dtype: ExtensionDtype, copy: bool = ...) -> ExtensionArray:
284+
...
285+
286+
@overload
287+
def astype(self, dtype: AstypeArg, copy: bool = ...) -> ArrayLike:
288+
...
289+
290+
def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:
275291
"""
276292
Cast to a NumPy array or ExtensionArray with 'dtype'.
277293

pandas/core/arrays/integer.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from typing import overload
34
import warnings
45

56
import numpy as np
@@ -11,8 +12,10 @@
1112
)
1213
from pandas._typing import (
1314
ArrayLike,
15+
AstypeArg,
1416
Dtype,
1517
DtypeObj,
18+
npt,
1619
)
1720
from pandas.compat.numpy import function as nv
1821
from pandas.util._decorators import cache_readonly
@@ -33,6 +36,7 @@
3336
)
3437
from pandas.core.dtypes.missing import isna
3538

39+
from pandas.core.arrays import ExtensionArray
3640
from pandas.core.arrays.masked import (
3741
BaseMaskedArray,
3842
BaseMaskedDtype,
@@ -333,7 +337,19 @@ def _from_sequence_of_strings(
333337
def _coerce_to_array(self, value) -> tuple[np.ndarray, np.ndarray]:
334338
return coerce_to_array(value, dtype=self.dtype)
335339

336-
def astype(self, dtype, copy: bool = True) -> ArrayLike:
340+
@overload
341+
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
342+
...
343+
344+
@overload
345+
def astype(self, dtype: ExtensionDtype, copy: bool = ...) -> ExtensionArray:
346+
...
347+
348+
@overload
349+
def astype(self, dtype: AstypeArg, copy: bool = ...) -> ArrayLike:
350+
...
351+
352+
def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:
337353
"""
338354
Cast to a NumPy array or ExtensionArray with 'dtype'.
339355

pandas/core/arrays/masked.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
Any,
66
Sequence,
77
TypeVar,
8+
overload,
89
)
910

1011
import numpy as np
@@ -15,10 +16,11 @@
1516
)
1617
from pandas._typing import (
1718
ArrayLike,
18-
Dtype,
19+
AstypeArg,
1920
NpDtype,
2021
PositionalIndexer,
2122
Scalar,
23+
npt,
2224
type_t,
2325
)
2426
from pandas.errors import AbstractMethodError
@@ -282,9 +284,7 @@ def to_numpy( # type: ignore[override]
282284
if na_value is lib.no_default:
283285
na_value = libmissing.NA
284286
if dtype is None:
285-
# error: Incompatible types in assignment (expression has type
286-
# "Type[object]", variable has type "Union[str, dtype[Any], None]")
287-
dtype = object # type: ignore[assignment]
287+
dtype = object
288288
if self._hasna:
289289
if (
290290
not is_object_dtype(dtype)
@@ -303,7 +303,19 @@ def to_numpy( # type: ignore[override]
303303
data = self._data.astype(dtype, copy=copy)
304304
return data
305305

306-
def astype(self, dtype: Dtype, copy: bool = True) -> ArrayLike:
306+
@overload
307+
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
308+
...
309+
310+
@overload
311+
def astype(self, dtype: ExtensionDtype, copy: bool = ...) -> ExtensionArray:
312+
...
313+
314+
@overload
315+
def astype(self, dtype: AstypeArg, copy: bool = ...) -> ArrayLike:
316+
...
317+
318+
def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:
307319
dtype = pandas_dtype(dtype)
308320

309321
if is_dtype_equal(dtype, self.dtype):

pandas/core/arrays/period.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -351,9 +351,7 @@ def freq(self) -> BaseOffset:
351351
def __array__(self, dtype: NpDtype | None = None) -> np.ndarray:
352352
if dtype == "i8":
353353
return self.asi8
354-
# error: Non-overlapping equality check (left operand type: "Optional[Union[str,
355-
# dtype[Any]]]", right operand type: "Type[bool]")
356-
elif dtype == bool: # type: ignore[comparison-overlap]
354+
elif dtype == bool:
357355
return ~self._isnan
358356

359357
# This will raise TypeError for non-object dtypes

pandas/core/arrays/sparse/array.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from pandas._libs.tslibs import NaT
2828
from pandas._typing import (
2929
ArrayLike,
30+
AstypeArg,
3031
Dtype,
3132
NpDtype,
3233
Scalar,
@@ -527,9 +528,7 @@ def __array__(self, dtype: NpDtype | None = None) -> np.ndarray:
527528
try:
528529
dtype = np.result_type(self.sp_values.dtype, type(fill_value))
529530
except TypeError:
530-
# error: Incompatible types in assignment (expression has type
531-
# "Type[object]", variable has type "Union[str, dtype[Any], None]")
532-
dtype = object # type: ignore[assignment]
531+
dtype = object
533532

534533
out = np.full(self.shape, fill_value, dtype=dtype)
535534
out[self.sp_index.to_int_index().indices] = self.sp_values
@@ -1072,7 +1071,7 @@ def _concat_same_type(
10721071

10731072
return cls(data, sparse_index=sp_index, fill_value=fill_value)
10741073

1075-
def astype(self, dtype: Dtype | None = None, copy=True):
1074+
def astype(self, dtype: AstypeArg | None = None, copy=True):
10761075
"""
10771076
Change the dtype of a SparseArray.
10781077

pandas/core/common.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -232,12 +232,7 @@ def asarray_tuplesafe(values, dtype: NpDtype | None = None) -> np.ndarray:
232232
# expected "ndarray")
233233
return values._values # type: ignore[return-value]
234234

235-
# error: Non-overlapping container check (element type: "Union[str, dtype[Any],
236-
# None]", container item type: "type")
237-
if isinstance(values, list) and dtype in [ # type: ignore[comparison-overlap]
238-
np.object_,
239-
object,
240-
]:
235+
if isinstance(values, list) and dtype in [np.object_, object]:
241236
return construct_1d_object_array_from_listlike(values)
242237

243238
result = np.asarray(values, dtype=dtype)

0 commit comments

Comments
 (0)