Skip to content

Commit a9f371c

Browse files
itholicHyukjinKwon
authored andcommitted
[SPARK-36369][PYTHON] Fix Index.union to follow pandas 1.3
### What changes were proposed in this pull request? This PR proposes fixing the `Index.union` to follow the behavior of pandas 1.3. Before: ```python >>> ps_idx1 = ps.Index([1, 1, 1, 1, 1, 2, 2]) >>> ps_idx2 = ps.Index([1, 1, 2, 2, 2, 2, 2]) >>> ps_idx1.union(ps_idx2) Int64Index([1, 1, 1, 1, 1, 2, 2], dtype='int64') ``` After: ```python >>> ps_idx1 = ps.Index([1, 1, 1, 1, 1, 2, 2]) >>> ps_idx2 = ps.Index([1, 1, 2, 2, 2, 2, 2]) >>> ps_idx1.union(ps_idx2) Int64Index([1, 1, 1, 1, 1, 2, 2, 2, 2, 2], dtype='int64') ``` This bug is fixed in pandas-dev/pandas#36289. ### Why are the changes needed? We should follow the behavior of pandas as much as possible. ### Does this PR introduce _any_ user-facing change? Yes, the result for some cases have duplicates values will change. ### How was this patch tested? Unit test. Closes #33634 from itholic/SPARK-36369. Authored-by: itholic <[email protected]> Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent 5a22f9c commit a9f371c

File tree

2 files changed

+68
-66
lines changed

2 files changed

+68
-66
lines changed

python/pyspark/pandas/indexes/base.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -2292,9 +2292,7 @@ def union(
22922292

22932293
sdf_self = self._internal.spark_frame.select(self._internal.index_spark_columns)
22942294
sdf_other = other_idx._internal.spark_frame.select(other_idx._internal.index_spark_columns)
2295-
sdf = sdf_self.union(sdf_other.subtract(sdf_self))
2296-
if isinstance(self, MultiIndex):
2297-
sdf = sdf.drop_duplicates()
2295+
sdf = sdf_self.unionAll(sdf_other).exceptAll(sdf_self.intersectAll(sdf_other))
22982296
if sort:
22992297
sdf = sdf.sort(*self._internal.index_spark_column_names)
23002298

python/pyspark/pandas/tests/indexes/test_base.py

+67-63
Original file line numberDiff line numberDiff line change
@@ -1527,21 +1527,20 @@ def test_union(self):
15271527
almost=True,
15281528
)
15291529

1530-
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
1531-
# TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
1532-
pass
1533-
else:
1534-
self.assert_eq(psidx2.union(psidx1), pidx2.union(pidx1))
1535-
self.assert_eq(
1536-
psidx2.union([1, 2, 3, 4, 3, 4, 3, 4]),
1537-
pidx2.union([1, 2, 3, 4, 3, 4, 3, 4]),
1538-
almost=True,
1539-
)
1540-
self.assert_eq(
1541-
psidx2.union(ps.Series([1, 2, 3, 4, 3, 4, 3, 4])),
1542-
pidx2.union(pd.Series([1, 2, 3, 4, 3, 4, 3, 4])),
1543-
almost=True,
1544-
)
1530+
# Manually create the expected result here since there is a bug in Index.union
1531+
# dropping duplicated values in pandas < 1.3.
1532+
expected = pd.Index([1, 2, 3, 3, 3, 4, 4, 4, 5, 6])
1533+
self.assert_eq(psidx2.union(psidx1), expected)
1534+
self.assert_eq(
1535+
psidx2.union([1, 2, 3, 4, 3, 4, 3, 4]),
1536+
expected,
1537+
almost=True,
1538+
)
1539+
self.assert_eq(
1540+
psidx2.union(ps.Series([1, 2, 3, 4, 3, 4, 3, 4])),
1541+
expected,
1542+
almost=True,
1543+
)
15451544

15461545
# MultiIndex
15471546
pmidx1 = pd.MultiIndex.from_tuples([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")])
@@ -1553,80 +1552,85 @@ def test_union(self):
15531552
psmidx3 = ps.from_pandas(pmidx3)
15541553
psmidx4 = ps.from_pandas(pmidx4)
15551554

1556-
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
1557-
# TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
1558-
pass
1559-
else:
1560-
self.assert_eq(psmidx1.union(psmidx2), pmidx1.union(pmidx2))
1561-
self.assert_eq(psmidx2.union(psmidx1), pmidx2.union(pmidx1))
1562-
self.assert_eq(psmidx3.union(psmidx4), pmidx3.union(pmidx4))
1563-
self.assert_eq(psmidx4.union(psmidx3), pmidx4.union(pmidx3))
1564-
self.assert_eq(
1565-
psmidx1.union([("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")]),
1566-
pmidx1.union([("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")]),
1567-
)
1568-
self.assert_eq(
1569-
psmidx2.union([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")]),
1570-
pmidx2.union([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")]),
1571-
)
1572-
self.assert_eq(
1573-
psmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)]),
1574-
pmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)]),
1575-
)
1576-
self.assert_eq(
1577-
psmidx4.union([(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)]),
1578-
pmidx4.union([(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)]),
1579-
)
1555+
# Manually create the expected result here since there is a bug in MultiIndex.union
1556+
# dropping duplicated values in pandas < 1.3.
1557+
expected = pd.MultiIndex.from_tuples(
1558+
[("x", "a"), ("x", "a"), ("x", "b"), ("x", "b"), ("x", "c"), ("x", "d")]
1559+
)
1560+
self.assert_eq(psmidx1.union(psmidx2), expected)
1561+
self.assert_eq(psmidx2.union(psmidx1), expected)
1562+
self.assert_eq(
1563+
psmidx1.union([("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")]),
1564+
expected,
1565+
)
1566+
self.assert_eq(
1567+
psmidx2.union([("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")]),
1568+
expected,
1569+
)
1570+
1571+
expected = pd.MultiIndex.from_tuples(
1572+
[(1, 1), (1, 2), (1, 3), (1, 3), (1, 4), (1, 4), (1, 5), (1, 6)]
1573+
)
1574+
self.assert_eq(psmidx3.union(psmidx4), expected)
1575+
self.assert_eq(psmidx4.union(psmidx3), expected)
1576+
self.assert_eq(
1577+
psmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)]),
1578+
expected,
1579+
)
1580+
self.assert_eq(
1581+
psmidx4.union([(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)]),
1582+
expected,
1583+
)
15801584

1581-
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
1582-
# TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3
1583-
pass
15841585
# Testing if the result is correct after sort=False.
15851586
# The `sort` argument is added in pandas 0.24.
1586-
elif LooseVersion(pd.__version__) >= LooseVersion("0.24"):
1587+
if LooseVersion(pd.__version__) >= LooseVersion("0.24"):
1588+
# Manually create the expected result here since there is a bug in MultiIndex.union
1589+
# dropping duplicated values in pandas < 1.3.
1590+
expected = pd.MultiIndex.from_tuples(
1591+
[("x", "a"), ("x", "a"), ("x", "b"), ("x", "b"), ("x", "c"), ("x", "d")]
1592+
)
15871593
self.assert_eq(
15881594
psmidx1.union(psmidx2, sort=False).sort_values(),
1589-
pmidx1.union(pmidx2, sort=False).sort_values(),
1595+
expected,
15901596
)
15911597
self.assert_eq(
15921598
psmidx2.union(psmidx1, sort=False).sort_values(),
1593-
pmidx2.union(pmidx1, sort=False).sort_values(),
1594-
)
1595-
self.assert_eq(
1596-
psmidx3.union(psmidx4, sort=False).sort_values(),
1597-
pmidx3.union(pmidx4, sort=False).sort_values(),
1598-
)
1599-
self.assert_eq(
1600-
psmidx4.union(psmidx3, sort=False).sort_values(),
1601-
pmidx4.union(pmidx3, sort=False).sort_values(),
1599+
expected,
16021600
)
16031601
self.assert_eq(
16041602
psmidx1.union(
16051603
[("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")], sort=False
16061604
).sort_values(),
1607-
pmidx1.union(
1608-
[("x", "a"), ("x", "b"), ("x", "c"), ("x", "d")], sort=False
1609-
).sort_values(),
1605+
expected,
16101606
)
16111607
self.assert_eq(
16121608
psmidx2.union(
16131609
[("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")], sort=False
16141610
).sort_values(),
1615-
pmidx2.union(
1616-
[("x", "a"), ("x", "b"), ("x", "a"), ("x", "b")], sort=False
1617-
).sort_values(),
1611+
expected,
1612+
)
1613+
1614+
expected = pd.MultiIndex.from_tuples(
1615+
[(1, 1), (1, 2), (1, 3), (1, 3), (1, 4), (1, 4), (1, 5), (1, 6)]
1616+
)
1617+
self.assert_eq(
1618+
psmidx3.union(psmidx4, sort=False).sort_values(),
1619+
expected,
1620+
)
1621+
self.assert_eq(
1622+
psmidx4.union(psmidx3, sort=False).sort_values(),
1623+
expected,
16181624
)
16191625
self.assert_eq(
16201626
psmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)], sort=False).sort_values(),
1621-
pmidx3.union([(1, 3), (1, 4), (1, 5), (1, 6)], sort=False).sort_values(),
1627+
expected,
16221628
)
16231629
self.assert_eq(
16241630
psmidx4.union(
16251631
[(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)], sort=False
16261632
).sort_values(),
1627-
pmidx4.union(
1628-
[(1, 1), (1, 2), (1, 3), (1, 4), (1, 3), (1, 4)], sort=False
1629-
).sort_values(),
1633+
expected,
16301634
)
16311635

16321636
self.assertRaises(NotImplementedError, lambda: psidx1.union(psmidx1))

0 commit comments

Comments
 (0)