diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index a6d1986937d2b..d904cc6e1b3de 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -529,10 +529,8 @@ def astype(self, dtype, copy=True): Returns ------- - array : ndarray - NumPy ndarray with 'dtype' for its dtype. + np.ndarray or ExtensionArray """ - from pandas.core.arrays.string_ import StringDtype dtype = pandas_dtype(dtype) if is_dtype_equal(dtype, self.dtype): @@ -541,10 +539,10 @@ def astype(self, dtype, copy=True): else: return self.copy() - # FIXME: Really hard-code here? - if isinstance(dtype, StringDtype): - # allow conversion to StringArrays - return dtype.construct_array_type()._from_sequence(self, copy=False) + if isinstance(dtype, ExtensionDtype): + # allow conversion to e.g. StringArrays + cls = dtype.construct_array_type() + return cls._from_sequence(self, dtype=dtype, copy=copy) return np.array(self, dtype=dtype, copy=copy) diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index ecc45357db8c1..71a187cf5f4e0 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -101,7 +101,6 @@ ) import pandas.core.common as com from pandas.core.construction import ( - array as pd_array, extract_array, sanitize_array, ) @@ -494,19 +493,18 @@ def astype(self, dtype: Dtype, copy: bool = True) -> ArrayLike: """ dtype = pandas_dtype(dtype) if self.dtype is dtype: - result = self.copy() if copy else self + return self.copy() if copy else self elif is_categorical_dtype(dtype): dtype = cast(Union[str, CategoricalDtype], dtype) # GH 10696/18593/18630 dtype = self.dtype.update_dtype(dtype) - self = self.copy() if copy else self - result = self._set_dtype(dtype) + obj = self.copy() if copy else self + return obj._set_dtype(dtype) - # TODO: consolidate with ndarray case? elif isinstance(dtype, ExtensionDtype): - result = pd_array(self, dtype=dtype, copy=copy) + return super().astype(dtype, copy=copy) elif is_integer_dtype(dtype) and self.isna().any(): raise ValueError("Cannot convert float NaN to integer") diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 08cb12a1373bb..f78e148850e67 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -69,7 +69,6 @@ is_datetime64tz_dtype, is_datetime_or_timedelta_dtype, is_dtype_equal, - is_extension_array_dtype, is_float_dtype, is_integer_dtype, is_list_like, @@ -82,6 +81,7 @@ ) from pandas.core.dtypes.dtypes import ( DatetimeTZDtype, + ExtensionDtype, PeriodDtype, ) from pandas.core.dtypes.missing import ( @@ -385,14 +385,13 @@ def astype(self, dtype, copy: bool = True): # 3. DatetimeArray.astype handles datetime -> period dtype = pandas_dtype(dtype) + if isinstance(dtype, ExtensionDtype): + return super().astype(dtype=dtype, copy=copy) + if is_object_dtype(dtype): return self._box_values(self.asi8.ravel()).reshape(self.shape) - elif is_string_dtype(dtype) and not is_categorical_dtype(dtype): - if is_extension_array_dtype(dtype): - arr_cls = dtype.construct_array_type() - return arr_cls._from_sequence(self, dtype=dtype, copy=copy) - else: - return self._format_native_types() + elif is_string_dtype(dtype): + return self._format_native_types() elif is_integer_dtype(dtype): # we deliberately ignore int32 vs. int64 here. # See https://github.com/pandas-dev/pandas/issues/24381 for more. @@ -422,9 +421,6 @@ def astype(self, dtype, copy: bool = True): # and conversions for any datetimelike to float msg = f"Cannot cast {type(self).__name__} to dtype {dtype}" raise TypeError(msg) - elif is_categorical_dtype(dtype): - arr_cls = dtype.construct_array_type() - return arr_cls(self, dtype=dtype) else: return np.asarray(self, dtype=dtype) diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index 2318cae004c5a..d054bfa1503e1 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -48,7 +48,10 @@ needs_i8_conversion, pandas_dtype, ) -from pandas.core.dtypes.dtypes import IntervalDtype +from pandas.core.dtypes.dtypes import ( + ExtensionDtype, + IntervalDtype, +) from pandas.core.dtypes.generic import ( ABCDataFrame, ABCDatetimeIndex, @@ -70,7 +73,6 @@ ExtensionArray, _extension_array_shared_docs, ) -from pandas.core.arrays.categorical import Categorical import pandas.core.common as com from pandas.core.construction import ( array as pd_array, @@ -827,7 +829,6 @@ def astype(self, dtype, copy: bool = True): ExtensionArray or NumPy ndarray with 'dtype' for its dtype. """ from pandas import Index - from pandas.core.arrays.string_ import StringDtype if dtype is not None: dtype = pandas_dtype(dtype) @@ -848,13 +849,11 @@ def astype(self, dtype, copy: bool = True): ) raise TypeError(msg) from err return self._shallow_copy(new_left, new_right) - elif is_categorical_dtype(dtype): - return Categorical(np.asarray(self), dtype=dtype) - elif isinstance(dtype, StringDtype): - return dtype.construct_array_type()._from_sequence(self, copy=False) - # TODO: This try/except will be repeated. try: + if isinstance(dtype, ExtensionDtype): + cls = dtype.construct_array_type() + return cls._from_sequence(self, dtype=dtype, copy=copy) return np.asarray(self).astype(dtype, copy=copy) except (TypeError, ValueError) as err: msg = f"Cannot cast {type(self).__name__} to dtype {dtype}" diff --git a/pandas/core/arrays/masked.py b/pandas/core/arrays/masked.py index d274501143916..2d6eef44e23d4 100644 --- a/pandas/core/arrays/masked.py +++ b/pandas/core/arrays/masked.py @@ -320,8 +320,7 @@ def astype(self, dtype: Dtype, copy: bool = True) -> ArrayLike: return cls(data, mask, copy=False) if isinstance(dtype, ExtensionDtype): - eacls = dtype.construct_array_type() - return eacls._from_sequence(self, dtype=dtype, copy=copy) + return super().astype(dtype=dtype, copy=copy) raise NotImplementedError("subclass must implement astype to np.dtype") diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index 8d150c8f6ad3d..edd0c6c4b563e 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -435,8 +435,8 @@ def astype(self, dtype, copy=True): values = arr.astype(dtype.numpy_dtype) return FloatingArray(values, mask, copy=False) elif isinstance(dtype, ExtensionDtype): - cls = dtype.construct_array_type() - return cls._from_sequence(self, dtype=dtype, copy=copy) + return super().astype(dtype=dtype, copy=copy) + elif np.issubdtype(dtype, np.floating): arr = self._ndarray.copy() mask = self.isna() diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 433d45d94167d..b3b06bdf53cbd 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -1103,7 +1103,7 @@ def astype_nansafe( if issubclass(dtype.type, str): return lib.ensure_string_array(arr, skipna=skipna, convert_na_value=False) - elif is_datetime64_dtype(arr): + elif is_datetime64_dtype(arr.dtype): if dtype == np.int64: warnings.warn( f"casting {arr.dtype} values to int64 with .astype(...) " @@ -1123,7 +1123,7 @@ def astype_nansafe( raise TypeError(f"cannot astype a datetimelike from [{arr.dtype}] to [{dtype}]") - elif is_timedelta64_dtype(arr): + elif is_timedelta64_dtype(arr.dtype): if dtype == np.int64: warnings.warn( f"casting {arr.dtype} values to int64 with .astype(...) "