Skip to content

Commit 9a2dc47

Browse files
Backport PR #51854 on branch 2.0.x (ENH: Add misc pyarrow types to ArrowDtype.type) (#51887)
Backport PR #51854: ENH: Add misc pyarrow types to ArrowDtype.type Co-authored-by: Matthew Roeschke <[email protected]>
1 parent 5f2b051 commit 9a2dc47

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

pandas/core/arrays/arrow/dtype.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
StorageExtensionDtype,
2828
register_extension_dtype,
2929
)
30+
from pandas.core.dtypes.dtypes import CategoricalDtypeType
3031

3132
if not pa_version_under7p0:
3233
import pyarrow as pa
@@ -106,7 +107,7 @@ def type(self):
106107
return int
107108
elif pa.types.is_floating(pa_type):
108109
return float
109-
elif pa.types.is_string(pa_type):
110+
elif pa.types.is_string(pa_type) or pa.types.is_large_string(pa_type):
110111
return str
111112
elif (
112113
pa.types.is_binary(pa_type)
@@ -132,6 +133,14 @@ def type(self):
132133
return time
133134
elif pa.types.is_decimal(pa_type):
134135
return Decimal
136+
elif pa.types.is_dictionary(pa_type):
137+
# TODO: Potentially change this & CategoricalDtype.type to
138+
# something more representative of the scalar
139+
return CategoricalDtypeType
140+
elif pa.types.is_list(pa_type) or pa.types.is_large_list(pa_type):
141+
return list
142+
elif pa.types.is_map(pa_type):
143+
return dict
135144
elif pa.types.is_null(pa_type):
136145
# TODO: None? pd.NA? pa.null?
137146
return type(pa_type)

pandas/tests/extension/test_arrow.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from pandas.errors import PerformanceWarning
4040

4141
from pandas.core.dtypes.common import is_any_int_dtype
42+
from pandas.core.dtypes.dtypes import CategoricalDtypeType
4243

4344
import pandas as pd
4445
import pandas._testing as tm
@@ -1543,9 +1544,23 @@ def test_mode_dropna_false_mode_na(data):
15431544
tm.assert_series_equal(result, expected)
15441545

15451546

1546-
@pytest.mark.parametrize("arrow_dtype", [pa.binary(), pa.binary(16), pa.large_binary()])
1547-
def test_arrow_dtype_type(arrow_dtype):
1548-
assert ArrowDtype(arrow_dtype).type == bytes
1547+
@pytest.mark.parametrize(
1548+
"arrow_dtype, expected_type",
1549+
[
1550+
[pa.binary(), bytes],
1551+
[pa.binary(16), bytes],
1552+
[pa.large_binary(), bytes],
1553+
[pa.large_string(), str],
1554+
[pa.list_(pa.int64()), list],
1555+
[pa.large_list(pa.int64()), list],
1556+
[pa.map_(pa.string(), pa.int64()), dict],
1557+
[pa.dictionary(pa.int64(), pa.int64()), CategoricalDtypeType],
1558+
],
1559+
)
1560+
def test_arrow_dtype_type(arrow_dtype, expected_type):
1561+
# GH 51845
1562+
# TODO: Redundant with test_getitem_scalar once arrow_dtype exists in data fixture
1563+
assert ArrowDtype(arrow_dtype).type == expected_type
15491564

15501565

15511566
def test_is_bool_dtype():
@@ -1938,7 +1953,7 @@ def test_str_get(i, exp):
19381953

19391954
@pytest.mark.xfail(
19401955
reason="TODO: StringMethods._validate should support Arrow list types",
1941-
raises=NotImplementedError,
1956+
raises=AttributeError,
19421957
)
19431958
def test_str_join():
19441959
ser = pd.Series(ArrowExtensionArray(pa.array([list("abc"), list("123"), None])))

0 commit comments

Comments
 (0)