Skip to content

Commit f2f09e4

Browse files
itholic authored and HyukjinKwon committed
[SPARK-36369][PYTHON] Fix Index.union to follow pandas 1.3
This PR proposes fixing `Index.union` to follow the behavior of pandas 1.3. Before: ```python >>> ps_idx1 = ps.Index([1, 1, 1, 1, 1, 2, 2]) >>> ps_idx2 = ps.Index([1, 1, 2, 2, 2, 2, 2]) >>> ps_idx1.union(ps_idx2) Int64Index([1, 1, 1, 1, 1, 2, 2], dtype='int64') ``` After: ```python >>> ps_idx1 = ps.Index([1, 1, 1, 1, 1, 2, 2]) >>> ps_idx2 = ps.Index([1, 1, 2, 2, 2, 2, 2]) >>> ps_idx1.union(ps_idx2) Int64Index([1, 1, 1, 1, 1, 2, 2, 2, 2, 2], dtype='int64') ``` This bug was fixed in pandas-dev/pandas#36289. We should follow the behavior of pandas as much as possible. User-facing change: yes, the results for some cases that have duplicate values will change. Tested with unit tests. Closes #33634 from itholic/SPARK-36369. Authored-by: itholic <[email protected]> Signed-off-by: Hyukjin Kwon <[email protected]> (cherry picked from commit a9f371c) Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent cb075b5 commit f2f09e4

File tree

2 files changed

+68
-66
lines changed

2 files changed

+68
-66
lines changed

python/pyspark/pandas/indexes/base.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -2336,9 +2336,7 @@ def union(
23362336

23372337
sdf_self = self._internal.spark_frame.select(self._internal.index_spark_columns)
23382338
sdf_other = other_idx._internal.spark_frame.select(other_idx._internal.index_spark_columns)
2339-
sdf = sdf_self.union(sdf_other.subtract(sdf_self))
2340-
if isinstance(self, MultiIndex):
2341-
sdf = sdf.drop_duplicates()
2339+
sdf = sdf_self.unionAll(sdf_other).exceptAll(sdf_self.intersectAll(sdf_other))
23422340
if sort:
23432341
sdf = sdf.sort(*self._internal.index_spark_column_names)
23442342

python/pyspark/pandas/tests/indexes/test_base.py

+67-63
Original file line numberDiff line numberDiff line change
@@ -1487,21 +1487,20 @@ def test_union(self):
14871487
almost=True,
14881488
)
14891489

1490-
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
1491-
# TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
1492-
pass
1493-
else:
1494-
self.assert_eq(psidx2.union(psidx1), pidx2.union(pidx1))
1495-
self.assert_eq(
1496-
psidx2.union([1, 2, 3, 4, 3, 4, 3, 4]),
1497-
pidx2.union([1, 2, 3, 4, 3, 4, 3, 4]),
1498-
almost=True,
1499-
)
1500-
self.assert_eq(
1501-
psidx2.union(ps.Series([1, 2, 3, 4, 3, 4, 3, 4])),
1502-
pidx2.union(pd.Series([1, 2, 3, 4, 3, 4, 3, 4])),
1503-
almost=True,
1504-
)
1490+
# Manually create the expected result here since there is a bug in Index.union
1491+
# dropping duplicated values in pandas < 1.3.
1492+
expected = pd.Index([1, 2, 3, 3, 3, 4, 4, 4, 5, 6])
1493+
self.assert_eq(psidx2.union(psidx1), expected)
1494+
self.assert_eq(
1495+
psidx2.union([1, 2, 3, 4, 3, 4, 3, 4]),
1496+
expected,
1497+
almost=True,
1498+
)
1499+
self.assert_eq(
1500+
psidx2.union(ps.Series([1, 2, 3, 4, 3, 4, 3, 4])),
1501+
expected,
1502+
almost=True,
1503+
)
15051504

15061505
# MultiIndex
15071506
pmidx1 = pd.MultiIndex.from_tuples([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")])
@@ -1513,80 +1512,85 @@ def test_union(self):
15131512
psmidx3 = ps.from_pandas(pmidx3)
15141513
psmidx4 = ps.from_pandas(pmidx4)
15151514

1516-
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
1517-
# TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
1518-
pass
1519-
else:
1520-
self.assert_eq(psmidx1.union(psmidx2), pmidx1.union(pmidx2))
1521-
self.assert_eq(psmidx2.union(psmidx1), pmidx2.union(pmidx1))
1522-
self.assert_eq(psmidx3.union(psmidx4), pmidx3.union(pmidx4))
1523-
self.assert_eq(psmidx4.union(psmidx3), pmidx4.union(pmidx3))
1524-
self.assert_eq(
1525-
psmidx1.union([("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")]),
1526-
pmidx1.union([("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")]),
1527-
)
1528-
self.assert_eq(
1529-
psmidx2.union([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")]),
1530-
pmidx2.union([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")]),
1531-
)
1532-
self.assert_eq(
1533-
psmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)]),
1534-
pmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)]),
1535-
)
1536-
self.assert_eq(
1537-
psmidx4.union([(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)]),
1538-
pmidx4.union([(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)]),
1539-
)
1515+
# Manually create the expected result here since there is a bug in MultiIndex.union
1516+
# dropping duplicated values in pandas < 1.3.
1517+
expected = pd.MultiIndex.from_tuples(
1518+
[("x", "a"), ("x", "a"), ("x", "b"), ("x", "b"), ("x", "c"), ("x", "d")]
1519+
)
1520+
self.assert_eq(psmidx1.union(psmidx2), expected)
1521+
self.assert_eq(psmidx2.union(psmidx1), expected)
1522+
self.assert_eq(
1523+
psmidx1.union([("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")]),
1524+
expected,
1525+
)
1526+
self.assert_eq(
1527+
psmidx2.union([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")]),
1528+
expected,
1529+
)
1530+
1531+
expected = pd.MultiIndex.from_tuples(
1532+
[(1, 1), (1, 2), (1, 3), (1, 3), (1, 4), (1, 4), (1, 5), (1, 6)]
1533+
)
1534+
self.assert_eq(psmidx3.union(psmidx4), expected)
1535+
self.assert_eq(psmidx4.union(psmidx3), expected)
1536+
self.assert_eq(
1537+
psmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)]),
1538+
expected,
1539+
)
1540+
self.assert_eq(
1541+
psmidx4.union([(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)]),
1542+
expected,
1543+
)
15401544

1541-
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
1542-
# TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
1543-
pass
15441545
# Testing if the result is correct after sort=False.
15451546
# The `sort` argument is added in pandas 0.24.
1546-
elif LooseVersion(pd.__version__) >= LooseVersion("0.24"):
1547+
if LooseVersion(pd.__version__) >= LooseVersion("0.24"):
1548+
# Manually create the expected result here since there is a bug in MultiIndex.union
1549+
# dropping duplicated values in pandas < 1.3.
1550+
expected = pd.MultiIndex.from_tuples(
1551+
[("x", "a"), ("x", "a"), ("x", "b"), ("x", "b"), ("x", "c"), ("x", "d")]
1552+
)
15471553
self.assert_eq(
15481554
psmidx1.union(psmidx2, sort=False).sort_values(),
1549-
pmidx1.union(pmidx2, sort=False).sort_values(),
1555+
expected,
15501556
)
15511557
self.assert_eq(
15521558
psmidx2.union(psmidx1, sort=False).sort_values(),
1553-
pmidx2.union(pmidx1, sort=False).sort_values(),
1554-
)
1555-
self.assert_eq(
1556-
psmidx3.union(psmidx4, sort=False).sort_values(),
1557-
pmidx3.union(pmidx4, sort=False).sort_values(),
1558-
)
1559-
self.assert_eq(
1560-
psmidx4.union(psmidx3, sort=False).sort_values(),
1561-
pmidx4.union(pmidx3, sort=False).sort_values(),
1559+
expected,
15621560
)
15631561
self.assert_eq(
15641562
psmidx1.union(
15651563
[("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")], sort=False
15661564
).sort_values(),
1567-
pmidx1.union(
1568-
[("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")], sort=False
1569-
).sort_values(),
1565+
expected,
15701566
)
15711567
self.assert_eq(
15721568
psmidx2.union(
15731569
[("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")], sort=False
15741570
).sort_values(),
1575-
pmidx2.union(
1576-
[("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")], sort=False
1577-
).sort_values(),
1571+
expected,
1572+
)
1573+
1574+
expected = pd.MultiIndex.from_tuples(
1575+
[(1, 1), (1, 2), (1, 3), (1, 3), (1, 4), (1, 4), (1, 5), (1, 6)]
1576+
)
1577+
self.assert_eq(
1578+
psmidx3.union(psmidx4, sort=False).sort_values(),
1579+
expected,
1580+
)
1581+
self.assert_eq(
1582+
psmidx4.union(psmidx3, sort=False).sort_values(),
1583+
expected,
15781584
)
15791585
self.assert_eq(
15801586
psmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)], sort=False).sort_values(),
1581-
pmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)], sort=False).sort_values(),
1587+
expected,
15821588
)
15831589
self.assert_eq(
15841590
psmidx4.union(
15851591
[(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)], sort=False
15861592
).sort_values(),
1587-
pmidx4.union(
1588-
[(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)], sort=False
1589-
).sort_values(),
1593+
expected,
15901594
)
15911595

15921596
self.assertRaises(NotImplementedError, lambda: psidx1.union(psmidx1))

0 commit comments

Comments
 (0)