Skip to content

[ArrowStringArray] ENH: raise an ImportError when trying to create an arrow string dtype if pyarrow is not installed #41732

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 31, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 26 additions & 31 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
type_t,
)
from pandas.compat import (
pa_version_under1p0,
pa_version_under2p0,
pa_version_under3p0,
pa_version_under4p0,
)
from pandas.compat.pyarrow import pa_version_under1p0
from pandas.util._decorators import doc
from pandas.util._validators import validate_fillna_kwargs

Expand Down Expand Up @@ -55,31 +55,33 @@
)
from pandas.core.strings.object_array import ObjectStringArrayMixin

try:
# PyArrow backed StringArrays are available starting at 1.0.0, but this
# file is imported from even if pyarrow is < 1.0.0, before pyarrow.compute
# and its compute functions existed. GH38801
if not pa_version_under1p0:
import pyarrow as pa
except ImportError:
pa = None
else:
# PyArrow backed StringArrays are available starting at 1.0.0, but this
# file is imported from even if pyarrow is < 1.0.0, before pyarrow.compute
# and its compute functions existed. GH38801
if not pa_version_under1p0:
import pyarrow.compute as pc

ARROW_CMP_FUNCS = {
"eq": pc.equal,
"ne": pc.not_equal,
"lt": pc.less,
"gt": pc.greater,
"le": pc.less_equal,
"ge": pc.greater_equal,
}
import pyarrow.compute as pc

ARROW_CMP_FUNCS = {
"eq": pc.equal,
"ne": pc.not_equal,
"lt": pc.less,
"gt": pc.greater,
"le": pc.less_equal,
"ge": pc.greater_equal,
}


if TYPE_CHECKING:
from pandas import Series


def _chk_pyarrow_available() -> None:
if pa_version_under1p0:
msg = "pyarrow>=1.0.0 is required for PyArrow backed StringArray."
raise ImportError(msg)


@register_extension_dtype
class ArrowStringDtype(StringDtype):
"""
Expand Down Expand Up @@ -112,6 +114,9 @@ class ArrowStringDtype(StringDtype):
#: StringDtype.na_value uses pandas.NA
na_value = libmissing.NA

def __init__(self):
_chk_pyarrow_available()

@property
def type(self) -> type[str]:
return str
Expand Down Expand Up @@ -213,10 +218,8 @@ class ArrowStringArray(OpsMixin, ExtensionArray, ObjectStringArrayMixin):
Length: 4, dtype: arrow_string
"""

_dtype = ArrowStringDtype()

def __init__(self, values):
self._chk_pyarrow_available()
self._dtype = ArrowStringDtype()
if isinstance(values, pa.Array):
self._data = pa.chunked_array([values])
elif isinstance(values, pa.ChunkedArray):
Expand All @@ -229,19 +232,11 @@ def __init__(self, values):
"ArrowStringArray requires a PyArrow (chunked) array of string type"
)

@classmethod
def _chk_pyarrow_available(cls) -> None:
# TODO: maybe update import_optional_dependency to allow a minimum
# version to be specified rather than use the global minimum
if pa is None or pa_version_under1p0:
msg = "pyarrow>=1.0.0 is required for PyArrow backed StringArray."
raise ImportError(msg)

@classmethod
def _from_sequence(cls, scalars, dtype: Dtype | None = None, copy: bool = False):
from pandas.core.arrays.masked import BaseMaskedArray

cls._chk_pyarrow_available()
_chk_pyarrow_available()

if isinstance(scalars, BaseMaskedArray):
# avoid costly conversion to object dtype in ensure_string_array and
Expand Down
34 changes: 31 additions & 3 deletions pandas/tests/arrays/string_/test_string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,25 @@
import numpy as np
import pytest

from pandas.core.arrays.string_arrow import ArrowStringArray
from pandas.compat import pa_version_under1p0

pa = pytest.importorskip("pyarrow", minversion="1.0.0")
from pandas.core.arrays.string_arrow import (
ArrowStringArray,
ArrowStringDtype,
)


@pytest.mark.skipif(
pa_version_under1p0,
reason="pyarrow>=1.0.0 is required for PyArrow backed StringArray",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should the numpy test still run?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self._dtype = ArrowStringDtype() is defined in __init__ before the input is checked. creating an instance ArrowStringDtype now raises an ImportError, tested in the test added.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually we checked pyarrow was availble before with self._chk_pyarrow_available() but this module was skipped with pytest.importorskip("pyarrow", minversion="1.0.0")

)
@pytest.mark.parametrize("chunked", [True, False])
@pytest.mark.parametrize("array", [np, pa])
@pytest.mark.parametrize("array", ["numpy", "pyarrow"])
def test_constructor_not_string_type_raises(array, chunked):
import pyarrow as pa

array = pa if array == "pyarrow" else np

arr = array.array([1, 2, 3])
if chunked:
if array is np:
Expand All @@ -24,3 +35,20 @@ def test_constructor_not_string_type_raises(array, chunked):
)
with pytest.raises(ValueError, match=msg):
ArrowStringArray(arr)


@pytest.mark.skipif(
not pa_version_under1p0,
reason="pyarrow is installed",
)
def test_pyarrow_not_installed_raises():
msg = re.escape("pyarrow>=1.0.0 is required for PyArrow backed StringArray")

with pytest.raises(ImportError, match=msg):
ArrowStringDtype()

with pytest.raises(ImportError, match=msg):
ArrowStringArray([])

with pytest.raises(ImportError, match=msg):
ArrowStringArray._from_sequence(["a", None, "b"])