diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index bb910b687cd0f..ef4e9be55d597 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -38,11 +38,21 @@ class ExtensionArray(object): * copy * _concat_same_type - Some additional methods are available to satisfy pandas' internal, private - block API: + An additional method and attribute is available to satisfy pandas' + internal, private block API. - * _can_hold_na * _formatting_values + * _can_hold_na + + Some methods require casting the ExtensionArray to an ndarray of Python + objects with ``self.astype(object)``, which may be expensive. When + performance is a concern, we highly recommend overriding the following + methods: + + * fillna + * unique + * factorize / _values_for_factorize + * argsort / _values_for_argsort Some methods require casting the ExtensionArray to an ndarray of Python objects with ``self.astype(object)``, which may be expensive. When @@ -393,7 +403,8 @@ def _values_for_factorize(self): Returns ------- values : ndarray - An array suitable for factoraization. This should maintain order + + An array suitable for factorization. This should maintain order and be a supported dtype (Float64, Int64, UInt64, String, Object). By default, the extension array is cast to object dtype. na_value : object @@ -416,7 +427,7 @@ def factorize(self, na_sentinel=-1): Returns ------- labels : ndarray - An interger NumPy array that's an indexer into the original + An integer NumPy array that's an indexer into the original ExtensionArray. uniques : ExtensionArray An ExtensionArray containing the unique values of `self`. @@ -560,16 +571,13 @@ def _concat_same_type(cls, to_concat): """ raise AbstractMethodError(cls) - @property - def _can_hold_na(self): - # type: () -> bool - """Whether your array can hold missing values. True by default. + _can_hold_na = True + """Whether your array can hold missing values. True by default. - Notes - ----- - Setting this to false will optimize some operations like fillna. - """ - return True + Notes + ----- + Setting this to False will optimize some operations like fillna. + """ @property def _ndarray_values(self): diff --git a/pandas/core/ops.py b/pandas/core/ops.py index e14f82906cd06..6a45033254969 100644 --- a/pandas/core/ops.py +++ b/pandas/core/ops.py @@ -6,6 +6,7 @@ # necessary to enforce truediv in Python 2.X from __future__ import division import operator +import inspect import numpy as np import pandas as pd @@ -30,6 +31,7 @@ is_bool_dtype, is_list_like, is_scalar, + is_extension_array_dtype, _ensure_object) from pandas.core.dtypes.cast import ( maybe_upcast_putmask, find_common_type, @@ -1208,6 +1210,78 @@ def wrapper(self, other, axis=None): return self._constructor(res_values, index=self.index, name=res_name) + elif is_extension_array_dtype(self): + values = self.values + + # GH 20659 + # The idea here is as follows. First we see if the op is + # defined in the ExtensionArray subclass, and returns a + # result that is not NotImplemented. If so, we use that + # result. If that fails, then we try to see if the underlying + # dtype has implemented the op, and if not, then we try to + # put the values into a numpy array to see if the optimized + # comparisons will work. Either way, then we try an + # element by element comparison + + method = getattr(values, op_name, None) + # First see if the extension array object supports the op + res = NotImplemented + if inspect.ismethod(method): + try: + res = method(other) + except TypeError: + pass + except Exception as e: + raise e + + isfunc = inspect.isfunction(getattr(self.dtype.type, + op_name, None)) + + if res is NotImplemented and not isfunc: + # Try the fast implementation, i.e., see if passing a + # numpy array will allow the optimized comparison to + # work + nvalues = self.get_values() + if isinstance(other, list): + other = np.asarray(other) + + try: + with np.errstate(all='ignore'): + res = na_op(nvalues, other) + except TypeError: + pass + except Exception as e: + raise e + + if res is NotImplemented: + # Try it on each element. Support comparing to another + # ExtensionArray, or something that is list like, or + # a single object. This allows a result of a comparison + # to be an object as opposed to a boolean + if is_extension_array_dtype(other): + ovalues = other.values + elif is_list_like(other): + ovalues = other + else: # Assume its an object + ovalues = [other] * len(self) + + # Get the method for each object. + res = [getattr(a, op_name, None)(b) + for (a, b) in zip(values, ovalues)] + + # We can't use (NotImplemented in res) because the + # results might be objects that have overridden __eq__ + if any([isinstance(r, type(NotImplemented)) for r in res]): + msg = "invalid type comparison between {one} and {two}" + raise TypeError(msg.format(one=type(values), + two=type(other))) + + # At this point we have the result + # always return a full value series here + res_values = com._values_from_object(res) + return self._constructor(res_values, index=self.index, + name=res_name) + elif isinstance(other, ABCSeries): # By this point we have checked that self._indexed_same(other) res_values = na_op(self.values, other.values) diff --git a/pandas/tests/extension/base/getitem.py b/pandas/tests/extension/base/getitem.py index ac156900671a6..5f53a64269283 100644 --- a/pandas/tests/extension/base/getitem.py +++ b/pandas/tests/extension/base/getitem.py @@ -82,8 +82,9 @@ def test_getitem_scalar(self, data): assert isinstance(result, data.dtype.type) def test_getitem_scalar_na(self, data_missing, na_cmp, na_value): - result = data_missing[0] - assert na_cmp(result, na_value) + if data_missing._can_hold_na: + result = data_missing[0] + assert na_cmp(result, na_value) def test_getitem_mask(self, data): # Empty mask, raw array @@ -134,8 +135,9 @@ def test_take(self, data, na_value, na_cmp): def test_take_empty(self, data, na_value, na_cmp): empty = data[:0] - result = empty.take([-1]) - na_cmp(result[0], na_value) + if data._can_hold_na: + result = empty.take([-1]) + na_cmp(result[0], na_value) with tm.assert_raises_regex(IndexError, "cannot do a non-empty take"): empty.take([0, 1]) diff --git a/pandas/tests/extension/base/groupby.py b/pandas/tests/extension/base/groupby.py index a29ef2a509a63..baa9c4c0328fc 100644 --- a/pandas/tests/extension/base/groupby.py +++ b/pandas/tests/extension/base/groupby.py @@ -27,7 +27,10 @@ def test_groupby_extension_agg(self, as_index, data_for_grouping): _, index = pd.factorize(data_for_grouping, sort=True) # TODO(ExtensionIndex): remove astype index = pd.Index(index.astype(object), name="B") - expected = pd.Series([3, 1, 4], index=index, name="A") + if data_for_grouping._can_hold_na: + expected = pd.Series([3, 1, 4], index=index, name="A") + else: + expected = pd.Series([2, 3, 1, 4], index=index, name="A") if as_index: self.assert_series_equal(result, expected) else: @@ -41,16 +44,26 @@ def test_groupby_extension_no_sort(self, data_for_grouping): _, index = pd.factorize(data_for_grouping, sort=False) # TODO(ExtensionIndex): remove astype index = pd.Index(index.astype(object), name="B") - expected = pd.Series([1, 3, 4], index=index, name="A") + if data_for_grouping._can_hold_na: + expected = pd.Series([1, 3, 4], index=index, name="A") + else: + expected = pd.Series([1, 2, 3, 4], index=index, name="A") + self.assert_series_equal(result, expected) def test_groupby_extension_transform(self, data_for_grouping): valid = data_for_grouping[~data_for_grouping.isna()] - df = pd.DataFrame({"A": [1, 1, 3, 3, 1, 4], + if data_for_grouping._can_hold_na: + dfval = [1, 1, 3, 3, 1, 4] + exres = [3, 3, 2, 2, 3, 1] + else: + dfval = [1, 1, 2, 2, 3, 3, 1, 4] + exres = [3, 3, 2, 2, 2, 2, 3, 1] + df = pd.DataFrame({"A": dfval, "B": valid}) result = df.groupby("B").A.transform(len) - expected = pd.Series([3, 3, 2, 2, 3, 1], name="A") + expected = pd.Series(exres, name="A") self.assert_series_equal(result, expected) diff --git a/pandas/tests/extension/base/methods.py b/pandas/tests/extension/base/methods.py index c5436aa731d50..0dd8ee9057021 100644 --- a/pandas/tests/extension/base/methods.py +++ b/pandas/tests/extension/base/methods.py @@ -24,10 +24,11 @@ def test_value_counts(self, all_data, dropna): self.assert_series_equal(result, expected) def test_count(self, data_missing): - df = pd.DataFrame({"A": data_missing}) - result = df.count(axis='columns') - expected = pd.Series([0, 1]) - self.assert_series_equal(result, expected) + if data_missing._can_hold_na: + df = pd.DataFrame({"A": data_missing}) + result = df.count(axis='columns') + expected = pd.Series([0, 1]) + self.assert_series_equal(result, expected) def test_apply_simple_series(self, data): result = pd.Series(data).apply(id) @@ -40,7 +41,10 @@ def test_argsort(self, data_for_sorting): def test_argsort_missing(self, data_missing_for_sorting): result = pd.Series(data_missing_for_sorting).argsort() - expected = pd.Series(np.array([1, -1, 0], dtype=np.int64)) + if data_missing_for_sorting._can_hold_na: + expected = pd.Series(np.array([1, -1, 0], dtype=np.int64)) + else: + expected = pd.Series(np.array([1, 2, 0], dtype=np.int64)) self.assert_series_equal(result, expected) @pytest.mark.parametrize('ascending', [True, False]) @@ -58,7 +62,10 @@ def test_sort_values_missing(self, data_missing_for_sorting, ascending): ser = pd.Series(data_missing_for_sorting) result = ser.sort_values(ascending=ascending) if ascending: - expected = ser.iloc[[2, 0, 1]] + if data_missing_for_sorting._can_hold_na: + expected = ser.iloc[[2, 0, 1]] + else: + expected = ser.iloc[[1, 2, 0]] else: expected = ser.iloc[[0, 2, 1]] self.assert_series_equal(result, expected) diff --git a/pandas/tests/extension/base/missing.py b/pandas/tests/extension/base/missing.py index f6cee9af0b722..f09d9f3ba5d7b 100644 --- a/pandas/tests/extension/base/missing.py +++ b/pandas/tests/extension/base/missing.py @@ -24,7 +24,10 @@ def test_isna(self, data_missing): def test_dropna_series(self, data_missing): ser = pd.Series(data_missing) result = ser.dropna() - expected = ser.iloc[[1]] + if data_missing._can_hold_na: + expected = ser.iloc[[1]] + else: + expected = ser self.assert_series_equal(result, expected) def test_dropna_frame(self, data_missing): @@ -32,19 +35,28 @@ def test_dropna_frame(self, data_missing): # defaults result = df.dropna() - expected = df.iloc[[1]] + if data_missing._can_hold_na: + expected = df.iloc[[1]] + else: + expected = df self.assert_frame_equal(result, expected) # axis = 1 result = df.dropna(axis='columns') - expected = pd.DataFrame(index=[0, 1]) + if data_missing._can_hold_na: + expected = pd.DataFrame(index=[0, 1]) + else: + expected = df self.assert_frame_equal(result, expected) # multiple df = pd.DataFrame({"A": data_missing, "B": [1, np.nan]}) result = df.dropna() - expected = df.iloc[:0] + if data_missing._can_hold_na: + expected = df.iloc[:0] + else: + expected = df.iloc[:1] self.assert_frame_equal(result, expected) def test_fillna_scalar(self, data_missing): @@ -56,22 +68,32 @@ def test_fillna_scalar(self, data_missing): def test_fillna_limit_pad(self, data_missing): arr = data_missing.take([1, 0, 0, 0, 1]) result = pd.Series(arr).fillna(method='ffill', limit=2) - expected = pd.Series(data_missing.take([1, 1, 1, 0, 1])) + if data_missing._can_hold_na: + expected = pd.Series(data_missing.take([1, 1, 1, 0, 1])) + else: + expected = pd.Series(arr) self.assert_series_equal(result, expected) def test_fillna_limit_backfill(self, data_missing): arr = data_missing.take([1, 0, 0, 0, 1]) result = pd.Series(arr).fillna(method='backfill', limit=2) - expected = pd.Series(data_missing.take([1, 0, 1, 1, 1])) + if data_missing._can_hold_na: + expected = pd.Series(data_missing.take([1, 0, 1, 1, 1])) + else: + expected = pd.Series(arr) self.assert_series_equal(result, expected) def test_fillna_series(self, data_missing): + fill_value = data_missing[1] ser = pd.Series(data_missing) result = ser.fillna(fill_value) - expected = pd.Series( - data_missing._from_sequence([fill_value, fill_value])) + if data_missing._can_hold_na: + expected = pd.Series( + data_missing._from_sequence([fill_value, fill_value])) + else: + expected = ser self.assert_series_equal(result, expected) # Fill with a series @@ -90,8 +112,11 @@ def test_fillna_series_method(self, data_missing, method): data_missing = type(data_missing)(data_missing[::-1]) result = pd.Series(data_missing).fillna(method=method) - expected = pd.Series( - data_missing._from_sequence([fill_value, fill_value])) + if data_missing._can_hold_na: + expected = pd.Series( + data_missing._from_sequence([fill_value, fill_value])) + else: + expected = pd.Series(data_missing) self.assert_series_equal(result, expected) @@ -103,8 +128,13 @@ def test_fillna_frame(self, data_missing): "B": [1, 2] }).fillna(fill_value) + if data_missing._can_hold_na: + a = data_missing._from_sequence([fill_value, fill_value]) + else: + a = data_missing + expected = pd.DataFrame({ - "A": data_missing._from_sequence([fill_value, fill_value]), + "A": a, "B": [1, 2], }) diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index 53d74cd6d38cb..dd6d9e4221e39 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -1,4 +1,5 @@ import decimal +import operator import numpy as np import pandas as pd @@ -175,3 +176,49 @@ def test_dataframe_constructor_with_different_dtype_raises(): xpr = "Cannot coerce extension array to dtype 'int64'. " with tm.assert_raises_regex(ValueError, xpr): pd.DataFrame({"A": arr}, dtype='int64') + + +@pytest.mark.parametrize('op', ['lt', 'le', 'gt', 'ge', 'eq', 'ne']) +def test_comparisons(op): + arr1 = DecimalArray(make_data()) + arr2 = DecimalArray(make_data()) + ser1 = pd.Series(arr1) + ser2 = pd.Series(arr2) + func = getattr(operator, op) + + result = func(ser1, ser2) + expected = pd.Series([func(a, b) for (a, b) in zip(arr1, arr2)]) + tm.assert_series_equal(result, expected) + + oneval = decimal.Decimal('0.5') + result = func(ser1, oneval) + expected = pd.Series([func(a, oneval) for a in arr1]) + tm.assert_series_equal(result, expected) + + oneval = 0.5 + result = func(ser1, oneval) + expected = pd.Series([func(a, oneval) for a in arr1]) + tm.assert_series_equal(result, expected) + + alist = [i for i in arr2] + result = func(ser1, alist) + expected = pd.Series([func(a, b) for (a, b) in zip(arr1, alist)]) + tm.assert_series_equal(result, expected) + + alist = [float(i) for i in arr2] + result = func(ser1, alist) + expected = pd.Series([func(a, b) for (a, b) in zip(arr1, alist)]) + tm.assert_series_equal(result, expected) + + if op not in ['eq', 'ne']: + l2 = list(arr2) + l2[5] = 'abc' + with pytest.raises(TypeError) as excinfo: + func(ser1, "abc") + assert (str(excinfo.value) + .startswith("invalid type comparison between")) + + with pytest.raises(TypeError) as excinfo: + func(ser1, l2) + assert (str(excinfo.value) + .startswith("invalid type comparison between")) diff --git a/pandas/tests/extension/relobject/__init__.py b/pandas/tests/extension/relobject/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/pandas/tests/extension/relobject/array.py b/pandas/tests/extension/relobject/array.py new file mode 100644 index 0000000000000..5a6dc3738d6d3 --- /dev/null +++ b/pandas/tests/extension/relobject/array.py @@ -0,0 +1,211 @@ +import sys +import random +import numbers + + +import numpy as np + +import pandas as pd +from pandas.core.arrays import ExtensionArray +from pandas.core.dtypes.base import ExtensionDtype +from pandas.core.dtypes.common import _ensure_platform_int + +# Idea here is to test if we have an array of objects where +# relational operators are defined that return strings rather +# than booleans, but not for all of the operators. Also test +# if things work when _can_hold_na == False. +# NOTE: It is not expected that sorting and groupby() will +# work on these objects !! + + +def relation(left, op, right): + return str(left) + op + str(right) + + +class RelObj(object): + def __init__(self, value): + if type(value) == int: + self._value = value + else: + raise TypeError("Value should be int") + + def __eq__(self, other): + if isinstance(other, RelObj): + return relation(self, " == ", other) + else: + raise TypeError("Cannot compare RelObj to " + + "object of different type") + + def __le__(self, other): + if isinstance(other, RelObj): + return relation(self, " <= ", other) + else: + raise TypeError("Cannot compare RelObj to " + + "object of different type") + + def __ge__(self, other): + if isinstance(other, RelObj): + return relation(self, " >= ", other) + else: + raise TypeError("Cannot compare RelObj to " + + "object of different type") + + def __lt__(self, other): + raise Exception('lt not supported') + + def __gt__(self, other): + raise Exception('gt not supported') + + def __ne__(self, other): + raise Exception('ne not supported') + + @property + def value(self): + return self._value + + def __hash__(self): + return id(self) + + def __repr__(self): + return "RelObjValue: " + str(self._value) + + +class RelObjectDtype(ExtensionDtype): + type = RelObj + name = 'relobj' + kind = 'O' + + @classmethod + def construct_from_string(cls, string): + if string == cls.name: + return cls() + else: + raise TypeError("Cannot construct a '{}' from " + "'{}'".format(cls, string)) + + +class RelObjectArray(ExtensionArray): + dtype = RelObjectDtype() + + def __init__(self, values): + values = np.asarray(values, dtype=object) + + self.values = values + + @classmethod + def _from_sequence(cls, scalars): + return cls(scalars) + + def __getitem__(self, item): + if isinstance(item, numbers.Integral): + return self.values[item] + else: + return type(self)(self.values[item]) + + def copy(self, deep=False): + if deep: + return type(self)(self.values.copy()) + return type(self)(self) + + def __setitem__(self, key, value): + if pd.api.types.is_list_like(value): + value = [int(v) for v in value] + else: + value = int(value) + self.values[key] = value + + def __len__(self): + return len(self.values) + + def __repr__(self): + return 'RelObjArray({!r})'.format([i for i in self.values]) + + @property + def nbytes(self): + n = len(self) + if n: + return n * sys.getsizeof(self[0]) + return 0 + + def isna(self): + return np.full(len(self.values), False) + + def _values_for_argsort(self): + # type: () -> ndarray + """Return values for sorting. + + Returns + ------- + ndarray + The transformed values should maintain the ordering between values + within the array. + + See Also + -------- + ExtensionArray.argsort + """ + # Note: this is used in `ExtensionArray.argsort`. + # We will sort based on the value of the object + return np.array([ro.value for ro in self.values]) + + def unique(self): + """Compute unique values using the ID of the objects + + Cannot use pandas.unique() because it requires __eq__() + to return a boolean + + Returns + ------- + uniques : RelObjArray + """ + seen = set() + uniques = [x for x in self.values if x not in seen and not seen.add(x)] + return self._from_sequence(uniques) + + def factorize(self, na_sentinel=-1): + # type: (int) -> Tuple[ndarray, ExtensionArray] + """Encode the extension array as an enumerated type. + + Parameters + ---------- + na_sentinel : int, default -1 + Value to use in the `labels` array to indicate missing values. + + Returns + ------- + labels : ndarray + An integer NumPy array that's an indexer into the original + ExtensionArray. + uniques : ExtensionArray + An ExtensionArray containing the unique values of `self`. + + .. note:: + + uniques will *not* contain an entry for the NA value of + the ExtensionArray if there are any missing values present + in `self`. + """ + uniques = self.unique() + dun = {id(v): i for i, v in enumerate(uniques)} + labels = np.array([dun[id(v)] for v in self.values], dtype=np.intp) + return labels, uniques + + def take(self, indexer, allow_fill=True, fill_value=None): + indexer = np.asarray(indexer) + mask = indexer == -1 + + indexer = _ensure_platform_int(indexer) + out = self.values.take(indexer) + out[mask] = np.nan + + return type(self)(out) + + @classmethod + def _concat_same_type(cls, to_concat): + return cls(np.concatenate([x.values for x in to_concat])) + + _can_hold_na = False + + +def make_data(): + return [RelObj(random.randint(0, 2000)) for _ in range(100)] diff --git a/pandas/tests/extension/relobject/test_relobject.py b/pandas/tests/extension/relobject/test_relobject.py new file mode 100644 index 0000000000000..7977ad4926e90 --- /dev/null +++ b/pandas/tests/extension/relobject/test_relobject.py @@ -0,0 +1,268 @@ +import operator + +import numpy as np +import pandas as pd +import pandas.util.testing as tm +import pytest + +from pandas.tests.extension import base + +from .array import RelObjectDtype, RelObjectArray, RelObj, make_data + + +@pytest.fixture +def dtype(): + return RelObjectDtype() + + +@pytest.fixture +def data(): + return RelObjectArray(make_data()) + + +@pytest.fixture +def data_missing(): + # Since _can_hold_na is False, we don't have missing values + return RelObjectArray([RelObj(10), RelObj(5)]) + + +@pytest.fixture +def data_for_sorting(): + return RelObjectArray([RelObj(4), RelObj(5), RelObj(3)]) + + +@pytest.fixture +def data_missing_for_sorting(): + # Since _can_hold_na is False, we don't have missing values + # Tests assume middle value is smallest + return RelObjectArray([RelObj(1), RelObj(-1), RelObj(0)]) + + +@pytest.fixture +def na_cmp(): + return lambda x, y: x is np.nan and y is np.nan + + +@pytest.fixture +def na_value(): + return np.nan + + +@pytest.fixture +def data_for_grouping(): + b = RelObj(1) + a = RelObj(0) + c = RelObj(2) + d = RelObj(-1) + return RelObjectArray([b, b, d, d, a, a, b, c]) + + +class BaseRelObject(object): + + def assert_series_equal(self, left, right, *args, **kwargs): + + result = tm.assert_series_equal(left, right, + *args, **kwargs) + if result: + diff = 0 + for l, r in zip(left.values, right.values): + if l is not r: + diff += 1 + if diff > 0: + diff = diff * 100.0 / len(left) + obj = 'RelObjArray' + msg = '{obj} values are different ({pct} %)'.format( + obj='RelObjArray', pct=np.round(diff, 5)) + tm.raise_assert_detail(obj, msg, left, right) + return result + + def assert_frame_equal(self, left, right, *args, **kwargs): + relobjs = (left.dtypes == 'relobj').index + + for col in relobjs: + self.assert_series_equal(left[col], right[col], + *args, **kwargs) + + left = left.drop(columns=relobjs) + right = right.drop(columns=relobjs) + tm.assert_frame_equal(left, right, *args, **kwargs) + + +class TestDtype(BaseRelObject, base.BaseDtypeTests): + pass + + +class TestInterface(BaseRelObject, base.BaseInterfaceTests): + pass + + +class TestConstructors(BaseRelObject, base.BaseConstructorsTests): + pass + + +class TestReshaping(BaseRelObject, base.BaseReshapingTests): + pass + + +class TestGetitem(BaseRelObject, base.BaseGetitemTests): + pass + + +class TestMissing(BaseRelObject, base.BaseMissingTests): + pass + + +class TestMethods(BaseRelObject, base.BaseMethodsTests): + @pytest.mark.parametrize('dropna', [True, False]) + @pytest.mark.xfail(reason="value_counts not implemented yet.") + def test_value_counts(self, all_data, dropna): + all_data = all_data[:10] + other = all_data + + result = pd.Series(all_data).value_counts(dropna=dropna).sort_index() + expected = pd.Series(other).value_counts(dropna=dropna).sort_index() + + tm.assert_series_equal(result, expected) + + @pytest.mark.xfail(reason="sorting not appropriate") + def test_argsort(self, data_for_sorting): + pass + + @pytest.mark.xfail(reason="sorting not appropriate") + def test_argsort_missing(self, data_missing_for_sorting): + pass + + @pytest.mark.parametrize('ascending', [True, False]) + @pytest.mark.xfail(reason="sorting not appropriate") + def test_sort_values(self, data_for_sorting, ascending): + pass + + @pytest.mark.parametrize('ascending', [True, False]) + @pytest.mark.xfail(reason="sorting not appropriate") + def test_sort_values_missing(self, data_missing_for_sorting, ascending): + pass + + @pytest.mark.parametrize('ascending', [True, False]) + @pytest.mark.xfail(reason="sorting not appropriate") + def test_sort_values_frame(self, data_for_sorting, ascending): + pass + + def test_factorize(self, data_for_grouping, na_sentinel=None): + labels, uniques = pd.factorize(data_for_grouping, + na_sentinel=na_sentinel) + expected_labels = np.array([0, 0, 1, + 1, 2, 2, 0, 3], + dtype=np.intp) + expected_uniques = data_for_grouping.take([0, 2, 4, 7]) + + tm.assert_numpy_array_equal(labels, expected_labels) + self.assert_extension_array_equal(uniques, expected_uniques) + + +class TestCasting(BaseRelObject, base.BaseCastingTests): + pass + + +class TestGroupby(BaseRelObject, base.BaseGroupbyTests): + + @pytest.mark.xfail(reason="transform fails when __eq__ returns obj") + def test_groupby_extension_transform(self, data_for_grouping): + pass + + @pytest.mark.xfail(reason="apply fails when __eq__ returns obj") + def test_groupby_extension_apply(self, data_for_grouping, op): + pass + + +def test_series_constructor_coerce_data_to_extension_dtype_raises(): + xpr = ("Cannot cast data to extension dtype 'relobj'. Pass the " + "extension array directly.") + with tm.assert_raises_regex(ValueError, xpr): + pd.Series([0, 1, 2], dtype=RelObjectDtype()) + + +def test_series_constructor_with_same_dtype_ok(): + arr = RelObjectArray([10]) + result = pd.Series(arr, dtype=RelObjectDtype()) + expected = pd.Series(arr) + tm.assert_series_equal(result, expected) + + +def test_series_constructor_coerce_extension_array_to_dtype_raises(): + arr = RelObjectArray([10]) + xpr = r"Cannot specify a dtype 'float.* \('relobj'\)." + + with tm.assert_raises_regex(ValueError, xpr): + pd.Series(arr, dtype='float') + + +def test_dataframe_constructor_with_same_dtype_ok(): + arr = RelObjectArray([10]) + + result = pd.DataFrame({"A": arr}, dtype=RelObjectDtype()) + expected = pd.DataFrame({"A": arr}) + tm.assert_frame_equal(result, expected) + + +def test_dataframe_constructor_with_different_dtype_raises(): + arr = RelObjectArray([10]) + + xpr = "Cannot coerce extension array to dtype 'float" + with tm.assert_raises_regex(ValueError, xpr): + pd.DataFrame({"A": arr}, dtype='float') + + +@pytest.mark.parametrize( + 'op, supported', + [ + ('lt', False), + ('le', True), + ('gt', False), + ('ge', True), + ('eq', True), + ('ne', False)]) +def test_comparisons(op, supported): + arr1 = RelObjectArray(make_data()) + arr2 = RelObjectArray(make_data()) + ser1 = pd.Series(arr1) + ser2 = pd.Series(arr2) + func = getattr(operator, op) + + nsuppmsg = op + " not supported" + nocomparemsg = "Cannot compare RelObj to object of different type" + + if supported: + result = func(ser1, ser2) + expected = pd.Series([func(a, b) for (a, b) in zip(arr1, arr2)]) + tm.assert_series_equal(result, expected) + else: + with tm.assert_raises_regex(Exception, nsuppmsg): + result = func(ser1, ser2) + + oneval = 10 + if supported: + etype = TypeError + msg = nocomparemsg + else: + etype = Exception + msg = nsuppmsg + with tm.assert_raises_regex(etype, msg): + result = func(ser1, oneval) + + alist = [i for i in arr2] + if supported: + result = func(ser1, alist) + expected = pd.Series([func(a, b) for (a, b) in zip(arr1, alist)]) + tm.assert_series_equal(result, expected) + else: + with tm.assert_raises_regex(Exception, nsuppmsg): + result = func(ser1, alist) + + if op not in ['eq', 'ne']: + l2 = list(arr2) + l2[5] = 'abc' + with tm.assert_raises_regex(etype, msg): + func(ser1, "abc") + + with tm.assert_raises_regex(etype, msg): + func(ser1, l2)