Skip to content

Commit dc3c4b7

Browse files
authored
REF: simplify EA.astype (#44133)
1 parent c8ad024 commit dc3c4b7

File tree

6 files changed, with 20 additions (+20) and 45 deletions (-45).

pandas/core/arrays/base.py

+5 -12
Original file line number | Diff line number | Diff line change
@@ -561,11 +561,9 @@ def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:
561561
Returns
562562
-------
563563
array : np.ndarray or ExtensionArray
564-
An ExtensionArray if dtype is StringDtype,
565-
or same as that of underlying array.
564+
An ExtensionArray if dtype is ExtensionDtype,
566565
Otherwise a NumPy ndarray with 'dtype' for its dtype.
567566
"""
568-
from pandas.core.arrays.string_ import StringDtype
569567

570568
dtype = pandas_dtype(dtype)
571569
if is_dtype_equal(dtype, self.dtype):
@@ -574,16 +572,11 @@ def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:
574572
else:
575573
return self.copy()
576574

577-
# FIXME: Really hard-code here?
578-
if isinstance(dtype, StringDtype):
579-
# allow conversion to StringArrays
580-
return dtype.construct_array_type()._from_sequence(self, copy=False)
575+
if isinstance(dtype, ExtensionDtype):
576+
cls = dtype.construct_array_type()
577+
return cls._from_sequence(self, dtype=dtype, copy=copy)
581578

582-
# error: Argument "dtype" to "array" has incompatible type
583-
# "Union[ExtensionDtype, dtype[Any]]"; expected "Union[dtype[Any], None, type,
584-
# _SupportsDType, str, Union[Tuple[Any, int], Tuple[Any, Union[int,
585-
# Sequence[int]]], List[Any], _DTypeDict, Tuple[Any, Any]]]"
586-
return np.array(self, dtype=dtype, copy=copy) # type: ignore[arg-type]
579+
return np.array(self, dtype=dtype, copy=copy)
587580

588581
def isna(self) -> np.ndarray | ExtensionArraySupportsAnyAll:
589582
"""

pandas/core/arrays/categorical.py

+1 -2
Original file line number | Diff line number | Diff line change
@@ -110,7 +110,6 @@
110110
)
111111
import pandas.core.common as com
112112
from pandas.core.construction import (
113-
array as pd_array,
114113
extract_array,
115114
sanitize_array,
116115
)
@@ -527,7 +526,7 @@ def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:
527526

528527
# TODO: consolidate with ndarray case?
529528
elif isinstance(dtype, ExtensionDtype):
530-
result = pd_array(self, dtype=dtype, copy=copy)
529+
return super().astype(dtype, copy=copy)
531530

532531
elif is_integer_dtype(dtype) and self.isna().any():
533532
raise ValueError("Cannot convert float NaN to integer")

pandas/core/arrays/datetimelike.py

+6 -10
Original file line number | Diff line number | Diff line change
@@ -74,7 +74,6 @@
7474
is_datetime64tz_dtype,
7575
is_datetime_or_timedelta_dtype,
7676
is_dtype_equal,
77-
is_extension_array_dtype,
7877
is_float_dtype,
7978
is_integer_dtype,
8079
is_list_like,
@@ -87,6 +86,7 @@
8786
)
8887
from pandas.core.dtypes.dtypes import (
8988
DatetimeTZDtype,
89+
ExtensionDtype,
9090
PeriodDtype,
9191
)
9292
from pandas.core.dtypes.missing import (
@@ -407,12 +407,11 @@ def astype(self, dtype, copy: bool = True):
407407

408408
if is_object_dtype(dtype):
409409
return self._box_values(self.asi8.ravel()).reshape(self.shape)
410-
elif is_string_dtype(dtype) and not is_categorical_dtype(dtype):
411-
if is_extension_array_dtype(dtype):
412-
arr_cls = dtype.construct_array_type()
413-
return arr_cls._from_sequence(self, dtype=dtype, copy=copy)
414-
else:
415-
return self._format_native_types()
410+
411+
elif isinstance(dtype, ExtensionDtype):
412+
return super().astype(dtype, copy=copy)
413+
elif is_string_dtype(dtype):
414+
return self._format_native_types()
416415
elif is_integer_dtype(dtype):
417416
# we deliberately ignore int32 vs. int64 here.
418417
# See https://github.com/pandas-dev/pandas/issues/24381 for more.
@@ -442,9 +441,6 @@ def astype(self, dtype, copy: bool = True):
442441
# and conversions for any datetimelike to float
443442
msg = f"Cannot cast {type(self).__name__} to dtype {dtype}"
444443
raise TypeError(msg)
445-
elif is_categorical_dtype(dtype):
446-
arr_cls = dtype.construct_array_type()
447-
return arr_cls(self, dtype=dtype)
448444
else:
449445
return np.asarray(self, dtype=dtype)
450446

pandas/core/arrays/interval.py

+6 -13
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,6 @@
7979
ExtensionArray,
8080
_extension_array_shared_docs,
8181
)
82-
from pandas.core.arrays.categorical import Categorical
8382
import pandas.core.common as com
8483
from pandas.core.construction import (
8584
array as pd_array,
@@ -850,7 +849,6 @@ def astype(self, dtype, copy: bool = True):
850849
ExtensionArray or NumPy ndarray with 'dtype' for its dtype.
851850
"""
852851
from pandas import Index
853-
from pandas.core.arrays.string_ import StringDtype
854852

855853
if dtype is not None:
856854
dtype = pandas_dtype(dtype)
@@ -871,17 +869,12 @@ def astype(self, dtype, copy: bool = True):
871869
)
872870
raise TypeError(msg) from err
873871
return self._shallow_copy(new_left, new_right)
874-
elif is_categorical_dtype(dtype):
875-
return Categorical(np.asarray(self), dtype=dtype)
876-
elif isinstance(dtype, StringDtype):
877-
return dtype.construct_array_type()._from_sequence(self, copy=False)
878-
879-
# TODO: This try/except will be repeated.
880-
try:
881-
return np.asarray(self).astype(dtype, copy=copy)
882-
except (TypeError, ValueError) as err:
883-
msg = f"Cannot cast {type(self).__name__} to dtype {dtype}"
884-
raise TypeError(msg) from err
872+
else:
873+
try:
874+
return super().astype(dtype, copy=copy)
875+
except (TypeError, ValueError) as err:
876+
msg = f"Cannot cast {type(self).__name__} to dtype {dtype}"
877+
raise TypeError(msg) from err
885878

886879
def equals(self, other) -> bool:
887880
if type(self) != type(other):

pandas/core/arrays/string_.py

+1 -2
Original file line number | Diff line number | Diff line change
@@ -437,8 +437,7 @@ def astype(self, dtype, copy=True):
437437
values = arr.astype(dtype.numpy_dtype)
438438
return FloatingArray(values, mask, copy=False)
439439
elif isinstance(dtype, ExtensionDtype):
440-
cls = dtype.construct_array_type()
441-
return cls._from_sequence(self, dtype=dtype, copy=copy)
440+
return super().astype(dtype, copy=copy)
442441
elif np.issubdtype(dtype, np.floating):
443442
arr = self._ndarray.copy()
444443
mask = self.isna()

pandas/core/arrays/string_arrow.py

+1 -6
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,6 @@
3535
from pandas.util._decorators import doc
3636
from pandas.util._validators import validate_fillna_kwargs
3737

38-
from pandas.core.dtypes.base import ExtensionDtype
3938
from pandas.core.dtypes.common import (
4039
is_array_like,
4140
is_bool_dtype,
@@ -685,11 +684,7 @@ def astype(self, dtype, copy=True):
685684
data = self._data.cast(pa.from_numpy_dtype(dtype.numpy_dtype))
686685
return dtype.__from_arrow__(data)
687686

688-
elif isinstance(dtype, ExtensionDtype):
689-
cls = dtype.construct_array_type()
690-
return cls._from_sequence(self, dtype=dtype, copy=copy)
691-
692-
return super().astype(dtype, copy)
687+
return super().astype(dtype, copy=copy)
693688

694689
# ------------------------------------------------------------------------
695690
# String methods interface

0 commit comments

Comments
 (0)