Skip to content

Commit af43bfc

Browse files
authored
ENH: Add misc pyarrow types to ArrowDtype.type (#51854)
* ENH: Add misc pyarrow types to ArrowDtype.type * change exception * Change to CategoricalDtypeType
1 parent 78947dd commit af43bfc

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

pandas/core/arrays/arrow/dtype.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
StorageExtensionDtype,
2828
register_extension_dtype,
2929
)
30+
from pandas.core.dtypes.dtypes import CategoricalDtypeType
3031

3132
if not pa_version_under7p0:
3233
import pyarrow as pa
@@ -106,7 +107,7 @@ def type(self):
106107
return int
107108
elif pa.types.is_floating(pa_type):
108109
return float
109-
elif pa.types.is_string(pa_type):
110+
elif pa.types.is_string(pa_type) or pa.types.is_large_string(pa_type):
110111
return str
111112
elif (
112113
pa.types.is_binary(pa_type)
@@ -132,6 +133,14 @@ def type(self):
132133
return time
133134
elif pa.types.is_decimal(pa_type):
134135
return Decimal
136+
elif pa.types.is_dictionary(pa_type):
137+
# TODO: Potentially change this & CategoricalDtype.type to
138+
# something more representative of the scalar
139+
return CategoricalDtypeType
140+
elif pa.types.is_list(pa_type) or pa.types.is_large_list(pa_type):
141+
return list
142+
elif pa.types.is_map(pa_type):
143+
return dict
135144
elif pa.types.is_null(pa_type):
136145
# TODO: None? pd.NA? pa.null?
137146
return type(pa_type)

pandas/tests/extension/test_arrow.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from pandas.errors import PerformanceWarning
4040

4141
from pandas.core.dtypes.common import is_any_int_dtype
42+
from pandas.core.dtypes.dtypes import CategoricalDtypeType
4243

4344
import pandas as pd
4445
import pandas._testing as tm
@@ -1530,9 +1531,23 @@ def test_mode_dropna_false_mode_na(data):
15301531
tm.assert_series_equal(result, expected)
15311532

15321533

1533-
@pytest.mark.parametrize("arrow_dtype", [pa.binary(), pa.binary(16), pa.large_binary()])
1534-
def test_arrow_dtype_type(arrow_dtype):
1535-
assert ArrowDtype(arrow_dtype).type == bytes
1534+
@pytest.mark.parametrize(
1535+
"arrow_dtype, expected_type",
1536+
[
1537+
[pa.binary(), bytes],
1538+
[pa.binary(16), bytes],
1539+
[pa.large_binary(), bytes],
1540+
[pa.large_string(), str],
1541+
[pa.list_(pa.int64()), list],
1542+
[pa.large_list(pa.int64()), list],
1543+
[pa.map_(pa.string(), pa.int64()), dict],
1544+
[pa.dictionary(pa.int64(), pa.int64()), CategoricalDtypeType],
1545+
],
1546+
)
1547+
def test_arrow_dtype_type(arrow_dtype, expected_type):
1548+
# GH 51845
1549+
# TODO: Redundant with test_getitem_scalar once arrow_dtype exists in data fixture
1550+
assert ArrowDtype(arrow_dtype).type == expected_type
15361551

15371552

15381553
def test_is_bool_dtype():
@@ -1925,7 +1940,7 @@ def test_str_get(i, exp):
19251940

19261941
@pytest.mark.xfail(
19271942
reason="TODO: StringMethods._validate should support Arrow list types",
1928-
raises=NotImplementedError,
1943+
raises=AttributeError,
19291944
)
19301945
def test_str_join():
19311946
ser = pd.Series(ArrowExtensionArray(pa.array([list("abc"), list("123"), None])))

0 commit comments

Comments
 (0)