Skip to content

Commit 29c51af

Browse files
jbrockmendelluckyvs1
authored andcommitted
REF: Separate values-casting from Block.astype (pandas-dev#38455)
* REF: Separate values-casting from Block.astype * fix xfail
1 parent 0a5c089 commit 29c51af

File tree

2 files changed

+68
-65
lines changed

2 files changed

+68
-65
lines changed

pandas/core/internals/blocks.py

+59-61
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pandas._libs.internals import BlockPlacement
1111
from pandas._libs.tslibs import conversion
1212
from pandas._libs.tslibs.timezones import tz_compare
13-
from pandas._typing import ArrayLike, Scalar, Shape
13+
from pandas._typing import ArrayLike, DtypeObj, Scalar, Shape
1414
from pandas.util._validators import validate_bool_kwarg
1515

1616
from pandas.core.dtypes.cast import (
@@ -68,7 +68,7 @@
6868
)
6969
from pandas.core.base import PandasObject
7070
import pandas.core.common as com
71-
from pandas.core.construction import extract_array
71+
from pandas.core.construction import array as pd_array, extract_array
7272
from pandas.core.indexers import (
7373
check_setitem_lengths,
7474
is_empty_indexer,
@@ -593,7 +593,7 @@ def astype(self, dtype, copy: bool = False, errors: str = "raise"):
593593
dtype : str, dtype convertible
594594
copy : bool, default False
595595
copy if indicated
596-
errors : str, {'raise', 'ignore'}, default 'ignore'
596+
errors : str, {'raise', 'ignore'}, default 'raise'
597597
- ``raise`` : allow exceptions to be raised
598598
- ``ignore`` : suppress exceptions. On error return original object
599599
@@ -617,69 +617,23 @@ def astype(self, dtype, copy: bool = False, errors: str = "raise"):
617617
)
618618
raise TypeError(msg)
619619

620-
if dtype is not None:
621-
dtype = pandas_dtype(dtype)
622-
623-
# may need to convert to categorical
624-
if is_categorical_dtype(dtype):
625-
626-
if is_categorical_dtype(self.values.dtype):
627-
# GH 10696/18593: update an existing categorical efficiently
628-
return self.make_block(self.values.astype(dtype, copy=copy))
629-
630-
return self.make_block(Categorical(self.values, dtype=dtype))
631-
632620
dtype = pandas_dtype(dtype)
633621

634-
# astype processing
635-
if is_dtype_equal(self.dtype, dtype):
636-
if copy:
637-
return self.copy()
638-
return self
639-
640-
# force the copy here
641-
if self.is_extension:
642-
try:
643-
values = self.values.astype(dtype)
644-
except (ValueError, TypeError):
645-
if errors == "ignore":
646-
values = self.values
647-
else:
648-
raise
649-
else:
650-
if issubclass(dtype.type, str):
651-
652-
# use native type formatting for datetime/tz/timedelta
653-
if self.is_datelike:
654-
values = self.to_native_types().values
655-
656-
# astype formatting
657-
else:
658-
# Because we have neither is_extension nor is_datelike,
659-
# self.values already has the correct shape
660-
values = self.values
661-
622+
try:
623+
new_values = self._astype(dtype, copy=copy)
624+
except (ValueError, TypeError):
625+
# e.g. astype_nansafe can fail on object-dtype of strings
626+
# trying to convert to float
627+
if errors == "ignore":
628+
new_values = self.values
662629
else:
663-
values = self.get_values(dtype=dtype)
664-
665-
# _astype_nansafe works fine with 1-d only
666-
vals1d = values.ravel()
667-
try:
668-
values = astype_nansafe(vals1d, dtype, copy=True)
669-
except (ValueError, TypeError):
670-
# e.g. astype_nansafe can fail on object-dtype of strings
671-
# trying to convert to float
672-
if errors == "raise":
673-
raise
674-
newb = self.copy() if copy else self
675-
return newb
676-
677-
# TODO(EA2D): special case not needed with 2D EAs
678-
if isinstance(values, np.ndarray):
679-
values = values.reshape(self.shape)
630+
raise
680631

681-
newb = self.make_block(values)
632+
if isinstance(new_values, np.ndarray):
633+
# TODO(EA2D): special case not needed with 2D EAs
634+
new_values = new_values.reshape(self.shape)
682635

636+
newb = self.make_block(new_values)
683637
if newb.is_numeric and self.is_numeric:
684638
if newb.shape != self.shape:
685639
raise TypeError(
@@ -689,6 +643,50 @@ def astype(self, dtype, copy: bool = False, errors: str = "raise"):
689643
)
690644
return newb
691645

646+
def _astype(self, dtype: DtypeObj, copy: bool) -> ArrayLike:
647+
values = self.values
648+
649+
if is_categorical_dtype(dtype):
650+
651+
if is_categorical_dtype(values.dtype):
652+
# GH#10696/GH#18593: update an existing categorical efficiently
653+
return values.astype(dtype, copy=copy)
654+
655+
return Categorical(values, dtype=dtype)
656+
657+
if is_dtype_equal(values.dtype, dtype):
658+
if copy:
659+
return values.copy()
660+
return values
661+
662+
if isinstance(values, ExtensionArray):
663+
values = values.astype(dtype, copy=copy)
664+
665+
else:
666+
if issubclass(dtype.type, str):
667+
if values.dtype.kind in ["m", "M"]:
668+
# use native type formatting for datetime/tz/timedelta
669+
arr = pd_array(values)
670+
# Note: in the case where dtype is an np.dtype, i.e. not
671+
# StringDtype, this matches arr.astype(dtype), xref GH#36153
672+
values = arr._format_native_types(na_rep="NaT")
673+
674+
elif is_object_dtype(dtype):
675+
if values.dtype.kind in ["m", "M"]:
676+
# Wrap in Timedelta/Timestamp
677+
arr = pd_array(values)
678+
values = arr.astype(object)
679+
else:
680+
values = values.astype(object)
681+
# We still need to go through astype_nansafe for
682+
# e.g. dtype = Sparse[object, 0]
683+
684+
# astype_nansafe works with 1-d only
685+
vals1d = values.ravel()
686+
values = astype_nansafe(vals1d, dtype, copy=True)
687+
688+
return values
689+
692690
def convert(
693691
self,
694692
copy: bool = True,

pandas/tests/arrays/string_/test_string.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
import pandas.util._test_decorators as td
77

8+
from pandas.core.dtypes.common import is_dtype_equal
9+
810
import pandas as pd
911
import pandas._testing as tm
1012
from pandas.core.arrays.string_arrow import ArrowStringArray, ArrowStringDtype
@@ -127,11 +129,14 @@ def test_astype_roundtrip(dtype, request):
127129
mark = pytest.mark.xfail(reason=reason)
128130
request.node.add_marker(mark)
129131

130-
s = pd.Series(pd.date_range("2000", periods=12))
131-
s[0] = None
132+
ser = pd.Series(pd.date_range("2000", periods=12))
133+
ser[0] = None
134+
135+
casted = ser.astype(dtype)
136+
assert is_dtype_equal(casted.dtype, dtype)
132137

133-
result = s.astype(dtype).astype("datetime64[ns]")
134-
tm.assert_series_equal(result, s)
138+
result = casted.astype("datetime64[ns]")
139+
tm.assert_series_equal(result, ser)
135140

136141

137142
def test_add(dtype, request):

0 commit comments

Comments
 (0)