|
7 | 7 | import pytest
|
8 | 8 |
|
9 | 9 | import cudf
|
10 |
| -from cudf.core._compat import PANDAS_GE_200 |
| 10 | +from cudf.core._compat import PANDAS_GE_200, PANDAS_GE_220 |
11 | 11 | from cudf.core.dtypes import CategoricalDtype, Decimal64Dtype, Decimal128Dtype
|
12 | 12 | from cudf.testing._utils import (
|
13 | 13 | INTEGER_TYPES,
|
@@ -160,33 +160,30 @@ def _check_series(expect, got):
|
160 | 160 | def test_dataframe_join_suffix():
|
161 | 161 | np.random.seed(0)
|
162 | 162 |
|
163 |
| - df = cudf.DataFrame() |
164 |
| - for k in "abc": |
165 |
| - df[k] = np.random.randint(0, 5, 5) |
| 163 | + df = cudf.DataFrame(np.random.randint(0, 5, (5, 3)), columns=list("abc")) |
166 | 164 |
|
167 | 165 | left = df.set_index("a")
|
168 | 166 | right = df.set_index("c")
|
169 |
| - with pytest.raises(ValueError) as raises: |
170 |
| - left.join(right) |
171 |
| - raises.match( |
172 |
| - "there are overlapping columns but lsuffix" |
173 |
| - " and rsuffix are not defined" |
| 167 | + msg = ( |
| 168 | + "there are overlapping columns but lsuffix and rsuffix are not defined" |
174 | 169 | )
|
| 170 | + with pytest.raises(ValueError, match=msg): |
| 171 | + left.join(right) |
175 | 172 |
|
176 | 173 | got = left.join(right, lsuffix="_left", rsuffix="_right", sort=True)
|
177 |
| - # Get expected value |
178 |
| - pddf = df.to_pandas() |
179 |
| - expect = pddf.set_index("a").join( |
180 |
| - pddf.set_index("c"), lsuffix="_left", rsuffix="_right" |
| 174 | + expect = left.to_pandas().join( |
| 175 | + right.to_pandas(), |
| 176 | + lsuffix="_left", |
| 177 | + rsuffix="_right", |
| 178 | + sort=PANDAS_GE_220, |
181 | 179 | )
|
182 |
| - # Check |
183 |
| - assert list(expect.columns) == list(got.columns) |
184 |
| - assert_eq(expect.index.values, got.index.values) |
| 180 | + # TODO: Retain result index name |
| 181 | + expect.index.name = None |
| 182 | + assert_eq(got, expect) |
185 | 183 |
|
186 | 184 | got_sorted = got.sort_values(by=["b_left", "c", "b_right"], axis=0)
|
187 | 185 | expect_sorted = expect.sort_values(by=["b_left", "c", "b_right"], axis=0)
|
188 |
| - for k in expect_sorted.columns: |
189 |
| - _check_series(expect_sorted[k].fillna(-1), got_sorted[k].fillna(-1)) |
| 186 | + assert_eq(got_sorted, expect_sorted) |
190 | 187 |
|
191 | 188 |
|
192 | 189 | def test_dataframe_join_cats():
|
|
0 commit comments