Skip to content

Commit 5a35050

Browse files
BUG/TST: Fix infer_dtype for Period array-likes and general ExtensionArrays (#37367)
1 parent f7a44db commit 5a35050

File tree

7 files changed

+47
-15
lines changed

7 files changed

+47
-15
lines changed

doc/source/whatsnew/v1.3.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,8 @@ ExtensionArray
448448
Other
449449
^^^^^
450450
- Bug in :class:`Index` constructor sometimes silently ignorning a specified ``dtype`` (:issue:`38879`)
451+
- Bug in :func:`pandas.api.types.infer_dtype` not recognizing Series, Index or array with a period dtype (:issue:`23553`)
452+
- Bug in :func:`pandas.api.types.infer_dtype` raising an error for general :class:`.ExtensionArray` objects. It will now return ``"unknown-array"`` instead of raising (:issue:`37367`)
451453
- Bug in constructing a :class:`Series` from a list and a :class:`PandasDtype` (:issue:`39357`)
452454
- Bug in :class:`Styler` which caused CSS to duplicate on multiple renders. (:issue:`39395`)
453455
- ``inspect.getmembers(Series)`` no longer raises an ``AbstractMethodError`` (:issue:`38782`)

pandas/_libs/lib.pyx

+12-8
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ from pandas._libs cimport util
6969
from pandas._libs.util cimport INT64_MAX, INT64_MIN, UINT64_MAX, is_nan
7070

7171
from pandas._libs.tslib import array_to_datetime
72+
from pandas._libs.tslibs.period import Period
7273

7374
from pandas._libs.missing cimport (
7475
C_NA,
@@ -1082,6 +1083,7 @@ _TYPE_MAP = {
10821083
"timedelta64[ns]": "timedelta64",
10831084
"m": "timedelta64",
10841085
"interval": "interval",
1086+
Period: "period",
10851087
}
10861088

10871089
# types only exist on certain platform
@@ -1233,8 +1235,8 @@ cdef object _try_infer_map(object dtype):
12331235
cdef:
12341236
object val
12351237
str attr
1236-
for attr in ["name", "kind", "base"]:
1237-
val = getattr(dtype, attr)
1238+
for attr in ["name", "kind", "base", "type"]:
1239+
val = getattr(dtype, attr, None)
12381240
if val in _TYPE_MAP:
12391241
return _TYPE_MAP[val]
12401242
return None
@@ -1275,6 +1277,7 @@ def infer_dtype(value: object, skipna: bool = True) -> str:
12751277
- time
12761278
- period
12771279
- mixed
1280+
- unknown-array
12781281

12791282
Raises
12801283
------
@@ -1287,6 +1290,9 @@ def infer_dtype(value: object, skipna: bool = True) -> str:
12871290
specialized
12881291
- 'mixed-integer-float' are floats and integers
12891292
- 'mixed-integer' are integers mixed with non-integers
1293+
- 'unknown-array' is the catchall for something that *is* an array (has
1294+
a dtype attribute), but has a dtype unknown to pandas (e.g. external
1295+
extension array)
12901296

12911297
Examples
12921298
--------
@@ -1355,12 +1361,10 @@ def infer_dtype(value: object, skipna: bool = True) -> str:
13551361
# e.g. categoricals
13561362
dtype = value.dtype
13571363
if not isinstance(dtype, np.dtype):
1358-
value = _try_infer_map(value.dtype)
1359-
if value is not None:
1360-
return value
1361-
1362-
# its ndarray-like but we can't handle
1363-
raise ValueError(f"cannot infer type for {type(value)}")
1364+
inferred = _try_infer_map(value.dtype)
1365+
if inferred is not None:
1366+
return inferred
1367+
return "unknown-array"
13641368

13651369
# Unwrap Series/Index
13661370
values = np.asarray(value)

pandas/core/strings/accessor.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -202,11 +202,7 @@ def _validate(data):
202202
if isinstance(values.dtype, StringDtype):
203203
return "string"
204204

205-
try:
206-
inferred_dtype = lib.infer_dtype(values, skipna=True)
207-
except ValueError:
208-
# GH#27571 mostly occurs with ExtensionArray
209-
inferred_dtype = None
205+
inferred_dtype = lib.infer_dtype(values, skipna=True)
210206

211207
if inferred_dtype not in allowed_types:
212208
raise AttributeError("Can only use .str accessor with string values!")

pandas/tests/dtypes/test_inference.py

+13
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,19 @@ def test_infer_dtype_period(self):
891891
arr = np.array([Period("2011-01", freq="D"), Period("2011-02", freq="M")])
892892
assert lib.infer_dtype(arr, skipna=True) == "period"
893893

894+
@pytest.mark.parametrize("klass", [pd.array, pd.Series, pd.Index])
895+
@pytest.mark.parametrize("skipna", [True, False])
896+
def test_infer_dtype_period_array(self, klass, skipna):
897+
# https://github.com/pandas-dev/pandas/issues/23553
898+
values = klass(
899+
[
900+
Period("2011-01-01", freq="D"),
901+
Period("2011-01-02", freq="D"),
902+
pd.NaT,
903+
]
904+
)
905+
assert lib.infer_dtype(values, skipna=skipna) == "period"
906+
894907
def test_infer_dtype_period_mixed(self):
895908
arr = np.array(
896909
[Period("2011-01", freq="M"), np.datetime64("nat")], dtype=object

pandas/tests/extension/base/dtype.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55

66
import pandas as pd
7-
from pandas.api.types import is_object_dtype, is_string_dtype
7+
from pandas.api.types import infer_dtype, is_object_dtype, is_string_dtype
88
from pandas.tests.extension.base.base import BaseExtensionTests
99

1010

@@ -123,3 +123,11 @@ def test_get_common_dtype(self, dtype):
123123
# still testing as good practice to have this working (and it is the
124124
# only case we can test in general)
125125
assert dtype._get_common_dtype([dtype]) == dtype
126+
127+
@pytest.mark.parametrize("skipna", [True, False])
128+
def test_infer_dtype(self, data, data_missing, skipna):
129+
# only testing that this works without raising an error
130+
res = infer_dtype(data, skipna=skipna)
131+
assert isinstance(res, str)
132+
res = infer_dtype(data_missing, skipna=skipna)
133+
assert isinstance(res, str)

pandas/tests/extension/decimal/test_decimal.py

+8
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import pandas as pd
99
import pandas._testing as tm
10+
from pandas.api.types import infer_dtype
1011
from pandas.tests.extension import base
1112
from pandas.tests.extension.decimal.array import (
1213
DecimalArray,
@@ -120,6 +121,13 @@ class TestDtype(BaseDecimal, base.BaseDtypeTests):
120121
def test_hashable(self, dtype):
121122
pass
122123

124+
@pytest.mark.parametrize("skipna", [True, False])
125+
def test_infer_dtype(self, data, data_missing, skipna):
126+
# here overriding base test to ensure we fall back to return
127+
# "unknown-array" for an EA pandas doesn't know
128+
assert infer_dtype(data, skipna=skipna) == "unknown-array"
129+
assert infer_dtype(data_missing, skipna=skipna) == "unknown-array"
130+
123131

124132
class TestInterface(BaseDecimal, base.BaseInterfaceTests):
125133
pass

pandas/tests/io/test_parquet.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -938,7 +938,8 @@ def test_unsupported(self, fp):
938938

939939
# period
940940
df = pd.DataFrame({"a": pd.period_range("2013", freq="M", periods=3)})
941-
self.check_error_on_write(df, fp, ValueError, "cannot infer type for")
941+
# error from fastparquet -> don't check exact error message
942+
self.check_error_on_write(df, fp, ValueError, None)
942943

943944
# mixed
944945
df = pd.DataFrame({"a": ["a", 1, 2.0]})

0 commit comments

Comments
 (0)