Skip to content

REF: use astype_array instead of astype_nansafe #50222

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions pandas/_libs/tslibs/np_datetime.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -312,10 +312,10 @@ cpdef ndarray astype_overflowsafe(
"""
if values.descr.type_num == dtype.type_num == cnp.NPY_DATETIME:
# i.e. dtype.kind == "M"
pass
dtype_name = "datetime64"
elif values.descr.type_num == dtype.type_num == cnp.NPY_TIMEDELTA:
# i.e. dtype.kind == "m"
pass
dtype_name = "timedelta64"
else:
raise TypeError(
"astype_overflowsafe values.dtype and dtype must be either "
Expand All @@ -326,14 +326,14 @@ cpdef ndarray astype_overflowsafe(
NPY_DATETIMEUNIT from_unit = get_unit_from_dtype(values.dtype)
NPY_DATETIMEUNIT to_unit = get_unit_from_dtype(dtype)

if (
from_unit == NPY_DATETIMEUNIT.NPY_FR_GENERIC
or to_unit == NPY_DATETIMEUNIT.NPY_FR_GENERIC
):
if from_unit == NPY_DATETIMEUNIT.NPY_FR_GENERIC:
raise TypeError(f"{dtype_name} values must have a unit specified")

if to_unit == NPY_DATETIMEUNIT.NPY_FR_GENERIC:
# without raising explicitly here, we end up with a SystemError
# built-in function [...] returned a result with an error
raise ValueError(
"datetime64/timedelta64 values and dtype must have a unit specified"
f"{dtype_name} dtype must have a unit specified"
)

if from_unit == to_unit:
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/arrays/sparse/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from pandas.errors import PerformanceWarning
from pandas.util._exceptions import find_stack_level

from pandas.core.dtypes.astype import astype_nansafe
from pandas.core.dtypes.astype import astype_array
from pandas.core.dtypes.base import (
ExtensionDtype,
register_extension_dtype,
Expand Down Expand Up @@ -363,7 +363,7 @@ def update_dtype(self, dtype) -> SparseDtype:
raise TypeError("sparse arrays of extension dtypes not supported")

fv_asarray = np.atleast_1d(np.array(self.fill_value))
fvarr = astype_nansafe(fv_asarray, dtype)
fvarr = astype_array(fv_asarray, dtype)
# NB: not fv_0d.item(), as that casts dt64->int
fill_value = fvarr[0]
dtype = cls(dtype, fill_value=fill_value)
Expand Down
58 changes: 9 additions & 49 deletions pandas/core/dtypes/astype.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import inspect
from typing import (
TYPE_CHECKING,
cast,
overload,
)

Expand All @@ -34,34 +33,29 @@
ExtensionDtype,
PandasDtype,
)
from pandas.core.dtypes.missing import isna

if TYPE_CHECKING:
from pandas.core.arrays import (
DatetimeArray,
ExtensionArray,
TimedeltaArray,
)
from pandas.core.arrays import ExtensionArray


_dtype_obj = np.dtype(object)


@overload
def astype_nansafe(
def _astype_nansafe(
arr: np.ndarray, dtype: np.dtype, copy: bool = ..., skipna: bool = ...
) -> np.ndarray:
...


@overload
def astype_nansafe(
def _astype_nansafe(
arr: np.ndarray, dtype: ExtensionDtype, copy: bool = ..., skipna: bool = ...
) -> ExtensionArray:
...


def astype_nansafe(
def _astype_nansafe(
arr: np.ndarray, dtype: DtypeObj, copy: bool = True, skipna: bool = False
) -> ArrayLike:
"""
Expand Down Expand Up @@ -90,13 +84,12 @@ def astype_nansafe(
elif not isinstance(dtype, np.dtype): # pragma: no cover
raise ValueError("dtype must be np.dtype or ExtensionDtype")

if arr.dtype.kind in ["m", "M"] and (
issubclass(dtype.type, str) or dtype == _dtype_obj
):
if arr.dtype.kind in ["m", "M"]:
from pandas.core.construction import ensure_wrapped_if_datetimelike

arr = ensure_wrapped_if_datetimelike(arr)
return arr.astype(dtype, copy=copy)
res = arr.astype(dtype, copy=copy)
return np.asarray(res)

if issubclass(dtype.type, str):
shape = arr.shape
Expand All @@ -106,39 +99,6 @@ def astype_nansafe(
arr, skipna=skipna, convert_na_value=False
).reshape(shape)

elif is_datetime64_dtype(arr.dtype):
if dtype == np.int64:
if isna(arr).any():
raise ValueError("Cannot convert NaT values to integer")
return arr.view(dtype)

# allow frequency conversions
if dtype.kind == "M":
from pandas.core.construction import ensure_wrapped_if_datetimelike

dta = ensure_wrapped_if_datetimelike(arr)
dta = cast("DatetimeArray", dta)
return dta.astype(dtype, copy=copy)._ndarray

raise TypeError(f"cannot astype a datetimelike from [{arr.dtype}] to [{dtype}]")

elif is_timedelta64_dtype(arr.dtype):
if dtype == np.int64:
if isna(arr).any():
raise ValueError("Cannot convert NaT values to integer")
return arr.view(dtype)

elif dtype.kind == "m":
# give the requested dtype for supported units (s, ms, us, ns)
# and doing the old convert-to-float behavior otherwise.
from pandas.core.construction import ensure_wrapped_if_datetimelike

tda = ensure_wrapped_if_datetimelike(arr)
tda = cast("TimedeltaArray", tda)
return tda.astype(dtype, copy=copy)._ndarray

raise TypeError(f"cannot astype a timedelta from [{arr.dtype}] to [{dtype}]")

elif np.issubdtype(arr.dtype, np.floating) and is_integer_dtype(dtype):
return _astype_float_to_int_nansafe(arr, dtype, copy)

Expand Down Expand Up @@ -231,7 +191,7 @@ def astype_array(values: ArrayLike, dtype: DtypeObj, copy: bool = False) -> Arra
values = values.astype(dtype, copy=copy)

else:
values = astype_nansafe(values, dtype, copy=copy)
values = _astype_nansafe(values, dtype, copy=copy)

# in pandas we don't store numpy str dtypes, so convert to object
if isinstance(dtype, np.dtype) and issubclass(values.dtype.type, str):
Expand Down Expand Up @@ -288,7 +248,7 @@ def astype_array_safe(
try:
new_values = astype_array(values, dtype, copy=copy)
except (ValueError, TypeError):
# e.g. astype_nansafe can fail on object-dtype of strings
# e.g. _astype_nansafe can fail on object-dtype of strings
# trying to convert to float
if errors == "ignore":
new_values = values
Expand Down
6 changes: 3 additions & 3 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
rewrite_exception,
)

from pandas.core.dtypes.astype import astype_nansafe
from pandas.core.dtypes.astype import astype_array
from pandas.core.dtypes.cast import (
LossySetitemError,
can_hold_element,
Expand Down Expand Up @@ -993,8 +993,8 @@ def astype(self, dtype, copy: bool = True):
# GH#38607 see test_astype_str_from_bytes
new_values = values.astype(dtype, copy=copy)
else:
# GH#13149 specifically use astype_nansafe instead of astype
new_values = astype_nansafe(values, dtype=dtype, copy=copy)
# GH#13149 specifically use astype_array instead of astype
new_values = astype_array(values, dtype=dtype, copy=copy)

# pass copy=False because any copying will be done in the astype above
if self._is_backward_compat_public_numeric_index:
Expand Down
20 changes: 11 additions & 9 deletions pandas/io/parsers/base_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,10 @@
)
from pandas.util._exceptions import find_stack_level

from pandas.core.dtypes.astype import astype_nansafe
from pandas.core.dtypes.astype import astype_array
from pandas.core.dtypes.common import (
ensure_object,
is_bool_dtype,
is_categorical_dtype,
is_dict_like,
is_dtype_equal,
is_extension_array_dtype,
Expand Down Expand Up @@ -799,17 +798,13 @@ def _cast_types(self, values: ArrayLike, cast_type: DtypeObj, column) -> ArrayLi
-------
converted : ndarray or ExtensionArray
"""
if is_categorical_dtype(cast_type):
known_cats = (
isinstance(cast_type, CategoricalDtype)
and cast_type.categories is not None
)
if isinstance(cast_type, CategoricalDtype):
known_cats = cast_type.categories is not None

if not is_object_dtype(values.dtype) and not known_cats:
# TODO: this is for consistency with
# c-parser which parses all categories
# as strings

values = lib.ensure_string_array(
values, skipna=False, convert_na_value=False
)
Expand Down Expand Up @@ -842,9 +837,16 @@ def _cast_types(self, values: ArrayLike, cast_type: DtypeObj, column) -> ArrayLi

elif isinstance(values, ExtensionArray):
values = values.astype(cast_type, copy=False)
elif issubclass(cast_type.type, str):
# TODO: why skipna=True here and False above? some tests depend
# on it here, but nothing fails if we change it above
# (as no tests get there as of 2022-12-06)
values = lib.ensure_string_array(
values, skipna=True, convert_na_value=False
)
else:
try:
values = astype_nansafe(values, cast_type, copy=True, skipna=True)
values = astype_array(values, cast_type, copy=True)
except ValueError as err:
raise ValueError(
f"Unable to convert column {column} to type {cast_type}"
Expand Down
47 changes: 3 additions & 44 deletions pandas/tests/dtypes/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pandas.util._test_decorators as td

from pandas.core.dtypes.astype import astype_nansafe
from pandas.core.dtypes.astype import astype_array
import pandas.core.dtypes.common as com
from pandas.core.dtypes.dtypes import (
CategoricalDtype,
Expand Down Expand Up @@ -715,62 +715,21 @@ def test__is_dtype_type(input_param, result):
assert com._is_dtype_type(input_param, lambda tipo: tipo == result)


@pytest.mark.parametrize("val", [np.datetime64("NaT"), np.timedelta64("NaT")])
@pytest.mark.parametrize("typ", [np.int64])
def test_astype_nansafe(val, typ):
arr = np.array([val])

typ = np.dtype(typ)

msg = "Cannot convert NaT values to integer"
with pytest.raises(ValueError, match=msg):
astype_nansafe(arr, dtype=typ)


def test_astype_nansafe_copy_false(any_int_numpy_dtype):
# GH#34457 use astype, not view
arr = np.array([1, 2, 3], dtype=any_int_numpy_dtype)

dtype = np.dtype("float64")
result = astype_nansafe(arr, dtype, copy=False)
result = astype_array(arr, dtype, copy=False)

expected = np.array([1.0, 2.0, 3.0], dtype=dtype)
tm.assert_numpy_array_equal(result, expected)


@pytest.mark.parametrize("from_type", [np.datetime64, np.timedelta64])
@pytest.mark.parametrize(
"to_type",
[
np.uint8,
np.uint16,
np.uint32,
np.int8,
np.int16,
np.int32,
np.float16,
np.float32,
],
)
def test_astype_datetime64_bad_dtype_raises(from_type, to_type):
arr = np.array([from_type("2018")])

to_type = np.dtype(to_type)

msg = "|".join(
[
"cannot astype a timedelta",
"cannot astype a datetimelike",
]
)
with pytest.raises(TypeError, match=msg):
astype_nansafe(arr, dtype=to_type)


@pytest.mark.parametrize("from_type", [np.datetime64, np.timedelta64])
def test_astype_object_preserves_datetime_na(from_type):
arr = np.array([from_type("NaT", "ns")])
result = astype_nansafe(arr, dtype=np.dtype("object"))
result = astype_array(arr, dtype=np.dtype("object"))

assert isna(result)[0]

Expand Down