Skip to content

Commit 3882fd2

Browse files
authored
BUG: require arraylike in infer_dtype_from_array (#38473)
1 parent cd0fb05 commit 3882fd2

File tree

3 files changed

+23
-8
lines changed

3 files changed

+23
-8
lines changed

pandas/core/dtypes/cast.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,7 @@ def infer_dtype_from(val, pandas_dtype: bool = False) -> Tuple[DtypeObj, Any]:
673673
If False, scalar/array belongs to pandas extension types is inferred as
674674
object
675675
"""
676-
if is_scalar(val):
676+
if not is_list_like(val):
677677
return infer_dtype_from_scalar(val, pandas_dtype=pandas_dtype)
678678
return infer_dtype_from_array(val, pandas_dtype=pandas_dtype)
679679

@@ -814,7 +814,7 @@ def infer_dtype_from_array(
814814
return arr.dtype, arr
815815

816816
if not is_list_like(arr):
817-
arr = [arr]
817+
raise TypeError("'arr' must be list-like")
818818

819819
if pandas_dtype and is_extension_array_dtype(arr):
820820
return arr.dtype, arr

pandas/core/missing.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pandas._typing import ArrayLike, Axis, DtypeObj
1111
from pandas.compat._optional import import_optional_dependency
1212

13-
from pandas.core.dtypes.cast import infer_dtype_from_array
13+
from pandas.core.dtypes.cast import infer_dtype_from
1414
from pandas.core.dtypes.common import (
1515
ensure_float64,
1616
is_integer_dtype,
@@ -40,7 +40,7 @@ def mask_missing(arr: ArrayLike, values_to_mask) -> np.ndarray:
4040
# When called from Block.replace/replace_list, values_to_mask is a scalar
4141
# known to be holdable by arr.
4242
# When called from Series._single_replace, values_to_mask is tuple or list
43-
dtype, values_to_mask = infer_dtype_from_array(values_to_mask)
43+
dtype, values_to_mask = infer_dtype_from(values_to_mask)
4444
values_to_mask = np.array(values_to_mask, dtype=dtype)
4545

4646
na_mask = isna(values_to_mask)

pandas/tests/dtypes/cast/test_infer_dtype.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,29 @@ def test_infer_dtype_from_scalar_errors():
141141

142142

143143
@pytest.mark.parametrize(
144-
"arr, expected, pandas_dtype",
144+
"value, expected, pandas_dtype",
145145
[
146146
("foo", np.object_, False),
147147
(b"foo", np.object_, False),
148-
(1, np.int_, False),
148+
(1, np.int64, False),
149149
(1.5, np.float_, False),
150+
(np.datetime64("2016-01-01"), np.dtype("M8[ns]"), False),
151+
(Timestamp("20160101"), np.dtype("M8[ns]"), False),
152+
(Timestamp("20160101", tz="UTC"), np.object_, False),
153+
(Timestamp("20160101", tz="UTC"), "datetime64[ns, UTC]", True),
154+
],
155+
)
156+
def test_infer_dtype_from_scalar(value, expected, pandas_dtype):
157+
dtype, _ = infer_dtype_from_scalar(value, pandas_dtype=pandas_dtype)
158+
assert is_dtype_equal(dtype, expected)
159+
160+
with pytest.raises(TypeError, match="must be list-like"):
161+
infer_dtype_from_array(value, pandas_dtype=pandas_dtype)
162+
163+
164+
@pytest.mark.parametrize(
165+
"arr, expected, pandas_dtype",
166+
[
150167
([1], np.int_, False),
151168
(np.array([1], dtype=np.int64), np.int64, False),
152169
([np.nan, 1, ""], np.object_, False),
@@ -155,8 +172,6 @@ def test_infer_dtype_from_scalar_errors():
155172
(Categorical([1, 2, 3]), np.int64, False),
156173
(Categorical(list("aabc")), "category", True),
157174
(Categorical([1, 2, 3]), "category", True),
158-
(Timestamp("20160101"), np.object_, False),
159-
(np.datetime64("2016-01-01"), np.dtype("=M8[D]"), False),
160175
(date_range("20160101", periods=3), np.dtype("=M8[ns]"), False),
161176
(
162177
date_range("20160101", periods=3, tz="US/Eastern"),

0 commit comments

Comments
 (0)