Skip to content

Commit 4f3cd55

Browse files
tqa236pmhatre1
authored andcommitted
Fix SparseDtype comparison (pandas-dev#57783)
* Fix SparseDtype comparison * Fix tests * Add whatsnew * Fix
1 parent 2bcb844 commit 4f3cd55

File tree

3 files changed

+14
-10
lines changed

3 files changed

+14
-10
lines changed

pandas/core/dtypes/dtypes.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -1705,17 +1705,15 @@ def __eq__(self, other: object) -> bool:
17051705

17061706
if isinstance(other, type(self)):
17071707
subtype = self.subtype == other.subtype
1708-
if self._is_na_fill_value:
1708+
if self._is_na_fill_value or other._is_na_fill_value:
17091709
# this case is complicated by two things:
17101710
# SparseDtype(float, float(nan)) == SparseDtype(float, np.nan)
17111711
# SparseDtype(float, np.nan) != SparseDtype(float, pd.NaT)
17121712
# i.e. we want to treat any floating-point NaN as equal, but
17131713
# not a floating-point NaN and a datetime NaT.
1714-
fill_value = (
1715-
other._is_na_fill_value
1716-
and isinstance(self.fill_value, type(other.fill_value))
1717-
or isinstance(other.fill_value, type(self.fill_value))
1718-
)
1714+
fill_value = isinstance(
1715+
self.fill_value, type(other.fill_value)
1716+
) or isinstance(other.fill_value, type(self.fill_value))
17191717
else:
17201718
with warnings.catch_warnings():
17211719
# Ignore spurious numpy warning

pandas/tests/arrays/sparse/test_dtype.py

+8
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,14 @@ def test_nans_equal():
6868
assert b == a
6969

7070

71+
def test_nans_not_equal():
72+
# GH 54770
73+
a = SparseDtype(float, 0)
74+
b = SparseDtype(float, pd.NA)
75+
assert a != b
76+
assert b != a
77+
78+
7179
with warnings.catch_warnings():
7280
msg = "Allowing arbitrary scalar fill_value in SparseDtype is deprecated"
7381
warnings.filterwarnings("ignore", msg, category=FutureWarning)

pandas/tests/extension/test_sparse.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -240,10 +240,6 @@ def test_fillna_limit_backfill(self, data_missing):
240240
super().test_fillna_limit_backfill(data_missing)
241241

242242
def test_fillna_no_op_returns_copy(self, data, request):
243-
if np.isnan(data.fill_value):
244-
request.applymarker(
245-
pytest.mark.xfail(reason="returns array with different fill value")
246-
)
247243
super().test_fillna_no_op_returns_copy(data)
248244

249245
@pytest.mark.xfail(reason="Unsupported")
@@ -400,6 +396,8 @@ def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request):
400396
"rmul",
401397
"floordiv",
402398
"rfloordiv",
399+
"truediv",
400+
"rtruediv",
403401
"pow",
404402
"mod",
405403
"rmod",

0 commit comments

Comments
 (0)