Skip to content

Commit 74fa740

Browse files
lithomas1 and phofl authored
Backport PR #56445: Adjust merge tests for new string option (#56938)
Co-authored-by: Patrick Hoefler <[email protected]>
1 parent 988c3a4 commit 74fa740

File tree

4 files changed

+42
-25
lines changed

4 files changed

+42
-25
lines changed

pandas/tests/reshape/merge/test_join.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -631,7 +631,7 @@ def test_mixed_type_join_with_suffix(self):
631631
df.insert(5, "dt", "foo")
632632

633633
grouped = df.groupby("id")
634-
msg = re.escape("agg function failed [how->mean,dtype->object]")
634+
msg = re.escape("agg function failed [how->mean,dtype->")
635635
with pytest.raises(TypeError, match=msg):
636636
grouped.mean()
637637
mn = grouped.mean(numeric_only=True)
@@ -776,7 +776,7 @@ def test_join_on_tz_aware_datetimeindex(self):
776776
)
777777
result = df1.join(df2.set_index("date"), on="date")
778778
expected = df1.copy()
779-
expected["vals_2"] = Series([np.nan] * 2 + list("tuv"), dtype=object)
779+
expected["vals_2"] = Series([np.nan] * 2 + list("tuv"))
780780
tm.assert_frame_equal(result, expected)
781781

782782
def test_join_datetime_string(self):

pandas/tests/reshape/merge/test_merge.py

+28 -18
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,10 @@
88
import numpy as np
99
import pytest
1010

11-
from pandas.core.dtypes.common import is_object_dtype
11+
from pandas.core.dtypes.common import (
12+
is_object_dtype,
13+
is_string_dtype,
14+
)
1215
from pandas.core.dtypes.dtypes import CategoricalDtype
1316

1417
import pandas as pd
@@ -316,14 +319,15 @@ def test_merge_copy(self):
316319
merged["d"] = "peekaboo"
317320
assert (right["d"] == "bar").all()
318321

319-
def test_merge_nocopy(self, using_array_manager):
322+
def test_merge_nocopy(self, using_array_manager, using_infer_string):
320323
left = DataFrame({"a": 0, "b": 1}, index=range(10))
321324
right = DataFrame({"c": "foo", "d": "bar"}, index=range(10))
322325

323326
merged = merge(left, right, left_index=True, right_index=True, copy=False)
324327

325328
assert np.shares_memory(merged["a"]._values, left["a"]._values)
326-
assert np.shares_memory(merged["d"]._values, right["d"]._values)
329+
if not using_infer_string:
330+
assert np.shares_memory(merged["d"]._values, right["d"]._values)
327331

328332
def test_intelligently_handle_join_key(self):
329333
# #733, be a bit more 1337 about not returning unconsolidated DataFrame
@@ -667,11 +671,13 @@ def test_merge_nan_right(self):
667671
"i1_": {0: 0, 1: np.nan},
668672
"i3": {0: 0.0, 1: np.nan},
669673
None: {0: 0, 1: 0},
670-
}
674+
},
675+
columns=Index(["i1", "i2", "i1_", "i3", None], dtype=object),
671676
)
672677
.set_index(None)
673678
.reset_index()[["i1", "i2", "i1_", "i3"]]
674679
)
680+
result.columns = result.columns.astype("object")
675681
tm.assert_frame_equal(result, expected, check_dtype=False)
676682

677683
def test_merge_nan_right2(self):
@@ -820,7 +826,7 @@ def test_overlapping_columns_error_message(self):
820826

821827
# #2649, #10639
822828
df2.columns = ["key1", "foo", "foo"]
823-
msg = r"Data columns not unique: Index\(\['foo'\], dtype='object'\)"
829+
msg = r"Data columns not unique: Index\(\['foo'\], dtype='object|string'\)"
824830
with pytest.raises(MergeError, match=msg):
825831
merge(df, df2)
826832

@@ -1498,7 +1504,7 @@ def test_different(self, right_vals):
14981504
# We allow merging on object and categorical cols and cast
14991505
# categorical cols to object
15001506
result = merge(left, right, on="A")
1501-
assert is_object_dtype(result.A.dtype)
1507+
assert is_object_dtype(result.A.dtype) or is_string_dtype(result.A.dtype)
15021508

15031509
@pytest.mark.parametrize(
15041510
"d1", [np.int64, np.int32, np.intc, np.int16, np.int8, np.uint8]
@@ -1637,7 +1643,7 @@ def test_merge_incompat_dtypes_are_ok(self, df1_vals, df2_vals):
16371643
result = merge(df1, df2, on=["A"])
16381644
assert is_object_dtype(result.A.dtype)
16391645
result = merge(df2, df1, on=["A"])
1640-
assert is_object_dtype(result.A.dtype)
1646+
assert is_object_dtype(result.A.dtype) or is_string_dtype(result.A.dtype)
16411647

16421648
@pytest.mark.parametrize(
16431649
"df1_vals, df2_vals",
@@ -1867,25 +1873,27 @@ def right():
18671873

18681874

18691875
class TestMergeCategorical:
1870-
def test_identical(self, left):
1876+
def test_identical(self, left, using_infer_string):
18711877
# merging on the same, should preserve dtypes
18721878
merged = merge(left, left, on="X")
18731879
result = merged.dtypes.sort_index()
1880+
dtype = np.dtype("O") if not using_infer_string else "string"
18741881
expected = Series(
1875-
[CategoricalDtype(categories=["foo", "bar"]), np.dtype("O"), np.dtype("O")],
1882+
[CategoricalDtype(categories=["foo", "bar"]), dtype, dtype],
18761883
index=["X", "Y_x", "Y_y"],
18771884
)
18781885
tm.assert_series_equal(result, expected)
18791886

1880-
def test_basic(self, left, right):
1887+
def test_basic(self, left, right, using_infer_string):
18811888
# we have matching Categorical dtypes in X
18821889
# so should preserve the merged column
18831890
merged = merge(left, right, on="X")
18841891
result = merged.dtypes.sort_index()
1892+
dtype = np.dtype("O") if not using_infer_string else "string"
18851893
expected = Series(
18861894
[
18871895
CategoricalDtype(categories=["foo", "bar"]),
1888-
np.dtype("O"),
1896+
dtype,
18891897
np.dtype("int64"),
18901898
],
18911899
index=["X", "Y", "Z"],
@@ -1989,16 +1997,17 @@ def test_multiindex_merge_with_unordered_categoricalindex(self, ordered):
19891997
).set_index(["id", "p"])
19901998
tm.assert_frame_equal(result, expected)
19911999

1992-
def test_other_columns(self, left, right):
2000+
def test_other_columns(self, left, right, using_infer_string):
19932001
# non-merge columns should preserve if possible
19942002
right = right.assign(Z=right.Z.astype("category"))
19952003

19962004
merged = merge(left, right, on="X")
19972005
result = merged.dtypes.sort_index()
2006+
dtype = np.dtype("O") if not using_infer_string else "string"
19982007
expected = Series(
19992008
[
20002009
CategoricalDtype(categories=["foo", "bar"]),
2001-
np.dtype("O"),
2010+
dtype,
20022011
CategoricalDtype(categories=[1, 2]),
20032012
],
20042013
index=["X", "Y", "Z"],
@@ -2017,7 +2026,9 @@ def test_other_columns(self, left, right):
20172026
lambda x: x.astype(CategoricalDtype(ordered=True)),
20182027
],
20192028
)
2020-
def test_dtype_on_merged_different(self, change, join_type, left, right):
2029+
def test_dtype_on_merged_different(
2030+
self, change, join_type, left, right, using_infer_string
2031+
):
20212032
# our merging columns, X now has 2 different dtypes
20222033
# so we must be object as a result
20232034

@@ -2029,9 +2040,8 @@ def test_dtype_on_merged_different(self, change, join_type, left, right):
20292040
merged = merge(left, right, on="X", how=join_type)
20302041

20312042
result = merged.dtypes.sort_index()
2032-
expected = Series(
2033-
[np.dtype("O"), np.dtype("O"), np.dtype("int64")], index=["X", "Y", "Z"]
2034-
)
2043+
dtype = np.dtype("O") if not using_infer_string else "string"
2044+
expected = Series([dtype, dtype, np.dtype("int64")], index=["X", "Y", "Z"])
20352045
tm.assert_series_equal(result, expected)
20362046

20372047
def test_self_join_multiple_categories(self):
@@ -2499,7 +2509,7 @@ def test_merge_multiindex_columns():
24992509
expected_index = MultiIndex.from_tuples(tuples, names=["outer", "inner"])
25002510
expected = DataFrame(columns=expected_index)
25012511

2502-
tm.assert_frame_equal(result, expected)
2512+
tm.assert_frame_equal(result, expected, check_dtype=False)
25032513

25042514

25052515
def test_merge_datetime_upcast_dtype():

pandas/tests/reshape/merge/test_merge_asof.py

+11 -4
Original file line number | Diff line number | Diff line change
@@ -3081,8 +3081,11 @@ def test_on_float_by_int(self):
30813081

30823082
tm.assert_frame_equal(result, expected)
30833083

3084-
def test_merge_datatype_error_raises(self):
3085-
msg = r"Incompatible merge dtype, .*, both sides must have numeric dtype"
3084+
def test_merge_datatype_error_raises(self, using_infer_string):
3085+
if using_infer_string:
3086+
msg = "incompatible merge keys"
3087+
else:
3088+
msg = r"Incompatible merge dtype, .*, both sides must have numeric dtype"
30863089

30873090
left = pd.DataFrame({"left_val": [1, 5, 10], "a": ["a", "b", "c"]})
30883091
right = pd.DataFrame({"right_val": [1, 2, 3, 6, 7], "a": [1, 2, 3, 6, 7]})
@@ -3134,7 +3137,7 @@ def test_merge_on_nans(self, func, side):
31343137
else:
31353138
merge_asof(df, df_null, on="a")
31363139

3137-
def test_by_nullable(self, any_numeric_ea_dtype):
3140+
def test_by_nullable(self, any_numeric_ea_dtype, using_infer_string):
31383141
# Note: this test passes if instead of using pd.array we use
31393142
# np.array([np.nan, 1]). Other than that, I (@jbrockmendel)
31403143
# have NO IDEA what the expected behavior is.
@@ -3176,6 +3179,8 @@ def test_by_nullable(self, any_numeric_ea_dtype):
31763179
}
31773180
)
31783181
expected["value_y"] = np.array([np.nan, np.nan, np.nan], dtype=object)
3182+
if using_infer_string:
3183+
expected["value_y"] = expected["value_y"].astype("string[pyarrow_numpy]")
31793184
tm.assert_frame_equal(result, expected)
31803185

31813186
def test_merge_by_col_tz_aware(self):
@@ -3201,7 +3206,7 @@ def test_merge_by_col_tz_aware(self):
32013206
)
32023207
tm.assert_frame_equal(result, expected)
32033208

3204-
def test_by_mixed_tz_aware(self):
3209+
def test_by_mixed_tz_aware(self, using_infer_string):
32053210
# GH 26649
32063211
left = pd.DataFrame(
32073212
{
@@ -3225,6 +3230,8 @@ def test_by_mixed_tz_aware(self):
32253230
columns=["by_col1", "by_col2", "on_col", "value_x"],
32263231
)
32273232
expected["value_y"] = np.array([np.nan], dtype=object)
3233+
if using_infer_string:
3234+
expected["value_y"] = expected["value_y"].astype("string[pyarrow_numpy]")
32283235
tm.assert_frame_equal(result, expected)
32293236

32303237
@pytest.mark.parametrize("dtype", ["float64", "int16", "m8[ns]", "M8[us]"])

pandas/tests/reshape/merge/test_multi.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -639,7 +639,7 @@ def test_join_multi_levels_outer(self, portfolio, household, expected):
639639
axis=0,
640640
sort=True,
641641
).reindex(columns=expected.columns)
642-
tm.assert_frame_equal(result, expected)
642+
tm.assert_frame_equal(result, expected, check_index_type=False)
643643

644644
def test_join_multi_levels_invalid(self, portfolio, household):
645645
portfolio = portfolio.copy()

0 commit comments

Comments
 (0)