Skip to content

Commit a53a9e8

Browse files
phoflcbpygit
authored andcommitted
BUG: assert_series_equal not properly respecting check-dtype (pandas-dev#56654)
1 parent ab666f4 commit a53a9e8

File tree

4 files changed

+21
-25
lines changed

4 files changed

+21
-25
lines changed

pandas/_testing/asserters.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -949,9 +949,15 @@ def assert_series_equal(
949949
obj=str(obj),
950950
)
951951
else:
952+
# convert both to NumPy if not, check_dtype would raise earlier
953+
lv, rv = left_values, right_values
954+
if isinstance(left_values, ExtensionArray):
955+
lv = left_values.to_numpy()
956+
if isinstance(right_values, ExtensionArray):
957+
rv = right_values.to_numpy()
952958
assert_numpy_array_equal(
953-
left_values,
954-
right_values,
959+
lv,
960+
rv,
955961
check_dtype=check_dtype,
956962
obj=str(obj),
957963
index_values=left.index,

pandas/tests/extension/test_numpy.py

-10
Original file line numberDiff line numberDiff line change
@@ -421,16 +421,6 @@ def test_index_from_listlike_with_dtype(self, data):
421421
def test_EA_types(self, engine, data, request):
422422
super().test_EA_types(engine, data, request)
423423

424-
@pytest.mark.xfail(reason="Expect NumpyEA, get np.ndarray")
425-
def test_compare_array(self, data, comparison_op):
426-
super().test_compare_array(data, comparison_op)
427-
428-
def test_compare_scalar(self, data, comparison_op, request):
429-
if data.dtype.kind == "f" or comparison_op.__name__ in ["eq", "ne"]:
430-
mark = pytest.mark.xfail(reason="Expect NumpyEA, get np.ndarray")
431-
request.applymarker(mark)
432-
super().test_compare_scalar(data, comparison_op)
433-
434424

435425
class Test2DCompat(base.NDArrayBacked2DTests):
436426
pass

pandas/tests/util/test_assert_frame_equal.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -211,10 +211,7 @@ def test_assert_frame_equal_extension_dtype_mismatch():
211211
"\\[right\\]: int[32|64]"
212212
)
213213

214-
# TODO: this shouldn't raise (or should raise a better error message)
215-
# https://github.com/pandas-dev/pandas/issues/56131
216-
with pytest.raises(AssertionError, match="classes are different"):
217-
tm.assert_frame_equal(left, right, check_dtype=False)
214+
tm.assert_frame_equal(left, right, check_dtype=False)
218215

219216
with pytest.raises(AssertionError, match=msg):
220217
tm.assert_frame_equal(left, right, check_dtype=True)
@@ -246,7 +243,6 @@ def test_assert_frame_equal_ignore_extension_dtype_mismatch():
246243
tm.assert_frame_equal(left, right, check_dtype=False)
247244

248245

249-
@pytest.mark.xfail(reason="https://github.com/pandas-dev/pandas/issues/56131")
250246
def test_assert_frame_equal_ignore_extension_dtype_mismatch_cross_class():
251247
# https://github.com/pandas-dev/pandas/issues/35715
252248
left = DataFrame({"a": [1, 2, 3]}, dtype="Int64")
@@ -300,9 +296,7 @@ def test_frame_equal_mixed_dtypes(frame_or_series, any_numeric_ea_dtype, indexer
300296
dtypes = (any_numeric_ea_dtype, "int64")
301297
obj1 = frame_or_series([1, 2], dtype=dtypes[indexer[0]])
302298
obj2 = frame_or_series([1, 2], dtype=dtypes[indexer[1]])
303-
msg = r'(Series|DataFrame.iloc\[:, 0\] \(column name="0"\) classes) are different'
304-
with pytest.raises(AssertionError, match=msg):
305-
tm.assert_equal(obj1, obj2, check_exact=True, check_dtype=False)
299+
tm.assert_equal(obj1, obj2, check_exact=True, check_dtype=False)
306300

307301

308302
def test_assert_frame_equal_check_like_different_indexes():

pandas/tests/util/test_assert_series_equal.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -290,10 +290,7 @@ def test_assert_series_equal_extension_dtype_mismatch():
290290
\\[left\\]: Int64
291291
\\[right\\]: int[32|64]"""
292292

293-
# TODO: this shouldn't raise (or should raise a better error message)
294-
# https://github.com/pandas-dev/pandas/issues/56131
295-
with pytest.raises(AssertionError, match="Series classes are different"):
296-
tm.assert_series_equal(left, right, check_dtype=False)
293+
tm.assert_series_equal(left, right, check_dtype=False)
297294

298295
with pytest.raises(AssertionError, match=msg):
299296
tm.assert_series_equal(left, right, check_dtype=True)
@@ -372,7 +369,6 @@ def test_assert_series_equal_ignore_extension_dtype_mismatch():
372369
tm.assert_series_equal(left, right, check_dtype=False)
373370

374371

375-
@pytest.mark.xfail(reason="https://github.com/pandas-dev/pandas/issues/56131")
376372
def test_assert_series_equal_ignore_extension_dtype_mismatch_cross_class():
377373
# https://github.com/pandas-dev/pandas/issues/35715
378374
left = Series([1, 2, 3], dtype="Int64")
@@ -456,3 +452,13 @@ def test_large_unequal_ints(dtype):
456452
right = Series([1577840521123543], dtype=dtype)
457453
with pytest.raises(AssertionError, match="Series are different"):
458454
tm.assert_series_equal(left, right)
455+
456+
457+
@pytest.mark.parametrize("dtype", [None, object])
458+
@pytest.mark.parametrize("check_exact", [True, False])
459+
@pytest.mark.parametrize("val", [3, 3.5])
460+
def test_ea_and_numpy_no_dtype_check(val, check_exact, dtype):
461+
# GH#56651
462+
left = Series([1, 2, val], dtype=dtype)
463+
right = Series(pd.array([1, 2, val]))
464+
tm.assert_series_equal(left, right, check_dtype=False, check_exact=check_exact)

0 commit comments

Comments
 (0)