diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 71863c8925e89..d456f9c56e309 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -10,6 +10,7 @@ datetime, timedelta, ) +import inspect from typing import ( TYPE_CHECKING, Any, @@ -87,6 +88,7 @@ is_timedelta64_dtype, is_timedelta64_ns_dtype, is_unsigned_integer_dtype, + pandas_dtype, ) from pandas.core.dtypes.dtypes import ( DatetimeTZDtype, @@ -1227,6 +1229,107 @@ def astype_nansafe( return arr.astype(dtype, copy=copy) +def astype_array(values: ArrayLike, dtype: DtypeObj, copy: bool = False) -> ArrayLike: + """ + Cast array (ndarray or ExtensionArray) to the new dtype. + + Parameters + ---------- + values : ndarray or ExtensionArray + dtype : dtype object + copy : bool, default False + copy if indicated + + Returns + ------- + ndarray or ExtensionArray + """ + if ( + values.dtype.kind in ["m", "M"] + and dtype.kind in ["i", "u"] + and isinstance(dtype, np.dtype) + and dtype.itemsize != 8 + ): + # TODO(2.0) remove special case once deprecation on DTA/TDA is enforced + msg = rf"cannot astype a datetimelike from [{values.dtype}] to [{dtype}]" + raise TypeError(msg) + + if is_datetime64tz_dtype(dtype) and is_datetime64_dtype(values.dtype): + return astype_dt64_to_dt64tz(values, dtype, copy, via_utc=True) + + if is_dtype_equal(values.dtype, dtype): + if copy: + return values.copy() + return values + + if isinstance(values, ABCExtensionArray): + values = values.astype(dtype, copy=copy) + + else: + values = astype_nansafe(values, dtype, copy=copy) + + # in pandas we don't store numpy str dtypes, so convert to object + if isinstance(dtype, np.dtype) and issubclass(values.dtype.type, str): + values = np.array(values, dtype=object) + + return values + + +def astype_array_safe( + values: ArrayLike, dtype, copy: bool = False, errors: str = "raise" +) -> ArrayLike: + """ + Cast array (ndarray or ExtensionArray) to the new dtype. + + This basically is the implementation for DataFrame/Series.astype and + includes all custom logic for pandas (NaN-safety, converting str to object, + not allowing ) + + Parameters + ---------- + values : ndarray or ExtensionArray + dtype : str, dtype convertible + copy : bool, default False + copy if indicated + errors : str, {'raise', 'ignore'}, default 'raise' + - ``raise`` : allow exceptions to be raised + - ``ignore`` : suppress exceptions. On error return original object + + Returns + ------- + ndarray or ExtensionArray + """ + errors_legal_values = ("raise", "ignore") + + if errors not in errors_legal_values: + invalid_arg = ( + "Expected value of kwarg 'errors' to be one of " + f"{list(errors_legal_values)}. Supplied value is '{errors}'" + ) + raise ValueError(invalid_arg) + + if inspect.isclass(dtype) and issubclass(dtype, ExtensionDtype): + msg = ( + f"Expected an instance of {dtype.__name__}, " + "but got the class instead. Try instantiating 'dtype'." + ) + raise TypeError(msg) + + dtype = pandas_dtype(dtype) + + try: + new_values = astype_array(values, dtype, copy=copy) + except (ValueError, TypeError): + # e.g. astype_nansafe can fail on object-dtype of strings + # trying to convert to float + if errors == "ignore": + new_values = values + else: + raise + + return new_values + + def soft_convert_objects( values: np.ndarray, datetime: bool = True, diff --git a/pandas/core/internals/array_manager.py b/pandas/core/internals/array_manager.py index 5001754017dda..48e27e7100d2f 100644 --- a/pandas/core/internals/array_manager.py +++ b/pandas/core/internals/array_manager.py @@ -28,6 +28,7 @@ from pandas.util._validators import validate_bool_kwarg from pandas.core.dtypes.cast import ( + astype_array_safe, find_common_type, infer_dtype_from_scalar, ) @@ -499,7 +500,7 @@ def downcast(self) -> ArrayManager: return self.apply_with_block("downcast") def astype(self, dtype, copy: bool = False, errors: str = "raise") -> ArrayManager: - return self.apply("astype", dtype=dtype, copy=copy) # , errors=errors) + return self.apply(astype_array_safe, dtype=dtype, copy=copy, errors=errors) def convert( self, diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index b65043be6fda6..f2b8499a316b7 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -1,6 +1,5 @@ from __future__ import annotations -import inspect import re from typing import ( TYPE_CHECKING, @@ -36,8 +35,7 @@ from pandas.util._validators import validate_bool_kwarg from pandas.core.dtypes.cast import ( - astype_dt64_to_dt64tz, - astype_nansafe, + astype_array_safe, can_hold_element, find_common_type, infer_dtype_from, @@ -49,7 +47,6 @@ ) from pandas.core.dtypes.common import ( is_categorical_dtype, - is_datetime64_dtype, is_datetime64tz_dtype, is_dtype_equal, is_extension_array_dtype, @@ -652,33 +649,11 @@ def astype(self, dtype, copy: bool = False, errors: str = "raise"): ------- Block """ - errors_legal_values = ("raise", "ignore") - - if errors not in errors_legal_values: - invalid_arg = ( - "Expected value of kwarg 'errors' to be one of " - f"{list(errors_legal_values)}. Supplied value is '{errors}'" - ) - raise ValueError(invalid_arg) - - if inspect.isclass(dtype) and issubclass(dtype, ExtensionDtype): - msg = ( - f"Expected an instance of {dtype.__name__}, " - "but got the class instead. Try instantiating 'dtype'." - ) - raise TypeError(msg) - - dtype = pandas_dtype(dtype) + values = self.values + if values.dtype.kind in ["m", "M"]: + values = self.array_values() - try: - new_values = self._astype(dtype, copy=copy) - except (ValueError, TypeError): - # e.g. astype_nansafe can fail on object-dtype of strings - # trying to convert to float - if errors == "ignore": - new_values = self.values - else: - raise + new_values = astype_array_safe(values, dtype, copy=copy, errors=errors) newb = self.make_block(new_values) if newb.shape != self.shape: @@ -689,37 +664,6 @@ def astype(self, dtype, copy: bool = False, errors: str = "raise"): ) return newb - def _astype(self, dtype: DtypeObj, copy: bool) -> ArrayLike: - values = self.values - if values.dtype.kind in ["m", "M"]: - values = self.array_values() - - if ( - values.dtype.kind in ["m", "M"] - and dtype.kind in ["i", "u"] - and isinstance(dtype, np.dtype) - and dtype.itemsize != 8 - ): - # TODO(2.0) remove special case once deprecation on DTA/TDA is enforced - msg = rf"cannot astype a datetimelike from [{values.dtype}] to [{dtype}]" - raise TypeError(msg) - - if is_datetime64tz_dtype(dtype) and is_datetime64_dtype(values.dtype): - return astype_dt64_to_dt64tz(values, dtype, copy, via_utc=True) - - if is_dtype_equal(values.dtype, dtype): - if copy: - return values.copy() - return values - - if isinstance(values, ExtensionArray): - values = values.astype(dtype, copy=copy) - - else: - values = astype_nansafe(values, dtype, copy=copy) - - return values - def convert( self, copy: bool = True, diff --git a/pandas/tests/frame/methods/test_astype.py b/pandas/tests/frame/methods/test_astype.py index 8c11f659e8454..161fe7990a327 100644 --- a/pandas/tests/frame/methods/test_astype.py +++ b/pandas/tests/frame/methods/test_astype.py @@ -3,8 +3,6 @@ import numpy as np import pytest -import pandas.util._test_decorators as td - import pandas as pd from pandas import ( Categorical, @@ -92,7 +90,6 @@ def test_astype_mixed_type(self, mixed_type_frame): casted = mn.astype("O") _check_cast(casted, "object") - @td.skip_array_manager_not_yet_implemented def test_astype_with_exclude_string(self, float_frame): df = float_frame.copy() expected = float_frame.astype(int) @@ -127,7 +124,6 @@ def test_astype_with_view_mixed_float(self, mixed_float_frame): casted = tf.astype(np.int64) casted = tf.astype(np.float32) # noqa - @td.skip_array_manager_not_yet_implemented @pytest.mark.parametrize("dtype", [np.int32, np.int64]) @pytest.mark.parametrize("val", [np.nan, np.inf]) def test_astype_cast_nan_inf_int(self, val, dtype): @@ -386,7 +382,6 @@ def test_astype_to_datetimelike_unit(self, arr_dtype, dtype, unit): tm.assert_frame_equal(result, expected) - @td.skip_array_manager_not_yet_implemented @pytest.mark.parametrize("unit", ["ns", "us", "ms", "s", "h", "m", "D"]) def test_astype_to_datetime_unit(self, unit): # tests all units from datetime origination @@ -411,7 +406,6 @@ def test_astype_to_timedelta_unit_ns(self, unit): tm.assert_frame_equal(result, expected) - @td.skip_array_manager_not_yet_implemented @pytest.mark.parametrize("unit", ["us", "ms", "s", "h", "m", "D"]) def test_astype_to_timedelta_unit(self, unit): # coerce to float @@ -441,7 +435,6 @@ def test_astype_to_incorrect_datetimelike(self, unit): with pytest.raises(TypeError, match=msg): df.astype(dtype) - @td.skip_array_manager_not_yet_implemented def test_astype_arg_for_errors(self): # GH#14878 @@ -570,7 +563,6 @@ def test_astype_empty_dtype_dict(self): tm.assert_frame_equal(result, df) assert result is not df - @td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) ignore keyword @pytest.mark.parametrize( "df", [ diff --git a/pandas/util/_exceptions.py b/pandas/util/_exceptions.py index 5ca96a1f9989f..c31c421ee1445 100644 --- a/pandas/util/_exceptions.py +++ b/pandas/util/_exceptions.py @@ -31,7 +31,7 @@ def find_stack_level() -> int: if stack[n].function == "astype": break - while stack[n].function in ["astype", "apply", "_astype"]: + while stack[n].function in ["astype", "apply", "astype_array_safe", "astype_array"]: # e.g. # bump up Block.astype -> BlockManager.astype -> NDFrame.astype # bump up Datetime.Array.astype -> DatetimeIndex.astype