Skip to content

Commit 98809e7

Browse files
committed
TST: Arrow-backed BoolArray
1 parent 805c7c2 commit 98809e7

File tree

3 files changed

+145
-0
lines changed

3 files changed

+145
-0
lines changed

pandas/tests/extension/arrow/__init__.py

Whitespace-only changes.

pandas/tests/extension/arrow/bool.py

+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import copy
2+
import itertools
3+
4+
import numpy as np
5+
import pyarrow as pa
6+
import pandas as pd
7+
from pandas.api.extensions import (
8+
ExtensionDtype, ExtensionArray
9+
)
10+
11+
12+
# @register_extension_dtype
13+
class ArrowBoolDtype(ExtensionDtype):
    """ExtensionDtype describing the boolean data held by ArrowBoolArray."""

    # Scalar type, registered string alias and missing-value marker.
    type = np.bool_
    name = 'arrow_bool'
    na_value = pa.NULL

    @classmethod
    def construct_from_string(cls, string):
        """Return an instance of this dtype when ``string`` equals ``cls.name``.

        Raises
        ------
        TypeError
            If ``string`` is anything other than ``cls.name``.
        """
        if string == cls.name:
            return cls()

        message = "Cannot construct a '{}' from '{}'".format(cls, string)
        raise TypeError(message)

    @classmethod
    def construct_array_type(cls):
        """Return the array class paired with this dtype."""
        return ArrowBoolArray
30+
31+
32+
class ArrowBoolArray(ExtensionArray):
    """ExtensionArray of booleans stored in a ``pyarrow.ChunkedArray``.

    The wrapped chunked array is kept in ``self._data``; most operations
    convert it with ``to_pandas()`` and delegate to pandas/NumPy.
    """

    def __init__(self, values):
        """Wrap an existing boolean ``pyarrow.ChunkedArray``.

        Raises
        ------
        ValueError
            If ``values`` is not a ``pyarrow.ChunkedArray``.
        """
        if not isinstance(values, pa.ChunkedArray):
            raise ValueError(
                "values must be a pyarrow.ChunkedArray, got "
                "'{}'".format(type(values).__name__))

        assert values.type == pa.bool_()
        self._data = values
        self._dtype = ArrowBoolDtype()

    def __repr__(self):
        return "ArrowBoolArray({})".format(repr(self._data))

    @classmethod
    def from_scalars(cls, values):
        """Build an array from any sequence accepted by ``np.asarray``."""
        arr = pa.chunked_array([pa.array(np.asarray(values))])
        return cls(arr)

    @classmethod
    def from_array(cls, arr):
        """Build an array from a single ``pyarrow.Array`` (one chunk)."""
        assert isinstance(arr, pa.Array)
        return cls(pa.chunked_array([arr]))

    @classmethod
    def _from_sequence(cls, scalars, dtype=None, copy=False):
        """Construct from a sequence of scalars.

        ``dtype`` and ``copy`` are accepted for interface compatibility
        and are not used.
        """
        return cls.from_scalars(scalars)

    def __getitem__(self, item):
        # NOTE(review): the result is whatever indexing the ``to_pandas()``
        # output yields; for slices/masks that is not an ArrowBoolArray.
        # Confirm whether callers rely on getting the same array type back.
        return self._data.to_pandas()[item]

    def __len__(self):
        return len(self._data)

    @property
    def dtype(self):
        """The ArrowBoolDtype instance created in ``__init__``."""
        return self._dtype

    @property
    def nbytes(self):
        """Total size of all non-null buffers across every chunk."""
        return sum(x.size for chunk in self._data.chunks
                   for x in chunk.buffers()
                   if x is not None)

    def isna(self):
        """Return ``pd.isna`` applied to the pandas-converted data."""
        return pd.isna(self._data.to_pandas())

    def take(self, indices, allow_fill=False, fill_value=None):
        """Take elements by position, delegating to pandas' ``take``.

        When ``allow_fill`` is true and no ``fill_value`` is given, the
        dtype's ``na_value`` is used as the fill value.
        """
        from pandas.core.algorithms import take
        data = self._data.to_pandas()

        if allow_fill and fill_value is None:
            fill_value = self.dtype.na_value

        result = take(data, indices, fill_value=fill_value,
                      allow_fill=allow_fill)
        return self._from_sequence(result, dtype=self.dtype)

    def copy(self, deep=False):
        """Return a new ArrowBoolArray wrapping a copy of the data.

        ``deep`` selects ``copy.deepcopy`` over ``copy.copy`` for the
        underlying chunked array.
        """
        # Inside the method body ``copy`` resolves to the module-level
        # ``copy`` import, not to this method.
        if deep:
            data = copy.deepcopy(self._data)
        else:
            data = copy.copy(self._data)
        # Re-wrap so callers get an ArrowBoolArray back rather than the raw
        # pyarrow.ChunkedArray.
        return type(self)(data)

    @classmethod
    def _concat_same_type(cls, to_concat):
        """Concatenate several arrays of this type into one.

        The chunks of every input are chained into a single chunked array.
        """
        chunks = list(itertools.chain.from_iterable(x._data.chunks
                                                    for x in to_concat))
        arr = pa.chunked_array(chunks)
        return cls(arr)
+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import numpy as np
2+
import pytest
3+
import pandas as pd
4+
import pandas.util.testing as tm
5+
from pandas.tests.extension import base
6+
7+
pytest.importorskip('pyarrow')
8+
9+
from .bool import ArrowBoolDtype, ArrowBoolArray
10+
11+
12+
@pytest.fixture
def dtype():
    """Provide a fresh ArrowBoolDtype instance to each test."""
    arrow_dtype = ArrowBoolDtype()
    return arrow_dtype
15+
16+
17+
@pytest.fixture
def data():
    """Provide an ArrowBoolArray built from 100 random booleans."""
    random_bools = np.random.randint(0, 2, size=100, dtype=bool)
    return ArrowBoolArray.from_scalars(random_bools)
21+
22+
23+
class BaseArrowTests(object):
    """Empty mixin shared by the Arrow extension test classes below."""
25+
26+
27+
class TestDtype(BaseArrowTests, base.BaseDtypeTests):
    """Run the inherited ``base.BaseDtypeTests`` with the Arrow fixtures."""
29+
30+
31+
class TestInterface(BaseArrowTests, base.BaseInterfaceTests):
    """Run the inherited ``base.BaseInterfaceTests`` with the Arrow fixtures.

    ``test_repr`` is overridden so that it is skipped for now.
    """

    def test_repr(self, data):
        # pytest.skip() raises the skip exception itself, so a leading
        # ``raise`` is redundant; this also matches TestConstructors.
        pytest.skip("TODO")
34+
35+
36+
class TestConstructors(BaseArrowTests, base.BaseConstructorsTests):
    """Run the inherited ``base.BaseConstructorsTests`` with Arrow fixtures.

    ``test_from_dtype`` is overridden so that it is skipped (see GH-22666).
    """

    def test_from_dtype(self, data):
        # Skipped rather than removed; the reason string names the issue.
        pytest.skip("GH-22666")
39+
40+
41+
def test_is_bool_dtype(data):
    """Check that ``data`` is treated as boolean and works as a Series mask.

    Indexing a Series with ``data`` must give the same result as indexing
    it with ``np.asarray(data)``.
    """
    assert pd.api.types.is_bool_dtype(data)
    assert pd.core.common.is_bool_indexer(data)

    series = pd.Series(range(len(data)))
    masked_by_array = series[data]
    masked_by_ndarray = series[np.asarray(data)]
    tm.assert_series_equal(masked_by_array, masked_by_ndarray)

0 commit comments

Comments
 (0)