Skip to content

Commit e2598a1

Browse files
authored
REF: use astype_array instead of astype_nansafe (#50222)
1 parent b1c5b5d commit e2598a1

File tree

6 files changed

+35
-114
lines changed

6 files changed

+35
-114
lines changed

pandas/_libs/tslibs/np_datetime.pyx

+7-7
Original file line numberDiff line numberDiff line change
@@ -312,10 +312,10 @@ cpdef ndarray astype_overflowsafe(
312312
"""
313313
if values.descr.type_num == dtype.type_num == cnp.NPY_DATETIME:
314314
# i.e. dtype.kind == "M"
315-
pass
315+
dtype_name = "datetime64"
316316
elif values.descr.type_num == dtype.type_num == cnp.NPY_TIMEDELTA:
317317
# i.e. dtype.kind == "m"
318-
pass
318+
dtype_name = "timedelta64"
319319
else:
320320
raise TypeError(
321321
"astype_overflowsafe values.dtype and dtype must be either "
@@ -326,14 +326,14 @@ cpdef ndarray astype_overflowsafe(
326326
NPY_DATETIMEUNIT from_unit = get_unit_from_dtype(values.dtype)
327327
NPY_DATETIMEUNIT to_unit = get_unit_from_dtype(dtype)
328328

329-
if (
330-
from_unit == NPY_DATETIMEUNIT.NPY_FR_GENERIC
331-
or to_unit == NPY_DATETIMEUNIT.NPY_FR_GENERIC
332-
):
329+
if from_unit == NPY_DATETIMEUNIT.NPY_FR_GENERIC:
330+
raise TypeError(f"{dtype_name} values must have a unit specified")
331+
332+
if to_unit == NPY_DATETIMEUNIT.NPY_FR_GENERIC:
333333
# without raising explicitly here, we end up with a SystemError
334334
# built-in function [...] returned a result with an error
335335
raise ValueError(
336-
"datetime64/timedelta64 values and dtype must have a unit specified"
336+
f"{dtype_name} dtype must have a unit specified"
337337
)
338338

339339
if from_unit == to_unit:

pandas/core/arrays/sparse/dtype.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from pandas.errors import PerformanceWarning
1919
from pandas.util._exceptions import find_stack_level
2020

21-
from pandas.core.dtypes.astype import astype_nansafe
21+
from pandas.core.dtypes.astype import astype_array
2222
from pandas.core.dtypes.base import (
2323
ExtensionDtype,
2424
register_extension_dtype,
@@ -363,7 +363,7 @@ def update_dtype(self, dtype) -> SparseDtype:
363363
raise TypeError("sparse arrays of extension dtypes not supported")
364364

365365
fv_asarray = np.atleast_1d(np.array(self.fill_value))
366-
fvarr = astype_nansafe(fv_asarray, dtype)
366+
fvarr = astype_array(fv_asarray, dtype)
367367
# NB: not fv_0d.item(), as that casts dt64->int
368368
fill_value = fvarr[0]
369369
dtype = cls(dtype, fill_value=fill_value)

pandas/core/dtypes/astype.py

+9-49
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import inspect
88
from typing import (
99
TYPE_CHECKING,
10-
cast,
1110
overload,
1211
)
1312

@@ -34,34 +33,29 @@
3433
ExtensionDtype,
3534
PandasDtype,
3635
)
37-
from pandas.core.dtypes.missing import isna
3836

3937
if TYPE_CHECKING:
40-
from pandas.core.arrays import (
41-
DatetimeArray,
42-
ExtensionArray,
43-
TimedeltaArray,
44-
)
38+
from pandas.core.arrays import ExtensionArray
4539

4640

4741
_dtype_obj = np.dtype(object)
4842

4943

5044
@overload
51-
def astype_nansafe(
45+
def _astype_nansafe(
5246
arr: np.ndarray, dtype: np.dtype, copy: bool = ..., skipna: bool = ...
5347
) -> np.ndarray:
5448
...
5549

5650

5751
@overload
58-
def astype_nansafe(
52+
def _astype_nansafe(
5953
arr: np.ndarray, dtype: ExtensionDtype, copy: bool = ..., skipna: bool = ...
6054
) -> ExtensionArray:
6155
...
6256

6357

64-
def astype_nansafe(
58+
def _astype_nansafe(
6559
arr: np.ndarray, dtype: DtypeObj, copy: bool = True, skipna: bool = False
6660
) -> ArrayLike:
6761
"""
@@ -90,13 +84,12 @@ def astype_nansafe(
9084
elif not isinstance(dtype, np.dtype): # pragma: no cover
9185
raise ValueError("dtype must be np.dtype or ExtensionDtype")
9286

93-
if arr.dtype.kind in ["m", "M"] and (
94-
issubclass(dtype.type, str) or dtype == _dtype_obj
95-
):
87+
if arr.dtype.kind in ["m", "M"]:
9688
from pandas.core.construction import ensure_wrapped_if_datetimelike
9789

9890
arr = ensure_wrapped_if_datetimelike(arr)
99-
return arr.astype(dtype, copy=copy)
91+
res = arr.astype(dtype, copy=copy)
92+
return np.asarray(res)
10093

10194
if issubclass(dtype.type, str):
10295
shape = arr.shape
@@ -106,39 +99,6 @@ def astype_nansafe(
10699
arr, skipna=skipna, convert_na_value=False
107100
).reshape(shape)
108101

109-
elif is_datetime64_dtype(arr.dtype):
110-
if dtype == np.int64:
111-
if isna(arr).any():
112-
raise ValueError("Cannot convert NaT values to integer")
113-
return arr.view(dtype)
114-
115-
# allow frequency conversions
116-
if dtype.kind == "M":
117-
from pandas.core.construction import ensure_wrapped_if_datetimelike
118-
119-
dta = ensure_wrapped_if_datetimelike(arr)
120-
dta = cast("DatetimeArray", dta)
121-
return dta.astype(dtype, copy=copy)._ndarray
122-
123-
raise TypeError(f"cannot astype a datetimelike from [{arr.dtype}] to [{dtype}]")
124-
125-
elif is_timedelta64_dtype(arr.dtype):
126-
if dtype == np.int64:
127-
if isna(arr).any():
128-
raise ValueError("Cannot convert NaT values to integer")
129-
return arr.view(dtype)
130-
131-
elif dtype.kind == "m":
132-
# give the requested dtype for supported units (s, ms, us, ns)
133-
# and doing the old convert-to-float behavior otherwise.
134-
from pandas.core.construction import ensure_wrapped_if_datetimelike
135-
136-
tda = ensure_wrapped_if_datetimelike(arr)
137-
tda = cast("TimedeltaArray", tda)
138-
return tda.astype(dtype, copy=copy)._ndarray
139-
140-
raise TypeError(f"cannot astype a timedelta from [{arr.dtype}] to [{dtype}]")
141-
142102
elif np.issubdtype(arr.dtype, np.floating) and is_integer_dtype(dtype):
143103
return _astype_float_to_int_nansafe(arr, dtype, copy)
144104

@@ -231,7 +191,7 @@ def astype_array(values: ArrayLike, dtype: DtypeObj, copy: bool = False) -> Arra
231191
values = values.astype(dtype, copy=copy)
232192

233193
else:
234-
values = astype_nansafe(values, dtype, copy=copy)
194+
values = _astype_nansafe(values, dtype, copy=copy)
235195

236196
# in pandas we don't store numpy str dtypes, so convert to object
237197
if isinstance(dtype, np.dtype) and issubclass(values.dtype.type, str):
@@ -288,7 +248,7 @@ def astype_array_safe(
288248
try:
289249
new_values = astype_array(values, dtype, copy=copy)
290250
except (ValueError, TypeError):
291-
# e.g. astype_nansafe can fail on object-dtype of strings
251+
# e.g. _astype_nansafe can fail on object-dtype of strings
292252
# trying to convert to float
293253
if errors == "ignore":
294254
new_values = values

pandas/core/indexes/base.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@
7373
rewrite_exception,
7474
)
7575

76-
from pandas.core.dtypes.astype import astype_nansafe
76+
from pandas.core.dtypes.astype import astype_array
7777
from pandas.core.dtypes.cast import (
7878
LossySetitemError,
7979
can_hold_element,
@@ -992,8 +992,8 @@ def astype(self, dtype, copy: bool = True):
992992
# GH#38607 see test_astype_str_from_bytes
993993
new_values = values.astype(dtype, copy=copy)
994994
else:
995-
# GH#13149 specifically use astype_nansafe instead of astype
996-
new_values = astype_nansafe(values, dtype=dtype, copy=copy)
995+
# GH#13149 specifically use astype_array instead of astype
996+
new_values = astype_array(values, dtype=dtype, copy=copy)
997997

998998
# pass copy=False because any copying will be done in the astype above
999999
if self._is_backward_compat_public_numeric_index:

pandas/io/parsers/base_parser.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,10 @@
4848
)
4949
from pandas.util._exceptions import find_stack_level
5050

51-
from pandas.core.dtypes.astype import astype_nansafe
51+
from pandas.core.dtypes.astype import astype_array
5252
from pandas.core.dtypes.common import (
5353
ensure_object,
5454
is_bool_dtype,
55-
is_categorical_dtype,
5655
is_dict_like,
5756
is_dtype_equal,
5857
is_extension_array_dtype,
@@ -799,17 +798,13 @@ def _cast_types(self, values: ArrayLike, cast_type: DtypeObj, column) -> ArrayLi
799798
-------
800799
converted : ndarray or ExtensionArray
801800
"""
802-
if is_categorical_dtype(cast_type):
803-
known_cats = (
804-
isinstance(cast_type, CategoricalDtype)
805-
and cast_type.categories is not None
806-
)
801+
if isinstance(cast_type, CategoricalDtype):
802+
known_cats = cast_type.categories is not None
807803

808804
if not is_object_dtype(values.dtype) and not known_cats:
809805
# TODO: this is for consistency with
810806
# c-parser which parses all categories
811807
# as strings
812-
813808
values = lib.ensure_string_array(
814809
values, skipna=False, convert_na_value=False
815810
)
@@ -842,9 +837,16 @@ def _cast_types(self, values: ArrayLike, cast_type: DtypeObj, column) -> ArrayLi
842837

843838
elif isinstance(values, ExtensionArray):
844839
values = values.astype(cast_type, copy=False)
840+
elif issubclass(cast_type.type, str):
841+
# TODO: why skipna=True here and False above? some tests depend
842+
# on it here, but nothing fails if we change it above
843+
# (as no tests get there as of 2022-12-06)
844+
values = lib.ensure_string_array(
845+
values, skipna=True, convert_na_value=False
846+
)
845847
else:
846848
try:
847-
values = astype_nansafe(values, cast_type, copy=True, skipna=True)
849+
values = astype_array(values, cast_type, copy=True)
848850
except ValueError as err:
849851
raise ValueError(
850852
f"Unable to convert column {column} to type {cast_type}"

pandas/tests/dtypes/test_common.py

+3-44
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import pandas.util._test_decorators as td
99

10-
from pandas.core.dtypes.astype import astype_nansafe
10+
from pandas.core.dtypes.astype import astype_array
1111
import pandas.core.dtypes.common as com
1212
from pandas.core.dtypes.dtypes import (
1313
CategoricalDtype,
@@ -715,62 +715,21 @@ def test__is_dtype_type(input_param, result):
715715
assert com._is_dtype_type(input_param, lambda tipo: tipo == result)
716716

717717

718-
@pytest.mark.parametrize("val", [np.datetime64("NaT"), np.timedelta64("NaT")])
719-
@pytest.mark.parametrize("typ", [np.int64])
720-
def test_astype_nansafe(val, typ):
721-
arr = np.array([val])
722-
723-
typ = np.dtype(typ)
724-
725-
msg = "Cannot convert NaT values to integer"
726-
with pytest.raises(ValueError, match=msg):
727-
astype_nansafe(arr, dtype=typ)
728-
729-
730718
def test_astype_nansafe_copy_false(any_int_numpy_dtype):
731719
# GH#34457 use astype, not view
732720
arr = np.array([1, 2, 3], dtype=any_int_numpy_dtype)
733721

734722
dtype = np.dtype("float64")
735-
result = astype_nansafe(arr, dtype, copy=False)
723+
result = astype_array(arr, dtype, copy=False)
736724

737725
expected = np.array([1.0, 2.0, 3.0], dtype=dtype)
738726
tm.assert_numpy_array_equal(result, expected)
739727

740728

741-
@pytest.mark.parametrize("from_type", [np.datetime64, np.timedelta64])
742-
@pytest.mark.parametrize(
743-
"to_type",
744-
[
745-
np.uint8,
746-
np.uint16,
747-
np.uint32,
748-
np.int8,
749-
np.int16,
750-
np.int32,
751-
np.float16,
752-
np.float32,
753-
],
754-
)
755-
def test_astype_datetime64_bad_dtype_raises(from_type, to_type):
756-
arr = np.array([from_type("2018")])
757-
758-
to_type = np.dtype(to_type)
759-
760-
msg = "|".join(
761-
[
762-
"cannot astype a timedelta",
763-
"cannot astype a datetimelike",
764-
]
765-
)
766-
with pytest.raises(TypeError, match=msg):
767-
astype_nansafe(arr, dtype=to_type)
768-
769-
770729
@pytest.mark.parametrize("from_type", [np.datetime64, np.timedelta64])
771730
def test_astype_object_preserves_datetime_na(from_type):
772731
arr = np.array([from_type("NaT", "ns")])
773-
result = astype_nansafe(arr, dtype=np.dtype("object"))
732+
result = astype_array(arr, dtype=np.dtype("object"))
774733

775734
assert isna(result)[0]
776735

0 commit comments

Comments
 (0)