Skip to content

Commit ab82fb0

Browse files
[ArrayManager] Implement .equals method (#39721)
1 parent 373d677 commit ab82fb0

File tree

4 files changed

+43
-23
lines changed

4 files changed

+43
-23
lines changed

pandas/core/internals/array_manager.py

+11-4
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@
2020
)
2121
from pandas.core.dtypes.dtypes import ExtensionDtype, PandasDtype
2222
from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries
23-
from pandas.core.dtypes.missing import isna
23+
from pandas.core.dtypes.missing import array_equals, isna
2424

2525
import pandas.core.algorithms as algos
2626
from pandas.core.arrays import ExtensionArray
@@ -829,9 +829,16 @@ def _make_na_array(self, fill_value=None):
829829
values.fill(fill_value)
830830
return values
831831

832-
def equals(self, other: object) -> bool:
833-
# TODO
834-
raise NotImplementedError
832+
def _equal_values(self, other) -> bool:
833+
"""
834+
Used in .equals defined in base class. Only check the column values
835+
assuming shape and indexes have already been checked.
836+
"""
837+
for left, right in zip(self.arrays, other.arrays):
838+
if not array_equals(left, right):
839+
return False
840+
else:
841+
return True
835842

836843
def unstack(self, unstacker, fill_value) -> ArrayManager:
837844
"""

pandas/core/internals/base.py

+22
Original file line number | Diff line number | Diff line change
@@ -70,3 +70,25 @@ def reindex_axis(
7070
consolidate=consolidate,
7171
only_slice=only_slice,
7272
)
73+
74+
def _equal_values(self: T, other: T) -> bool:
75+
"""
76+
To be implemented by the subclasses. Only check the column values
77+
assuming shape and indexes have already been checked.
78+
"""
79+
raise AbstractMethodError(self)
80+
81+
def equals(self, other: object) -> bool:
82+
"""
83+
Implementation for DataFrame.equals
84+
"""
85+
if not isinstance(other, DataManager):
86+
return False
87+
88+
self_axes, other_axes = self.axes, other.axes
89+
if len(self_axes) != len(other_axes):
90+
return False
91+
if not all(ax1.equals(ax2) for ax1, ax2 in zip(self_axes, other_axes)):
92+
return False
93+
94+
return self._equal_values(other)

pandas/core/internals/managers.py

+5-10
Original file line number | Diff line number | Diff line change
@@ -1395,16 +1395,11 @@ def take(self, indexer, axis: int = 1, verify: bool = True, convert: bool = True
13951395
consolidate=False,
13961396
)
13971397

1398-
def equals(self, other: object) -> bool:
1399-
if not isinstance(other, BlockManager):
1400-
return False
1401-
1402-
self_axes, other_axes = self.axes, other.axes
1403-
if len(self_axes) != len(other_axes):
1404-
return False
1405-
if not all(ax1.equals(ax2) for ax1, ax2 in zip(self_axes, other_axes)):
1406-
return False
1407-
1398+
def _equal_values(self: T, other: T) -> bool:
1399+
"""
1400+
Used in .equals defined in base class. Only check the column values
1401+
assuming shape and indexes have already been checked.
1402+
"""
14081403
if self.ndim == 1:
14091404
# For SingleBlockManager (i.e.Series)
14101405
if other.ndim != 1:

pandas/tests/frame/methods/test_equals.py

+5-9
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,8 @@
11
import numpy as np
22

3-
import pandas.util._test_decorators as td
4-
53
from pandas import DataFrame, date_range
64
import pandas._testing as tm
75

8-
# TODO(ArrayManager) implement equals
9-
pytestmark = td.skip_array_manager_not_yet_implemented
10-
116

127
class TestEquals:
138
def test_dataframe_not_equal(self):
@@ -16,13 +11,14 @@ def test_dataframe_not_equal(self):
1611
df2 = DataFrame({"a": ["s", "d"], "b": [1, 2]})
1712
assert df1.equals(df2) is False
1813

19-
def test_equals_different_blocks(self):
14+
def test_equals_different_blocks(self, using_array_manager):
2015
# GH#9330
2116
df0 = DataFrame({"A": ["x", "y"], "B": [1, 2], "C": ["w", "z"]})
2217
df1 = df0.reset_index()[["A", "B", "C"]]
23-
# this assert verifies that the above operations have
24-
# induced a block rearrangement
25-
assert df0._mgr.blocks[0].dtype != df1._mgr.blocks[0].dtype
18+
if not using_array_manager:
19+
# this assert verifies that the above operations have
20+
# induced a block rearrangement
21+
assert df0._mgr.blocks[0].dtype != df1._mgr.blocks[0].dtype
2622

2723
# do the real tests
2824
tm.assert_frame_equal(df0, df1)

0 commit comments

Comments
 (0)