diff --git a/doc/source/whatsnew/v0.24.0.txt b/doc/source/whatsnew/v0.24.0.txt index 55e76512b2440..61f8ebe3618f1 100644 --- a/doc/source/whatsnew/v0.24.0.txt +++ b/doc/source/whatsnew/v0.24.0.txt @@ -29,12 +29,22 @@ Datetimelike API Changes - For :class:`DatetimeIndex` and :class:`TimedeltaIndex` with non-``None`` ``freq`` attribute, addition or subtraction of integer-dtyped array or ``Index`` will return an object of the same class (:issue:`19959`) +.. _whatsnew_0240.api.extension: + +ExtensionType Changes +^^^^^^^^^^^^^^^^^^^^^ + +- ``ExtensionArray`` has gained the abstract methods ``.dropna()`` and ``.append()`` (:issue:`21185`) +- ``ExtensionDtype`` has gained the ability to instantiate from string dtypes, e.g. ``decimal`` would instantiate a registered ``DecimalDtype``; furthermore + the dtype has gained the ``construct_array_type`` (:issue:`21185`) +- The ``ExtensionArray`` constructor, ``_from_sequence`` now takes the keyword arg ``copy=False`` (:issue:`21185`) + .. _whatsnew_0240.api.other: Other API Changes ^^^^^^^^^^^^^^^^^ -- +- Invalid construction of ``IntervalDtype`` will now always raise a ``TypeError`` rather than a ``ValueError`` if the subdtype is invalid (:issue:`21185`) - - diff --git a/pandas/conftest.py b/pandas/conftest.py index a463f573c82e0..227ee3a150154 100644 --- a/pandas/conftest.py +++ b/pandas/conftest.py @@ -92,6 +92,15 @@ def observed(request): def all_arithmetic_operators(request): """ Fixture for dunder names for common arithmetic operations + """ + return request.param + + +@pytest.fixture(params=['__eq__', '__ne__', '__le__', + '__lt__', '__ge__', '__gt__']) +def all_compare_operators(request): + """ + Fixture for dunder names for common compare operations """ return request.param diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 88bc497f9f22d..f6dd6dac87035 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -154,7 +154,7 @@ def _reconstruct_data(values, dtype, original): """ from
pandas import Index if is_extension_array_dtype(dtype): - pass + values = dtype.construct_array_type(values)._from_sequence(values) elif is_datetime64tz_dtype(dtype) or is_period_dtype(dtype): values = Index(original)._shallow_copy(values, name=None) elif is_bool_dtype(dtype): @@ -705,7 +705,7 @@ def value_counts(values, sort=True, ascending=False, normalize=False, else: - if is_categorical_dtype(values) or is_sparse(values): + if is_extension_array_dtype(values) or is_sparse(values): # handle Categorical and sparse, result = Series(values)._values.value_counts(dropna=dropna) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 1922801c30719..04bf3bb8a06c3 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -9,6 +9,9 @@ from pandas.errors import AbstractMethodError from pandas.compat.numpy import function as nv +from pandas.compat import set_function_name, PY3 +from pandas.core import ops +import operator _not_implemented_message = "{} does not implement {}." @@ -36,6 +39,7 @@ class ExtensionArray(object): * isna * take * copy + * append * _concat_same_type An additional method is available to satisfy pandas' internal, @@ -49,6 +53,7 @@ class ExtensionArray(object): methods: * fillna + * dropna * unique * factorize / _values_for_factorize * argsort / _values_for_argsort @@ -82,7 +87,7 @@ class ExtensionArray(object): # Constructors # ------------------------------------------------------------------------ @classmethod - def _from_sequence(cls, scalars): + def _from_sequence(cls, scalars, copy=False): """Construct a new ExtensionArray from a sequence of scalars. Parameters @@ -90,6 +95,8 @@ def _from_sequence(cls, scalars): scalars : Sequence Each element will be an instance of the scalar type for this array, ``cls.dtype.type``. 
+ copy : boolean, default False + if True, copy the underlying data Returns ------- ExtensionArray @@ -379,6 +386,16 @@ def fillna(self, value=None, method=None, limit=None): new_values = self.copy() return new_values + def dropna(self): + """ Return ExtensionArray without NA values + + Returns + ------- + valid : ExtensionArray + """ + + return self[~self.isna()] + def unique(self): """Compute the ExtensionArray of unique values. @@ -567,6 +584,34 @@ def copy(self, deep=False): """ raise AbstractMethodError(self) + def append(self, other): + """ + Append a collection of Arrays together + + Parameters + ---------- + other : ExtensionArray or list/tuple of ExtensionArrays + + Returns + ------- + appended : ExtensionArray + """ + + to_concat = [self] + cls = self.__class__ + + if isinstance(other, (list, tuple)): + to_concat = to_concat + list(other) + else: + to_concat.append(other) + + for obj in to_concat: + if not isinstance(obj, cls): + raise TypeError('all inputs must be of type {}'.format( + cls.__name__)) + + return cls._concat_same_type(to_concat) + # ------------------------------------------------------------------------ # Block-related methods # ------------------------------------------------------------------------ @@ -610,3 +655,56 @@ def _ndarray_values(self): used for interacting with our indexers.
""" return np.array(self) + + # ------------------------------------------------------------------------ + # ops-related methods + # ------------------------------------------------------------------------ + + @classmethod + def _add_comparison_methods_binary(cls): + cls.__eq__ = cls._make_comparison_op(operator.eq) + cls.__ne__ = cls._make_comparison_op(operator.ne) + cls.__lt__ = cls._make_comparison_op(operator.lt) + cls.__gt__ = cls._make_comparison_op(operator.gt) + cls.__le__ = cls._make_comparison_op(operator.le) + cls.__ge__ = cls._make_comparison_op(operator.ge) + + @classmethod + def _add_numeric_methods_binary(cls): + """ add in numeric methods """ + cls.__add__ = cls._make_arithmetic_op(operator.add) + cls.__radd__ = cls._make_arithmetic_op(ops.radd) + cls.__sub__ = cls._make_arithmetic_op(operator.sub) + cls.__rsub__ = cls._make_arithmetic_op(ops.rsub) + cls.__mul__ = cls._make_arithmetic_op(operator.mul) + cls.__rmul__ = cls._make_arithmetic_op(ops.rmul) + cls.__rpow__ = cls._make_arithmetic_op(ops.rpow) + cls.__pow__ = cls._make_arithmetic_op(operator.pow) + cls.__mod__ = cls._make_arithmetic_op(operator.mod) + cls.__rmod__ = cls._make_arithmetic_op(ops.rmod) + cls.__floordiv__ = cls._make_arithmetic_op(operator.floordiv) + cls.__rfloordiv__ = cls._make_arithmetic_op(ops.rfloordiv) + cls.__truediv__ = cls._make_arithmetic_op(operator.truediv) + cls.__rtruediv__ = cls._make_arithmetic_op(ops.rtruediv) + if not PY3: + cls.__div__ = cls._make_arithmetic_op(operator.div) + cls.__rdiv__ = cls._make_arithmetic_op(ops.rdiv) + + cls.__divmod__ = cls._make_arithmetic_op(divmod) + cls.__rdivmod__ = cls._make_arithmetic_op(ops.rdivmod) + + @classmethod + def make_comparison_op(cls, op): + def cmp_method(self, other): + raise NotImplementedError + + name = '__{name}__'.format(name=op.__name__) + return set_function_name(cmp_method, name, cls) + + @classmethod + def make_arithmetic_op(cls, op): + def integer_arithmetic_method(self, other): + raise 
NotImplementedError + + name = '__{name}__'.format(name=op.__name__) + return set_function_name(integer_arithmetic_method, name, cls) diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index 49e98c16c716e..c0c9a8d22ce4f 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -109,6 +109,11 @@ class ExtensionDtype(_DtypeOpsMixin): * name * construct_from_string + Optionally one can override construct_array_type for construction + with the name of this dtype via the Registry + + * construct_array_type + The `na_value` class attribute can be used to set the default NA value for this type. :attr:`numpy.nan` is used by default. @@ -156,6 +161,22 @@ def name(self): """ raise AbstractMethodError(self) + @classmethod + def construct_array_type(cls, array=None): + """Return the array type associated with this dtype + + Parameters + ---------- + array : array-like, optional + + Returns + ------- + type + """ + if array is None: + return cls + raise NotImplementedError + @classmethod def construct_from_string(cls, string): """Attempt to construct this type from a string. diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index e4ed6d544d42e..73176887ca0d9 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -647,6 +647,11 @@ def conv(r, dtype): def astype_nansafe(arr, dtype, copy=True): """ return a view if copy is False, but need to be very careful as the result shape could change! 
""" + + # dispatch on extension dtype if needed + if is_extension_array_dtype(dtype): + return dtype.array_type._from_sequence(arr, copy=copy) + if not isinstance(dtype, np.dtype): dtype = pandas_dtype(dtype) diff --git a/pandas/core/dtypes/common.py b/pandas/core/dtypes/common.py index c45838e6040a9..37d260088c4d4 100644 --- a/pandas/core/dtypes/common.py +++ b/pandas/core/dtypes/common.py @@ -9,7 +9,7 @@ DatetimeTZDtype, DatetimeTZDtypeType, PeriodDtype, PeriodDtypeType, IntervalDtype, IntervalDtypeType, - ExtensionDtype, PandasExtensionDtype) + ExtensionDtype, registry) from .generic import (ABCCategorical, ABCPeriodIndex, ABCDatetimeIndex, ABCSeries, ABCSparseArray, ABCSparseSeries, ABCCategoricalIndex, @@ -1975,38 +1975,13 @@ def pandas_dtype(dtype): np.dtype or a pandas dtype """ - if isinstance(dtype, DatetimeTZDtype): - return dtype - elif isinstance(dtype, PeriodDtype): - return dtype - elif isinstance(dtype, CategoricalDtype): - return dtype - elif isinstance(dtype, IntervalDtype): - return dtype - elif isinstance(dtype, string_types): - try: - return DatetimeTZDtype.construct_from_string(dtype) - except TypeError: - pass - - if dtype.startswith('period[') or dtype.startswith('Period['): - # do not parse string like U as period[U] - try: - return PeriodDtype.construct_from_string(dtype) - except TypeError: - pass - - elif dtype.startswith('interval') or dtype.startswith('Interval'): - try: - return IntervalDtype.construct_from_string(dtype) - except TypeError: - pass + # registered extension types + result = registry.find(dtype) + if result is not None: + return result - try: - return CategoricalDtype.construct_from_string(dtype) - except TypeError: - pass - elif isinstance(dtype, (PandasExtensionDtype, ExtensionDtype)): + # un-registered extension types + if isinstance(dtype, ExtensionDtype): return dtype try: diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index 708f54f5ca75b..7d147da661e34 100644 --- 
a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -8,6 +8,60 @@ from .base import ExtensionDtype, _DtypeOpsMixin +class Registry(object): + """ Registry for dtype inference + + We can directly construct dtypes in pandas_dtypes if they are + a type; the registry allows us to register an extension dtype + to try inference from a string or a dtype class + + These are tried in order for inference. + """ + dtypes = [] + + @classmethod + def register(self, dtype): + """ + Parameters + ---------- + dtype : ExtensionDtype + """ + if not issubclass(dtype, (PandasExtensionDtype, ExtensionDtype)): + raise ValueError("can only register pandas extension dtypes") + + self.dtypes.append(dtype) + + def find(self, dtype): + """ + Parameters + ---------- + dtype : PandasExtensionDtype or string + + Returns + ------- + return the first matching dtype, otherwise return None + """ + if not isinstance(dtype, compat.string_types): + dtype_type = dtype + if not isinstance(dtype, type): + dtype_type = type(dtype) + if issubclass(dtype_type, (PandasExtensionDtype, ExtensionDtype)): + return dtype + + return None + + for dtype_type in self.dtypes: + try: + return dtype_type.construct_from_string(dtype) + except TypeError: + pass + + return None + + +registry = Registry() + + class PandasExtensionDtype(_DtypeOpsMixin): """ A np.dtype duck-typed class, suitable for holding a custom dtype. 
@@ -263,6 +317,21 @@ def _hash_categories(categories, ordered=True): else: return np.bitwise_xor.reduce(hashed) + @classmethod + def construct_array_type(cls, array=None): + """Return the array type associated with this dtype + + Parameters + ---------- + array : array-like, optional + + Returns + ------- + type + """ + from pandas import Categorical + return Categorical + @classmethod def construct_from_string(cls, string): """ attempt to construct this type from a string, raise a TypeError if @@ -552,11 +621,16 @@ def _parse_dtype_strict(cls, freq): @classmethod def construct_from_string(cls, string): """ - attempt to construct this type from a string, raise a TypeError - if its not possible + Strict construction from a string, raise a TypeError if not + possible """ from pandas.tseries.offsets import DateOffset - if isinstance(string, (compat.string_types, DateOffset)): + + if (isinstance(string, compat.string_types) and + (string.startswith('period[') or + string.startswith('Period[')) or + isinstance(string, DateOffset)): + # do not parse string like U as period[U] # avoid tuple to be regarded as freq try: return cls(freq=string) @@ -656,7 +730,7 @@ def __new__(cls, subtype=None): try: subtype = pandas_dtype(subtype) except TypeError: - raise ValueError("could not construct IntervalDtype") + raise TypeError("could not construct IntervalDtype") if is_categorical_dtype(subtype) or is_string_dtype(subtype): # GH 19016 @@ -678,8 +752,11 @@ def construct_from_string(cls, string): attempt to construct this type from a string, raise a TypeError if its not possible """ - if isinstance(string, compat.string_types): + if (isinstance(string, compat.string_types) and + (string.startswith('interval') or + string.startswith('Interval'))): return cls(string) + msg = "a string needs to be passed, got type {typ}" raise TypeError(msg.format(typ=type(string))) @@ -723,3 +800,10 @@ def is_dtype(cls, dtype): else: return False return super(IntervalDtype, cls).is_dtype(dtype) + + 
+# register the dtypes in search order +registry.register(DatetimeTZDtype) +registry.register(PeriodDtype) +registry.register(IntervalDtype) +registry.register(CategoricalDtype) diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index 8f8d8760583ce..2694f5d5be384 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -800,7 +800,7 @@ def astype(self, dtype, copy=True): @cache_readonly def dtype(self): """Return the dtype object of the underlying data""" - return IntervalDtype.construct_from_string(str(self.left.dtype)) + return IntervalDtype(str(self.left.dtype)) @property def inferred_type(self): diff --git a/pandas/core/internals.py b/pandas/core/internals.py index fe508dc1bb0bc..a5e9107b8a660 100644 --- a/pandas/core/internals.py +++ b/pandas/core/internals.py @@ -633,8 +633,9 @@ def _astype(self, dtype, copy=False, errors='raise', values=None, return self.make_block(Categorical(self.values, dtype=dtype)) # astype processing - dtype = np.dtype(dtype) - if self.dtype == dtype: + if not is_extension_array_dtype(dtype): + dtype = np.dtype(dtype) + if is_dtype_equal(self.dtype, dtype): if copy: return self.copy() return self @@ -662,7 +663,13 @@ def _astype(self, dtype, copy=False, errors='raise', values=None, # _astype_nansafe works fine with 1-d only values = astype_nansafe(values.ravel(), dtype, copy=True) - values = values.reshape(self.shape) + + # TODO(extension) + # should we make this attribute? 
+ try: + values = values.reshape(self.shape) + except AttributeError: + pass newb = make_block(values, placement=self.mgr_locs, klass=klass) @@ -3170,6 +3177,10 @@ def get_block_type(values, dtype=None): cls = TimeDeltaBlock elif issubclass(vtype, np.complexfloating): cls = ComplexBlock + elif is_categorical(values): + cls = CategoricalBlock + elif is_extension_array_dtype(values): + cls = ExtensionBlock elif issubclass(vtype, np.datetime64): assert not is_datetimetz(values) cls = DatetimeBlock @@ -3179,10 +3190,6 @@ def get_block_type(values, dtype=None): cls = IntBlock elif dtype == np.bool_: cls = BoolBlock - elif is_categorical(values): - cls = CategoricalBlock - elif is_extension_array_dtype(values): - cls = ExtensionBlock else: cls = ObjectBlock return cls diff --git a/pandas/core/missing.py b/pandas/core/missing.py index 31c489e2f8941..cb5ee8388c2c4 100644 --- a/pandas/core/missing.py +++ b/pandas/core/missing.py @@ -638,7 +638,8 @@ def fill_zeros(result, x, y, name, fill): # if we have a fill of inf, then sign it correctly # (GH 6178 and PR 9308) if np.isinf(fill): - signs = np.sign(y if name.startswith(('r', '__r')) else x) + signs = y if name.startswith(('r', '__r')) else x + signs = np.sign(signs.astype('float', copy=False)) negative_inf_mask = (signs.ravel() < 0) & mask np.putmask(result, negative_inf_mask, -fill) diff --git a/pandas/core/ops.py b/pandas/core/ops.py index e14f82906cd06..69d1efb0304ab 100644 --- a/pandas/core/ops.py +++ b/pandas/core/ops.py @@ -27,7 +27,7 @@ is_integer_dtype, is_categorical_dtype, is_object_dtype, is_timedelta64_dtype, is_datetime64_dtype, is_datetime64tz_dtype, - is_bool_dtype, + is_bool_dtype, is_extension_array_dtype, is_list_like, is_scalar, _ensure_object) @@ -1003,8 +1003,18 @@ def _arith_method_SERIES(cls, op, special): if op is divmod else _construct_result) def na_op(x, y): - import pandas.core.computation.expressions as expressions + # handle extension array ops + # TODO(extension) + # the ops *between* 
non-same-type extension arrays are not + # very well defined + if (is_extension_array_dtype(x) or is_extension_array_dtype(y)): + if (op_name.startswith('__r') and not + is_extension_array_dtype(y) and not + is_scalar(y)): + y = x.__class__._from_sequence(y) + return op(x, y) + import pandas.core.computation.expressions as expressions try: result = expressions.evaluate(op, str_rep, x, y, **eval_kwargs) except TypeError: @@ -1025,6 +1035,20 @@ def na_op(x, y): return result def safe_na_op(lvalues, rvalues): + """ + return the result of evaluating na_op on the passed in values + + try coercion to object type if the native types are not compatible + + Parameters + ---------- + lvalues : array-like + rvalues : array-like + + Raises + ------ + invalid operation raises TypeError + """ try: with np.errstate(all='ignore'): return na_op(lvalues, rvalues) @@ -1035,14 +1059,21 @@ def safe_na_op(lvalues, rvalues): raise def wrapper(left, right): - if isinstance(right, ABCDataFrame): return NotImplemented left, right = _align_method_SERIES(left, right) res_name = get_op_result_name(left, right) - if is_datetime64_dtype(left) or is_datetime64tz_dtype(left): + if is_categorical_dtype(left): + raise TypeError("{typ} cannot perform the operation " + "{op}".format(typ=type(left).__name__, op=str_rep)) + + elif (is_extension_array_dtype(left) or + is_extension_array_dtype(right)): + pass + + elif is_datetime64_dtype(left) or is_datetime64tz_dtype(left): result = dispatch_to_index_op(op, left, right, pd.DatetimeIndex) return construct_result(left, result, index=left.index, name=res_name, @@ -1054,10 +1085,6 @@ def wrapper(left, right): index=left.index, name=res_name, dtype=result.dtype) - elif is_categorical_dtype(left): - raise TypeError("{typ} cannot perform the operation " - "{op}".format(typ=type(left).__name__, op=str_rep)) - lvalues = left.values rvalues = right if isinstance(rvalues, ABCSeries): @@ -1136,6 +1163,14 @@ def na_op(x, y): # The `not is_scalar(y)` check excludes 
the string "category" return op(y, x) + # handle extension array ops + # TODO(extension) + # the ops *between* non-same-type extension arrays are not + # very well defined + elif (is_extension_array_dtype(x) or + is_extension_array_dtype(y)): + return op(x, y) + elif is_object_dtype(x.dtype): result = _comp_method_OBJECT_ARRAY(op, x, y) diff --git a/pandas/core/series.py b/pandas/core/series.py index d59401414181f..4f63c56706c72 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -4054,11 +4054,9 @@ def _try_cast(arr, take_fast_path): subarr = Categorical(arr, dtype.categories, ordered=dtype.ordered) elif is_extension_array_dtype(dtype): - # We don't allow casting to third party dtypes, since we don't - # know what array belongs to which type. - msg = ("Cannot cast data to extension dtype '{}'. " - "Pass the extension array directly.".format(dtype)) - raise ValueError(msg) + # create an extension array from its dtype + array_type = dtype.construct_array_type(subarr) + subarr = array_type(subarr, copy=copy) elif dtype is not None and raise_cast_failure: raise diff --git a/pandas/io/formats/format.py b/pandas/io/formats/format.py index 12201f62946ac..adb4bf3f47572 100644 --- a/pandas/io/formats/format.py +++ b/pandas/io/formats/format.py @@ -514,7 +514,6 @@ def _to_str_columns(self): Render a DataFrame to a list of columns (as lists of strings). 
""" frame = self.tr_frame - # may include levels names also str_index = self._get_formatted_index(frame) diff --git a/pandas/tests/dtypes/test_dtypes.py b/pandas/tests/dtypes/test_dtypes.py index cc833af03ae66..fd8042212f658 100644 --- a/pandas/tests/dtypes/test_dtypes.py +++ b/pandas/tests/dtypes/test_dtypes.py @@ -9,10 +9,9 @@ from pandas import ( Series, Categorical, CategoricalIndex, IntervalIndex, date_range) -from pandas.compat import string_types from pandas.core.dtypes.dtypes import ( DatetimeTZDtype, PeriodDtype, - IntervalDtype, CategoricalDtype) + IntervalDtype, CategoricalDtype, registry) from pandas.core.dtypes.common import ( is_categorical_dtype, is_categorical, is_datetime64tz_dtype, is_datetimetz, @@ -448,7 +447,7 @@ def test_construction_not_supported(self, subtype): def test_construction_errors(self): msg = 'could not construct IntervalDtype' - with tm.assert_raises_regex(ValueError, msg): + with tm.assert_raises_regex(TypeError, msg): IntervalDtype('xx') def test_construction_from_string(self): @@ -458,14 +457,21 @@ def test_construction_from_string(self): assert is_dtype_equal(self.dtype, result) @pytest.mark.parametrize('string', [ - 'foo', 'interval[foo]', 'foo[int64]', 0, 3.14, ('a', 'b'), None]) + 'foo', 'foo[int64]', 0, 3.14, ('a', 'b'), None]) def test_construction_from_string_errors(self, string): - if isinstance(string, string_types): - error, msg = ValueError, 'could not construct IntervalDtype' - else: - error, msg = TypeError, 'a string needs to be passed, got type' + # these are invalid entirely + msg = 'a string needs to be passed, got type' + + with tm.assert_raises_regex(TypeError, msg): + IntervalDtype.construct_from_string(string) + + @pytest.mark.parametrize('string', [ + 'interval[foo]']) + def test_construction_from_string_error_subtype(self, string): + # this is an invalid subtype + msg = 'could not construct IntervalDtype' - with tm.assert_raises_regex(error, msg): + with tm.assert_raises_regex(TypeError, msg): 
IntervalDtype.construct_from_string(string) def test_subclass(self): @@ -767,3 +773,24 @@ def test_update_dtype_errors(self, bad_dtype): msg = 'a CategoricalDtype must be passed to perform an update, ' with tm.assert_raises_regex(ValueError, msg): dtype.update_dtype(bad_dtype) + + +@pytest.mark.parametrize( + 'dtype', + [DatetimeTZDtype, CategoricalDtype, + PeriodDtype, IntervalDtype]) +def test_registry(dtype): + assert dtype in registry.dtypes + + +@pytest.mark.parametrize( + 'dtype, expected', + [('int64', None), + ('interval', IntervalDtype()), + ('interval[int64]', IntervalDtype()), + ('category', CategoricalDtype()), + ('period[D]', PeriodDtype('D')), + ('datetime64[ns, US/Eastern]', DatetimeTZDtype('ns', 'US/Eastern'))]) +def test_registry_find(dtype, expected): + + assert registry.find(dtype) == expected diff --git a/pandas/tests/extension/base/__init__.py b/pandas/tests/extension/base/__init__.py index 9da985625c4ee..7bbba1e8640b1 100644 --- a/pandas/tests/extension/base/__init__.py +++ b/pandas/tests/extension/base/__init__.py @@ -46,6 +46,7 @@ class TestMyDtype(BaseDtypeTests): from .getitem import BaseGetitemTests # noqa from .groupby import BaseGroupbyTests # noqa from .interface import BaseInterfaceTests # noqa +from .ops import BaseOpsTests # noqa from .methods import BaseMethodsTests # noqa from .missing import BaseMissingTests # noqa from .reshaping import BaseReshapingTests # noqa diff --git a/pandas/tests/extension/base/constructors.py b/pandas/tests/extension/base/constructors.py index 489a430bb4020..972ef7f37acca 100644 --- a/pandas/tests/extension/base/constructors.py +++ b/pandas/tests/extension/base/constructors.py @@ -1,5 +1,6 @@ import pytest +import numpy as np import pandas as pd import pandas.util.testing as tm from pandas.core.internals import ExtensionBlock @@ -45,3 +46,14 @@ def test_series_given_mismatched_index_raises(self, data): msg = 'Length of passed values is 3, index implies 5' with tm.assert_raises_regex(ValueError, msg): 
pd.Series(data[:3], index=[0, 1, 2, 3, 4]) + + def test_from_dtype(self, data): + # construct from our dtype & string dtype + dtype = data.dtype + + expected = pd.Series(data) + result = pd.Series(np.array(data), dtype=dtype) + self.assert_series_equal(result, expected) + + result = pd.Series(np.array(data), dtype=str(dtype)) + self.assert_series_equal(result, expected) diff --git a/pandas/tests/extension/base/dtype.py b/pandas/tests/extension/base/dtype.py index 63d3d807c270c..52a12816c8722 100644 --- a/pandas/tests/extension/base/dtype.py +++ b/pandas/tests/extension/base/dtype.py @@ -1,3 +1,4 @@ +import pytest import numpy as np import pandas as pd @@ -46,3 +47,10 @@ def test_eq_with_str(self, dtype): def test_eq_with_numpy_object(self, dtype): assert dtype != np.dtype('object') + + def test_array_type(self, data, dtype): + assert dtype.construct_array_type() is type(data) + + def test_array_type_with_arg(self, data, dtype): + with pytest.raises(NotImplementedError): + dtype.construct_array_type('foo') diff --git a/pandas/tests/extension/base/interface.py b/pandas/tests/extension/base/interface.py index 8ef8debbdc666..69de0e1900831 100644 --- a/pandas/tests/extension/base/interface.py +++ b/pandas/tests/extension/base/interface.py @@ -40,6 +40,16 @@ def test_repr(self, data): df = pd.DataFrame({"A": data}) repr(df) + def test_repr_array(self, data): + # some arrays may be able to assert + # attributes in the repr + repr(data) + + def test_repr_array_long(self, data): + # some arrays may be able to assert a ... 
in the repr + with pd.option_context('display.max_seq_items', 1): + repr(data) + def test_dtype_name_in_info(self, data): buf = StringIO() pd.DataFrame({"A": data}).info(buf=buf) diff --git a/pandas/tests/extension/base/methods.py b/pandas/tests/extension/base/methods.py index c5436aa731d50..0ad3196277c34 100644 --- a/pandas/tests/extension/base/methods.py +++ b/pandas/tests/extension/base/methods.py @@ -19,7 +19,8 @@ def test_value_counts(self, all_data, dropna): other = all_data result = pd.Series(all_data).value_counts(dropna=dropna).sort_index() - expected = pd.Series(other).value_counts(dropna=dropna).sort_index() + expected = pd.Series(other).value_counts( + dropna=dropna).sort_index() self.assert_series_equal(result, expected) diff --git a/pandas/tests/extension/base/missing.py b/pandas/tests/extension/base/missing.py index af26d83df3fe2..43b2702c72193 100644 --- a/pandas/tests/extension/base/missing.py +++ b/pandas/tests/extension/base/missing.py @@ -23,6 +23,11 @@ def test_isna(self, data_missing): expected = pd.Series([], dtype=bool) self.assert_series_equal(result, expected) + def test_dropna_array(self, data_missing): + result = data_missing.dropna() + expected = data_missing[[1]] + self.assert_extension_array_equal(result, expected) + def test_dropna_series(self, data_missing): ser = pd.Series(data_missing) result = ser.dropna() diff --git a/pandas/tests/extension/base/ops.py b/pandas/tests/extension/base/ops.py new file mode 100644 index 0000000000000..7bd97a94d4094 --- /dev/null +++ b/pandas/tests/extension/base/ops.py @@ -0,0 +1,54 @@ +import pytest +import numpy as np +import pandas as pd +from .base import BaseExtensionTests + + +class BaseOpsTests(BaseExtensionTests): + """Various Series and DataFrame ops methos.""" + + def compare(self, s, op, other, exc=NotImplementedError): + + with pytest.raises(exc): + getattr(s, op)(other) + + def test_arith_scalar(self, data, all_arithmetic_operators): + # scalar + op = all_arithmetic_operators + s = 
pd.Series(data) + self.compare(s, op, 1, exc=TypeError) + + def test_arith_array(self, data, all_arithmetic_operators): + # ndarray & other series + op = all_arithmetic_operators + s = pd.Series(data) + self.compare(s, op, np.ones(len(s), dtype=s.dtype.type), exc=TypeError) + + def test_compare_scalar(self, data, all_compare_operators): + op = all_compare_operators + + s = pd.Series(data) + + if op in '__eq__': + assert getattr(data, op)(0) is NotImplemented + assert not getattr(s, op)(0).all() + elif op in '__ne__': + assert getattr(data, op)(0) is NotImplemented + assert getattr(s, op)(0).all() + + else: + + # array + getattr(data, op)(0) is NotImplementedError + + # series + s = pd.Series(data) + with pytest.raises(TypeError): + getattr(s, op)(0) + + def test_error(self, data, all_arithmetic_operators): + + # invalid ops + op = all_arithmetic_operators + with pytest.raises(AttributeError): + getattr(data, op) diff --git a/pandas/tests/extension/base/reshaping.py b/pandas/tests/extension/base/reshaping.py index fe920a47ab740..ff739c97f2785 100644 --- a/pandas/tests/extension/base/reshaping.py +++ b/pandas/tests/extension/base/reshaping.py @@ -26,6 +26,14 @@ def test_concat(self, data, in_frame): assert dtype == data.dtype assert isinstance(result._data.blocks[0], ExtensionBlock) + def test_append(self, data): + + wrapped = pd.Series(data) + result = wrapped.append(wrapped) + expected = pd.concat([wrapped, wrapped]) + + self.assert_series_equal(result, expected) + @pytest.mark.parametrize('in_frame', [True, False]) def test_concat_all_na_block(self, data_missing, in_frame): valid_block = pd.Series(data_missing.take([1, 1]), index=[0, 1]) @@ -84,6 +92,7 @@ def test_concat_columns(self, data, na_value): expected = pd.DataFrame({ 'A': data._from_sequence(list(data[:3]) + [na_value]), 'B': [np.nan, 1, 2, 3]}) + result = pd.concat([df1, df2], axis=1) self.assert_frame_equal(result, expected) result = pd.concat([df1['A'], df2['B']], axis=1) diff --git 
a/pandas/tests/extension/category/test_categorical.py b/pandas/tests/extension/category/test_categorical.py index 530a4e7a22a7a..b331cded4ac6a 100644 --- a/pandas/tests/extension/category/test_categorical.py +++ b/pandas/tests/extension/category/test_categorical.py @@ -52,7 +52,25 @@ def data_for_grouping(): class TestDtype(base.BaseDtypeTests): - pass + + def test_array_type_with_arg(self, data, dtype): + assert dtype.construct_array_type('foo') is Categorical + + +class TestOps(base.BaseOpsTests): + + def test_compare_scalar(self, data, all_compare_operators): + op = all_compare_operators + + if op == '__eq__': + assert not getattr(data, op)(0).all() + + elif op == '__ne__': + assert getattr(data, op)(0).all() + + else: + with pytest.raises(TypeError): + getattr(data, op)(0) class TestInterface(base.BaseInterfaceTests): diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py index 90f0181beab0d..7bdbbf77cf4d6 100644 --- a/pandas/tests/extension/decimal/array.py +++ b/pandas/tests/extension/decimal/array.py @@ -15,6 +15,20 @@ class DecimalDtype(ExtensionDtype): name = 'decimal' na_value = decimal.Decimal('NaN') + @classmethod + def construct_array_type(cls, array=None): + """Return the array type associated with this dtype + + Parameters + ---------- + array : array-like, optional + + Returns + ------- + type + """ + return DecimalArray + @classmethod def construct_from_string(cls, string): if string == cls.name: @@ -27,7 +41,7 @@ def construct_from_string(cls, string): class DecimalArray(ExtensionArray): dtype = DecimalDtype() - def __init__(self, values): + def __init__(self, values, copy=False): assert all(isinstance(v, decimal.Decimal) for v in values) values = np.asarray(values, dtype=object) @@ -40,7 +54,7 @@ def __init__(self, values): # self._values = self.values = self.data @classmethod - def _from_sequence(cls, scalars): + def _from_sequence(cls, scalars, copy=False): return cls(scalars) @classmethod diff --git 
a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index 1f8cf0264f62f..ca646486d2bff 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -92,15 +92,56 @@ def assert_frame_equal(self, left, right, *args, **kwargs): class TestDtype(BaseDecimal, base.BaseDtypeTests): - pass + + def test_array_type_with_arg(self, data, dtype): + assert dtype.construct_array_type('foo') is DecimalArray class TestInterface(BaseDecimal, base.BaseInterfaceTests): pass +class TestOps(BaseDecimal, base.BaseOpsTests): + + def compare(self, s, op, other): + # TODO(extension) + + pytest.xfail("not implemented") + + result = getattr(s, op)(other) + expected = result + + self.assert_series_equal(result, expected) + + def test_arith_scalar(self, data, all_arithmetic_operators): + # scalar + op = all_arithmetic_operators + s = pd.Series(data) + self.compare(s, op, 1) + + def test_arith_array(self, data, all_arithmetic_operators): + # ndarray & other series + op = all_arithmetic_operators + s = pd.Series(data) + self.compare(s, op, np.ones(len(s), dtype=s.dtype.type)) + + @pytest.mark.xfail(reason="Not implemented") + def test_compare_scalar(self, data, all_compare_operators): + op = all_compare_operators + + # array + result = getattr(data, op)(0) + expected = getattr(data.data, op)(0) + + tm.assert_series_equal(result, expected) + + class TestConstructors(BaseDecimal, base.BaseConstructorsTests): - pass + + @pytest.mark.xfail(reason="not implemented constructor from dtype") + def test_from_dtype(self, data): + # construct from our dtype & string dtype + pass class TestReshaping(BaseDecimal, base.BaseReshapingTests): @@ -147,6 +188,10 @@ class TestGroupby(BaseDecimal, base.BaseGroupbyTests): pass +# TODO(extension) +@pytest.mark.xfail(reason=( + "raising AssertionError as this is not implemented, " + "though easy enough to do")) def 
test_series_constructor_coerce_data_to_extension_dtype_raises(): xpr = ("Cannot cast data to extension dtype 'decimal'. Pass the " "extension array directly.") diff --git a/pandas/tests/extension/json/array.py b/pandas/tests/extension/json/array.py index 10be7836cb8d7..f5d7d58277cc5 100644 --- a/pandas/tests/extension/json/array.py +++ b/pandas/tests/extension/json/array.py @@ -32,6 +32,20 @@ class JSONDtype(ExtensionDtype): # source compatibility with Py2. na_value = {} + @classmethod + def construct_array_type(cls, array=None): + """Return the array type associated with this dtype + + Parameters + ---------- + array : array-like, optional + + Returns + ------- + type + """ + return JSONArray + @classmethod def construct_from_string(cls, string): if string == cls.name: @@ -44,7 +58,7 @@ def construct_from_string(cls, string): class JSONArray(ExtensionArray): dtype = JSONDtype() - def __init__(self, values): + def __init__(self, values, copy=False): for val in values: if not isinstance(val, self.dtype.type): raise TypeError @@ -58,7 +72,7 @@ def __init__(self, values): # self._values = self.values = self.data @classmethod - def _from_sequence(cls, scalars): + def _from_sequence(cls, scalars, copy=False): return cls(scalars) @classmethod diff --git a/pandas/tests/extension/json/test_json.py b/pandas/tests/extension/json/test_json.py index b7ac8033f3f6d..97d3ddf1ec54a 100644 --- a/pandas/tests/extension/json/test_json.py +++ b/pandas/tests/extension/json/test_json.py @@ -107,7 +107,9 @@ def assert_frame_equal(self, left, right, *args, **kwargs): class TestDtype(BaseJSON, base.BaseDtypeTests): - pass + + def test_array_type_with_arg(self, data, dtype): + assert dtype.construct_array_type('foo') is JSONArray class TestInterface(BaseJSON, base.BaseInterfaceTests): @@ -130,13 +132,21 @@ def test_custom_asserts(self): class TestConstructors(BaseJSON, base.BaseConstructorsTests): - pass + + @pytest.mark.xfail(reason="not implemented constructor from dtype") + def 
test_from_dtype(self, data): + # construct from our dtype & string dtype + pass class TestReshaping(BaseJSON, base.BaseReshapingTests): pass +class TestOps(BaseJSON, base.BaseOpsTests): + pass + + class TestGetitem(BaseJSON, base.BaseGetitemTests): pass