Skip to content

REF: simplify EA.astype #44133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,11 +561,9 @@ def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:
Returns
-------
array : np.ndarray or ExtensionArray
An ExtensionArray if dtype is StringDtype,
or same as that of underlying array.
An ExtensionArray if dtype is ExtensionDtype,
otherwise a NumPy ndarray with 'dtype' for its dtype.
"""
from pandas.core.arrays.string_ import StringDtype

dtype = pandas_dtype(dtype)
if is_dtype_equal(dtype, self.dtype):
Expand All @@ -574,16 +572,11 @@ def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:
else:
return self.copy()

# FIXME: Really hard-code here?
if isinstance(dtype, StringDtype):
# allow conversion to StringArrays
return dtype.construct_array_type()._from_sequence(self, copy=False)
if isinstance(dtype, ExtensionDtype):
cls = dtype.construct_array_type()
return cls._from_sequence(self, dtype=dtype, copy=copy)

# error: Argument "dtype" to "array" has incompatible type
# "Union[ExtensionDtype, dtype[Any]]"; expected "Union[dtype[Any], None, type,
# _SupportsDType, str, Union[Tuple[Any, int], Tuple[Any, Union[int,
# Sequence[int]]], List[Any], _DTypeDict, Tuple[Any, Any]]]"
return np.array(self, dtype=dtype, copy=copy) # type: ignore[arg-type]
return np.array(self, dtype=dtype, copy=copy)

def isna(self) -> np.ndarray | ExtensionArraySupportsAnyAll:
"""
Expand Down
3 changes: 1 addition & 2 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@
)
import pandas.core.common as com
from pandas.core.construction import (
array as pd_array,
extract_array,
sanitize_array,
)
Expand Down Expand Up @@ -527,7 +526,7 @@ def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:

# TODO: consolidate with ndarray case?
elif isinstance(dtype, ExtensionDtype):
result = pd_array(self, dtype=dtype, copy=copy)
return super().astype(dtype, copy=copy)

elif is_integer_dtype(dtype) and self.isna().any():
raise ValueError("Cannot convert float NaN to integer")
Expand Down
16 changes: 6 additions & 10 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@
is_datetime64tz_dtype,
is_datetime_or_timedelta_dtype,
is_dtype_equal,
is_extension_array_dtype,
is_float_dtype,
is_integer_dtype,
is_list_like,
Expand All @@ -87,6 +86,7 @@
)
from pandas.core.dtypes.dtypes import (
DatetimeTZDtype,
ExtensionDtype,
PeriodDtype,
)
from pandas.core.dtypes.missing import (
Expand Down Expand Up @@ -407,12 +407,11 @@ def astype(self, dtype, copy: bool = True):

if is_object_dtype(dtype):
return self._box_values(self.asi8.ravel()).reshape(self.shape)
elif is_string_dtype(dtype) and not is_categorical_dtype(dtype):
if is_extension_array_dtype(dtype):
arr_cls = dtype.construct_array_type()
return arr_cls._from_sequence(self, dtype=dtype, copy=copy)
else:
return self._format_native_types()

elif isinstance(dtype, ExtensionDtype):
return super().astype(dtype, copy=copy)
elif is_string_dtype(dtype):
return self._format_native_types()
elif is_integer_dtype(dtype):
# we deliberately ignore int32 vs. int64 here.
# See https://github.com/pandas-dev/pandas/issues/24381 for more.
Expand Down Expand Up @@ -442,9 +441,6 @@ def astype(self, dtype, copy: bool = True):
# and conversions for any datetimelike to float
msg = f"Cannot cast {type(self).__name__} to dtype {dtype}"
raise TypeError(msg)
elif is_categorical_dtype(dtype):
arr_cls = dtype.construct_array_type()
return arr_cls(self, dtype=dtype)
else:
return np.asarray(self, dtype=dtype)

Expand Down
19 changes: 6 additions & 13 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@
ExtensionArray,
_extension_array_shared_docs,
)
from pandas.core.arrays.categorical import Categorical
import pandas.core.common as com
from pandas.core.construction import (
array as pd_array,
Expand Down Expand Up @@ -850,7 +849,6 @@ def astype(self, dtype, copy: bool = True):
ExtensionArray or NumPy ndarray with 'dtype' for its dtype.
"""
from pandas import Index
from pandas.core.arrays.string_ import StringDtype

if dtype is not None:
dtype = pandas_dtype(dtype)
Expand All @@ -871,17 +869,12 @@ def astype(self, dtype, copy: bool = True):
)
raise TypeError(msg) from err
return self._shallow_copy(new_left, new_right)
elif is_categorical_dtype(dtype):
return Categorical(np.asarray(self), dtype=dtype)
elif isinstance(dtype, StringDtype):
return dtype.construct_array_type()._from_sequence(self, copy=False)

# TODO: This try/except will be repeated.
try:
return np.asarray(self).astype(dtype, copy=copy)
except (TypeError, ValueError) as err:
msg = f"Cannot cast {type(self).__name__} to dtype {dtype}"
raise TypeError(msg) from err
else:
try:
return super().astype(dtype, copy=copy)
except (TypeError, ValueError) as err:
msg = f"Cannot cast {type(self).__name__} to dtype {dtype}"
raise TypeError(msg) from err

def equals(self, other) -> bool:
if type(self) != type(other):
Expand Down
3 changes: 1 addition & 2 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,8 +437,7 @@ def astype(self, dtype, copy=True):
values = arr.astype(dtype.numpy_dtype)
return FloatingArray(values, mask, copy=False)
elif isinstance(dtype, ExtensionDtype):
cls = dtype.construct_array_type()
return cls._from_sequence(self, dtype=dtype, copy=copy)
return super().astype(dtype, copy=copy)
elif np.issubdtype(dtype, np.floating):
arr = self._ndarray.copy()
mask = self.isna()
Expand Down
7 changes: 1 addition & 6 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from pandas.util._decorators import doc
from pandas.util._validators import validate_fillna_kwargs

from pandas.core.dtypes.base import ExtensionDtype
from pandas.core.dtypes.common import (
is_array_like,
is_bool_dtype,
Expand Down Expand Up @@ -685,11 +684,7 @@ def astype(self, dtype, copy=True):
data = self._data.cast(pa.from_numpy_dtype(dtype.numpy_dtype))
return dtype.__from_arrow__(data)

elif isinstance(dtype, ExtensionDtype):
cls = dtype.construct_array_type()
return cls._from_sequence(self, dtype=dtype, copy=copy)

return super().astype(dtype, copy)
return super().astype(dtype, copy=copy)

# ------------------------------------------------------------------------
# String methods interface
Expand Down