Skip to content

BUG: assert_attr_equal with numpy nat or pd.NA #39461

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v1.3.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ Other
- Bug in :class:`Index` constructor sometimes silently ignoring a specified ``dtype`` (:issue:`38879`)
- Bug in constructing a :class:`Series` from a list and a :class:`PandasDtype` (:issue:`39357`)
- Bug in :class:`Styler` which caused CSS to duplicate on multiple renders. (:issue:`39395`)
-
- Bug in :func:`pandas.testing.assert_series_equal`, :func:`pandas.testing.assert_frame_equal`, :func:`pandas.testing.assert_index_equal` and :func:`pandas.testing.assert_extension_array_equal` incorrectly raising when an attribute has an unrecognized NA type (:issue:`39461`)

.. ---------------------------------------------------------------------------

Expand Down
13 changes: 12 additions & 1 deletion pandas/_testing/asserters.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,13 +459,24 @@ def assert_attr_equal(attr: str, left, right, obj: str = "Attributes"):
):
# np.nan
return True
elif (
isinstance(left_attr, (np.datetime64, np.timedelta64))
and isinstance(right_attr, (np.datetime64, np.timedelta64))
and type(left_attr) is type(right_attr)
and np.isnat(left_attr)
and np.isnat(right_attr)
):
# np.datetime64("nat") or np.timedelta64("nat")
return True

try:
result = left_attr == right_attr
except TypeError:
# datetimetz on rhs may raise TypeError
result = False
if not isinstance(result, bool):
if (left_attr is pd.NA) ^ (right_attr is pd.NA):
result = False
elif not isinstance(result, bool):
result = result.all()

if result:
Expand Down
2 changes: 2 additions & 0 deletions pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,8 @@ def factorize(
values, dtype = _ensure_data(values)

if original.dtype.kind in ["m", "M"]:
# Note: factorize_array will cast NaT bc it has a __int__
# method, but will not cast the more-correct dtype.type("nat")
na_value = iNaT
else:
na_value = None
Expand Down
2 changes: 2 additions & 0 deletions pandas/core/arrays/sparse/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,8 @@ def __init__(
stacklevel=2,
)
data = np.asarray(data, dtype="datetime64[ns]")
if fill_value is NaT:
fill_value = np.datetime64("NaT", "ns")
data = np.asarray(data)
sparse_values, sparse_index, fill_value = make_sparse(
data, kind=kind, fill_value=fill_value, dtype=dtype
Expand Down
6 changes: 3 additions & 3 deletions pandas/core/dtypes/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,14 +559,14 @@ def na_value_for_dtype(dtype, compat: bool = True):
>>> na_value_for_dtype(np.dtype('bool'))
False
>>> na_value_for_dtype(np.dtype('datetime64[ns]'))
NaT
numpy.datetime64('NaT')
"""
dtype = pandas_dtype(dtype)

if is_extension_array_dtype(dtype):
return dtype.na_value
if needs_i8_conversion(dtype):
return NaT
elif needs_i8_conversion(dtype):
return dtype.type("NaT", "ns")
elif is_float_dtype(dtype):
return np.nan
elif is_integer_dtype(dtype):
Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/arrays/sparse/test_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
("float", np.nan),
("bool", False),
("object", np.nan),
("datetime64[ns]", pd.NaT),
("timedelta64[ns]", pd.NaT),
("datetime64[ns]", np.datetime64("NaT", "ns")),
("timedelta64[ns]", np.timedelta64("NaT", "ns")),
],
)
def test_inferred_dtype(dtype, fill_value):
Expand Down
10 changes: 7 additions & 3 deletions pandas/tests/dtypes/test_missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,8 +472,8 @@ def test_array_equivalent_nested():
"dtype, na_value",
[
# Datetime-like
(np.dtype("M8[ns]"), NaT),
(np.dtype("m8[ns]"), NaT),
(np.dtype("M8[ns]"), np.datetime64("NaT", "ns")),
(np.dtype("m8[ns]"), np.timedelta64("NaT", "ns")),
(DatetimeTZDtype.construct_from_string("datetime64[ns, US/Eastern]"), NaT),
(PeriodDtype("M"), NaT),
# Integer
Expand All @@ -499,7 +499,11 @@ def test_array_equivalent_nested():
)
def test_na_value_for_dtype(dtype, na_value):
result = na_value_for_dtype(dtype)
assert result is na_value
# identity check doesn't work for datetime64/timedelta64("NaT") because they
# are not singletons
assert result is na_value or (
isna(result) and isna(na_value) and type(result) is type(na_value)
)


class TestNAObj:
Expand Down
30 changes: 30 additions & 0 deletions pandas/tests/util/test_assert_attr_equal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from types import SimpleNamespace

import pytest

from pandas.core.dtypes.common import is_float

import pandas._testing as tm


def test_assert_attr_equal(nulls_fixture):
    """Compare an object's ``na_value`` attribute against itself.

    The attribute is set to whatever ``nulls_fixture`` supplies, and the
    same object is passed as both sides of ``tm.assert_attr_equal``.
    """
    holder = SimpleNamespace(na_value=nulls_fixture)
    assert tm.assert_attr_equal("na_value", holder, holder)


def test_assert_attr_equal_different_nulls(nulls_fixture, nulls_fixture2):
    """Compare ``na_value`` attributes taken from two separate fixtures.

    The comparison is expected to pass when both fixtures are the very same
    object, or when both are floats; every other pairing must raise an
    ``AssertionError`` mentioning that "na_value" are different.
    """
    left = SimpleNamespace(na_value=nulls_fixture)
    right = SimpleNamespace(na_value=nulls_fixture2)

    same_object = nulls_fixture is nulls_fixture2
    # we consider float("nan") and np.float64("nan") to be equivalent
    expect_match = same_object or (
        is_float(nulls_fixture) and is_float(nulls_fixture2)
    )

    if not expect_match:
        with pytest.raises(AssertionError, match='"na_value" are different'):
            tm.assert_attr_equal("na_value", left, right)
        return

    assert tm.assert_attr_equal("na_value", left, right)