Skip to content

Commit 44913fc

Browse files
authored
1 parent d50c910 commit 44913fc

File tree

1 file changed

+15
-18
lines changed

1 file changed

+15
-18
lines changed

python/cudf/cudf/tests/test_joining.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pytest
88

99
import cudf
10-
from cudf.core._compat import PANDAS_GE_200
10+
from cudf.core._compat import PANDAS_GE_200, PANDAS_GE_220
1111
from cudf.core.dtypes import CategoricalDtype, Decimal64Dtype, Decimal128Dtype
1212
from cudf.testing._utils import (
1313
INTEGER_TYPES,
@@ -160,33 +160,30 @@ def _check_series(expect, got):
160160
def test_dataframe_join_suffix():
161161
np.random.seed(0)
162162

163-
df = cudf.DataFrame()
164-
for k in "abc":
165-
df[k] = np.random.randint(0, 5, 5)
163+
df = cudf.DataFrame(np.random.randint(0, 5, (5, 3)), columns=list("abc"))
166164

167165
left = df.set_index("a")
168166
right = df.set_index("c")
169-
with pytest.raises(ValueError) as raises:
170-
left.join(right)
171-
raises.match(
172-
"there are overlapping columns but lsuffix"
173-
" and rsuffix are not defined"
167+
msg = (
168+
"there are overlapping columns but lsuffix and rsuffix are not defined"
174169
)
170+
with pytest.raises(ValueError, match=msg):
171+
left.join(right)
175172

176173
got = left.join(right, lsuffix="_left", rsuffix="_right", sort=True)
177-
# Get expected value
178-
pddf = df.to_pandas()
179-
expect = pddf.set_index("a").join(
180-
pddf.set_index("c"), lsuffix="_left", rsuffix="_right"
174+
expect = left.to_pandas().join(
175+
right.to_pandas(),
176+
lsuffix="_left",
177+
rsuffix="_right",
178+
sort=PANDAS_GE_220,
181179
)
182-
# Check
183-
assert list(expect.columns) == list(got.columns)
184-
assert_eq(expect.index.values, got.index.values)
180+
# TODO: Retain result index name
181+
expect.index.name = None
182+
assert_eq(got, expect)
185183

186184
got_sorted = got.sort_values(by=["b_left", "c", "b_right"], axis=0)
187185
expect_sorted = expect.sort_values(by=["b_left", "c", "b_right"], axis=0)
188-
for k in expect_sorted.columns:
189-
_check_series(expect_sorted[k].fillna(-1), got_sorted[k].fillna(-1))
186+
assert_eq(got_sorted, expect_sorted)
190187

191188

192189
def test_dataframe_join_cats():

0 commit comments

Comments
 (0)