Skip to content

Commit 2be18d6

Browse files
jbrockmendel authored and Yi Wei committed
BUG: infer_dtype with ArrowDtype (pandas-dev#53023)
* BUG: infer_dtype with ArrowDtype
* xfail
1 parent b1d3a15 commit 2be18d6

File tree

3 files changed

+27
-11
lines changed

3 files changed

+27
-11
lines changed

pandas/_libs/lib.pyx

+12 -1
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,11 @@ from cpython.datetime cimport (
1313
PyDateTime_Check,
1414
PyDelta_Check,
1515
PyTime_Check,
16+
date,
17+
datetime,
1618
import_datetime,
19+
time,
20+
timedelta,
1721
)
1822
from cpython.iterator cimport PyIter_Check
1923
from cpython.number cimport PyNumber_Check
@@ -1204,6 +1208,12 @@ _TYPE_MAP = {
12041208
"m": "timedelta64",
12051209
"interval": "interval",
12061210
Period: "period",
1211+
datetime: "datetime64",
1212+
date: "date",
1213+
time: "time",
1214+
timedelta: "timedelta64",
1215+
Decimal: "decimal",
1216+
bytes: "bytes",
12071217
}
12081218

12091219
# types only exist on certain platform
@@ -1373,7 +1383,8 @@ cdef object _try_infer_map(object dtype):
13731383
cdef:
13741384
object val
13751385
str attr
1376-
for attr in ["kind", "name", "base", "type"]:
1386+
for attr in ["type", "kind", "name", "base"]:
1387+
# Checking type before kind matters for ArrowDtype cases
13771388
val = getattr(dtype, attr, None)
13781389
if val in _TYPE_MAP:
13791390
return _TYPE_MAP[val]

pandas/tests/extension/decimal/test_decimal.py

+1 -10
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@
66

77
import pandas as pd
88
import pandas._testing as tm
9-
from pandas.api.types import infer_dtype
109
from pandas.tests.extension import base
1110
from pandas.tests.extension.decimal.array import (
1211
DecimalArray,
@@ -70,15 +69,7 @@ def data_for_grouping():
7069

7170

7271
class TestDtype(base.BaseDtypeTests):
73-
def test_hashable(self, dtype):
74-
pass
75-
76-
@pytest.mark.parametrize("skipna", [True, False])
77-
def test_infer_dtype(self, data, data_missing, skipna):
78-
# here overriding base test to ensure we fall back to return
79-
# "unknown-array" for an EA pandas doesn't know
80-
assert infer_dtype(data, skipna=skipna) == "unknown-array"
81-
assert infer_dtype(data_missing, skipna=skipna) == "unknown-array"
72+
pass
8273

8374

8475
class TestInterface(base.BaseInterfaceTests):

pandas/tests/extension/test_arrow.py

+14
Original file line number | Diff line number | Diff line change
@@ -2905,3 +2905,17 @@ def test_duration_overflow_from_ndarray_containing_nat():
29052905
result = ser_ts + ser_td
29062906
expected = pd.Series([2, None], dtype=ArrowDtype(pa.timestamp("ns")))
29072907
tm.assert_series_equal(result, expected)
2908+
2909+
2910+
def test_infer_dtype_pyarrow_dtype(data, request):
2911+
res = lib.infer_dtype(data)
2912+
assert res != "unknown-array"
2913+
2914+
if data._hasna and res in ["floating", "datetime64", "timedelta64"]:
2915+
mark = pytest.mark.xfail(
2916+
reason="in infer_dtype pd.NA is not ignored in these cases "
2917+
"even with skipna=True in the list(data) check below"
2918+
)
2919+
request.node.add_marker(mark)
2920+
2921+
assert res == lib.infer_dtype(list(data), skipna=True)

0 commit comments

Comments
 (0)