From 9189c756119b5dcc5c3c671cef5249f5248f49e8 Mon Sep 17 00:00:00 2001 From: phofl Date: Thu, 15 Apr 2021 22:24:00 +0200 Subject: [PATCH 1/2] Fix typing for union_with_duplicates --- pandas/core/algorithms.py | 8 ++++---- pandas/core/indexes/base.py | 6 +----- pandas/tests/test_algos.py | 14 +++++++++----- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 9e2dd846f0379..1efd596419791 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -1876,21 +1876,21 @@ def _sort_tuples(values: np.ndarray) -> np.ndarray: return values[indexer] -def union_with_duplicates(lvals: np.ndarray, rvals: np.ndarray) -> np.ndarray: +def union_with_duplicates(lvals: ArrayLike, rvals: ArrayLike) -> ArrayLike: """ Extracts the union from lvals and rvals with respect to duplicates and nans in both arrays. Parameters ---------- - lvals: np.ndarray + lvals: ArrayLike left values which is ordered in front. - rvals: np.ndarray + rvals: ArrayLike right values ordered after lvals. Returns ------- - np.ndarray containing the unsorted union of both arrays + ArrayLike containing the unsorted union of both arrays """ indexer = [] l_count = value_counts(lvals, dropna=False) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 705a279638097..8275a38c19312 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -2983,11 +2983,7 @@ def _union(self, other: Index, sort): elif not other.is_unique: # other has duplicates - # error: Argument 1 to "union_with_duplicates" has incompatible type - # "Union[ExtensionArray, ndarray]"; expected "ndarray" - # error: Argument 2 to "union_with_duplicates" has incompatible type - # "Union[ExtensionArray, ndarray]"; expected "ndarray" - result = algos.union_with_duplicates(lvals, rvals) # type: ignore[arg-type] + result = algos.union_with_duplicates(lvals, rvals) return _maybe_try_sort(result, sort) # Self may have duplicates diff --git a/pandas/tests/test_algos.py b/pandas/tests/test_algos.py index 127baae6e9352..ef2c47ecdb16c 100644 --- a/pandas/tests/test_algos.py +++ b/pandas/tests/test_algos.py @@ -2418,10 +2418,14 @@ def test_diff_low_precision_int(self, dtype): tm.assert_numpy_array_equal(result, expected) -def test_union_with_duplicates(): +@pytest.mark.parametrize("op", [np.array, pd.array]) +def test_union_with_duplicates(op): # GH#36289 - lvals = np.array([3, 1, 3, 4]) - rvals = np.array([2, 3, 1, 1]) + lvals = op([3, 1, 3, 4]) + rvals = op([2, 3, 1, 1]) result = algos.union_with_duplicates(lvals, rvals) - expected = np.array([3, 3, 1, 1, 4, 2]) - tm.assert_numpy_array_equal(result, expected) + expected = op([3, 3, 1, 1, 4, 2]) + if isinstance(expected, np.ndarray): + tm.assert_numpy_array_equal(result, expected) + else: + tm.assert_extension_array_equal(result, expected) From df9664d8a15d24c06d9fd261b611ee5f133f3c50 Mon Sep 17 00:00:00 2001 From: phofl Date: Tue, 20 Apr 2021 22:22:49 +0200 Subject: [PATCH 2/2] Change test --- pandas/core/indexes/base.py | 1 - pandas/tests/test_algos.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index a2c23464c60d2..58f5ca3de5dce 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -3005,7 +3005,6 @@ def _union(self, other: Index, sort): elif not other.is_unique: # other has duplicates - result = algos.union_with_duplicates(lvals, rvals) return _maybe_try_sort(result, sort) diff --git a/pandas/tests/test_algos.py b/pandas/tests/test_algos.py index 55cb7c028fcac..75fc7a782772a 100644 --- a/pandas/tests/test_algos.py +++ b/pandas/tests/test_algos.py @@ -2423,9 +2423,11 @@ def test_union_with_duplicates(op): # GH#36289 lvals = op([3, 1, 3, 4]) rvals = op([2, 3, 1, 1]) - result = algos.union_with_duplicates(lvals, rvals) expected = op([3, 3, 1, 1, 4, 2]) if isinstance(expected, np.ndarray): + result = algos.union_with_duplicates(lvals, rvals) tm.assert_numpy_array_equal(result, expected) else: + with tm.assert_produces_warning(RuntimeWarning): + result = algos.union_with_duplicates(lvals, rvals) tm.assert_extension_array_equal(result, expected)