diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 9958be47267ee..04b5cc726a4d0 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -38,11 +38,11 @@ class ExtensionArray(object): * copy * _concat_same_type - Some additional methods are available to satisfy pandas' internal, private - block API: + An additional method and attribute are available to satisfy pandas' + internal, private block API. - * _can_hold_na * _formatting_values + * _can_hold_na Some methods require casting the ExtensionArray to an ndarray of Python objects with ``self.astype(object)``, which may be expensive. When @@ -399,7 +399,8 @@ def _values_for_factorize(self): Returns ------- values : ndarray - An array suitable for factoraization. This should maintain order + + An array suitable for factorization. This should maintain order and be a supported dtype (Float64, Int64, UInt64, String, Object). By default, the extension array is cast to object dtype. na_value : object @@ -422,7 +423,7 @@ def factorize(self, na_sentinel=-1): Returns ------- labels : ndarray - An interger NumPy array that's an indexer into the original + An integer NumPy array that's an indexer into the original ExtensionArray. uniques : ExtensionArray An ExtensionArray containing the unique values of `self`. @@ -566,16 +567,10 @@ def _concat_same_type(cls, to_concat): """ raise AbstractMethodError(cls) - @property - def _can_hold_na(self): - # type: () -> bool - """Whether your array can hold missing values. True by default. - - Notes - ----- - Setting this to false will optimize some operations like fillna. - """ - return True + # The _can_hold_na attribute tells pandas whether your array can + # hold missing values. True by default. Setting this to False will + # optimize some operations like fillna. 
+ _can_hold_na = True @property def _ndarray_values(self): diff --git a/pandas/tests/extension/base/getitem.py b/pandas/tests/extension/base/getitem.py index ac156900671a6..2ed2ca444def4 100644 --- a/pandas/tests/extension/base/getitem.py +++ b/pandas/tests/extension/base/getitem.py @@ -134,8 +134,9 @@ def test_take(self, data, na_value, na_cmp): def test_take_empty(self, data, na_value, na_cmp): empty = data[:0] - result = empty.take([-1]) - na_cmp(result[0], na_value) + if data._can_hold_na: + result = empty.take([-1]) + na_cmp(result[0], na_value) with tm.assert_raises_regex(IndexError, "cannot do a non-empty take"): empty.take([0, 1]) diff --git a/pandas/tests/extension/base/groupby.py b/pandas/tests/extension/base/groupby.py index a29ef2a509a63..baa9c4c0328fc 100644 --- a/pandas/tests/extension/base/groupby.py +++ b/pandas/tests/extension/base/groupby.py @@ -27,7 +27,10 @@ def test_groupby_extension_agg(self, as_index, data_for_grouping): _, index = pd.factorize(data_for_grouping, sort=True) # TODO(ExtensionIndex): remove astype index = pd.Index(index.astype(object), name="B") - expected = pd.Series([3, 1, 4], index=index, name="A") + if data_for_grouping._can_hold_na: + expected = pd.Series([3, 1, 4], index=index, name="A") + else: + expected = pd.Series([2, 3, 1, 4], index=index, name="A") if as_index: self.assert_series_equal(result, expected) else: @@ -41,16 +44,26 @@ def test_groupby_extension_no_sort(self, data_for_grouping): _, index = pd.factorize(data_for_grouping, sort=False) # TODO(ExtensionIndex): remove astype index = pd.Index(index.astype(object), name="B") - expected = pd.Series([1, 3, 4], index=index, name="A") + if data_for_grouping._can_hold_na: + expected = pd.Series([1, 3, 4], index=index, name="A") + else: + expected = pd.Series([1, 2, 3, 4], index=index, name="A") + self.assert_series_equal(result, expected) def test_groupby_extension_transform(self, data_for_grouping): valid = data_for_grouping[~data_for_grouping.isna()] - df = 
pd.DataFrame({"A": [1, 1, 3, 3, 1, 4], + if data_for_grouping._can_hold_na: + dfval = [1, 1, 3, 3, 1, 4] + exres = [3, 3, 2, 2, 3, 1] + else: + dfval = [1, 1, 2, 2, 3, 3, 1, 4] + exres = [3, 3, 2, 2, 2, 2, 3, 1] + df = pd.DataFrame({"A": dfval, "B": valid}) result = df.groupby("B").A.transform(len) - expected = pd.Series([3, 3, 2, 2, 3, 1], name="A") + expected = pd.Series(exres, name="A") self.assert_series_equal(result, expected) diff --git a/pandas/tests/extension/base/missing.py b/pandas/tests/extension/base/missing.py index f6cee9af0b722..32cf29818e069 100644 --- a/pandas/tests/extension/base/missing.py +++ b/pandas/tests/extension/base/missing.py @@ -9,10 +9,7 @@ class BaseMissingTests(BaseExtensionTests): def test_isna(self, data_missing): - if data_missing._can_hold_na: - expected = np.array([True, False]) - else: - expected = np.array([False, False]) + expected = np.array([True, False]) result = pd.isna(data_missing) tm.assert_numpy_array_equal(result, expected) diff --git a/pandas/tests/extension/conftest.py b/pandas/tests/extension/conftest.py index 4cb4ea21d9be3..5bbec6810c71f 100644 --- a/pandas/tests/extension/conftest.py +++ b/pandas/tests/extension/conftest.py @@ -17,7 +17,9 @@ def data(): @pytest.fixture def data_missing(): - """Length-2 array with [NA, Valid]""" + """Length-2 array with [NA, Valid] + Use pytest.skip() if _can_hold_na==False + """ raise NotImplementedError @@ -46,6 +48,7 @@ def data_missing_for_sorting(): This should be three items [B, NA, A] with A < B and NA missing. + Use pytest.skip() if _can_hold_na==False """ raise NotImplementedError @@ -57,7 +60,7 @@ def na_cmp(): Should return a function of two arguments that returns True if both arguments are (scalar) NA for your type. 
- By default, uses ``operator.or`` + By default, uses ``operator.is_`` """ return operator.is_ @@ -75,5 +78,10 @@ def data_for_grouping(): Expected to be like [B, B, NA, NA, A, A, B, C] Where A < B < C and NA is missing + + If _can_hold_na==False, use a 4th value D for NA, + + where D < A < B < C + """ raise NotImplementedError diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py index 5d749126e0cec..d859e46d2d88c 100644 --- a/pandas/tests/extension/decimal/array.py +++ b/pandas/tests/extension/decimal/array.py @@ -103,5 +103,25 @@ def _concat_same_type(cls, to_concat): return cls(np.concatenate([x._data for x in to_concat])) +class DecimalNoNaArray(DecimalArray): + + _can_hold_na = False + + def isna(self): + return np.array([False] * len(self._data)) + + @property + def _na_value(self): + raise ValueError("No NA value for DecimalNoNaArray") + + def take(self, indexer, allow_fill=False, fill_value=None): + indexer = np.asarray(indexer) + + indexer = _ensure_platform_int(indexer) + out = self._data.take(indexer) + + return type(self)(out) + + def make_data(): return [decimal.Decimal(random.random()) for _ in range(100)] diff --git a/pandas/tests/extension/decimal/test_nona.py b/pandas/tests/extension/decimal/test_nona.py new file mode 100644 index 0000000000000..b2cfb2b8bf907 --- /dev/null +++ b/pandas/tests/extension/decimal/test_nona.py @@ -0,0 +1,135 @@ +import decimal + +import numpy as np +import pandas as pd +import pandas.util.testing as tm +import pytest + +from pandas.tests.extension import base + +from .array import DecimalDtype, DecimalNoNaArray, make_data +from .test_decimal import ( + BaseDecimal, TestDtype, TestInterface, TestConstructors, + TestReshaping, TestGetitem, TestMissing, TestCasting, + TestGroupby) + + +@pytest.fixture +def dtype(): + return DecimalDtype() + + +@pytest.fixture +def data(): + return DecimalNoNaArray(make_data()) + + +@pytest.fixture +def data_missing(): + pytest.skip("No missing data 
tests for _can_hold_na=False") + + +@pytest.fixture +def data_for_sorting(): + return DecimalNoNaArray([decimal.Decimal('1'), + decimal.Decimal('2'), + decimal.Decimal('0')]) + + +@pytest.fixture +def data_missing_for_sorting(): + pytest.skip("No missing data tests for _can_hold_na=False") + + +@pytest.fixture +def na_cmp(): + pytest.skip("No missing data tests for _can_hold_na=False") + + +@pytest.fixture +def na_value(): + pytest.skip("No missing data tests for _can_hold_na=False") + + +@pytest.fixture +def data_for_grouping(): + b = decimal.Decimal('1.0') + a = decimal.Decimal('0.0') + c = decimal.Decimal('2.0') + d = decimal.Decimal('-1.0') + return DecimalNoNaArray([b, b, d, d, a, a, b, c]) + + +class TestNoNaDtype(TestDtype): + pass + + +class TestNoNaInterface(TestInterface): + pass + + +class TestNoNaConstructors(TestConstructors): + pass + + +class TestNoNaReshaping(TestReshaping): + pass + + +class TestNoNaGetitem(TestGetitem): + pass + + +class TestNoNaMissing(TestMissing): + pass + + +class TestNoNaMethods(BaseDecimal, base.BaseMethodsTests): + def test_factorize(self, data_for_grouping, na_sentinel=None): + labels, uniques = pd.factorize(data_for_grouping) + expected_labels = np.array([0, 0, 1, + 1, 2, 2, 0, 3], + dtype=np.intp) + expected_uniques = data_for_grouping.take([0, 2, 4, 7]) + + tm.assert_numpy_array_equal(labels, expected_labels) + self.assert_extension_array_equal(uniques, expected_uniques) + + +class TestNoNaCasting(TestCasting): + pass + + +class TestNoNaGroupby(TestGroupby): + pass + + +def test_series_constructor_with_same_dtype_ok(): + arr = DecimalNoNaArray([decimal.Decimal('10.0')]) + result = pd.Series(arr, dtype=DecimalDtype()) + expected = pd.Series(arr) + tm.assert_series_equal(result, expected) + + +def test_series_constructor_coerce_extension_array_to_dtype_raises(): + arr = DecimalNoNaArray([decimal.Decimal('10.0')]) + xpr = r"Cannot specify a dtype 'int64' .* \('decimal'\)." 
+ + with tm.assert_raises_regex(ValueError, xpr): + pd.Series(arr, dtype='int64') + + +def test_dataframe_constructor_with_same_dtype_ok(): + arr = DecimalNoNaArray([decimal.Decimal('10.0')]) + + result = pd.DataFrame({"A": arr}, dtype=DecimalDtype()) + expected = pd.DataFrame({"A": arr}) + tm.assert_frame_equal(result, expected) + + +def test_dataframe_constructor_with_different_dtype_raises(): + arr = DecimalNoNaArray([decimal.Decimal('10.0')]) + + xpr = "Cannot coerce extension array to dtype 'int64'. " + with tm.assert_raises_regex(ValueError, xpr): + pd.DataFrame({"A": arr}, dtype='int64')