Skip to content

Commit 8ec9e0a

Browse files
authored
ENH: recognize Decimal("NaN") in pd.isna (#39409)
1 parent 7c669c2 commit 8ec9e0a

File tree

16 files changed

+139
-60
lines changed

16 files changed

+139
-60
lines changed

doc/source/whatsnew/v1.3.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ Missing
379379
^^^^^^^
380380

381381
- Bug in :class:`Grouper` now correctly propagates ``dropna`` argument and :meth:`DataFrameGroupBy.transform` now correctly handles missing values for ``dropna=True`` (:issue:`35612`)
382-
-
382+
- Bug in :func:`isna`, and :meth:`Series.isna`, :meth:`Index.isna`, :meth:`DataFrame.isna` (and the corresponding ``notna`` functions) not recognizing ``Decimal("NaN")`` objects (:issue:`39409`)
383383
-
384384

385385
MultiIndex

pandas/_libs/missing.pyx

+17-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from decimal import Decimal
12
import numbers
23

34
import cython
@@ -36,6 +37,8 @@ cdef:
3637

3738
bint is_32bit = not IS64
3839

40+
type cDecimal = Decimal # for faster isinstance checks
41+
3942

4043
cpdef bint is_matching_na(object left, object right, bint nan_matches_none=False):
4144
"""
@@ -86,6 +89,8 @@ cpdef bint is_matching_na(object left, object right, bint nan_matches_none=False
8689
and util.is_timedelta64_object(right)
8790
and get_timedelta64_value(right) == NPY_NAT
8891
)
92+
elif is_decimal_na(left):
93+
return is_decimal_na(right)
8994
return False
9095

9196

@@ -113,7 +118,18 @@ cpdef bint checknull(object val):
113118
The difference between `checknull` and `checknull_old` is that `checknull`
114119
does *not* consider INF or NEGINF to be NA.
115120
"""
116-
return val is C_NA or is_null_datetimelike(val, inat_is_null=False)
121+
return (
122+
val is C_NA
123+
or is_null_datetimelike(val, inat_is_null=False)
124+
or is_decimal_na(val)
125+
)
126+
127+
128+
cdef inline bint is_decimal_na(object val):
129+
"""
130+
Is this a decimal.Decimal object Decimal("NAN").
131+
"""
132+
return isinstance(val, cDecimal) and val != val
117133

118134

119135
cpdef bint checknull_old(object val):

pandas/_testing/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import collections
44
from datetime import datetime
5+
from decimal import Decimal
56
from functools import wraps
67
import operator
78
import os
@@ -146,7 +147,7 @@
146147
+ BYTES_DTYPES
147148
)
148149

149-
NULL_OBJECTS = [None, np.nan, pd.NaT, float("nan"), pd.NA]
150+
NULL_OBJECTS = [None, np.nan, pd.NaT, float("nan"), pd.NA, Decimal("NaN")]
150151

151152
EMPTY_STRING_PATTERN = re.compile("^$")
152153

pandas/_testing/asserters.py

+3-16
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88

99
from pandas._libs.lib import no_default
10+
from pandas._libs.missing import is_matching_na
1011
import pandas._libs.testing as _testing
1112

1213
from pandas.core.dtypes.common import (
@@ -458,22 +459,8 @@ def assert_attr_equal(attr: str, left, right, obj: str = "Attributes"):
458459

459460
if left_attr is right_attr:
460461
return True
461-
elif (
462-
is_number(left_attr)
463-
and np.isnan(left_attr)
464-
and is_number(right_attr)
465-
and np.isnan(right_attr)
466-
):
467-
# np.nan
468-
return True
469-
elif (
470-
isinstance(left_attr, (np.datetime64, np.timedelta64))
471-
and isinstance(right_attr, (np.datetime64, np.timedelta64))
472-
and type(left_attr) is type(right_attr)
473-
and np.isnat(left_attr)
474-
and np.isnat(right_attr)
475-
):
476-
# np.datetime64("nat") or np.timedelta64("nat")
462+
elif is_matching_na(left_attr, right_attr):
463+
# e.g. both np.nan, both NaT, both pd.NA, ...
477464
return True
478465

479466
try:

pandas/conftest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def nselect_method(request):
304304
# ----------------------------------------------------------------
305305
# Missing values & co.
306306
# ----------------------------------------------------------------
307-
@pytest.fixture(params=tm.NULL_OBJECTS, ids=str)
307+
@pytest.fixture(params=tm.NULL_OBJECTS, ids=lambda x: type(x).__name__)
308308
def nulls_fixture(request):
309309
"""
310310
Fixture for each null type in pandas.

pandas/core/dtypes/missing.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
missing types & inference
33
"""
4+
from decimal import Decimal
45
from functools import partial
56

67
import numpy as np
@@ -610,20 +611,24 @@ def is_valid_na_for_dtype(obj, dtype: DtypeObj) -> bool:
610611
"""
611612
if not lib.is_scalar(obj) or not isna(obj):
612613
return False
613-
if dtype.kind == "M":
614+
elif dtype.kind == "M":
614615
if isinstance(dtype, np.dtype):
615616
# i.e. not tzaware
616-
return not isinstance(obj, np.timedelta64)
617+
return not isinstance(obj, (np.timedelta64, Decimal))
617618
# we have to rule out tznaive dt64("NaT")
618-
return not isinstance(obj, (np.timedelta64, np.datetime64))
619-
if dtype.kind == "m":
620-
return not isinstance(obj, np.datetime64)
621-
if dtype.kind in ["i", "u", "f", "c"]:
619+
return not isinstance(obj, (np.timedelta64, np.datetime64, Decimal))
620+
elif dtype.kind == "m":
621+
return not isinstance(obj, (np.datetime64, Decimal))
622+
elif dtype.kind in ["i", "u", "f", "c"]:
622623
# Numeric
623624
return obj is not NaT and not isinstance(obj, (np.datetime64, np.timedelta64))
624625

626+
elif dtype == np.dtype(object):
627+
# This is needed for Categorical, but is kind of weird
628+
return True
629+
625630
# must be PeriodDType
626-
return not isinstance(obj, (np.datetime64, np.timedelta64))
631+
return not isinstance(obj, (np.datetime64, np.timedelta64, Decimal))
627632

628633

629634
def isna_all(arr: ArrayLike) -> bool:

pandas/tests/dtypes/cast/test_promote.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import datetime
6+
from decimal import Decimal
67

78
import numpy as np
89
import pytest
@@ -538,7 +539,20 @@ def test_maybe_promote_any_numpy_dtype_with_na(any_numpy_dtype_reduced, nulls_fi
538539
fill_value = nulls_fixture
539540
dtype = np.dtype(any_numpy_dtype_reduced)
540541

541-
if is_integer_dtype(dtype) and fill_value is not NaT:
542+
if isinstance(fill_value, Decimal):
543+
# Subject to change, but ATM (When Decimal(NAN) is being added to nulls_fixture)
544+
# this is the existing behavior in maybe_promote,
545+
# hinges on is_valid_na_for_dtype
546+
if dtype.kind in ["i", "u", "f", "c"]:
547+
if dtype.kind in ["i", "u"]:
548+
expected_dtype = np.dtype(np.float64)
549+
else:
550+
expected_dtype = dtype
551+
exp_val_for_scalar = np.nan
552+
else:
553+
expected_dtype = np.dtype(object)
554+
exp_val_for_scalar = fill_value
555+
elif is_integer_dtype(dtype) and fill_value is not NaT:
542556
# integer + other missing value (np.nan / None) casts to float
543557
expected_dtype = np.float64
544558
exp_val_for_scalar = np.nan

pandas/tests/dtypes/test_missing.py

+48-13
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,43 @@ def test_period(self):
317317
tm.assert_series_equal(isna(s), exp)
318318
tm.assert_series_equal(notna(s), ~exp)
319319

320+
def test_decimal(self):
321+
# scalars GH#23530
322+
a = Decimal(1.0)
323+
assert pd.isna(a) is False
324+
assert pd.notna(a) is True
325+
326+
b = Decimal("NaN")
327+
assert pd.isna(b) is True
328+
assert pd.notna(b) is False
329+
330+
# array
331+
arr = np.array([a, b])
332+
expected = np.array([False, True])
333+
result = pd.isna(arr)
334+
tm.assert_numpy_array_equal(result, expected)
335+
336+
result = pd.notna(arr)
337+
tm.assert_numpy_array_equal(result, ~expected)
338+
339+
# series
340+
ser = Series(arr)
341+
expected = Series(expected)
342+
result = pd.isna(ser)
343+
tm.assert_series_equal(result, expected)
344+
345+
result = pd.notna(ser)
346+
tm.assert_series_equal(result, ~expected)
347+
348+
# index
349+
idx = pd.Index(arr)
350+
expected = np.array([False, True])
351+
result = pd.isna(idx)
352+
tm.assert_numpy_array_equal(result, expected)
353+
354+
result = pd.notna(idx)
355+
tm.assert_numpy_array_equal(result, ~expected)
356+
320357

321358
@pytest.mark.parametrize("dtype_equal", [True, False])
322359
def test_array_equivalent(dtype_equal):
@@ -619,24 +656,22 @@ def test_empty_like(self):
619656

620657

621658
class TestLibMissing:
622-
def test_checknull(self):
623-
for value in na_vals:
624-
assert libmissing.checknull(value)
659+
@pytest.mark.parametrize("func", [libmissing.checknull, isna])
660+
def test_checknull(self, func):
661+
for value in na_vals + sometimes_na_vals:
662+
assert func(value)
625663

626664
for value in inf_vals:
627-
assert not libmissing.checknull(value)
665+
assert not func(value)
628666

629667
for value in int_na_vals:
630-
assert not libmissing.checknull(value)
631-
632-
for value in sometimes_na_vals:
633-
assert not libmissing.checknull(value)
668+
assert not func(value)
634669

635670
for value in never_na_vals:
636-
assert not libmissing.checknull(value)
671+
assert not func(value)
637672

638673
def test_checknull_old(self):
639-
for value in na_vals:
674+
for value in na_vals + sometimes_na_vals:
640675
assert libmissing.checknull_old(value)
641676

642677
for value in inf_vals:
@@ -645,9 +680,6 @@ def test_checknull_old(self):
645680
for value in int_na_vals:
646681
assert not libmissing.checknull_old(value)
647682

648-
for value in sometimes_na_vals:
649-
assert not libmissing.checknull_old(value)
650-
651683
for value in never_na_vals:
652684
assert not libmissing.checknull_old(value)
653685

@@ -682,6 +714,9 @@ def test_is_matching_na(self, nulls_fixture, nulls_fixture2):
682714
elif is_float(left) and is_float(right):
683715
# np.nan vs float("NaN") we consider as matching
684716
assert libmissing.is_matching_na(left, right)
717+
elif type(left) is type(right):
718+
# e.g. both Decimal("NaN")
719+
assert libmissing.is_matching_na(left, right)
685720
else:
686721
assert not libmissing.is_matching_na(left, right)
687722

pandas/tests/extension/base/interface.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ def test_contains(self, data, data_missing):
4848

4949
# the data can never contain other nan-likes than na_value
5050
for na_value_obj in tm.NULL_OBJECTS:
51-
if na_value_obj is na_value:
51+
if na_value_obj is na_value or type(na_value_obj) == type(na_value):
52+
# type check for e.g. two instances of Decimal("NAN")
5253
continue
5354
assert na_value_obj not in data
5455
assert na_value_obj not in data_missing

pandas/tests/extension/decimal/test_decimal.py

-13
Original file line numberDiff line numberDiff line change
@@ -178,13 +178,6 @@ class TestBooleanReduce(Reduce, base.BaseBooleanReduceTests):
178178
class TestMethods(BaseDecimal, base.BaseMethodsTests):
179179
@pytest.mark.parametrize("dropna", [True, False])
180180
def test_value_counts(self, all_data, dropna, request):
181-
if any(x != x for x in all_data):
182-
mark = pytest.mark.xfail(
183-
reason="tm.assert_series_equal incorrectly raises",
184-
raises=AssertionError,
185-
)
186-
request.node.add_marker(mark)
187-
188181
all_data = all_data[:10]
189182
if dropna:
190183
other = np.array(all_data[~all_data.isna()])
@@ -212,12 +205,6 @@ class TestCasting(BaseDecimal, base.BaseCastingTests):
212205

213206

214207
class TestGroupby(BaseDecimal, base.BaseGroupbyTests):
215-
def test_groupby_apply_identity(self, data_for_grouping, request):
216-
if any(x != x for x in data_for_grouping):
217-
mark = pytest.mark.xfail(reason="tm.assert_series_equal raises incorrectly")
218-
request.node.add_marker(mark)
219-
super().test_groupby_apply_identity(data_for_grouping)
220-
221208
def test_groupby_agg_extension(self, data_for_grouping):
222209
super().test_groupby_agg_extension(data_for_grouping)
223210

pandas/tests/indexes/test_index_new.py

+6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""
22
Tests for the Index constructor conducting inference.
33
"""
4+
from decimal import Decimal
5+
46
import numpy as np
57
import pytest
68

@@ -89,6 +91,10 @@ def test_constructor_infer_periodindex(self):
8991
def test_constructor_infer_nat_dt_like(
9092
self, pos, klass, dtype, ctor, nulls_fixture, request
9193
):
94+
if isinstance(nulls_fixture, Decimal):
95+
# We dont cast these to datetime64/timedelta64
96+
return
97+
9298
expected = klass([NaT, NaT])
9399
assert expected.dtype == dtype
94100
data = [ctor]

pandas/tests/indexes/test_numeric.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,14 @@ def test_numeric_compat(self):
104104
def test_insert_na(self, nulls_fixture):
105105
# GH 18295 (test missing)
106106
index = self.create_index()
107+
na_val = nulls_fixture
107108

108-
if nulls_fixture is pd.NaT:
109+
if na_val is pd.NaT:
109110
expected = Index([index[0], pd.NaT] + list(index[1:]), dtype=object)
110111
else:
111112
expected = Float64Index([index[0], np.nan] + list(index[1:]))
112-
result = index.insert(1, nulls_fixture)
113+
114+
result = index.insert(1, na_val)
113115
tm.assert_index_equal(result, expected)
114116

115117

pandas/tests/io/json/test_pandas.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import datetime
22
from datetime import timedelta
3+
from decimal import Decimal
34
from io import StringIO
45
import json
56
import os
@@ -1742,8 +1743,12 @@ def test_json_pandas_na(self):
17421743
result = DataFrame([[pd.NA]]).to_json()
17431744
assert result == '{"0":{"0":null}}'
17441745

1745-
def test_json_pandas_nulls(self, nulls_fixture):
1746+
def test_json_pandas_nulls(self, nulls_fixture, request):
17461747
# GH 31615
1748+
if isinstance(nulls_fixture, Decimal):
1749+
mark = pytest.mark.xfail(reason="not implemented")
1750+
request.node.add_marker(mark)
1751+
17471752
result = DataFrame([[nulls_fixture]]).to_json()
17481753
assert result == '{"0":{"0":null}}'
17491754

pandas/tests/tools/test_to_datetime.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
datetime,
77
timedelta,
88
)
9+
from decimal import Decimal
910
import locale
1011

1112
from dateutil.parser import parse
@@ -2446,9 +2447,15 @@ def test_nullable_integer_to_datetime():
24462447

24472448
@pytest.mark.parametrize("klass", [np.array, list])
24482449
def test_na_to_datetime(nulls_fixture, klass):
2449-
result = pd.to_datetime(klass([nulls_fixture]))
24502450

2451-
assert result[0] is pd.NaT
2451+
if isinstance(nulls_fixture, Decimal):
2452+
with pytest.raises(TypeError, match="not convertible to datetime"):
2453+
pd.to_datetime(klass([nulls_fixture]))
2454+
2455+
else:
2456+
result = pd.to_datetime(klass([nulls_fixture]))
2457+
2458+
assert result[0] is pd.NaT
24522459

24532460

24542461
def test_empty_string_datetime_coerce__format():

0 commit comments

Comments
 (0)