Skip to content

Commit b193658

Browse files
committed
add extension tests for numpy_bool
1 parent 1c49fb0 commit b193658

File tree

9 files changed

+172
-39
lines changed

9 files changed

+172
-39
lines changed

pandas/conftest.py

+14
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,20 @@ def all_numeric_reductions(request):
159159
return request.param
160160

161161

162+
_all_numeric_reductions_for_boolean = ['min', 'max', 'mean', 'prod',
163+
'std', 'var', 'median',
164+
'kurt', 'skew']
165+
166+
167+
@pytest.fixture(params=_all_numeric_reductions_for_boolean)
168+
def all_numeric_reductions_for_boolean(request):
169+
"""
170+
Fixture for numeric reduction names that are not allowed
171+
for boolean.
172+
"""
173+
return request.param
174+
175+
162176
_all_boolean_reductions = ['all', 'any']
163177

164178

pandas/core/arrays/integer.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
is_bool_dtype, is_float, is_float_dtype, is_integer, is_integer_dtype,
1515
is_list_like, is_object_dtype, is_scalar)
1616
from pandas.core.dtypes.dtypes import register_extension_dtype
17-
from pandas.core.dtypes.generic import ABCIndexClass, ABCMaskArray, ABCSeries
17+
from pandas.core.dtypes.generic import ABCIndexClass, ABCSeries
1818
from pandas.core.dtypes.missing import isna, notna
1919

2020
from pandas.core import nanops
@@ -288,8 +288,7 @@ def __init__(self, values, mask, copy=False):
288288
and is_integer_dtype(values.dtype)):
289289
raise TypeError("values should be integer numpy array. Use "
290290
"the 'integer_array' function instead")
291-
if not (isinstance(mask, (np.ndarray, ABCMaskArray)) and
292-
is_bool_dtype(mask.dtype)):
291+
if not is_bool_dtype(mask):
293292
raise TypeError("mask should be boolean numpy array. Use "
294293
"the 'integer_array' function instead")
295294

pandas/core/arrays/mask/_base.py

+17-10
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""A boolean mask interace.
1+
"""A boolean mask interface.
22
33
This module provides an interface to a numpy / pyarrow boolean mask.
44
This is limited as not all of the implementations can hold NA, so
@@ -11,6 +11,7 @@
1111

1212
from pandas import compat
1313
from pandas.api.extensions import ExtensionDtype
14+
from pandas.api.types import is_scalar
1415
from pandas.core.arrays.base import ExtensionArray
1516
from pandas.core.missing import isna
1617

@@ -39,7 +40,7 @@ def __eq__(self, other):
3940
# compare == to np.dtype('bool')
4041
if isinstance(other, compat.string_types):
4142
return other == self.name
42-
elif other is self:
43+
elif isinstance(other, type(self)):
4344
return True
4445
elif isinstance(other, np.dtype):
4546
return other == 'bool'
@@ -55,10 +56,6 @@ class BoolArray(ExtensionArray):
5556
def _from_sequence(cls, scalars, dtype=None, copy=False):
5657
return cls.from_scalars(scalars)
5758

58-
@property
59-
def dtype(self):
60-
return self._dtype
61-
6259
@property
6360
def size(self):
6461
return len(self)
@@ -94,13 +91,21 @@ def __iand__(self, other):
9491
return type(self).from_scalars(
9592
np.array(self, copy=False) & (np.array(other, copy=False)))
9693

94+
def __getitem__(self, item):
95+
arr = np.array(self, copy=False)
96+
if is_scalar(item):
97+
return arr[item]
98+
else:
99+
arr = arr[item]
100+
return type(self).from_scalars(arr)
101+
97102
def view(self, dtype=None):
98103
arr = np.array(self._data, copy=False)
99104
if dtype is not None:
100105
arr = arr.view(dtype=dtype)
101106
return arr
102107

103-
def sum(self, axis=None):
108+
def sum(self, axis=None, min_count=None):
104109
return np.array(self, copy=False).sum()
105110

106111
def copy(self, deep=False):
@@ -120,9 +125,11 @@ def _reduce(self, method, skipna=True, **kwargs):
120125
arr = self[~self.isna()]
121126
else:
122127
arr = self
123-
128+
# we only allow explicitly defined methods
129+
# ndarrays actually support: mean, var, prod, min, max
124130
try:
125131
op = getattr(arr, method)
132+
return op()
126133
except AttributeError:
127-
raise TypeError
128-
return op(**kwargs)
134+
pass
135+
raise TypeError

pandas/core/arrays/mask/_numpy.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ class NumpyBoolArray(BoolArray):
2121
"""Generic class which can be used to represent missing data.
2222
"""
2323

24+
dtype = NumpyBoolDtype()
25+
2426
@classmethod
2527
def from_scalars(cls, values):
2628
arr = np.asarray(values).astype(np.bool_, copy=False)
@@ -39,10 +41,6 @@ def __init__(self, mask, copy=True):
3941
if copy:
4042
mask = mask.copy()
4143
self._data = mask
42-
self._dtype = NumpyBoolDtype()
43-
44-
def __getitem__(self, key):
45-
return self._data[key]
4644

4745
def __setitem__(self, key, value):
4846
self._data[key] = value
@@ -60,8 +58,13 @@ def nbytes(self):
6058
def reshape(self, shape, **kwargs):
6159
return np.array(self, copy=False).reshape(shape, **kwargs)
6260

63-
def astype(self, dtype, copy=False):
64-
return np.array(self, copy=False).astype(dtype, copy=copy)
61+
def astype(self, dtype, copy=True):
62+
# needed to fix this astype for the Series constructor.
63+
if isinstance(dtype, type(self.dtype)) and dtype == self.dtype:
64+
if copy:
65+
return self.copy()
66+
return self
67+
return super(NumpyBoolArray, self).astype(dtype, copy)
6568

6669
def take(self, indices, allow_fill=False, fill_value=None, axis=None):
6770
# TODO: had to add axis here
@@ -73,3 +76,7 @@ def take(self, indices, allow_fill=False, fill_value=None, axis=None):
7376
result = take(data, indices, fill_value=fill_value,
7477
allow_fill=allow_fill)
7578
return self._from_sequence(result, dtype=self.dtype)
79+
80+
def _concat_same_type(cls, to_concat):
81+
concat = np.concatenate(to_concat)
82+
return cls.from_scalars(concat)

pandas/core/arrays/mask/_pyarrow.py

+2-9
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import numpy as np
1212

1313
from pandas.api.extensions import take
14-
from pandas.api.types import is_scalar
1514
from pandas.core.arrays.mask._base import BoolArray, BoolDtype
1615

1716
# we require pyarrow >= 0.10.0
@@ -35,6 +34,8 @@ def construct_array_type(cls):
3534

3635
class ArrowBoolArray(BoolArray):
3736

37+
dtype = ArrowBoolDtype()
38+
3839
@classmethod
3940
def from_scalars(cls, values):
4041
values = np.asarray(values).astype(np.bool_, copy=False)
@@ -53,14 +54,6 @@ def __init__(self, values, copy=False):
5354
values = values.copy()
5455

5556
self._data = values
56-
self._dtype = ArrowBoolDtype()
57-
58-
def __getitem__(self, item):
59-
if is_scalar(item):
60-
return self._data.to_pandas()[item]
61-
else:
62-
vals = self._data.to_pandas()[item]
63-
return type(self).from_scalars(vals)
6457

6558
def __setitem__(self, key, value):
6659
# TODO: hack-a-minute

pandas/tests/arrays/mask/test_mask.py

+21-6
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,29 @@
66

77

88
@pytest.fixture(params=['numpy', 'arrow', 'mask'])
9-
def mask_type(request):
9+
def mask_dtype(request):
10+
""" dtype type """
1011
if request.param == 'numpy':
11-
from pandas.core.arrays.mask._numpy import NumpyBoolArray
12-
return NumpyBoolArray
12+
from pandas.core.arrays.mask._numpy import NumpyBoolDtype
13+
return NumpyBoolDtype
1314
elif request.param == 'arrow':
1415
pytest.importorskip('pyarrow', minversion="0.10.0")
15-
from pandas.core.arrays.mask._pyarrow import ArrowBoolArray
16-
return ArrowBoolArray
16+
from pandas.core.arrays.mask._pyarrow import ArrowBoolDtype
17+
return ArrowBoolDtype
1718
elif request.param == 'mask':
1819
from pandas.core.arrays.mask import get_mask_array_type
19-
return get_mask_array_type()
20+
return type(get_mask_array_type().dtype)
21+
22+
23+
@pytest.fixture
24+
def mask_type(mask_dtype):
25+
""" array type """
26+
return mask_dtype.construct_array_type()
2027

2128

2229
@pytest.fixture
2330
def mask(mask_type):
31+
""" array object """
2432
return mask_type._from_sequence([1, 0, 1])
2533

2634

@@ -93,3 +101,10 @@ def test_functions(mask):
93101
assert (mask2 == mask).all()
94102

95103
assert mask.size == len(mask)
104+
105+
106+
def test_dtype(mask_dtype):
107+
m = mask_dtype()
108+
assert m == m
109+
assert m == mask_dtype()
110+
assert hash(m) is not None

pandas/tests/extension/mask/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
"""
2+
This file contains a minimal set of tests for compliance with the extension
3+
array interface test suite, and should contain no other tests.
4+
The test suite for the full functionality of the array is located in
5+
`pandas/tests/arrays/`.
6+
7+
The tests in this file are inherited from the BaseExtensionTests, and only
8+
minimal tweaks should be applied to get the tests passing (by overwriting a
9+
parent method).
10+
11+
Additional tests should either be added to one of the BaseExtensionTests
12+
classes (if they are relevant for the extension interface for all dtypes), or
13+
be added to the array-specific tests in `pandas/tests/arrays/`.
14+
15+
"""
16+
import numpy as np
17+
import pytest
18+
19+
import pandas as pd
20+
from pandas.tests.extension import base
21+
import pandas.util.testing as tm
22+
23+
from pandas.core.arrays.mask._numpy import (
24+
NumpyBoolArray, NumpyBoolDtype)
25+
26+
27+
@pytest.fixture
28+
def dtype():
29+
return NumpyBoolDtype()
30+
31+
32+
@pytest.fixture
33+
def data():
34+
return NumpyBoolArray.from_scalars(np.random.randint(0, 2, size=100,
35+
dtype=bool))
36+
37+
38+
@pytest.fixture
39+
def data_missing():
40+
pytest.skip("not supported in NumpyBoolArray")
41+
42+
43+
class BaseNumpyTests(object):
44+
pass
45+
46+
47+
class TestDtype(BaseNumpyTests, base.BaseDtypeTests):
48+
def test_array_type_with_arg(self, data, dtype):
49+
pytest.skip("GH-22666")
50+
51+
52+
class TestInterface(BaseNumpyTests, base.BaseInterfaceTests):
53+
def test_repr(self, data):
54+
raise pytest.skip("TODO")
55+
56+
57+
class TestConstructors(BaseNumpyTests, base.BaseConstructorsTests):
58+
def test_from_dtype(self, data):
59+
pytest.skip("GH-22666")
60+
61+
62+
class TestReduceBoolean(base.BaseBooleanReduceTests):
63+
64+
@pytest.mark.parametrize('skipna', [True, False])
65+
def test_reduce_series_numeric(
66+
self, data, all_numeric_reductions_for_boolean, skipna):
67+
op_name = all_numeric_reductions_for_boolean
68+
s = pd.Series(data)
69+
70+
with pytest.raises(TypeError):
71+
getattr(s, op_name)(skipna=skipna)
72+
73+
74+
def test_is_bool_dtype(data):
75+
assert pd.api.types.is_bool_dtype(data)
76+
assert pd.core.common.is_bool_indexer(data)
77+
s = pd.Series(range(len(data)))
78+
result = s[data]
79+
expected = s[np.asarray(data)]
80+
tm.assert_series_equal(result, expected)

pandas/tests/arrays/mask/test_pyarrow_bool.py renamed to pandas/tests/extension/mask/test_pyarrow_bool.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
"""
2+
This file contains a minimal set of tests for compliance with the extension
3+
array interface test suite, and should contain no other tests.
4+
The test suite for the full functionality of the array is located in
5+
`pandas/tests/arrays/`.
6+
7+
The tests in this file are inherited from the BaseExtensionTests, and only
8+
minimal tweaks should be applied to get the tests passing (by overwriting a
9+
parent method).
10+
11+
Additional tests should either be added to one of the BaseExtensionTests
12+
classes (if they are relevant for the extension interface for all dtypes), or
13+
be added to the array-specific tests in `pandas/tests/arrays/`.
14+
15+
"""
116
import numpy as np
217
import pytest
318

@@ -46,13 +61,16 @@ def test_from_dtype(self, data):
4661
pytest.skip("GH-22666")
4762

4863

49-
class TestReduce(base.BaseNoReduceTests):
50-
def test_reduce_series_boolean(self):
51-
pass
64+
class TestReduceBoolean(base.BaseBooleanReduceTests):
5265

66+
@pytest.mark.parametrize('skipna', [True, False])
67+
def test_reduce_series_numeric(
68+
self, data, all_numeric_reductions_for_boolean, skipna):
69+
op_name = all_numeric_reductions_for_boolean
70+
s = pd.Series(data)
5371

54-
class TestReduceBoolean(base.BaseBooleanReduceTests):
55-
pass
72+
with pytest.raises(TypeError):
73+
getattr(s, op_name)(skipna=skipna)
5674

5775

5876
def test_is_bool_dtype(data):

0 commit comments

Comments
 (0)