Skip to content

Commit 629b9b7

Browse files
simonjayhawkinsJulianWgs
authored andcommitted
[ArrowStringArray] ENH: raise an ImportError when trying to create an arrow string dtype if pyarrow is not installed (pandas-dev#41732)
1 parent 7b531a5 commit 629b9b7

File tree

2 files changed

+57
-34
lines changed

2 files changed

+57
-34
lines changed

pandas/core/arrays/string_arrow.py

+26-31
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@
2323
type_t,
2424
)
2525
from pandas.compat import (
26+
pa_version_under1p0,
2627
pa_version_under2p0,
2728
pa_version_under3p0,
2829
pa_version_under4p0,
2930
)
30-
from pandas.compat.pyarrow import pa_version_under1p0
3131
from pandas.util._decorators import doc
3232
from pandas.util._validators import validate_fillna_kwargs
3333

@@ -55,31 +55,33 @@
5555
)
5656
from pandas.core.strings.object_array import ObjectStringArrayMixin
5757

58-
try:
58+
# PyArrow backed StringArrays are available starting at 1.0.0, but this
59+
# file is imported from even if pyarrow is < 1.0.0, before pyarrow.compute
60+
# and its compute functions existed. GH38801
61+
if not pa_version_under1p0:
5962
import pyarrow as pa
60-
except ImportError:
61-
pa = None
62-
else:
63-
# PyArrow backed StringArrays are available starting at 1.0.0, but this
64-
# file is imported from even if pyarrow is < 1.0.0, before pyarrow.compute
65-
# and its compute functions existed. GH38801
66-
if not pa_version_under1p0:
67-
import pyarrow.compute as pc
68-
69-
ARROW_CMP_FUNCS = {
70-
"eq": pc.equal,
71-
"ne": pc.not_equal,
72-
"lt": pc.less,
73-
"gt": pc.greater,
74-
"le": pc.less_equal,
75-
"ge": pc.greater_equal,
76-
}
63+
import pyarrow.compute as pc
64+
65+
ARROW_CMP_FUNCS = {
66+
"eq": pc.equal,
67+
"ne": pc.not_equal,
68+
"lt": pc.less,
69+
"gt": pc.greater,
70+
"le": pc.less_equal,
71+
"ge": pc.greater_equal,
72+
}
7773

7874

7975
if TYPE_CHECKING:
8076
from pandas import Series
8177

8278

79+
def _chk_pyarrow_available() -> None:
80+
if pa_version_under1p0:
81+
msg = "pyarrow>=1.0.0 is required for PyArrow backed StringArray."
82+
raise ImportError(msg)
83+
84+
8385
@register_extension_dtype
8486
class ArrowStringDtype(StringDtype):
8587
"""
@@ -112,6 +114,9 @@ class ArrowStringDtype(StringDtype):
112114
#: StringDtype.na_value uses pandas.NA
113115
na_value = libmissing.NA
114116

117+
def __init__(self):
118+
_chk_pyarrow_available()
119+
115120
@property
116121
def type(self) -> type[str]:
117122
return str
@@ -213,10 +218,8 @@ class ArrowStringArray(OpsMixin, ExtensionArray, ObjectStringArrayMixin):
213218
Length: 4, dtype: arrow_string
214219
"""
215220

216-
_dtype = ArrowStringDtype()
217-
218221
def __init__(self, values):
219-
self._chk_pyarrow_available()
222+
self._dtype = ArrowStringDtype()
220223
if isinstance(values, pa.Array):
221224
self._data = pa.chunked_array([values])
222225
elif isinstance(values, pa.ChunkedArray):
@@ -229,19 +232,11 @@ def __init__(self, values):
229232
"ArrowStringArray requires a PyArrow (chunked) array of string type"
230233
)
231234

232-
@classmethod
233-
def _chk_pyarrow_available(cls) -> None:
234-
# TODO: maybe update import_optional_dependency to allow a minimum
235-
# version to be specified rather than use the global minimum
236-
if pa is None or pa_version_under1p0:
237-
msg = "pyarrow>=1.0.0 is required for PyArrow backed StringArray."
238-
raise ImportError(msg)
239-
240235
@classmethod
241236
def _from_sequence(cls, scalars, dtype: Dtype | None = None, copy: bool = False):
242237
from pandas.core.arrays.masked import BaseMaskedArray
243238

244-
cls._chk_pyarrow_available()
239+
_chk_pyarrow_available()
245240

246241
if isinstance(scalars, BaseMaskedArray):
247242
# avoid costly conversion to object dtype in ensure_string_array and

pandas/tests/arrays/string_/test_string_arrow.py

+31-3
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,25 @@
33
import numpy as np
44
import pytest
55

6-
from pandas.core.arrays.string_arrow import ArrowStringArray
6+
from pandas.compat import pa_version_under1p0
77

8-
pa = pytest.importorskip("pyarrow", minversion="1.0.0")
8+
from pandas.core.arrays.string_arrow import (
9+
ArrowStringArray,
10+
ArrowStringDtype,
11+
)
912

1013

14+
@pytest.mark.skipif(
15+
pa_version_under1p0,
16+
reason="pyarrow>=1.0.0 is required for PyArrow backed StringArray",
17+
)
1118
@pytest.mark.parametrize("chunked", [True, False])
12-
@pytest.mark.parametrize("array", [np, pa])
19+
@pytest.mark.parametrize("array", ["numpy", "pyarrow"])
1320
def test_constructor_not_string_type_raises(array, chunked):
21+
import pyarrow as pa
22+
23+
array = pa if array == "pyarrow" else np
24+
1425
arr = array.array([1, 2, 3])
1526
if chunked:
1627
if array is np:
@@ -24,3 +35,20 @@ def test_constructor_not_string_type_raises(array, chunked):
2435
)
2536
with pytest.raises(ValueError, match=msg):
2637
ArrowStringArray(arr)
38+
39+
40+
@pytest.mark.skipif(
41+
not pa_version_under1p0,
42+
reason="pyarrow is installed",
43+
)
44+
def test_pyarrow_not_installed_raises():
45+
msg = re.escape("pyarrow>=1.0.0 is required for PyArrow backed StringArray")
46+
47+
with pytest.raises(ImportError, match=msg):
48+
ArrowStringDtype()
49+
50+
with pytest.raises(ImportError, match=msg):
51+
ArrowStringArray([])
52+
53+
with pytest.raises(ImportError, match=msg):
54+
ArrowStringArray._from_sequence(["a", None, "b"])

0 commit comments

Comments
 (0)