From fe2986adead52838a89c95120e3f951e6e410905 Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 21 Oct 2021 09:22:49 -0700 Subject: [PATCH] REF: simplify EA.astype --- pandas/core/arrays/base.py | 17 +++++------------ pandas/core/arrays/categorical.py | 3 +-- pandas/core/arrays/datetimelike.py | 16 ++++++---------- pandas/core/arrays/interval.py | 19 ++++++------------- pandas/core/arrays/string_.py | 3 +-- pandas/core/arrays/string_arrow.py | 7 +------ 6 files changed, 20 insertions(+), 45 deletions(-) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index bf54f7166e14d..8365e58632c70 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -561,11 +561,9 @@ def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike: Returns ------- array : np.ndarray or ExtensionArray - An ExtensionArray if dtype is StringDtype, - or same as that of underlying array. + An ExtensionArray if dtype is ExtensionDtype, Otherwise a NumPy ndarray with 'dtype' for its dtype. """ - from pandas.core.arrays.string_ import StringDtype dtype = pandas_dtype(dtype) if is_dtype_equal(dtype, self.dtype): @@ -574,16 +572,11 @@ def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike: else: return self.copy() - # FIXME: Really hard-code here? - if isinstance(dtype, StringDtype): - # allow conversion to StringArrays - return dtype.construct_array_type()._from_sequence(self, copy=False) + if isinstance(dtype, ExtensionDtype): + cls = dtype.construct_array_type() + return cls._from_sequence(self, dtype=dtype, copy=copy) - # error: Argument "dtype" to "array" has incompatible type - # "Union[ExtensionDtype, dtype[Any]]"; expected "Union[dtype[Any], None, type, - # _SupportsDType, str, Union[Tuple[Any, int], Tuple[Any, Union[int, - # Sequence[int]]], List[Any], _DTypeDict, Tuple[Any, Any]]]" - return np.array(self, dtype=dtype, copy=copy) # type: ignore[arg-type] + return np.array(self, dtype=dtype, copy=copy) def isna(self) -> np.ndarray | ExtensionArraySupportsAnyAll: """ diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index a4d6c0f3cd832..35e2dd25678e5 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -110,7 +110,6 @@ ) import pandas.core.common as com from pandas.core.construction import ( - array as pd_array, extract_array, sanitize_array, ) @@ -527,7 +526,7 @@ def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike: # TODO: consolidate with ndarray case? elif isinstance(dtype, ExtensionDtype): - result = pd_array(self, dtype=dtype, copy=copy) + return super().astype(dtype, copy=copy) elif is_integer_dtype(dtype) and self.isna().any(): raise ValueError("Cannot convert float NaN to integer") diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index d1b926bd25055..72c00dfe7c65a 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -74,7 +74,6 @@ is_datetime64tz_dtype, is_datetime_or_timedelta_dtype, is_dtype_equal, - is_extension_array_dtype, is_float_dtype, is_integer_dtype, is_list_like, @@ -87,6 +86,7 @@ ) from pandas.core.dtypes.dtypes import ( DatetimeTZDtype, + ExtensionDtype, PeriodDtype, ) from pandas.core.dtypes.missing import ( @@ -407,12 +407,11 @@ def astype(self, dtype, copy: bool = True): if is_object_dtype(dtype): return self._box_values(self.asi8.ravel()).reshape(self.shape) - elif is_string_dtype(dtype) and not is_categorical_dtype(dtype): - if is_extension_array_dtype(dtype): - arr_cls = dtype.construct_array_type() - return arr_cls._from_sequence(self, dtype=dtype, copy=copy) - else: - return self._format_native_types() + + elif isinstance(dtype, ExtensionDtype): + return super().astype(dtype, copy=copy) + elif is_string_dtype(dtype): + return self._format_native_types() elif is_integer_dtype(dtype): # we deliberately ignore int32 vs. int64 here. # See https://github.com/pandas-dev/pandas/issues/24381 for more. @@ -442,9 +441,6 @@ def astype(self, dtype, copy: bool = True): # and conversions for any datetimelike to float msg = f"Cannot cast {type(self).__name__} to dtype {dtype}" raise TypeError(msg) - elif is_categorical_dtype(dtype): - arr_cls = dtype.construct_array_type() - return arr_cls(self, dtype=dtype) else: return np.asarray(self, dtype=dtype) diff --git a/pandas/core/arrays/interval.py b/pandas/core/arrays/interval.py index 68365613c8c77..d5718d59bf8b0 100644 --- a/pandas/core/arrays/interval.py +++ b/pandas/core/arrays/interval.py @@ -79,7 +79,6 @@ ExtensionArray, _extension_array_shared_docs, ) -from pandas.core.arrays.categorical import Categorical import pandas.core.common as com from pandas.core.construction import ( array as pd_array, @@ -850,7 +849,6 @@ def astype(self, dtype, copy: bool = True): ExtensionArray or NumPy ndarray with 'dtype' for its dtype. """ from pandas import Index - from pandas.core.arrays.string_ import StringDtype if dtype is not None: dtype = pandas_dtype(dtype) @@ -871,17 +869,12 @@ def astype(self, dtype, copy: bool = True): ) raise TypeError(msg) from err return self._shallow_copy(new_left, new_right) - elif is_categorical_dtype(dtype): - return Categorical(np.asarray(self), dtype=dtype) - elif isinstance(dtype, StringDtype): - return dtype.construct_array_type()._from_sequence(self, copy=False) - - # TODO: This try/except will be repeated. - try: - return np.asarray(self).astype(dtype, copy=copy) - except (TypeError, ValueError) as err: - msg = f"Cannot cast {type(self).__name__} to dtype {dtype}" - raise TypeError(msg) from err + else: + try: + return super().astype(dtype, copy=copy) + except (TypeError, ValueError) as err: + msg = f"Cannot cast {type(self).__name__} to dtype {dtype}" + raise TypeError(msg) from err def equals(self, other) -> bool: if type(self) != type(other): diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index d93fa4bbdd7fc..e9fb5bdf80045 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -437,8 +437,7 @@ def astype(self, dtype, copy=True): values = arr.astype(dtype.numpy_dtype) return FloatingArray(values, mask, copy=False) elif isinstance(dtype, ExtensionDtype): - cls = dtype.construct_array_type() - return cls._from_sequence(self, dtype=dtype, copy=copy) + return super().astype(dtype, copy=copy) elif np.issubdtype(dtype, np.floating): arr = self._ndarray.copy() mask = self.isna() diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index c7d08f7873c09..a83cfa89c4728 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -35,7 +35,6 @@ from pandas.util._decorators import doc from pandas.util._validators import validate_fillna_kwargs -from pandas.core.dtypes.base import ExtensionDtype from pandas.core.dtypes.common import ( is_array_like, is_bool_dtype, @@ -685,11 +684,7 @@ def astype(self, dtype, copy=True): data = self._data.cast(pa.from_numpy_dtype(dtype.numpy_dtype)) return dtype.__from_arrow__(data) - elif isinstance(dtype, ExtensionDtype): - cls = dtype.construct_array_type() - return cls._from_sequence(self, dtype=dtype, copy=copy) - - return super().astype(dtype, copy) + return super().astype(dtype, copy=copy) # ------------------------------------------------------------------------ # String methods interface