18
18
Sequence ,
19
19
TypeVar ,
20
20
cast ,
21
+ overload ,
21
22
)
22
23
23
24
import numpy as np
24
25
25
26
from pandas ._libs import lib
26
27
from pandas ._typing import (
27
28
ArrayLike ,
29
+ AstypeArg ,
28
30
Dtype ,
29
31
FillnaOptions ,
30
32
PositionalIndexer ,
@@ -520,9 +522,21 @@ def nbytes(self) -> int:
520
522
# Additional Methods
521
523
# ------------------------------------------------------------------------
522
524
523
- def astype (self , dtype , copy = True ):
525
+ @overload
526
+ def astype (self , dtype : npt .DTypeLike , copy : bool = ...) -> np .ndarray :
527
+ ...
528
+
529
+ @overload
530
+ def astype (self , dtype : ExtensionDtype , copy : bool = ...) -> ExtensionArray :
531
+ ...
532
+
533
+ @overload
534
+ def astype (self , dtype : AstypeArg , copy : bool = ...) -> ArrayLike :
535
+ ...
536
+
537
+ def astype (self , dtype : AstypeArg , copy : bool = True ) -> ArrayLike :
524
538
"""
525
- Cast to a NumPy array with 'dtype'.
539
+ Cast to a NumPy array or ExtensionArray with 'dtype'.
526
540
527
541
Parameters
528
542
----------
@@ -535,8 +549,10 @@ def astype(self, dtype, copy=True):
535
549
536
550
Returns
537
551
-------
538
- array : ndarray
539
- NumPy ndarray with 'dtype' for its dtype.
552
+ array : np.ndarray or ExtensionArray
553
+ An ExtensionArray if dtype is StringDtype,
554
+ or same as that of underlying array.
555
+ Otherwise a NumPy ndarray with 'dtype' for its dtype.
540
556
"""
541
557
from pandas .core .arrays .string_ import StringDtype
542
558
@@ -552,7 +568,11 @@ def astype(self, dtype, copy=True):
552
568
# allow conversion to StringArrays
553
569
return dtype .construct_array_type ()._from_sequence (self , copy = False )
554
570
555
- return np .array (self , dtype = dtype , copy = copy )
571
+ # error: Argument "dtype" to "array" has incompatible type
572
+ # "Union[ExtensionDtype, dtype[Any]]"; expected "Union[dtype[Any], None, type,
573
+ # _SupportsDType, str, Union[Tuple[Any, int], Tuple[Any, Union[int,
574
+ # Sequence[int]]], List[Any], _DTypeDict, Tuple[Any, Any]]]"
575
+ return np .array (self , dtype = dtype , copy = copy ) # type: ignore[arg-type]
556
576
557
577
def isna (self ) -> np .ndarray | ExtensionArraySupportsAnyAll :
558
578
"""
@@ -863,6 +883,8 @@ def searchsorted(
863
883
# 2. Values between the values in the `data_for_sorting` fixture
864
884
# 3. Missing values.
865
885
arr = self .astype (object )
886
+ if isinstance (value , ExtensionArray ):
887
+ value = value .astype (object )
866
888
return arr .searchsorted (value , side = side , sorter = sorter )
867
889
868
890
def equals (self , other : object ) -> bool :
0 commit comments