diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 219c52c4a65b9..d6d7743f3f5f3 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -23,11 +23,11 @@ type_t, ) from pandas.compat import ( + pa_version_under1p0, pa_version_under2p0, pa_version_under3p0, pa_version_under4p0, ) -from pandas.compat.pyarrow import pa_version_under1p0 from pandas.util._decorators import doc from pandas.util._validators import validate_fillna_kwargs @@ -55,31 +55,33 @@ ) from pandas.core.strings.object_array import ObjectStringArrayMixin -try: +# PyArrow backed StringArrays are available starting at 1.0.0, but this +# file is imported from even if pyarrow is < 1.0.0, before pyarrow.compute +# and its compute functions existed. GH38801 +if not pa_version_under1p0: import pyarrow as pa -except ImportError: - pa = None -else: - # PyArrow backed StringArrays are available starting at 1.0.0, but this - # file is imported from even if pyarrow is < 1.0.0, before pyarrow.compute - # and its compute functions existed. GH38801 - if not pa_version_under1p0: - import pyarrow.compute as pc - - ARROW_CMP_FUNCS = { - "eq": pc.equal, - "ne": pc.not_equal, - "lt": pc.less, - "gt": pc.greater, - "le": pc.less_equal, - "ge": pc.greater_equal, - } + import pyarrow.compute as pc + + ARROW_CMP_FUNCS = { + "eq": pc.equal, + "ne": pc.not_equal, + "lt": pc.less, + "gt": pc.greater, + "le": pc.less_equal, + "ge": pc.greater_equal, + } if TYPE_CHECKING: from pandas import Series +def _chk_pyarrow_available() -> None: + if pa_version_under1p0: + msg = "pyarrow>=1.0.0 is required for PyArrow backed StringArray." + raise ImportError(msg) + + @register_extension_dtype class ArrowStringDtype(StringDtype): """ @@ -112,6 +114,9 @@ class ArrowStringDtype(StringDtype): #: StringDtype.na_value uses pandas.NA na_value = libmissing.NA + def __init__(self): + _chk_pyarrow_available() + @property def type(self) -> type[str]: return str @@ -213,10 +218,8 @@ class ArrowStringArray(OpsMixin, ExtensionArray, ObjectStringArrayMixin): Length: 4, dtype: arrow_string """ - _dtype = ArrowStringDtype() - def __init__(self, values): - self._chk_pyarrow_available() + self._dtype = ArrowStringDtype() if isinstance(values, pa.Array): self._data = pa.chunked_array([values]) elif isinstance(values, pa.ChunkedArray): @@ -229,19 +232,11 @@ def __init__(self, values): "ArrowStringArray requires a PyArrow (chunked) array of string type" ) - @classmethod - def _chk_pyarrow_available(cls) -> None: - # TODO: maybe update import_optional_dependency to allow a minimum - # version to be specified rather than use the global minimum - if pa is None or pa_version_under1p0: - msg = "pyarrow>=1.0.0 is required for PyArrow backed StringArray." - raise ImportError(msg) - @classmethod def _from_sequence(cls, scalars, dtype: Dtype | None = None, copy: bool = False): from pandas.core.arrays.masked import BaseMaskedArray - cls._chk_pyarrow_available() + _chk_pyarrow_available() if isinstance(scalars, BaseMaskedArray): # avoid costly conversion to object dtype in ensure_string_array and diff --git a/pandas/tests/arrays/string_/test_string_arrow.py b/pandas/tests/arrays/string_/test_string_arrow.py index ec7f57940a67f..3db8333798e36 100644 --- a/pandas/tests/arrays/string_/test_string_arrow.py +++ b/pandas/tests/arrays/string_/test_string_arrow.py @@ -3,14 +3,25 @@ import numpy as np import pytest -from pandas.core.arrays.string_arrow import ArrowStringArray +from pandas.compat import pa_version_under1p0 -pa = pytest.importorskip("pyarrow", minversion="1.0.0") +from pandas.core.arrays.string_arrow import ( + ArrowStringArray, + ArrowStringDtype, +) +@pytest.mark.skipif( + pa_version_under1p0, + reason="pyarrow>=1.0.0 is required for PyArrow backed StringArray", +) @pytest.mark.parametrize("chunked", [True, False]) -@pytest.mark.parametrize("array", [np, pa]) +@pytest.mark.parametrize("array", ["numpy", "pyarrow"]) def test_constructor_not_string_type_raises(array, chunked): + import pyarrow as pa + + array = pa if array == "pyarrow" else np + arr = array.array([1, 2, 3]) if chunked: if array is np: @@ -24,3 +35,20 @@ def test_constructor_not_string_type_raises(array, chunked): ) with pytest.raises(ValueError, match=msg): ArrowStringArray(arr) + + +@pytest.mark.skipif( + not pa_version_under1p0, + reason="pyarrow is installed", +) +def test_pyarrow_not_installed_raises(): + msg = re.escape("pyarrow>=1.0.0 is required for PyArrow backed StringArray") + + with pytest.raises(ImportError, match=msg): + ArrowStringDtype() + + with pytest.raises(ImportError, match=msg): + ArrowStringArray([]) + + with pytest.raises(ImportError, match=msg): + ArrowStringArray._from_sequence(["a", None, "b"])