Skip to content

Commit 9e0b655

Browse files
phofllithomas1
andauthored
Adjust merge tests for new string option (#56445)
* BUG: merge not raising for String and numeric merges * BUG: merge not sorting for new string dtype * Add coverage * Fixup tests * Update --------- Co-authored-by: Thomas Li <[email protected]>
1 parent 0ffb7e9 commit 9e0b655

File tree

4 files changed

+42
-25
lines changed

4 files changed

+42
-25
lines changed

pandas/tests/reshape/merge/test_join.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,7 @@ def test_mixed_type_join_with_suffix(self):
630630
df.insert(5, "dt", "foo")
631631

632632
grouped = df.groupby("id")
633-
msg = re.escape("agg function failed [how->mean,dtype->object]")
633+
msg = re.escape("agg function failed [how->mean,dtype->")
634634
with pytest.raises(TypeError, match=msg):
635635
grouped.mean()
636636
mn = grouped.mean(numeric_only=True)
@@ -775,7 +775,7 @@ def test_join_on_tz_aware_datetimeindex(self):
775775
)
776776
result = df1.join(df2.set_index("date"), on="date")
777777
expected = df1.copy()
778-
expected["vals_2"] = Series([np.nan] * 2 + list("tuv"), dtype=object)
778+
expected["vals_2"] = Series([np.nan] * 2 + list("tuv"))
779779
tm.assert_frame_equal(result, expected)
780780

781781
def test_join_datetime_string(self):

pandas/tests/reshape/merge/test_merge.py

+28-18
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
import numpy as np
99
import pytest
1010

11-
from pandas.core.dtypes.common import is_object_dtype
11+
from pandas.core.dtypes.common import (
12+
is_object_dtype,
13+
is_string_dtype,
14+
)
1215
from pandas.core.dtypes.dtypes import CategoricalDtype
1316

1417
import pandas as pd
@@ -265,14 +268,15 @@ def test_merge_copy(self):
265268
merged["d"] = "peekaboo"
266269
assert (right["d"] == "bar").all()
267270

268-
def test_merge_nocopy(self):
271+
def test_merge_nocopy(self, using_infer_string):
269272
left = DataFrame({"a": 0, "b": 1}, index=range(10))
270273
right = DataFrame({"c": "foo", "d": "bar"}, index=range(10))
271274

272275
merged = merge(left, right, left_index=True, right_index=True, copy=False)
273276

274277
assert np.shares_memory(merged["a"]._values, left["a"]._values)
275-
assert np.shares_memory(merged["d"]._values, right["d"]._values)
278+
if not using_infer_string:
279+
assert np.shares_memory(merged["d"]._values, right["d"]._values)
276280

277281
def test_intelligently_handle_join_key(self):
278282
# #733, be a bit more 1337 about not returning unconsolidated DataFrame
@@ -660,11 +664,13 @@ def test_merge_nan_right(self):
660664
"i1_": {0: 0, 1: np.nan},
661665
"i3": {0: 0.0, 1: np.nan},
662666
None: {0: 0, 1: 0},
663-
}
667+
},
668+
columns=Index(["i1", "i2", "i1_", "i3", None], dtype=object),
664669
)
665670
.set_index(None)
666671
.reset_index()[["i1", "i2", "i1_", "i3"]]
667672
)
673+
result.columns = result.columns.astype("object")
668674
tm.assert_frame_equal(result, expected, check_dtype=False)
669675

670676
def test_merge_nan_right2(self):
@@ -808,7 +814,7 @@ def test_overlapping_columns_error_message(self):
808814

809815
# #2649, #10639
810816
df2.columns = ["key1", "foo", "foo"]
811-
msg = r"Data columns not unique: Index\(\['foo'\], dtype='object'\)"
817+
msg = r"Data columns not unique: Index\(\['foo'\], dtype='object|string'\)"
812818
with pytest.raises(MergeError, match=msg):
813819
merge(df, df2)
814820

@@ -1485,7 +1491,7 @@ def test_different(self, dtype):
14851491
# We allow merging on object and categorical cols and cast
14861492
# categorical cols to object
14871493
result = merge(left, right, on="A")
1488-
assert is_object_dtype(result.A.dtype)
1494+
assert is_object_dtype(result.A.dtype) or is_string_dtype(result.A.dtype)
14891495

14901496
@pytest.mark.parametrize("d2", [np.int64, np.float64, np.float32, np.float16])
14911497
def test_join_multi_dtypes(self, any_int_numpy_dtype, d2):
@@ -1621,7 +1627,7 @@ def test_merge_incompat_dtypes_are_ok(self, df1_vals, df2_vals):
16211627
result = merge(df1, df2, on=["A"])
16221628
assert is_object_dtype(result.A.dtype)
16231629
result = merge(df2, df1, on=["A"])
1624-
assert is_object_dtype(result.A.dtype)
1630+
assert is_object_dtype(result.A.dtype) or is_string_dtype(result.A.dtype)
16251631

16261632
@pytest.mark.parametrize(
16271633
"df1_vals, df2_vals",
@@ -1850,25 +1856,27 @@ def right():
18501856

18511857

18521858
class TestMergeCategorical:
1853-
def test_identical(self, left):
1859+
def test_identical(self, left, using_infer_string):
18541860
# merging on the same, should preserve dtypes
18551861
merged = merge(left, left, on="X")
18561862
result = merged.dtypes.sort_index()
1863+
dtype = np.dtype("O") if not using_infer_string else "string"
18571864
expected = Series(
1858-
[CategoricalDtype(categories=["foo", "bar"]), np.dtype("O"), np.dtype("O")],
1865+
[CategoricalDtype(categories=["foo", "bar"]), dtype, dtype],
18591866
index=["X", "Y_x", "Y_y"],
18601867
)
18611868
tm.assert_series_equal(result, expected)
18621869

1863-
def test_basic(self, left, right):
1870+
def test_basic(self, left, right, using_infer_string):
18641871
# we have matching Categorical dtypes in X
18651872
# so should preserve the merged column
18661873
merged = merge(left, right, on="X")
18671874
result = merged.dtypes.sort_index()
1875+
dtype = np.dtype("O") if not using_infer_string else "string"
18681876
expected = Series(
18691877
[
18701878
CategoricalDtype(categories=["foo", "bar"]),
1871-
np.dtype("O"),
1879+
dtype,
18721880
np.dtype("int64"),
18731881
],
18741882
index=["X", "Y", "Z"],
@@ -1972,16 +1980,17 @@ def test_multiindex_merge_with_unordered_categoricalindex(self, ordered):
19721980
).set_index(["id", "p"])
19731981
tm.assert_frame_equal(result, expected)
19741982

1975-
def test_other_columns(self, left, right):
1983+
def test_other_columns(self, left, right, using_infer_string):
19761984
# non-merge columns should preserve if possible
19771985
right = right.assign(Z=right.Z.astype("category"))
19781986

19791987
merged = merge(left, right, on="X")
19801988
result = merged.dtypes.sort_index()
1989+
dtype = np.dtype("O") if not using_infer_string else "string"
19811990
expected = Series(
19821991
[
19831992
CategoricalDtype(categories=["foo", "bar"]),
1984-
np.dtype("O"),
1993+
dtype,
19851994
CategoricalDtype(categories=[1, 2]),
19861995
],
19871996
index=["X", "Y", "Z"],
@@ -2000,7 +2009,9 @@ def test_other_columns(self, left, right):
20002009
lambda x: x.astype(CategoricalDtype(ordered=True)),
20012010
],
20022011
)
2003-
def test_dtype_on_merged_different(self, change, join_type, left, right):
2012+
def test_dtype_on_merged_different(
2013+
self, change, join_type, left, right, using_infer_string
2014+
):
20042015
# our merging columns, X now has 2 different dtypes
20052016
# so we must be object as a result
20062017

@@ -2012,9 +2023,8 @@ def test_dtype_on_merged_different(self, change, join_type, left, right):
20122023
merged = merge(left, right, on="X", how=join_type)
20132024

20142025
result = merged.dtypes.sort_index()
2015-
expected = Series(
2016-
[np.dtype("O"), np.dtype("O"), np.dtype("int64")], index=["X", "Y", "Z"]
2017-
)
2026+
dtype = np.dtype("O") if not using_infer_string else "string"
2027+
expected = Series([dtype, dtype, np.dtype("int64")], index=["X", "Y", "Z"])
20182028
tm.assert_series_equal(result, expected)
20192029

20202030
def test_self_join_multiple_categories(self):
@@ -2471,7 +2481,7 @@ def test_merge_multiindex_columns():
24712481
expected_index = MultiIndex.from_tuples(tuples, names=["outer", "inner"])
24722482
expected = DataFrame(columns=expected_index)
24732483

2474-
tm.assert_frame_equal(result, expected)
2484+
tm.assert_frame_equal(result, expected, check_dtype=False)
24752485

24762486

24772487
def test_merge_datetime_upcast_dtype():

pandas/tests/reshape/merge/test_merge_asof.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -3063,8 +3063,11 @@ def test_on_float_by_int(self):
30633063

30643064
tm.assert_frame_equal(result, expected)
30653065

3066-
def test_merge_datatype_error_raises(self):
3067-
msg = r"Incompatible merge dtype, .*, both sides must have numeric dtype"
3066+
def test_merge_datatype_error_raises(self, using_infer_string):
3067+
if using_infer_string:
3068+
msg = "incompatible merge keys"
3069+
else:
3070+
msg = r"Incompatible merge dtype, .*, both sides must have numeric dtype"
30683071

30693072
left = pd.DataFrame({"left_val": [1, 5, 10], "a": ["a", "b", "c"]})
30703073
right = pd.DataFrame({"right_val": [1, 2, 3, 6, 7], "a": [1, 2, 3, 6, 7]})
@@ -3116,7 +3119,7 @@ def test_merge_on_nans(self, func, side):
31163119
else:
31173120
merge_asof(df, df_null, on="a")
31183121

3119-
def test_by_nullable(self, any_numeric_ea_dtype):
3122+
def test_by_nullable(self, any_numeric_ea_dtype, using_infer_string):
31203123
# Note: this test passes if instead of using pd.array we use
31213124
# np.array([np.nan, 1]). Other than that, I (@jbrockmendel)
31223125
# have NO IDEA what the expected behavior is.
@@ -3158,6 +3161,8 @@ def test_by_nullable(self, any_numeric_ea_dtype):
31583161
}
31593162
)
31603163
expected["value_y"] = np.array([np.nan, np.nan, np.nan], dtype=object)
3164+
if using_infer_string:
3165+
expected["value_y"] = expected["value_y"].astype("string[pyarrow_numpy]")
31613166
tm.assert_frame_equal(result, expected)
31623167

31633168
def test_merge_by_col_tz_aware(self):
@@ -3183,7 +3188,7 @@ def test_merge_by_col_tz_aware(self):
31833188
)
31843189
tm.assert_frame_equal(result, expected)
31853190

3186-
def test_by_mixed_tz_aware(self):
3191+
def test_by_mixed_tz_aware(self, using_infer_string):
31873192
# GH 26649
31883193
left = pd.DataFrame(
31893194
{
@@ -3207,6 +3212,8 @@ def test_by_mixed_tz_aware(self):
32073212
columns=["by_col1", "by_col2", "on_col", "value_x"],
32083213
)
32093214
expected["value_y"] = np.array([np.nan], dtype=object)
3215+
if using_infer_string:
3216+
expected["value_y"] = expected["value_y"].astype("string[pyarrow_numpy]")
32103217
tm.assert_frame_equal(result, expected)
32113218

32123219
@pytest.mark.parametrize("dtype", ["float64", "int16", "m8[ns]", "M8[us]"])

pandas/tests/reshape/merge/test_multi.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -635,7 +635,7 @@ def test_join_multi_levels_outer(self, portfolio, household, expected):
635635
axis=0,
636636
sort=True,
637637
).reindex(columns=expected.columns)
638-
tm.assert_frame_equal(result, expected)
638+
tm.assert_frame_equal(result, expected, check_index_type=False)
639639

640640
def test_join_multi_levels_invalid(self, portfolio, household):
641641
portfolio = portfolio.copy()

0 commit comments

Comments
 (0)