Skip to content

Commit db495eb

Browse files
REF: move Block.astype implementation to dtypes/cast.py (#40141)
1 parent 7648a8d commit db495eb

File tree

5 files changed

+111
-71
lines changed

5 files changed

+111
-71
lines changed

pandas/core/dtypes/cast.py

+103
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
datetime,
1111
timedelta,
1212
)
13+
import inspect
1314
from typing import (
1415
TYPE_CHECKING,
1516
Any,
@@ -87,6 +88,7 @@
8788
is_timedelta64_dtype,
8889
is_timedelta64_ns_dtype,
8990
is_unsigned_integer_dtype,
91+
pandas_dtype,
9092
)
9193
from pandas.core.dtypes.dtypes import (
9294
DatetimeTZDtype,
@@ -1227,6 +1229,107 @@ def astype_nansafe(
12271229
return arr.astype(dtype, copy=copy)
12281230

12291231

1232+
def astype_array(values: ArrayLike, dtype: DtypeObj, copy: bool = False) -> ArrayLike:
1233+
"""
1234+
Cast array (ndarray or ExtensionArray) to the new dtype.
1235+
1236+
Parameters
1237+
----------
1238+
values : ndarray or ExtensionArray
1239+
dtype : dtype object
1240+
copy : bool, default False
1241+
copy if indicated
1242+
1243+
Returns
1244+
-------
1245+
ndarray or ExtensionArray
1246+
"""
1247+
if (
1248+
values.dtype.kind in ["m", "M"]
1249+
and dtype.kind in ["i", "u"]
1250+
and isinstance(dtype, np.dtype)
1251+
and dtype.itemsize != 8
1252+
):
1253+
# TODO(2.0) remove special case once deprecation on DTA/TDA is enforced
1254+
msg = rf"cannot astype a datetimelike from [{values.dtype}] to [{dtype}]"
1255+
raise TypeError(msg)
1256+
1257+
if is_datetime64tz_dtype(dtype) and is_datetime64_dtype(values.dtype):
1258+
return astype_dt64_to_dt64tz(values, dtype, copy, via_utc=True)
1259+
1260+
if is_dtype_equal(values.dtype, dtype):
1261+
if copy:
1262+
return values.copy()
1263+
return values
1264+
1265+
if isinstance(values, ABCExtensionArray):
1266+
values = values.astype(dtype, copy=copy)
1267+
1268+
else:
1269+
values = astype_nansafe(values, dtype, copy=copy)
1270+
1271+
# in pandas we don't store numpy str dtypes, so convert to object
1272+
if isinstance(dtype, np.dtype) and issubclass(values.dtype.type, str):
1273+
values = np.array(values, dtype=object)
1274+
1275+
return values
1276+
1277+
1278+
def astype_array_safe(
1279+
values: ArrayLike, dtype, copy: bool = False, errors: str = "raise"
1280+
) -> ArrayLike:
1281+
"""
1282+
Cast array (ndarray or ExtensionArray) to the new dtype.
1283+
1284+
This basically is the implementation for DataFrame/Series.astype and
1285+
includes all custom logic for pandas (NaN-safety, converting str to object,
1286+
not allowing )
1287+
1288+
Parameters
1289+
----------
1290+
values : ndarray or ExtensionArray
1291+
dtype : str, dtype convertible
1292+
copy : bool, default False
1293+
copy if indicated
1294+
errors : str, {'raise', 'ignore'}, default 'raise'
1295+
- ``raise`` : allow exceptions to be raised
1296+
- ``ignore`` : suppress exceptions. On error return original object
1297+
1298+
Returns
1299+
-------
1300+
ndarray or ExtensionArray
1301+
"""
1302+
errors_legal_values = ("raise", "ignore")
1303+
1304+
if errors not in errors_legal_values:
1305+
invalid_arg = (
1306+
"Expected value of kwarg 'errors' to be one of "
1307+
f"{list(errors_legal_values)}. Supplied value is '{errors}'"
1308+
)
1309+
raise ValueError(invalid_arg)
1310+
1311+
if inspect.isclass(dtype) and issubclass(dtype, ExtensionDtype):
1312+
msg = (
1313+
f"Expected an instance of {dtype.__name__}, "
1314+
"but got the class instead. Try instantiating 'dtype'."
1315+
)
1316+
raise TypeError(msg)
1317+
1318+
dtype = pandas_dtype(dtype)
1319+
1320+
try:
1321+
new_values = astype_array(values, dtype, copy=copy)
1322+
except (ValueError, TypeError):
1323+
# e.g. astype_nansafe can fail on object-dtype of strings
1324+
# trying to convert to float
1325+
if errors == "ignore":
1326+
new_values = values
1327+
else:
1328+
raise
1329+
1330+
return new_values
1331+
1332+
12301333
def soft_convert_objects(
12311334
values: np.ndarray,
12321335
datetime: bool = True,

pandas/core/internals/array_manager.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from pandas.util._validators import validate_bool_kwarg
2929

3030
from pandas.core.dtypes.cast import (
31+
astype_array_safe,
3132
find_common_type,
3233
infer_dtype_from_scalar,
3334
)
@@ -499,7 +500,7 @@ def downcast(self) -> ArrayManager:
499500
return self.apply_with_block("downcast")
500501

501502
def astype(self, dtype, copy: bool = False, errors: str = "raise") -> ArrayManager:
502-
return self.apply("astype", dtype=dtype, copy=copy) # , errors=errors)
503+
return self.apply(astype_array_safe, dtype=dtype, copy=copy, errors=errors)
503504

504505
def convert(
505506
self,

pandas/core/internals/blocks.py

+5-61
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import inspect
43
import re
54
from typing import (
65
TYPE_CHECKING,
@@ -36,8 +35,7 @@
3635
from pandas.util._validators import validate_bool_kwarg
3736

3837
from pandas.core.dtypes.cast import (
39-
astype_dt64_to_dt64tz,
40-
astype_nansafe,
38+
astype_array_safe,
4139
can_hold_element,
4240
find_common_type,
4341
infer_dtype_from,
@@ -49,7 +47,6 @@
4947
)
5048
from pandas.core.dtypes.common import (
5149
is_categorical_dtype,
52-
is_datetime64_dtype,
5350
is_datetime64tz_dtype,
5451
is_dtype_equal,
5552
is_extension_array_dtype,
@@ -652,33 +649,11 @@ def astype(self, dtype, copy: bool = False, errors: str = "raise"):
652649
-------
653650
Block
654651
"""
655-
errors_legal_values = ("raise", "ignore")
656-
657-
if errors not in errors_legal_values:
658-
invalid_arg = (
659-
"Expected value of kwarg 'errors' to be one of "
660-
f"{list(errors_legal_values)}. Supplied value is '{errors}'"
661-
)
662-
raise ValueError(invalid_arg)
663-
664-
if inspect.isclass(dtype) and issubclass(dtype, ExtensionDtype):
665-
msg = (
666-
f"Expected an instance of {dtype.__name__}, "
667-
"but got the class instead. Try instantiating 'dtype'."
668-
)
669-
raise TypeError(msg)
670-
671-
dtype = pandas_dtype(dtype)
652+
values = self.values
653+
if values.dtype.kind in ["m", "M"]:
654+
values = self.array_values()
672655

673-
try:
674-
new_values = self._astype(dtype, copy=copy)
675-
except (ValueError, TypeError):
676-
# e.g. astype_nansafe can fail on object-dtype of strings
677-
# trying to convert to float
678-
if errors == "ignore":
679-
new_values = self.values
680-
else:
681-
raise
656+
new_values = astype_array_safe(values, dtype, copy=copy, errors=errors)
682657

683658
newb = self.make_block(new_values)
684659
if newb.shape != self.shape:
@@ -689,37 +664,6 @@ def astype(self, dtype, copy: bool = False, errors: str = "raise"):
689664
)
690665
return newb
691666

692-
def _astype(self, dtype: DtypeObj, copy: bool) -> ArrayLike:
693-
values = self.values
694-
if values.dtype.kind in ["m", "M"]:
695-
values = self.array_values()
696-
697-
if (
698-
values.dtype.kind in ["m", "M"]
699-
and dtype.kind in ["i", "u"]
700-
and isinstance(dtype, np.dtype)
701-
and dtype.itemsize != 8
702-
):
703-
# TODO(2.0) remove special case once deprecation on DTA/TDA is enforced
704-
msg = rf"cannot astype a datetimelike from [{values.dtype}] to [{dtype}]"
705-
raise TypeError(msg)
706-
707-
if is_datetime64tz_dtype(dtype) and is_datetime64_dtype(values.dtype):
708-
return astype_dt64_to_dt64tz(values, dtype, copy, via_utc=True)
709-
710-
if is_dtype_equal(values.dtype, dtype):
711-
if copy:
712-
return values.copy()
713-
return values
714-
715-
if isinstance(values, ExtensionArray):
716-
values = values.astype(dtype, copy=copy)
717-
718-
else:
719-
values = astype_nansafe(values, dtype, copy=copy)
720-
721-
return values
722-
723667
def convert(
724668
self,
725669
copy: bool = True,

pandas/tests/frame/methods/test_astype.py

-8
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import numpy as np
44
import pytest
55

6-
import pandas.util._test_decorators as td
7-
86
import pandas as pd
97
from pandas import (
108
Categorical,
@@ -92,7 +90,6 @@ def test_astype_mixed_type(self, mixed_type_frame):
9290
casted = mn.astype("O")
9391
_check_cast(casted, "object")
9492

95-
@td.skip_array_manager_not_yet_implemented
9693
def test_astype_with_exclude_string(self, float_frame):
9794
df = float_frame.copy()
9895
expected = float_frame.astype(int)
@@ -127,7 +124,6 @@ def test_astype_with_view_mixed_float(self, mixed_float_frame):
127124
casted = tf.astype(np.int64)
128125
casted = tf.astype(np.float32) # noqa
129126

130-
@td.skip_array_manager_not_yet_implemented
131127
@pytest.mark.parametrize("dtype", [np.int32, np.int64])
132128
@pytest.mark.parametrize("val", [np.nan, np.inf])
133129
def test_astype_cast_nan_inf_int(self, val, dtype):
@@ -386,7 +382,6 @@ def test_astype_to_datetimelike_unit(self, arr_dtype, dtype, unit):
386382

387383
tm.assert_frame_equal(result, expected)
388384

389-
@td.skip_array_manager_not_yet_implemented
390385
@pytest.mark.parametrize("unit", ["ns", "us", "ms", "s", "h", "m", "D"])
391386
def test_astype_to_datetime_unit(self, unit):
392387
# tests all units from datetime origination
@@ -411,7 +406,6 @@ def test_astype_to_timedelta_unit_ns(self, unit):
411406

412407
tm.assert_frame_equal(result, expected)
413408

414-
@td.skip_array_manager_not_yet_implemented
415409
@pytest.mark.parametrize("unit", ["us", "ms", "s", "h", "m", "D"])
416410
def test_astype_to_timedelta_unit(self, unit):
417411
# coerce to float
@@ -441,7 +435,6 @@ def test_astype_to_incorrect_datetimelike(self, unit):
441435
with pytest.raises(TypeError, match=msg):
442436
df.astype(dtype)
443437

444-
@td.skip_array_manager_not_yet_implemented
445438
def test_astype_arg_for_errors(self):
446439
# GH#14878
447440

@@ -570,7 +563,6 @@ def test_astype_empty_dtype_dict(self):
570563
tm.assert_frame_equal(result, df)
571564
assert result is not df
572565

573-
@td.skip_array_manager_not_yet_implemented # TODO(ArrayManager) ignore keyword
574566
@pytest.mark.parametrize(
575567
"df",
576568
[

pandas/util/_exceptions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def find_stack_level() -> int:
3131
if stack[n].function == "astype":
3232
break
3333

34-
while stack[n].function in ["astype", "apply", "_astype"]:
34+
while stack[n].function in ["astype", "apply", "astype_array_safe", "astype_array"]:
3535
# e.g.
3636
# bump up Block.astype -> BlockManager.astype -> NDFrame.astype
3737
# bump up Datetime.Array.astype -> DatetimeIndex.astype

0 commit comments

Comments
 (0)