Skip to content

REF: simplify maybe_convert_objects #53021

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 4 additions & 55 deletions pandas/_libs/lib.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -70,45 +70,25 @@ def map_infer(
convert: bool = ...,
ignore_na: bool = ...,
) -> np.ndarray: ...
@overload # all convert_foo False -> only convert numeric
@overload
def maybe_convert_objects(
objects: npt.NDArray[np.object_],
*,
try_float: bool = ...,
safe: bool = ...,
convert_numeric: bool = ...,
convert_datetime: Literal[False] = ...,
convert_timedelta: Literal[False] = ...,
convert_period: Literal[False] = ...,
convert_interval: Literal[False] = ...,
convert_non_numeric: Literal[False] = ...,
convert_to_nullable_dtype: Literal[False] = ...,
dtype_if_all_nat: DtypeObj | None = ...,
) -> npt.NDArray[np.object_ | np.number]: ...
@overload # both convert_datetime and convert_to_nullable_integer False -> np.ndarray
def maybe_convert_objects(
objects: npt.NDArray[np.object_],
*,
try_float: bool = ...,
safe: bool = ...,
convert_numeric: bool = ...,
convert_datetime: Literal[False] = ...,
convert_timedelta: bool = ...,
convert_period: Literal[False] = ...,
convert_interval: Literal[False] = ...,
convert_to_nullable_dtype: Literal[False] = ...,
dtype_if_all_nat: DtypeObj | None = ...,
) -> np.ndarray: ...
@overload
def maybe_convert_objects(
objects: npt.NDArray[np.object_],
*,
try_float: bool = ...,
safe: bool = ...,
convert_numeric: bool = ...,
convert_datetime: bool = ...,
convert_timedelta: bool = ...,
convert_period: bool = ...,
convert_interval: bool = ...,
convert_non_numeric: bool = ...,
convert_to_nullable_dtype: Literal[True] = ...,
dtype_if_all_nat: DtypeObj | None = ...,
) -> ArrayLike: ...
Expand All @@ -119,38 +99,7 @@ def maybe_convert_objects(
try_float: bool = ...,
safe: bool = ...,
convert_numeric: bool = ...,
convert_datetime: Literal[True] = ...,
convert_timedelta: bool = ...,
convert_period: bool = ...,
convert_interval: bool = ...,
convert_to_nullable_dtype: bool = ...,
dtype_if_all_nat: DtypeObj | None = ...,
) -> ArrayLike: ...
@overload
def maybe_convert_objects(
objects: npt.NDArray[np.object_],
*,
try_float: bool = ...,
safe: bool = ...,
convert_numeric: bool = ...,
convert_datetime: bool = ...,
convert_timedelta: bool = ...,
convert_period: Literal[True] = ...,
convert_interval: bool = ...,
convert_to_nullable_dtype: bool = ...,
dtype_if_all_nat: DtypeObj | None = ...,
) -> ArrayLike: ...
@overload
def maybe_convert_objects(
objects: npt.NDArray[np.object_],
*,
try_float: bool = ...,
safe: bool = ...,
convert_numeric: bool = ...,
convert_datetime: bool = ...,
convert_timedelta: bool = ...,
convert_period: bool = ...,
convert_interval: bool = ...,
convert_non_numeric: bool = ...,
convert_to_nullable_dtype: bool = ...,
dtype_if_all_nat: DtypeObj | None = ...,
) -> ArrayLike: ...
Expand Down
48 changes: 12 additions & 36 deletions pandas/_libs/lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2385,11 +2385,8 @@ def maybe_convert_objects(ndarray[object] objects,
bint try_float=False,
bint safe=False,
bint convert_numeric=True, # NB: different default!
bint convert_datetime=False,
bint convert_timedelta=False,
bint convert_period=False,
bint convert_interval=False,
bint convert_to_nullable_dtype=False,
bint convert_non_numeric=False,
object dtype_if_all_nat=None) -> "ArrayLike":
"""
Type inference function-- convert object array to proper dtype
Expand All @@ -2406,21 +2403,11 @@ def maybe_convert_objects(ndarray[object] objects,
True, no upcasting will be performed.
convert_numeric : bool, default True
Whether to convert numeric entries.
convert_datetime : bool, default False
If an array-like object contains only datetime values or NaT is
encountered, whether to convert and return an array of M8[ns] dtype.
convert_timedelta : bool, default False
If an array-like object contains only timedelta values or NaT is
encountered, whether to convert and return an array of m8[ns] dtype.
convert_period : bool, default False
If an array-like object contains only (homogeneous-freq) Period values
or NaT, whether to convert and return a PeriodArray.
convert_interval : bool, default False
If an array-like object contains only Interval objects (with matching
dtypes and closedness) or NaN, whether to convert to IntervalArray.
convert_to_nullable_dtype : bool, default False
If an array-like object contains only integer or boolean values (and NaN),
whether to convert and return a Boolean/IntegerArray.
convert_non_numeric : bool, default False
Whether to convert datetime, timedelta, period, interval types.
dtype_if_all_nat : np.dtype, ExtensionDtype, or None, default None
Dtype to cast to if we have all-NaT.

Expand All @@ -2443,12 +2430,11 @@ def maybe_convert_objects(ndarray[object] objects,

if dtype_if_all_nat is not None:
# in practice we don't expect to ever pass dtype_if_all_nat
# without both convert_datetime and convert_timedelta, so disallow
# without convert_non_numeric, so disallow
# it to avoid needing to handle it below.
if not convert_datetime or not convert_timedelta:
if not convert_non_numeric:
raise ValueError(
"Cannot specify 'dtype_if_all_nat' without convert_datetime=True "
"and convert_timedelta=True"
"Cannot specify 'dtype_if_all_nat' without convert_non_numeric=True"
)

n = len(objects)
Expand All @@ -2473,7 +2459,7 @@ def maybe_convert_objects(ndarray[object] objects,
mask[i] = True
elif val is NaT:
seen.nat_ = True
if not (convert_datetime or convert_timedelta or convert_period):
if not convert_non_numeric:
seen.object_ = True
break
elif util.is_nan(val):
Expand All @@ -2491,7 +2477,7 @@ def maybe_convert_objects(ndarray[object] objects,
if not convert_numeric:
break
elif is_timedelta(val):
if convert_timedelta:
if convert_non_numeric:
seen.timedelta_ = True
try:
convert_to_timedelta64(val, "ns")
Expand Down Expand Up @@ -2532,7 +2518,7 @@ def maybe_convert_objects(ndarray[object] objects,
elif PyDateTime_Check(val) or util.is_datetime64_object(val):

# if we have a tz attached then return the objects
if convert_datetime:
if convert_non_numeric:
if getattr(val, "tzinfo", None) is not None:
seen.datetimetz_ = True
break
Expand All @@ -2548,7 +2534,7 @@ def maybe_convert_objects(ndarray[object] objects,
seen.object_ = True
break
elif is_period_object(val):
if convert_period:
if convert_non_numeric:
seen.period_ = True
break
else:
Expand All @@ -2564,7 +2550,7 @@ def maybe_convert_objects(ndarray[object] objects,
seen.object_ = True
break
elif is_interval(val):
if convert_interval:
if convert_non_numeric:
seen.interval_ = True
break
else:
Expand Down Expand Up @@ -2650,18 +2636,8 @@ def maybe_convert_objects(ndarray[object] objects,
elif dtype is not None:
# EA, we don't expect to get here, but _could_ implement
raise NotImplementedError(dtype)
elif convert_datetime and convert_timedelta:
# we don't guess
seen.object_ = True
elif convert_datetime:
res = np.empty((<object>objects).shape, dtype="M8[ns]")
res[:] = NPY_NAT
return res
elif convert_timedelta:
res = np.empty((<object>objects).shape, dtype="m8[ns]")
res[:] = NPY_NAT
return res
else:
# we don't guess
seen.object_ = True
else:
seen.object_ = True
Expand Down
5 changes: 1 addition & 4 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,10 +1171,7 @@ def maybe_infer_to_datetimelike(
# Here we do not convert numeric dtypes, as if we wanted that,
# numpy would have done it for us.
convert_numeric=False,
convert_period=True,
convert_interval=True,
convert_timedelta=True,
convert_datetime=True,
convert_non_numeric=True,
dtype_if_all_nat=np.dtype("M8[ns]"),
)

Expand Down
5 changes: 1 addition & 4 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6757,10 +6757,7 @@ def infer_objects(self, copy: bool = True) -> Index:
values = cast("npt.NDArray[np.object_]", values)
res_values = lib.maybe_convert_objects(
values,
convert_datetime=True,
convert_timedelta=True,
convert_period=True,
convert_interval=True,
convert_non_numeric=True,
)
if copy and res_values is values:
return self.copy()
Expand Down
5 changes: 1 addition & 4 deletions pandas/core/internals/array_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,10 +391,7 @@ def _convert(arr):
arr = np.asarray(arr)
result = lib.maybe_convert_objects(
arr,
convert_datetime=True,
convert_timedelta=True,
convert_period=True,
convert_interval=True,
convert_non_numeric=True,
)
if result is arr and copy:
return arr.copy()
Expand Down
5 changes: 1 addition & 4 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2285,10 +2285,7 @@ def convert(

res_values = lib.maybe_convert_objects(
values,
convert_datetime=True,
convert_timedelta=True,
convert_period=True,
convert_interval=True,
convert_non_numeric=True,
)
refs = None
if copy and res_values is values:
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/internals/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,8 +1018,8 @@ def convert(arr):
# 1) we DO get here when arr is all Timestamps and dtype=None
# 2) disabling this doesn't break the world, so this must be
# getting caught at a higher level
# 3) passing convert_datetime to maybe_convert_objects gets this right
# 4) convert_timedelta?
# 3) passing convert_non_numeric to maybe_convert_objects gets this right
# 4) convert_non_numeric?

if dtype is None:
if arr.dtype == np.dtype("O"):
Expand Down
46 changes: 15 additions & 31 deletions pandas/tests/dtypes/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,17 +727,15 @@ def test_maybe_convert_objects_nat_inference(self, val, dtype):
vals = np.array([pd.NaT, val], dtype=object)
result = lib.maybe_convert_objects(
vals,
convert_datetime=True,
convert_timedelta=True,
convert_non_numeric=True,
dtype_if_all_nat=dtype,
)
assert result.dtype == dtype
assert np.isnat(result).all()

result = lib.maybe_convert_objects(
vals[::-1],
convert_datetime=True,
convert_timedelta=True,
convert_non_numeric=True,
dtype_if_all_nat=dtype,
)
assert result.dtype == dtype
Expand Down Expand Up @@ -777,47 +775,37 @@ def test_maybe_convert_objects_datetime(self):
[np.datetime64("2000-01-01"), np.timedelta64(1, "s")], dtype=object
)
exp = arr.copy()
out = lib.maybe_convert_objects(
arr, convert_datetime=True, convert_timedelta=True
)
out = lib.maybe_convert_objects(arr, convert_non_numeric=True)
tm.assert_numpy_array_equal(out, exp)

arr = np.array([pd.NaT, np.timedelta64(1, "s")], dtype=object)
exp = np.array([np.timedelta64("NaT"), np.timedelta64(1, "s")], dtype="m8[ns]")
out = lib.maybe_convert_objects(
arr, convert_datetime=True, convert_timedelta=True
)
out = lib.maybe_convert_objects(arr, convert_non_numeric=True)
tm.assert_numpy_array_equal(out, exp)

# with convert_timedelta=True, the nan is a valid NA value for td64
# with convert_non_numeric=True, the nan is a valid NA value for td64
arr = np.array([np.timedelta64(1, "s"), np.nan], dtype=object)
exp = exp[::-1]
out = lib.maybe_convert_objects(
arr, convert_datetime=True, convert_timedelta=True
)
out = lib.maybe_convert_objects(arr, convert_non_numeric=True)
tm.assert_numpy_array_equal(out, exp)

def test_maybe_convert_objects_dtype_if_all_nat(self):
arr = np.array([pd.NaT, pd.NaT], dtype=object)
out = lib.maybe_convert_objects(
arr, convert_datetime=True, convert_timedelta=True
)
out = lib.maybe_convert_objects(arr, convert_non_numeric=True)
# no dtype_if_all_nat passed -> we don't guess
tm.assert_numpy_array_equal(out, arr)

out = lib.maybe_convert_objects(
arr,
convert_datetime=True,
convert_timedelta=True,
convert_non_numeric=True,
dtype_if_all_nat=np.dtype("timedelta64[ns]"),
)
exp = np.array(["NaT", "NaT"], dtype="timedelta64[ns]")
tm.assert_numpy_array_equal(out, exp)

out = lib.maybe_convert_objects(
arr,
convert_datetime=True,
convert_timedelta=True,
convert_non_numeric=True,
dtype_if_all_nat=np.dtype("datetime64[ns]"),
)
exp = np.array(["NaT", "NaT"], dtype="datetime64[ns]")
Expand All @@ -830,8 +818,7 @@ def test_maybe_convert_objects_dtype_if_all_nat_invalid(self):
with pytest.raises(ValueError, match="int64"):
lib.maybe_convert_objects(
arr,
convert_datetime=True,
convert_timedelta=True,
convert_non_numeric=True,
dtype_if_all_nat=np.dtype("int64"),
)

Expand All @@ -842,9 +829,7 @@ def test_maybe_convert_objects_datetime_overflow_safe(self, dtype):
stamp = stamp - datetime(1970, 1, 1)
arr = np.array([stamp], dtype=object)

out = lib.maybe_convert_objects(
arr, convert_datetime=True, convert_timedelta=True
)
out = lib.maybe_convert_objects(arr, convert_non_numeric=True)
# no OutOfBoundsDatetime/OutOfBoundsTimedeltas
tm.assert_numpy_array_equal(out, arr)

Expand All @@ -855,15 +840,15 @@ def test_maybe_convert_objects_mixed_datetimes(self):
for data in itertools.permutations(vals):
data = np.array(list(data), dtype=object)
expected = DatetimeIndex(data)._data._ndarray
result = lib.maybe_convert_objects(data, convert_datetime=True)
result = lib.maybe_convert_objects(data, convert_non_numeric=True)
tm.assert_numpy_array_equal(result, expected)

def test_maybe_convert_objects_timedelta64_nat(self):
obj = np.timedelta64("NaT", "ns")
arr = np.array([obj], dtype=object)
assert arr[0] is obj

result = lib.maybe_convert_objects(arr, convert_timedelta=True)
result = lib.maybe_convert_objects(arr, convert_non_numeric=True)

expected = np.array([obj], dtype="m8[ns]")
tm.assert_numpy_array_equal(result, expected)
Expand Down Expand Up @@ -1037,7 +1022,7 @@ def test_maybe_convert_objects_itemsize(self, data0, data1):
def test_mixed_dtypes_remain_object_array(self):
# GH14956
arr = np.array([datetime(2015, 1, 1, tzinfo=pytz.utc), 1], dtype=object)
result = lib.maybe_convert_objects(arr, convert_datetime=True)
result = lib.maybe_convert_objects(arr, convert_non_numeric=True)
tm.assert_numpy_array_equal(result, arr)

@pytest.mark.parametrize(
Expand All @@ -1050,8 +1035,7 @@ def test_mixed_dtypes_remain_object_array(self):
def test_maybe_convert_objects_ea(self, idx):
result = lib.maybe_convert_objects(
np.array(idx, dtype=object),
convert_period=True,
convert_interval=True,
convert_non_numeric=True,
)
tm.assert_extension_array_equal(result, idx._data)

Expand Down