Skip to content

Commit 59b2db1

Browse files
authored
TYP: overload lib.maybe_convert_objects (#41166)
1 parent 50c0f77 commit 59b2db1

File tree

10 files changed

+98
-85
lines changed

10 files changed

+98
-85
lines changed

pandas/_libs/lib.pyi

+51-7
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ from typing import (
55
Any,
66
Callable,
77
Generator,
8+
Literal,
9+
overload,
810
)
911

1012
import numpy as np
@@ -51,23 +53,65 @@ def is_float_array(values: np.ndarray, skipna: bool = False): ...
5153
def is_integer_array(values: np.ndarray, skipna: bool = False): ...
5254
def is_bool_array(values: np.ndarray, skipna: bool = False): ...
5355

54-
def fast_multiget(mapping: dict, keys: np.ndarray, default=np.nan) -> ArrayLike: ...
56+
def fast_multiget(mapping: dict, keys: np.ndarray, default=np.nan) -> np.ndarray: ...
5557

5658
def fast_unique_multiple_list_gen(gen: Generator, sort: bool = True) -> list: ...
5759
def fast_unique_multiple_list(lists: list, sort: bool = True) -> list: ...
5860
def fast_unique_multiple(arrays: list, sort: bool = True) -> list: ...
5961

6062
def map_infer(
6163
arr: np.ndarray, f: Callable[[Any], Any], convert: bool = True, ignore_na: bool = False
64+
) -> np.ndarray: ...
65+
66+
67+
@overload # both convert_datetime and convert_to_nullable_integer False -> np.ndarray
68+
def maybe_convert_objects(
69+
objects: np.ndarray, # np.ndarray[object]
70+
try_float: bool = ...,
71+
safe: bool = ...,
72+
convert_datetime: Literal[False] = ...,
73+
convert_timedelta: bool = ...,
74+
convert_to_nullable_integer: Literal[False] = ...,
75+
) -> np.ndarray: ...
76+
77+
@overload
78+
def maybe_convert_objects(
79+
objects: np.ndarray, # np.ndarray[object]
80+
try_float: bool = ...,
81+
safe: bool = ...,
82+
convert_datetime: Literal[False] = False,
83+
convert_timedelta: bool = ...,
84+
convert_to_nullable_integer: Literal[True] = ...,
6285
) -> ArrayLike: ...
6386

87+
@overload
6488
def maybe_convert_objects(
6589
objects: np.ndarray, # np.ndarray[object]
66-
try_float: bool = False,
67-
safe: bool = False,
68-
convert_datetime: bool = False,
69-
convert_timedelta: bool = False,
70-
convert_to_nullable_integer: bool = False,
90+
try_float: bool = ...,
91+
safe: bool = ...,
92+
convert_datetime: Literal[True] = ...,
93+
convert_timedelta: bool = ...,
94+
convert_to_nullable_integer: Literal[False] = ...,
95+
) -> ArrayLike: ...
96+
97+
@overload
98+
def maybe_convert_objects(
99+
objects: np.ndarray, # np.ndarray[object]
100+
try_float: bool = ...,
101+
safe: bool = ...,
102+
convert_datetime: Literal[True] = ...,
103+
convert_timedelta: bool = ...,
104+
convert_to_nullable_integer: Literal[True] = ...,
105+
) -> ArrayLike: ...
106+
107+
@overload
108+
def maybe_convert_objects(
109+
objects: np.ndarray, # np.ndarray[object]
110+
try_float: bool = ...,
111+
safe: bool = ...,
112+
convert_datetime: bool = ...,
113+
convert_timedelta: bool = ...,
114+
convert_to_nullable_integer: bool = ...,
71115
) -> ArrayLike: ...
72116

73117
def maybe_convert_numeric(
@@ -140,7 +184,7 @@ def map_infer_mask(
140184
convert: bool = ...,
141185
na_value: Any = ...,
142186
dtype: np.dtype = ...,
143-
) -> ArrayLike: ...
187+
) -> np.ndarray: ...
144188

145189
def indices_fast(
146190
index: np.ndarray, # ndarray[intp_t]

pandas/_libs/lib.pyx

+5-5
Original file line numberDiff line numberDiff line change
@@ -2488,7 +2488,7 @@ no_default = NoDefault.no_default # Sentinel indicating the default value.
24882488
@cython.wraparound(False)
24892489
def map_infer_mask(ndarray arr, object f, const uint8_t[:] mask, bint convert=True,
24902490
object na_value=no_default, cnp.dtype dtype=np.dtype(object)
2491-
) -> "ArrayLike":
2491+
) -> np.ndarray:
24922492
"""
24932493
Substitute for np.vectorize with pandas-friendly dtype inference.
24942494

@@ -2508,7 +2508,7 @@ def map_infer_mask(ndarray arr, object f, const uint8_t[:] mask, bint convert=Tr
25082508

25092509
Returns
25102510
-------
2511-
np.ndarray or ExtensionArray
2511+
np.ndarray
25122512
"""
25132513
cdef:
25142514
Py_ssize_t i, n
@@ -2545,7 +2545,7 @@ def map_infer_mask(ndarray arr, object f, const uint8_t[:] mask, bint convert=Tr
25452545
@cython.wraparound(False)
25462546
def map_infer(
25472547
ndarray arr, object f, bint convert=True, bint ignore_na=False
2548-
) -> "ArrayLike":
2548+
) -> np.ndarray:
25492549
"""
25502550
Substitute for np.vectorize with pandas-friendly dtype inference.
25512551

@@ -2559,7 +2559,7 @@ def map_infer(
25592559

25602560
Returns
25612561
-------
2562-
np.ndarray or ExtensionArray
2562+
np.ndarray
25632563
"""
25642564
cdef:
25652565
Py_ssize_t i, n
@@ -2697,7 +2697,7 @@ def to_object_array_tuples(rows: object) -> np.ndarray:
26972697

26982698
@cython.wraparound(False)
26992699
@cython.boundscheck(False)
2700-
def fast_multiget(dict mapping, ndarray keys, default=np.nan) -> "ArrayLike":
2700+
def fast_multiget(dict mapping, ndarray keys, default=np.nan) -> np.ndarray:
27012701
cdef:
27022702
Py_ssize_t i, n = len(keys)
27032703
object val

pandas/core/arrays/datetimelike.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,7 @@ def _box_values(self, values) -> np.ndarray:
262262
"""
263263
apply box func to passed values
264264
"""
265-
# error: Incompatible return value type (got
266-
# "Union[ExtensionArray, ndarray]", expected "ndarray")
267-
return lib.map_infer(values, self._box_func) # type: ignore[return-value]
265+
return lib.map_infer(values, self._box_func)
268266

269267
def __iter__(self):
270268
if self.ndim > 1:

pandas/core/arrays/string_.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -450,9 +450,7 @@ def _str_map(self, f, na_value=None, dtype: Dtype | None = None):
450450
if not na_value_is_na:
451451
mask[:] = False
452452

453-
# error: Argument 1 to "maybe_convert_objects" has incompatible
454-
# type "Union[ExtensionArray, ndarray]"; expected "ndarray"
455-
return constructor(result, mask) # type: ignore[arg-type]
453+
return constructor(result, mask)
456454

457455
elif is_string_dtype(dtype) and not is_object_dtype(dtype):
458456
# i.e. StringDtype

pandas/core/arrays/string_arrow.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -420,10 +420,8 @@ def fillna(self, value=None, method=None, limit=None):
420420
if mask.any():
421421
if method is not None:
422422
func = missing.get_fill_func(method)
423-
# error: Argument 1 to "to_numpy" of "ArrowStringArray" has incompatible
424-
# type "Type[object]"; expected "Union[str, dtype[Any], None]"
425423
new_values, _ = func(
426-
self.to_numpy(object), # type: ignore[arg-type]
424+
self.to_numpy("object"),
427425
limit=limit,
428426
mask=mask,
429427
)
@@ -740,11 +738,7 @@ def _str_map(self, f, na_value=None, dtype: Dtype | None = None):
740738
if not na_value_is_na:
741739
mask[:] = False
742740

743-
# error: Argument 1 to "IntegerArray" has incompatible type
744-
# "Union[ExtensionArray, ndarray]"; expected "ndarray"
745-
# error: Argument 1 to "BooleanArray" has incompatible type
746-
# "Union[ExtensionArray, ndarray]"; expected "ndarray"
747-
return constructor(result, mask) # type: ignore[arg-type]
741+
return constructor(result, mask)
748742

749743
elif is_string_dtype(dtype) and not is_object_dtype(dtype):
750744
# i.e. StringDtype

pandas/core/groupby/ops.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -996,8 +996,8 @@ def _aggregate_series_pure_python(self, obj: Series, func: F):
996996
counts[label] = group.shape[0]
997997
result[label] = res
998998

999-
out = lib.maybe_convert_objects(result, try_float=False)
1000-
out = maybe_cast_pointwise_result(out, obj.dtype, numeric_only=True)
999+
npvalues = lib.maybe_convert_objects(result, try_float=False)
1000+
out = maybe_cast_pointwise_result(npvalues, obj.dtype, numeric_only=True)
10011001

10021002
return out, counts
10031003

pandas/core/internals/managers.py

+6-11
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
ensure_platform_int,
3636
is_1d_only_ea_dtype,
3737
is_dtype_equal,
38-
is_extension_array_dtype,
3938
is_list_like,
4039
)
4140
from pandas.core.dtypes.dtypes import ExtensionDtype
@@ -701,16 +700,16 @@ def _interleave(
701700
# Give EAs some input on what happens here. Sparse needs this.
702701
if isinstance(dtype, SparseDtype):
703702
dtype = dtype.subtype
704-
elif is_extension_array_dtype(dtype):
703+
elif isinstance(dtype, ExtensionDtype):
705704
dtype = "object"
706705
elif is_dtype_equal(dtype, str):
707706
dtype = "object"
708707

709708
# error: Argument "dtype" to "empty" has incompatible type
710709
# "Union[ExtensionDtype, str, dtype[Any], Type[object], None]"; expected
711710
# "Union[dtype[Any], None, type, _SupportsDType, str, Union[Tuple[Any, int],
712-
# Tuple[Any, Union[int, Sequence[int]]], List[Any], _DTypeDict, Tuple[Any,
713-
# Any]]]"
711+
# Tuple[Any, Union[int, Sequence[int]]], List[Any], _DTypeDict,
712+
# Tuple[Any, Any]]]"
714713
result = np.empty(self.shape, dtype=dtype) # type: ignore[arg-type]
715714

716715
itemmask = np.zeros(self.shape[0])
@@ -1108,16 +1107,12 @@ def fast_xs(self, loc: int) -> ArrayLike:
11081107
dtype = interleaved_dtype([blk.dtype for blk in self.blocks])
11091108

11101109
n = len(self)
1111-
if is_extension_array_dtype(dtype):
1110+
if isinstance(dtype, ExtensionDtype):
11121111
# we'll eventually construct an ExtensionArray.
11131112
result = np.empty(n, dtype=object)
1113+
# TODO: let's just use dtype.empty?
11141114
else:
1115-
# error: Argument "dtype" to "empty" has incompatible type
1116-
# "Union[dtype, ExtensionDtype, None]"; expected "Union[dtype,
1117-
# None, type, _SupportsDtype, str, Tuple[Any, int], Tuple[Any,
1118-
# Union[int, Sequence[int]]], List[Any], _DtypeDict, Tuple[Any,
1119-
# Any]]"
1120-
result = np.empty(n, dtype=dtype) # type: ignore[arg-type]
1115+
result = np.empty(n, dtype=dtype)
11211116

11221117
result = ensure_wrapped_if_datetimelike(result)
11231118

pandas/core/series.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -3087,10 +3087,8 @@ def combine(self, other, func, fill_value=None) -> Series:
30873087
new_name = self.name
30883088

30893089
# try_float=False is to match _aggregate_series_pure_python
3090-
res_values = lib.maybe_convert_objects(new_values, try_float=False)
3091-
res_values = maybe_cast_pointwise_result(
3092-
res_values, self.dtype, same_dtype=False
3093-
)
3090+
npvalues = lib.maybe_convert_objects(new_values, try_float=False)
3091+
res_values = maybe_cast_pointwise_result(npvalues, self.dtype, same_dtype=False)
30943092
return self._constructor(res_values, index=new_index, name=new_name)
30953093

30963094
def combine_first(self, other) -> Series:

pandas/core/tools/datetimes.py

+28-41
Original file line numberDiff line numberDiff line change
@@ -252,9 +252,7 @@ def _convert_and_box_cache(
252252
from pandas import Series
253253

254254
result = Series(arg).map(cache_array)
255-
# error: Argument 1 to "_box_as_indexlike" has incompatible type "Series"; expected
256-
# "Union[ExtensionArray, ndarray]"
257-
return _box_as_indexlike(result, utc=None, name=name) # type: ignore[arg-type]
255+
return _box_as_indexlike(result._values, utc=None, name=name)
258256

259257

260258
def _return_parsed_timezone_results(result: np.ndarray, timezones, tz, name) -> Index:
@@ -368,13 +366,11 @@ def _convert_listlike_datetimes(
368366
arg, _ = maybe_convert_dtype(arg, copy=False)
369367
except TypeError:
370368
if errors == "coerce":
371-
result = np.array(["NaT"], dtype="datetime64[ns]").repeat(len(arg))
372-
return DatetimeIndex(result, name=name)
369+
npvalues = np.array(["NaT"], dtype="datetime64[ns]").repeat(len(arg))
370+
return DatetimeIndex(npvalues, name=name)
373371
elif errors == "ignore":
374-
# error: Incompatible types in assignment (expression has type
375-
# "Index", variable has type "ExtensionArray")
376-
result = Index(arg, name=name) # type: ignore[assignment]
377-
return result
372+
idx = Index(arg, name=name)
373+
return idx
378374
raise
379375

380376
arg = ensure_object(arg)
@@ -393,37 +389,30 @@ def _convert_listlike_datetimes(
393389
require_iso8601 = not infer_datetime_format
394390
format = None
395391

396-
# error: Incompatible types in assignment (expression has type "None", variable has
397-
# type "ExtensionArray")
398-
result = None # type: ignore[assignment]
399-
400392
if format is not None:
401-
# error: Incompatible types in assignment (expression has type
402-
# "Optional[Index]", variable has type "ndarray")
403-
result = _to_datetime_with_format( # type: ignore[assignment]
393+
res = _to_datetime_with_format(
404394
arg, orig_arg, name, tz, format, exact, errors, infer_datetime_format
405395
)
406-
if result is not None:
407-
return result
408-
409-
if result is None:
410-
assert format is None or infer_datetime_format
411-
utc = tz == "utc"
412-
result, tz_parsed = objects_to_datetime64ns(
413-
arg,
414-
dayfirst=dayfirst,
415-
yearfirst=yearfirst,
416-
utc=utc,
417-
errors=errors,
418-
require_iso8601=require_iso8601,
419-
allow_object=True,
420-
)
396+
if res is not None:
397+
return res
421398

422-
if tz_parsed is not None:
423-
# We can take a shortcut since the datetime64 numpy array
424-
# is in UTC
425-
dta = DatetimeArray(result, dtype=tz_to_dtype(tz_parsed))
426-
return DatetimeIndex._simple_new(dta, name=name)
399+
assert format is None or infer_datetime_format
400+
utc = tz == "utc"
401+
result, tz_parsed = objects_to_datetime64ns(
402+
arg,
403+
dayfirst=dayfirst,
404+
yearfirst=yearfirst,
405+
utc=utc,
406+
errors=errors,
407+
require_iso8601=require_iso8601,
408+
allow_object=True,
409+
)
410+
411+
if tz_parsed is not None:
412+
# We can take a shortcut since the datetime64 numpy array
413+
# is in UTC
414+
dta = DatetimeArray(result, dtype=tz_to_dtype(tz_parsed))
415+
return DatetimeIndex._simple_new(dta, name=name)
427416

428417
utc = tz == "utc"
429418
return _box_as_indexlike(result, utc=utc, name=name)
@@ -509,13 +498,11 @@ def _to_datetime_with_format(
509498

510499
# fallback
511500
if result is None:
512-
# error: Incompatible types in assignment (expression has type
513-
# "Optional[Index]", variable has type "Optional[ndarray]")
514-
result = _array_strptime_with_fallback( # type: ignore[assignment]
501+
res = _array_strptime_with_fallback(
515502
arg, name, tz, fmt, exact, errors, infer_datetime_format
516503
)
517-
if result is not None:
518-
return result
504+
if res is not None:
505+
return res
519506

520507
except ValueError as e:
521508
# Fallback to try to convert datetime objects if timezone-aware

pandas/io/formats/format.py

-1
Original file line numberDiff line numberDiff line change
@@ -1318,7 +1318,6 @@ def _format(x):
13181318
"ExtensionArray formatting should use ExtensionArrayFormatter"
13191319
)
13201320
inferred = lib.map_infer(vals, is_float)
1321-
inferred = cast(np.ndarray, inferred)
13221321
is_float_type = (
13231322
inferred
13241323
# vals may have 2 or more dimensions

0 commit comments

Comments
 (0)