Skip to content

REF: standardize astype in EA subclasses #41652

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
12 changes: 5 additions & 7 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,10 +529,8 @@ def astype(self, dtype, copy=True):

Returns
-------
array : ndarray
NumPy ndarray with 'dtype' for its dtype.
np.ndarray or ExtensionArray
"""
from pandas.core.arrays.string_ import StringDtype

dtype = pandas_dtype(dtype)
if is_dtype_equal(dtype, self.dtype):
Expand All @@ -541,10 +539,10 @@ def astype(self, dtype, copy=True):
else:
return self.copy()

# FIXME: Really hard-code here?
if isinstance(dtype, StringDtype):
# allow conversion to StringArrays
return dtype.construct_array_type()._from_sequence(self, copy=False)
if isinstance(dtype, ExtensionDtype):
# allow conversion to e.g. StringArrays
cls = dtype.construct_array_type()
return cls._from_sequence(self, dtype=dtype, copy=copy)

return np.array(self, dtype=dtype, copy=copy)

Expand Down
10 changes: 4 additions & 6 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@
)
import pandas.core.common as com
from pandas.core.construction import (
array as pd_array,
extract_array,
sanitize_array,
)
Expand Down Expand Up @@ -494,19 +493,18 @@ def astype(self, dtype: Dtype, copy: bool = True) -> ArrayLike:
"""
dtype = pandas_dtype(dtype)
if self.dtype is dtype:
result = self.copy() if copy else self
return self.copy() if copy else self

elif is_categorical_dtype(dtype):
dtype = cast(Union[str, CategoricalDtype], dtype)

# GH 10696/18593/18630
dtype = self.dtype.update_dtype(dtype)
self = self.copy() if copy else self
result = self._set_dtype(dtype)
obj = self.copy() if copy else self
return obj._set_dtype(dtype)

# TODO: consolidate with ndarray case?
elif isinstance(dtype, ExtensionDtype):
result = pd_array(self, dtype=dtype, copy=copy)
return super().astype(dtype, copy=copy)

elif is_integer_dtype(dtype) and self.isna().any():
raise ValueError("Cannot convert float NaN to integer")
Expand Down
16 changes: 6 additions & 10 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@
is_datetime64tz_dtype,
is_datetime_or_timedelta_dtype,
is_dtype_equal,
is_extension_array_dtype,
is_float_dtype,
is_integer_dtype,
is_list_like,
Expand All @@ -82,6 +81,7 @@
)
from pandas.core.dtypes.dtypes import (
DatetimeTZDtype,
ExtensionDtype,
PeriodDtype,
)
from pandas.core.dtypes.missing import (
Expand Down Expand Up @@ -385,14 +385,13 @@ def astype(self, dtype, copy: bool = True):
# 3. DatetimeArray.astype handles datetime -> period
dtype = pandas_dtype(dtype)

if isinstance(dtype, ExtensionDtype):
return super().astype(dtype=dtype, copy=copy)

if is_object_dtype(dtype):
return self._box_values(self.asi8.ravel()).reshape(self.shape)
elif is_string_dtype(dtype) and not is_categorical_dtype(dtype):
if is_extension_array_dtype(dtype):
arr_cls = dtype.construct_array_type()
return arr_cls._from_sequence(self, dtype=dtype, copy=copy)
else:
return self._format_native_types()
elif is_string_dtype(dtype):
return self._format_native_types()
elif is_integer_dtype(dtype):
# we deliberately ignore int32 vs. int64 here.
# See https://github.com/pandas-dev/pandas/issues/24381 for more.
Expand Down Expand Up @@ -422,9 +421,6 @@ def astype(self, dtype, copy: bool = True):
# and conversions for any datetimelike to float
msg = f"Cannot cast {type(self).__name__} to dtype {dtype}"
raise TypeError(msg)
elif is_categorical_dtype(dtype):
arr_cls = dtype.construct_array_type()
return arr_cls(self, dtype=dtype)
else:
return np.asarray(self, dtype=dtype)

Expand Down
15 changes: 7 additions & 8 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@
needs_i8_conversion,
pandas_dtype,
)
from pandas.core.dtypes.dtypes import IntervalDtype
from pandas.core.dtypes.dtypes import (
ExtensionDtype,
IntervalDtype,
)
from pandas.core.dtypes.generic import (
ABCDataFrame,
ABCDatetimeIndex,
Expand All @@ -70,7 +73,6 @@
ExtensionArray,
_extension_array_shared_docs,
)
from pandas.core.arrays.categorical import Categorical
import pandas.core.common as com
from pandas.core.construction import (
array as pd_array,
Expand Down Expand Up @@ -827,7 +829,6 @@ def astype(self, dtype, copy: bool = True):
ExtensionArray or NumPy ndarray with 'dtype' for its dtype.
"""
from pandas import Index
from pandas.core.arrays.string_ import StringDtype

if dtype is not None:
dtype = pandas_dtype(dtype)
Expand All @@ -848,13 +849,11 @@ def astype(self, dtype, copy: bool = True):
)
raise TypeError(msg) from err
return self._shallow_copy(new_left, new_right)
elif is_categorical_dtype(dtype):
return Categorical(np.asarray(self), dtype=dtype)
elif isinstance(dtype, StringDtype):
return dtype.construct_array_type()._from_sequence(self, copy=False)

# TODO: This try/except will be repeated.
try:
if isinstance(dtype, ExtensionDtype):
cls = dtype.construct_array_type()
return cls._from_sequence(self, dtype=dtype, copy=copy)
return np.asarray(self).astype(dtype, copy=copy)
except (TypeError, ValueError) as err:
msg = f"Cannot cast {type(self).__name__} to dtype {dtype}"
Expand Down
3 changes: 1 addition & 2 deletions pandas/core/arrays/masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,7 @@ def astype(self, dtype: Dtype, copy: bool = True) -> ArrayLike:
return cls(data, mask, copy=False)

if isinstance(dtype, ExtensionDtype):
eacls = dtype.construct_array_type()
return eacls._from_sequence(self, dtype=dtype, copy=copy)
return super().astype(dtype=dtype, copy=copy)

raise NotImplementedError("subclass must implement astype to np.dtype")

Expand Down
4 changes: 2 additions & 2 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,8 +435,8 @@ def astype(self, dtype, copy=True):
values = arr.astype(dtype.numpy_dtype)
return FloatingArray(values, mask, copy=False)
elif isinstance(dtype, ExtensionDtype):
cls = dtype.construct_array_type()
return cls._from_sequence(self, dtype=dtype, copy=copy)
return super().astype(dtype=dtype, copy=copy)

elif np.issubdtype(dtype, np.floating):
arr = self._ndarray.copy()
mask = self.isna()
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,7 +1103,7 @@ def astype_nansafe(
if issubclass(dtype.type, str):
return lib.ensure_string_array(arr, skipna=skipna, convert_na_value=False)

elif is_datetime64_dtype(arr):
elif is_datetime64_dtype(arr.dtype):
if dtype == np.int64:
warnings.warn(
f"casting {arr.dtype} values to int64 with .astype(...) "
Expand All @@ -1123,7 +1123,7 @@ def astype_nansafe(

raise TypeError(f"cannot astype a datetimelike from [{arr.dtype}] to [{dtype}]")

elif is_timedelta64_dtype(arr):
elif is_timedelta64_dtype(arr.dtype):
if dtype == np.int64:
warnings.warn(
f"casting {arr.dtype} values to int64 with .astype(...) "
Expand Down