Skip to content

Commit 7d9ad04

Browse files
authored
TST: tighten Decimal tests (#39381)
1 parent a5ecf22 commit 7d9ad04

File tree

2 files changed

+32
-17
lines changed

2 files changed

+32
-17
lines changed

pandas/tests/extension/decimal/array.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010

1111
from pandas.core.dtypes.base import ExtensionDtype
12-
from pandas.core.dtypes.common import is_dtype_equal, pandas_dtype
12+
from pandas.core.dtypes.common import is_dtype_equal, is_float, pandas_dtype
1313

1414
import pandas as pd
1515
from pandas.api.extensions import no_default, register_extension_dtype
@@ -52,8 +52,10 @@ class DecimalArray(OpsMixin, ExtensionScalarOpsMixin, ExtensionArray):
5252
__array_priority__ = 1000
5353

5454
def __init__(self, values, dtype=None, copy=False, context=None):
55-
for val in values:
56-
if not isinstance(val, decimal.Decimal):
55+
for i, val in enumerate(values):
56+
if is_float(val) and np.isnan(val):
57+
values[i] = DecimalDtype.na_value
58+
elif not isinstance(val, decimal.Decimal):
5759
raise TypeError("All values must be of type " + str(decimal.Decimal))
5860
values = np.asarray(values, dtype=object)
5961

@@ -228,6 +230,11 @@ def convert_values(param):
228230

229231
return np.asarray(res, dtype=bool)
230232

233+
def value_counts(self, dropna: bool = False):
234+
from pandas.core.algorithms import value_counts
235+
236+
return value_counts(self.to_numpy(), dropna=dropna)
237+
231238

232239
def to_decimal(values, context=None):
233240
return DecimalArray([decimal.Decimal(x) for x in values], context=context)

pandas/tests/extension/decimal/test_decimal.py

+22-14
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,7 @@ class TestInterface(BaseDecimal, base.BaseInterfaceTests):
122122

123123

124124
class TestConstructors(BaseDecimal, base.BaseConstructorsTests):
125-
@pytest.mark.skip(reason="not implemented constructor from dtype")
126-
def test_from_dtype(self, data):
127-
# construct from our dtype & string dtype
128-
pass
125+
pass
129126

130127

131128
class TestReshaping(BaseDecimal, base.BaseReshapingTests):
@@ -168,20 +165,32 @@ class TestBooleanReduce(Reduce, base.BaseBooleanReduceTests):
168165

169166
class TestMethods(BaseDecimal, base.BaseMethodsTests):
170167
@pytest.mark.parametrize("dropna", [True, False])
171-
@pytest.mark.xfail(reason="value_counts not implemented yet.")
172-
def test_value_counts(self, all_data, dropna):
168+
def test_value_counts(self, all_data, dropna, request):
169+
if any(x != x for x in all_data):
170+
mark = pytest.mark.xfail(
171+
reason="tm.assert_series_equal incorrectly raises",
172+
raises=AssertionError,
173+
)
174+
request.node.add_marker(mark)
175+
173176
all_data = all_data[:10]
174177
if dropna:
175178
other = np.array(all_data[~all_data.isna()])
176179
else:
177180
other = all_data
178181

179-
result = pd.Series(all_data).value_counts(dropna=dropna).sort_index()
180-
expected = pd.Series(other).value_counts(dropna=dropna).sort_index()
182+
vcs = pd.Series(all_data).value_counts(dropna=dropna)
183+
vcs_ex = pd.Series(other).value_counts(dropna=dropna)
184+
185+
with decimal.localcontext() as ctx:
186+
# avoid raising when comparing Decimal("NAN") < Decimal(2)
187+
ctx.traps[decimal.InvalidOperation] = False
188+
189+
result = vcs.sort_index()
190+
expected = vcs_ex.sort_index()
181191

182192
tm.assert_series_equal(result, expected)
183193

184-
@pytest.mark.xfail(reason="value_counts not implemented yet.")
185194
def test_value_counts_with_normalize(self, data):
186195
return super().test_value_counts_with_normalize(data)
187196

@@ -191,13 +200,12 @@ class TestCasting(BaseDecimal, base.BaseCastingTests):
191200

192201

193202
class TestGroupby(BaseDecimal, base.BaseGroupbyTests):
194-
@pytest.mark.xfail(
195-
reason="needs to correctly define __eq__ to handle nans, xref #27081."
196-
)
197-
def test_groupby_apply_identity(self, data_for_grouping):
203+
def test_groupby_apply_identity(self, data_for_grouping, request):
204+
if any(x != x for x in data_for_grouping):
205+
mark = pytest.mark.xfail(reason="tm.assert_series_equal raises incorrectly")
206+
request.node.add_marker(mark)
198207
super().test_groupby_apply_identity(data_for_grouping)
199208

200-
@pytest.mark.xfail(reason="GH#39098: Converts agg result to object")
201209
def test_groupby_agg_extension(self, data_for_grouping):
202210
super().test_groupby_agg_extension(data_for_grouping)
203211

0 commit comments

Comments
 (0)