From 2d9396fd973ac3e411c9a8d1bf1cb1cdab7b37bd Mon Sep 17 00:00:00 2001 From: Brock Date: Sun, 24 Jan 2021 13:33:35 -0800 Subject: [PATCH] TST: tighten Decimal tests --- pandas/tests/extension/decimal/array.py | 13 +++++-- .../tests/extension/decimal/test_decimal.py | 36 +++++++++++-------- 2 files changed, 32 insertions(+), 17 deletions(-) diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py index 3dcdde77aa420..4122fcaae496b 100644 --- a/pandas/tests/extension/decimal/array.py +++ b/pandas/tests/extension/decimal/array.py @@ -9,7 +9,7 @@ import numpy as np from pandas.core.dtypes.base import ExtensionDtype -from pandas.core.dtypes.common import is_dtype_equal, pandas_dtype +from pandas.core.dtypes.common import is_dtype_equal, is_float, pandas_dtype import pandas as pd from pandas.api.extensions import no_default, register_extension_dtype @@ -52,8 +52,10 @@ class DecimalArray(OpsMixin, ExtensionScalarOpsMixin, ExtensionArray): __array_priority__ = 1000 def __init__(self, values, dtype=None, copy=False, context=None): - for val in values: - if not isinstance(val, decimal.Decimal): + for i, val in enumerate(values): + if is_float(val) and np.isnan(val): + values[i] = DecimalDtype.na_value + elif not isinstance(val, decimal.Decimal): raise TypeError("All values must be of type " + str(decimal.Decimal)) values = np.asarray(values, dtype=object) @@ -228,6 +230,11 @@ def convert_values(param): return np.asarray(res, dtype=bool) + def value_counts(self, dropna: bool = False): + from pandas.core.algorithms import value_counts + + return value_counts(self.to_numpy(), dropna=dropna) + def to_decimal(values, context=None): return DecimalArray([decimal.Decimal(x) for x in values], context=context) diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index 08768bda312ba..5a1fbdf401b86 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -122,10 +122,7 @@ class TestInterface(BaseDecimal, base.BaseInterfaceTests): class TestConstructors(BaseDecimal, base.BaseConstructorsTests): - @pytest.mark.skip(reason="not implemented constructor from dtype") - def test_from_dtype(self, data): - # construct from our dtype & string dtype - pass + pass class TestReshaping(BaseDecimal, base.BaseReshapingTests): @@ -168,20 +165,32 @@ class TestBooleanReduce(Reduce, base.BaseBooleanReduceTests): class TestMethods(BaseDecimal, base.BaseMethodsTests): @pytest.mark.parametrize("dropna", [True, False]) - @pytest.mark.xfail(reason="value_counts not implemented yet.") - def test_value_counts(self, all_data, dropna): + def test_value_counts(self, all_data, dropna, request): + if any(x != x for x in all_data): + mark = pytest.mark.xfail( + reason="tm.assert_series_equal incorrectly raises", + raises=AssertionError, + ) + request.node.add_marker(mark) + all_data = all_data[:10] if dropna: other = np.array(all_data[~all_data.isna()]) else: other = all_data - result = pd.Series(all_data).value_counts(dropna=dropna).sort_index() - expected = pd.Series(other).value_counts(dropna=dropna).sort_index() + vcs = pd.Series(all_data).value_counts(dropna=dropna) + vcs_ex = pd.Series(other).value_counts(dropna=dropna) + + with decimal.localcontext() as ctx: + # avoid raising when comparing Decimal("NAN") < Decimal(2) + ctx.traps[decimal.InvalidOperation] = False + + result = vcs.sort_index() + expected = vcs_ex.sort_index() tm.assert_series_equal(result, expected) - @pytest.mark.xfail(reason="value_counts not implemented yet.") def test_value_counts_with_normalize(self, data): return super().test_value_counts_with_normalize(data) @@ -191,13 +200,12 @@ class TestCasting(BaseDecimal, base.BaseCastingTests): class TestGroupby(BaseDecimal, base.BaseGroupbyTests): - @pytest.mark.xfail( - reason="needs to correctly define __eq__ to handle nans, xref #27081." - ) - def test_groupby_apply_identity(self, data_for_grouping): + def test_groupby_apply_identity(self, data_for_grouping, request): + if any(x != x for x in data_for_grouping): + mark = pytest.mark.xfail(reason="tm.assert_series_equal raises incorrectly") + request.node.add_marker(mark) super().test_groupby_apply_identity(data_for_grouping) - @pytest.mark.xfail(reason="GH#39098: Converts agg result to object") def test_groupby_agg_extension(self, data_for_grouping): super().test_groupby_agg_extension(data_for_grouping)