Skip to content

Commit 9cf631f

Browse files
TST: reintroduce check_series_type in assert_series_equal (#32670)
Co-authored-by: Joris Van den Bossche <[email protected]>
1 parent a942264 commit 9cf631f

File tree

4 files changed

+31
-6
lines changed

4 files changed

+31
-6
lines changed

doc/source/whatsnew/v1.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,7 @@ Other
396396
- Set operations on an object-dtype :class:`Index` now always return object-dtype results (:issue:`31401`)
397397
- Bug in :meth:`AbstractHolidayCalendar.holidays` when no rules were defined (:issue:`31415`)
398398
- Bug in :meth:`DataFrame.to_records` incorrectly losing timezone information in timezone-aware ``datetime64`` columns (:issue:`32535`)
399+
- Fixed :func:`pandas.testing.assert_series_equal` to correctly raise if left object is a different subclass with ``check_series_type=True`` (:issue:`32670`).
399400
- :meth:`IntegerArray.astype` now supports ``datetime64`` dtype (:issue:32538`)
400401

401402
.. ---------------------------------------------------------------------------

pandas/_testing.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -1050,6 +1050,7 @@ def assert_series_equal(
10501050
right,
10511051
check_dtype=True,
10521052
check_index_type="equiv",
1053+
check_series_type=True,
10531054
check_less_precise=False,
10541055
check_names=True,
10551056
check_exact=False,
@@ -1070,6 +1071,8 @@ def assert_series_equal(
10701071
check_index_type : bool or {'equiv'}, default 'equiv'
10711072
Whether to check the Index class, dtype and inferred_type
10721073
are identical.
1074+
check_series_type : bool, default True
1075+
Whether to check the Series class is identical.
10731076
check_less_precise : bool or int, default False
10741077
Specify comparison precision. Only used when check_exact is False.
10751078
5 digits (False) or 3 digits (True) after decimal points are compared.
@@ -1101,10 +1104,8 @@ def assert_series_equal(
11011104
# instance validation
11021105
_check_isinstance(left, right, Series)
11031106

1104-
# TODO: There are some tests using rhs is sparse
1105-
# lhs is dense. Should use assert_class_equal in future
1106-
assert isinstance(left, type(right))
1107-
# assert_class_equal(left, right, obj=obj)
1107+
if check_series_type:
1108+
assert_class_equal(left, right, obj=obj)
11081109

11091110
# length comparison
11101111
if len(left) != len(right):

pandas/tests/frame/test_subclass.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -163,12 +163,14 @@ def test_subclass_align_combinations(self):
163163

164164
# frame + series
165165
res1, res2 = df.align(s, axis=0)
166-
exp1 = pd.DataFrame(
166+
exp1 = tm.SubclassedDataFrame(
167167
{"a": [1, np.nan, 3, np.nan, 5], "b": [1, np.nan, 3, np.nan, 5]},
168168
index=list("ABCDE"),
169169
)
170170
# name is lost when
171-
exp2 = pd.Series([1, 2, np.nan, 4, np.nan], index=list("ABCDE"), name="x")
171+
exp2 = tm.SubclassedSeries(
172+
[1, 2, np.nan, 4, np.nan], index=list("ABCDE"), name="x"
173+
)
172174

173175
assert isinstance(res1, tm.SubclassedDataFrame)
174176
tm.assert_frame_equal(res1, exp1)

pandas/tests/util/test_assert_series_equal.py

+21
Original file line numberDiff line numberDiff line change
@@ -194,3 +194,24 @@ def test_series_equal_categorical_mismatch(check_categorical):
194194
tm.assert_series_equal(s1, s2, check_categorical=check_categorical)
195195
else:
196196
_assert_series_equal_both(s1, s2, check_categorical=check_categorical)
197+
198+
199+
def test_series_equal_series_type():
200+
class MySeries(Series):
201+
pass
202+
203+
s1 = Series([1, 2])
204+
s2 = Series([1, 2])
205+
s3 = MySeries([1, 2])
206+
207+
tm.assert_series_equal(s1, s2, check_series_type=False)
208+
tm.assert_series_equal(s1, s2, check_series_type=True)
209+
210+
tm.assert_series_equal(s1, s3, check_series_type=False)
211+
tm.assert_series_equal(s3, s1, check_series_type=False)
212+
213+
with pytest.raises(AssertionError, match="Series classes are different"):
214+
tm.assert_series_equal(s1, s3, check_series_type=True)
215+
216+
with pytest.raises(AssertionError, match="Series classes are different"):
217+
tm.assert_series_equal(s3, s1, check_series_type=True)

0 commit comments

Comments
 (0)