Skip to content

Commit d6c63c5

Browse files
jbrockmendel, im-vinicius
authored and
im-vinicius
committed
CLN: remove pandas_dtype kwd from infer_dtype_from_x (pandas-dev#53064)
* REF: avoid object dtype in mask_missing
* REF: remove pandas_dtype kwarg from infer_dtype_from
* REF: remove pandas_dtype kwarg from infer_dtype_from_scalar
* mypy fixup
1 parent 79ae2da commit d6c63c5

File tree

9 files changed

+73
-114
lines changed

9 files changed

+73
-114
lines changed

pandas/core/array_algos/putmask.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -136,7 +136,7 @@ def setitem_datetimelike_compat(values: np.ndarray, num_set: int, other):
136136
other : Any
137137
"""
138138
if values.dtype == object:
139-
dtype, _ = infer_dtype_from(other, pandas_dtype=True)
139+
dtype, _ = infer_dtype_from(other)
140140

141141
if isinstance(dtype, np.dtype) and dtype.kind in "mM":
142142
# https://github.com/numpy/numpy/issues/12550

pandas/core/dtypes/cast.py

+18-44
Original file line number | Diff line number | Diff line change
@@ -631,7 +631,7 @@ def _maybe_promote(dtype: np.dtype, fill_value=np.nan):
631631

632632
# returns tuple of (dtype, fill_value)
633633
if issubclass(dtype.type, np.datetime64):
634-
inferred, fv = infer_dtype_from_scalar(fill_value, pandas_dtype=True)
634+
inferred, fv = infer_dtype_from_scalar(fill_value)
635635
if inferred == dtype:
636636
return dtype, fv
637637

@@ -645,7 +645,7 @@ def _maybe_promote(dtype: np.dtype, fill_value=np.nan):
645645
return _dtype_obj, fill_value
646646

647647
elif issubclass(dtype.type, np.timedelta64):
648-
inferred, fv = infer_dtype_from_scalar(fill_value, pandas_dtype=True)
648+
inferred, fv = infer_dtype_from_scalar(fill_value)
649649
if inferred == dtype:
650650
return dtype, fv
651651

@@ -735,33 +735,26 @@ def _ensure_dtype_type(value, dtype: np.dtype):
735735
return dtype.type(value)
736736

737737

738-
def infer_dtype_from(val, pandas_dtype: bool = False) -> tuple[DtypeObj, Any]:
738+
def infer_dtype_from(val) -> tuple[DtypeObj, Any]:
739739
"""
740740
Interpret the dtype from a scalar or array.
741741
742742
Parameters
743743
----------
744744
val : object
745-
pandas_dtype : bool, default False
746-
whether to infer dtype including pandas extension types.
747-
If False, scalar/array belongs to pandas extension types is inferred as
748-
object
749745
"""
750746
if not is_list_like(val):
751-
return infer_dtype_from_scalar(val, pandas_dtype=pandas_dtype)
752-
return infer_dtype_from_array(val, pandas_dtype=pandas_dtype)
747+
return infer_dtype_from_scalar(val)
748+
return infer_dtype_from_array(val)
753749

754750

755-
def infer_dtype_from_scalar(val, pandas_dtype: bool = False) -> tuple[DtypeObj, Any]:
751+
def infer_dtype_from_scalar(val) -> tuple[DtypeObj, Any]:
756752
"""
757753
Interpret the dtype from a scalar.
758754
759755
Parameters
760756
----------
761-
pandas_dtype : bool, default False
762-
whether to infer dtype including pandas extension types.
763-
If False, scalar belongs to pandas extension types is inferred as
764-
object
757+
val : object
765758
"""
766759
dtype: DtypeObj = _dtype_obj
767760

@@ -796,11 +789,7 @@ def infer_dtype_from_scalar(val, pandas_dtype: bool = False) -> tuple[DtypeObj,
796789
dtype = val.dtype
797790
# TODO: test with datetime(2920, 10, 1) based on test_replace_dtypes
798791
else:
799-
if pandas_dtype:
800-
dtype = DatetimeTZDtype(unit="ns", tz=val.tz)
801-
else:
802-
# return datetimetz as object
803-
return _dtype_obj, val
792+
dtype = DatetimeTZDtype(unit="ns", tz=val.tz)
804793

805794
elif isinstance(val, (np.timedelta64, dt.timedelta)):
806795
try:
@@ -834,12 +823,11 @@ def infer_dtype_from_scalar(val, pandas_dtype: bool = False) -> tuple[DtypeObj,
834823
elif is_complex(val):
835824
dtype = np.dtype(np.complex_)
836825

837-
elif pandas_dtype:
838-
if lib.is_period(val):
839-
dtype = PeriodDtype(freq=val.freq)
840-
elif lib.is_interval(val):
841-
subtype = infer_dtype_from_scalar(val.left, pandas_dtype=True)[0]
842-
dtype = IntervalDtype(subtype=subtype, closed=val.closed)
826+
if lib.is_period(val):
827+
dtype = PeriodDtype(freq=val.freq)
828+
elif lib.is_interval(val):
829+
subtype = infer_dtype_from_scalar(val.left)[0]
830+
dtype = IntervalDtype(subtype=subtype, closed=val.closed)
843831

844832
return dtype, val
845833

@@ -859,32 +847,18 @@ def dict_compat(d: dict[Scalar, Scalar]) -> dict[Scalar, Scalar]:
859847
return {maybe_box_datetimelike(key): value for key, value in d.items()}
860848

861849

862-
def infer_dtype_from_array(
863-
arr, pandas_dtype: bool = False
864-
) -> tuple[DtypeObj, ArrayLike]:
850+
def infer_dtype_from_array(arr) -> tuple[DtypeObj, ArrayLike]:
865851
"""
866852
Infer the dtype from an array.
867853
868854
Parameters
869855
----------
870856
arr : array
871-
pandas_dtype : bool, default False
872-
whether to infer dtype including pandas extension types.
873-
If False, array belongs to pandas extension types
874-
is inferred as object
875857
876858
Returns
877859
-------
878-
tuple (numpy-compat/pandas-compat dtype, array)
879-
880-
Notes
881-
-----
882-
if pandas_dtype=False. these infer to numpy dtypes
883-
exactly with the exception that mixed / object dtypes
884-
are not coerced by stringifying or conversion
860+
tuple (pandas-compat dtype, array)
885861
886-
if pandas_dtype=True. datetime64tz-aware/categorical
887-
types will retain there character.
888862
889863
Examples
890864
--------
@@ -901,7 +875,7 @@ def infer_dtype_from_array(
901875
raise TypeError("'arr' must be list-like")
902876

903877
arr_dtype = getattr(arr, "dtype", None)
904-
if pandas_dtype and isinstance(arr_dtype, ExtensionDtype):
878+
if isinstance(arr_dtype, ExtensionDtype):
905879
return arr.dtype, arr
906880

907881
elif isinstance(arr, ABCSeries):
@@ -1303,7 +1277,7 @@ def find_result_type(left: ArrayLike, right: Any) -> DtypeObj:
13031277
new_dtype = ensure_dtype_can_hold_na(left.dtype)
13041278

13051279
else:
1306-
dtype, _ = infer_dtype_from(right, pandas_dtype=True)
1280+
dtype, _ = infer_dtype_from(right)
13071281

13081282
new_dtype = find_common_type([left.dtype, dtype])
13091283

@@ -1466,7 +1440,7 @@ def construct_1d_arraylike_from_scalar(
14661440

14671441
if dtype is None:
14681442
try:
1469-
dtype, value = infer_dtype_from_scalar(value, pandas_dtype=True)
1443+
dtype, value = infer_dtype_from_scalar(value)
14701444
except OutOfBoundsDatetime:
14711445
dtype = _dtype_obj
14721446

pandas/core/frame.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -824,7 +824,7 @@ def __init__(
824824
columns = ensure_index(columns)
825825

826826
if not dtype:
827-
dtype, _ = infer_dtype_from_scalar(data, pandas_dtype=True)
827+
dtype, _ = infer_dtype_from_scalar(data)
828828

829829
# For data is a scalar extension dtype
830830
if isinstance(dtype, ExtensionDtype):

pandas/core/indexes/base.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -6097,7 +6097,7 @@ def _find_common_type_compat(self, target) -> DtypeObj:
60976097
Implementation of find_common_type that adjusts for Index-specific
60986098
special cases.
60996099
"""
6100-
target_dtype, _ = infer_dtype_from(target, pandas_dtype=True)
6100+
target_dtype, _ = infer_dtype_from(target)
61016101

61026102
# special case: if one dtype is uint64 and the other a signed int, return object
61036103
# See https://github.com/pandas-dev/pandas/issues/26778 for discussion

pandas/core/indexes/interval.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -558,7 +558,7 @@ def _maybe_convert_i8(self, key):
558558

559559
if scalar:
560560
# Timestamp/Timedelta
561-
key_dtype, key_i8 = infer_dtype_from_scalar(key, pandas_dtype=True)
561+
key_dtype, key_i8 = infer_dtype_from_scalar(key)
562562
if lib.is_period(key):
563563
key_i8 = key.ordinal
564564
elif isinstance(key_i8, Timestamp):

pandas/core/internals/array_manager.py

+3-7
Original file line number | Diff line number | Diff line change
@@ -84,6 +84,7 @@
8484
new_block,
8585
to_native_types,
8686
)
87+
from pandas.core.internals.managers import make_na_array
8788

8889
if TYPE_CHECKING:
8990
from pandas._typing import (
@@ -665,13 +666,8 @@ def _make_na_array(self, fill_value=None, use_na_proxy: bool = False):
665666
fill_value = np.nan
666667

667668
dtype, fill_value = infer_dtype_from_scalar(fill_value)
668-
# error: Argument "dtype" to "empty" has incompatible type "Union[dtype[Any],
669-
# ExtensionDtype]"; expected "Union[dtype[Any], None, type, _SupportsDType, str,
670-
# Union[Tuple[Any, int], Tuple[Any, Union[int, Sequence[int]]], List[Any],
671-
# _DTypeDict, Tuple[Any, Any]]]"
672-
values = np.empty(self.shape_proper[0], dtype=dtype) # type: ignore[arg-type]
673-
values.fill(fill_value)
674-
return values
669+
array_values = make_na_array(dtype, self.shape_proper[:1], fill_value)
670+
return array_values
675671

676672
def _equal_values(self, other) -> bool:
677673
"""

pandas/core/internals/managers.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -921,7 +921,7 @@ def _make_na_block(
921921

922922
shape = (len(placement), self.shape[1])
923923

924-
dtype, fill_value = infer_dtype_from_scalar(fill_value, pandas_dtype=True)
924+
dtype, fill_value = infer_dtype_from_scalar(fill_value)
925925
block_values = make_na_array(dtype, shape, fill_value)
926926
return new_block_2d(block_values, placement=placement)
927927

pandas/core/missing.py

+8-5
Original file line number | Diff line number | Diff line change
@@ -80,11 +80,14 @@ def mask_missing(arr: ArrayLike, values_to_mask) -> npt.NDArray[np.bool_]:
8080
# known to be holdable by arr.
8181
# When called from Series._single_replace, values_to_mask is tuple or list
8282
dtype, values_to_mask = infer_dtype_from(values_to_mask)
83-
# error: Argument "dtype" to "array" has incompatible type "Union[dtype[Any],
84-
# ExtensionDtype]"; expected "Union[dtype[Any], None, type, _SupportsDType, str,
85-
# Union[Tuple[Any, int], Tuple[Any, Union[int, Sequence[int]]], List[Any],
86-
# _DTypeDict, Tuple[Any, Any]]]"
87-
values_to_mask = np.array(values_to_mask, dtype=dtype) # type: ignore[arg-type]
83+
84+
if isinstance(dtype, np.dtype):
85+
values_to_mask = np.array(values_to_mask, dtype=dtype)
86+
else:
87+
cls = dtype.construct_array_type()
88+
if not lib.is_list_like(values_to_mask):
89+
values_to_mask = [values_to_mask]
90+
values_to_mask = cls._from_sequence(values_to_mask, dtype=dtype, copy=False)
8891

8992
potential_na = False
9093
if is_object_dtype(arr.dtype):

pandas/tests/dtypes/cast/test_infer_dtype.py

+39-53
Original file line number | Diff line number | Diff line change
@@ -25,11 +25,6 @@
2525
)
2626

2727

28-
@pytest.fixture(params=[True, False])
29-
def pandas_dtype(request):
30-
return request.param
31-
32-
3328
def test_infer_dtype_from_int_scalar(any_int_numpy_dtype):
3429
# Test that infer_dtype_from_scalar is
3530
# returning correct dtype for int and float.
@@ -81,36 +76,32 @@ def test_infer_dtype_from_timedelta(data):
8176

8277

8378
@pytest.mark.parametrize("freq", ["M", "D"])
84-
def test_infer_dtype_from_period(freq, pandas_dtype):
79+
def test_infer_dtype_from_period(freq):
8580
p = Period("2011-01-01", freq=freq)
86-
dtype, val = infer_dtype_from_scalar(p, pandas_dtype=pandas_dtype)
81+
dtype, val = infer_dtype_from_scalar(p)
8782

88-
if pandas_dtype:
89-
exp_dtype = f"period[{freq}]"
90-
else:
91-
exp_dtype = np.object_
83+
exp_dtype = f"period[{freq}]"
9284

9385
assert dtype == exp_dtype
9486
assert val == p
9587

9688

97-
@pytest.mark.parametrize(
98-
"data", [date(2000, 1, 1), "foo", Timestamp(1, tz="US/Eastern")]
99-
)
100-
def test_infer_dtype_misc(data):
101-
dtype, val = infer_dtype_from_scalar(data)
89+
def test_infer_dtype_misc():
90+
dt = date(2000, 1, 1)
91+
dtype, val = infer_dtype_from_scalar(dt)
10292
assert dtype == np.object_
10393

94+
ts = Timestamp(1, tz="US/Eastern")
95+
dtype, val = infer_dtype_from_scalar(ts)
96+
assert dtype == "datetime64[ns, US/Eastern]"
97+
10498

10599
@pytest.mark.parametrize("tz", ["UTC", "US/Eastern", "Asia/Tokyo"])
106-
def test_infer_from_scalar_tz(tz, pandas_dtype):
100+
def test_infer_from_scalar_tz(tz):
107101
dt = Timestamp(1, tz=tz)
108-
dtype, val = infer_dtype_from_scalar(dt, pandas_dtype=pandas_dtype)
102+
dtype, val = infer_dtype_from_scalar(dt)
109103

110-
if pandas_dtype:
111-
exp_dtype = f"datetime64[ns, {tz}]"
112-
else:
113-
exp_dtype = np.object_
104+
exp_dtype = f"datetime64[ns, {tz}]"
114105

115106
assert dtype == exp_dtype
116107
assert val == dt
@@ -126,11 +117,11 @@ def test_infer_from_scalar_tz(tz, pandas_dtype):
126117
(Timedelta(0), Timedelta(1), "timedelta64[ns]"),
127118
],
128119
)
129-
def test_infer_from_interval(left, right, subtype, closed, pandas_dtype):
120+
def test_infer_from_interval(left, right, subtype, closed):
130121
# GH 30337
131122
interval = Interval(left, right, closed)
132-
result_dtype, result_value = infer_dtype_from_scalar(interval, pandas_dtype)
133-
expected_dtype = f"interval[{subtype}, {closed}]" if pandas_dtype else np.object_
123+
result_dtype, result_value = infer_dtype_from_scalar(interval)
124+
expected_dtype = f"interval[{subtype}, {closed}]"
134125
assert result_dtype == expected_dtype
135126
assert result_value == interval
136127

@@ -143,54 +134,49 @@ def test_infer_dtype_from_scalar_errors():
143134

144135

145136
@pytest.mark.parametrize(
146-
"value, expected, pandas_dtype",
137+
"value, expected",
147138
[
148-
("foo", np.object_, False),
149-
(b"foo", np.object_, False),
150-
(1, np.int64, False),
151-
(1.5, np.float_, False),
152-
(np.datetime64("2016-01-01"), np.dtype("M8[ns]"), False),
153-
(Timestamp("20160101"), np.dtype("M8[ns]"), False),
154-
(Timestamp("20160101", tz="UTC"), np.object_, False),
155-
(Timestamp("20160101", tz="UTC"), "datetime64[ns, UTC]", True),
139+
("foo", np.object_),
140+
(b"foo", np.object_),
141+
(1, np.int64),
142+
(1.5, np.float_),
143+
(np.datetime64("2016-01-01"), np.dtype("M8[ns]")),
144+
(Timestamp("20160101"), np.dtype("M8[ns]")),
145+
(Timestamp("20160101", tz="UTC"), "datetime64[ns, UTC]"),
156146
],
157147
)
158-
def test_infer_dtype_from_scalar(value, expected, pandas_dtype):
159-
dtype, _ = infer_dtype_from_scalar(value, pandas_dtype=pandas_dtype)
148+
def test_infer_dtype_from_scalar(value, expected):
149+
dtype, _ = infer_dtype_from_scalar(value)
160150
assert is_dtype_equal(dtype, expected)
161151

162152
with pytest.raises(TypeError, match="must be list-like"):
163-
infer_dtype_from_array(value, pandas_dtype=pandas_dtype)
153+
infer_dtype_from_array(value)
164154

165155

166156
@pytest.mark.parametrize(
167-
"arr, expected, pandas_dtype",
157+
"arr, expected",
168158
[
169-
([1], np.int_, False),
170-
(np.array([1], dtype=np.int64), np.int64, False),
171-
([np.nan, 1, ""], np.object_, False),
172-
(np.array([[1.0, 2.0]]), np.float_, False),
173-
(Categorical(list("aabc")), np.object_, False),
174-
(Categorical([1, 2, 3]), np.int64, False),
175-
(Categorical(list("aabc")), "category", True),
176-
(Categorical([1, 2, 3]), "category", True),
177-
(date_range("20160101", periods=3), np.dtype("=M8[ns]"), False),
159+
([1], np.int_),
160+
(np.array([1], dtype=np.int64), np.int64),
161+
([np.nan, 1, ""], np.object_),
162+
(np.array([[1.0, 2.0]]), np.float_),
163+
(Categorical(list("aabc")), "category"),
164+
(Categorical([1, 2, 3]), "category"),
165+
(date_range("20160101", periods=3), np.dtype("=M8[ns]")),
178166
(
179167
date_range("20160101", periods=3, tz="US/Eastern"),
180168
"datetime64[ns, US/Eastern]",
181-
True,
182169
),
183-
(Series([1.0, 2, 3]), np.float64, False),
184-
(Series(list("abc")), np.object_, False),
170+
(Series([1.0, 2, 3]), np.float64),
171+
(Series(list("abc")), np.object_),
185172
(
186173
Series(date_range("20160101", periods=3, tz="US/Eastern")),
187174
"datetime64[ns, US/Eastern]",
188-
True,
189175
),
190176
],
191177
)
192-
def test_infer_dtype_from_array(arr, expected, pandas_dtype):
193-
dtype, _ = infer_dtype_from_array(arr, pandas_dtype=pandas_dtype)
178+
def test_infer_dtype_from_array(arr, expected):
179+
dtype, _ = infer_dtype_from_array(arr)
194180
assert is_dtype_equal(dtype, expected)
195181

196182

0 commit comments

Comments
 (0)