Skip to content

Commit 8275e88

Browse files
authored
ENH: Improve ref-tracking for group keys (pandas-dev#51442)
1 parent c133327 commit 8275e88

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

pandas/core/groupby/grouper.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -944,12 +944,17 @@ def is_in_obj(gpr) -> bool:
944944
if not hasattr(gpr, "name"):
945945
return False
946946
if using_copy_on_write():
947-
# For the CoW case, we need an equality check as the identity check
948-
# no longer works (each Series from column access is a new object)
947+
# For the CoW case, we check the references to determine if the
948+
# series is part of the object
949949
try:
950-
return gpr.equals(obj[gpr.name])
951-
except (AttributeError, KeyError, IndexError, InvalidIndexError):
950+
obj_gpr_column = obj[gpr.name]
951+
except (KeyError, IndexError, InvalidIndexError):
952952
return False
953+
if isinstance(gpr, Series) and isinstance(obj_gpr_column, Series):
954+
return gpr._mgr.references_same_values( # type: ignore[union-attr]
955+
obj_gpr_column._mgr, 0 # type: ignore[arg-type]
956+
)
957+
return False
953958
try:
954959
return gpr is obj[gpr.name]
955960
except (KeyError, IndexError, InvalidIndexError):

pandas/core/internals/managers.py

+9
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
cast,
1212
)
1313
import warnings
14+
import weakref
1415

1516
import numpy as np
1617

@@ -258,6 +259,14 @@ def add_references(self, mgr: BaseBlockManager) -> None:
258259
# "Block"; expected "SharedBlock"
259260
blk.refs.add_reference(blk) # type: ignore[arg-type]
260261

262+
def references_same_values(self, mgr: BaseBlockManager, blkno: int) -> bool:
263+
"""
264+
Checks if two blocks from two different block managers reference the
265+
same underlying values.
266+
"""
267+
ref = weakref.ref(self.blocks[blkno])
268+
return ref in mgr.blocks[blkno].refs.referenced_blocks
269+
261270
def get_dtypes(self):
262271
dtypes = np.array([blk.dtype for blk in self.blocks])
263272
return dtypes.take(self.blknos)

0 commit comments

Comments
 (0)