diff --git a/doc/source/whatsnew/v1.1.0.rst b/doc/source/whatsnew/v1.1.0.rst index 11757e1bf14e0..0630823f0de35 100644 --- a/doc/source/whatsnew/v1.1.0.rst +++ b/doc/source/whatsnew/v1.1.0.rst @@ -396,6 +396,7 @@ Other - Set operations on an object-dtype :class:`Index` now always return object-dtype results (:issue:`31401`) - Bug in :meth:`AbstractHolidayCalendar.holidays` when no rules were defined (:issue:`31415`) - Bug in :meth:`DataFrame.to_records` incorrectly losing timezone information in timezone-aware ``datetime64`` columns (:issue:`32535`) +- Fixed :func:`pandas.testing.assert_series_equal` to correctly raise if left object is a different subclass with ``check_series_type=True`` (:issue:`32670`). - :meth:`IntegerArray.astype` now supports ``datetime64`` dtype (:issue:32538`) .. --------------------------------------------------------------------------- diff --git a/pandas/_testing.py b/pandas/_testing.py index dff15c66750ac..d473b453d77d2 100644 --- a/pandas/_testing.py +++ b/pandas/_testing.py @@ -1050,6 +1050,7 @@ def assert_series_equal( right, check_dtype=True, check_index_type="equiv", + check_series_type=True, check_less_precise=False, check_names=True, check_exact=False, @@ -1070,6 +1071,8 @@ def assert_series_equal( check_index_type : bool or {'equiv'}, default 'equiv' Whether to check the Index class, dtype and inferred_type are identical. + check_series_type : bool, default True + Whether to check the Series class is identical. check_less_precise : bool or int, default False Specify comparison precision. Only used when check_exact is False. 5 digits (False) or 3 digits (True) after decimal points are compared. @@ -1101,10 +1104,8 @@ def assert_series_equal( # instance validation _check_isinstance(left, right, Series) - # TODO: There are some tests using rhs is sparse - # lhs is dense. Should use assert_class_equal in future - assert isinstance(left, type(right)) - # assert_class_equal(left, right, obj=obj) + if check_series_type: + assert_class_equal(left, right, obj=obj) # length comparison if len(left) != len(right): diff --git a/pandas/tests/frame/test_subclass.py b/pandas/tests/frame/test_subclass.py index a2e7dc527c4b8..16bf651829a04 100644 --- a/pandas/tests/frame/test_subclass.py +++ b/pandas/tests/frame/test_subclass.py @@ -163,12 +163,14 @@ def test_subclass_align_combinations(self): # frame + series res1, res2 = df.align(s, axis=0) - exp1 = pd.DataFrame( + exp1 = tm.SubclassedDataFrame( {"a": [1, np.nan, 3, np.nan, 5], "b": [1, np.nan, 3, np.nan, 5]}, index=list("ABCDE"), ) # name is lost when - exp2 = pd.Series([1, 2, np.nan, 4, np.nan], index=list("ABCDE"), name="x") + exp2 = tm.SubclassedSeries( + [1, 2, np.nan, 4, np.nan], index=list("ABCDE"), name="x" + ) assert isinstance(res1, tm.SubclassedDataFrame) tm.assert_frame_equal(res1, exp1) diff --git a/pandas/tests/util/test_assert_series_equal.py b/pandas/tests/util/test_assert_series_equal.py index eaf0824f52927..2550b32446055 100644 --- a/pandas/tests/util/test_assert_series_equal.py +++ b/pandas/tests/util/test_assert_series_equal.py @@ -194,3 +194,24 @@ def test_series_equal_categorical_mismatch(check_categorical): tm.assert_series_equal(s1, s2, check_categorical=check_categorical) else: _assert_series_equal_both(s1, s2, check_categorical=check_categorical) + + +def test_series_equal_series_type(): + class MySeries(Series): + pass + + s1 = Series([1, 2]) + s2 = Series([1, 2]) + s3 = MySeries([1, 2]) + + tm.assert_series_equal(s1, s2, check_series_type=False) + tm.assert_series_equal(s1, s2, check_series_type=True) + + tm.assert_series_equal(s1, s3, check_series_type=False) + tm.assert_series_equal(s3, s1, check_series_type=False) + + with pytest.raises(AssertionError, match="Series classes are different"): + tm.assert_series_equal(s1, s3, check_series_type=True) + + with pytest.raises(AssertionError, match="Series classes are different"): + tm.assert_series_equal(s3, s1, check_series_type=True)