Skip to content

Commit 30ddd47

Browse files
jbrockmendelSeeminSyed
authored andcommitted
TST: tighten check_categorical=False tests (pandas-dev#32636)
1 parent a496c42 commit 30ddd47

File tree

7 files changed

+84
-78
lines changed

7 files changed

+84
-78
lines changed

pandas/_testing.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -824,10 +824,14 @@ def assert_categorical_equal(
824824
left.codes, right.codes, check_dtype=check_dtype, obj=f"{obj}.codes",
825825
)
826826
else:
827+
try:
828+
lc = left.categories.sort_values()
829+
rc = right.categories.sort_values()
830+
except TypeError:
831+
# e.g. '<' not supported between instances of 'int' and 'str'
832+
lc, rc = left.categories, right.categories
827833
assert_index_equal(
828-
left.categories.sort_values(),
829-
right.categories.sort_values(),
830-
obj=f"{obj}.categories",
834+
lc, rc, obj=f"{obj}.categories",
831835
)
832836
assert_index_equal(
833837
left.categories.take(left.codes),
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,51 @@
1+
import numpy as np
12
import pytest
23

34
import pandas as pd
45
import pandas._testing as tm
56

67

78
@pytest.mark.parametrize(
8-
"to_replace,value,expected,check_types,check_categorical",
9+
"to_replace,value,expected,flip_categories",
910
[
1011
# one-to-one
11-
(1, 2, [2, 2, 3], True, True),
12-
(1, 4, [4, 2, 3], True, True),
13-
(4, 1, [1, 2, 3], True, True),
14-
(5, 6, [1, 2, 3], True, True),
12+
(1, 2, [2, 2, 3], False),
13+
(1, 4, [4, 2, 3], False),
14+
(4, 1, [1, 2, 3], False),
15+
(5, 6, [1, 2, 3], False),
1516
# many-to-one
16-
([1], 2, [2, 2, 3], True, True),
17-
([1, 2], 3, [3, 3, 3], True, True),
18-
([1, 2], 4, [4, 4, 3], True, True),
19-
((1, 2, 4), 5, [5, 5, 3], True, True),
20-
((5, 6), 2, [1, 2, 3], True, True),
17+
([1], 2, [2, 2, 3], False),
18+
([1, 2], 3, [3, 3, 3], False),
19+
([1, 2], 4, [4, 4, 3], False),
20+
((1, 2, 4), 5, [5, 5, 3], False),
21+
((5, 6), 2, [1, 2, 3], False),
2122
# many-to-many, handled outside of Categorical and results in separate dtype
22-
([1], [2], [2, 2, 3], False, False),
23-
([1, 4], [5, 2], [5, 2, 3], False, False),
23+
([1], [2], [2, 2, 3], True),
24+
([1, 4], [5, 2], [5, 2, 3], True),
2425
# check_categorical sorts categories, which crashes on mixed dtypes
25-
(3, "4", [1, 2, "4"], True, False),
26-
([1, 2, "3"], "5", ["5", "5", 3], True, False),
26+
(3, "4", [1, 2, "4"], False),
27+
([1, 2, "3"], "5", ["5", "5", 3], True),
2728
],
2829
)
29-
def test_replace(to_replace, value, expected, check_types, check_categorical):
30+
def test_replace(to_replace, value, expected, flip_categories):
3031
# GH 31720
32+
stays_categorical = not isinstance(value, list)
33+
3134
s = pd.Series([1, 2, 3], dtype="category")
3235
result = s.replace(to_replace, value)
3336
expected = pd.Series(expected, dtype="category")
3437
s.replace(to_replace, value, inplace=True)
38+
39+
if flip_categories:
40+
expected = expected.cat.set_categories(expected.cat.categories[::-1])
41+
42+
if not stays_categorical:
43+
# the replace call loses categorical dtype
44+
expected = pd.Series(np.asarray(expected))
45+
3546
tm.assert_series_equal(
36-
expected,
37-
result,
38-
check_dtype=check_types,
39-
check_categorical=check_categorical,
40-
check_category_order=False,
47+
expected, result, check_category_order=False,
4148
)
4249
tm.assert_series_equal(
43-
expected,
44-
s,
45-
check_dtype=check_types,
46-
check_categorical=check_categorical,
47-
check_category_order=False,
50+
expected, s, check_category_order=False,
4851
)

pandas/tests/generic/test_frame.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -273,17 +273,13 @@ def test_to_xarray_index_types(self, index):
273273
assert isinstance(result, Dataset)
274274

275275
# idempotency
276-
# categoricals are not preserved
277276
# datetimes w/tz are preserved
278277
# column names are lost
279278
expected = df.copy()
280279
expected["f"] = expected["f"].astype(object)
281280
expected.columns.name = None
282281
tm.assert_frame_equal(
283-
result.to_dataframe(),
284-
expected,
285-
check_index_type=False,
286-
check_categorical=False,
282+
result.to_dataframe(), expected,
287283
)
288284

289285
@td.skip_if_no("xarray", min_version="0.7.0")

pandas/tests/io/pytables/test_store.py

+1-14
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
from pandas.compat import is_platform_little_endian, is_platform_windows
1414
import pandas.util._test_decorators as td
1515

16-
from pandas.core.dtypes.common import is_categorical_dtype
17-
1816
import pandas as pd
1917
from pandas import (
2018
Categorical,
@@ -1057,18 +1055,7 @@ def test_latin_encoding(self, setup_path, dtype, val):
10571055

10581056
s_nan = ser.replace(nan_rep, np.nan)
10591057

1060-
if is_categorical_dtype(s_nan):
1061-
assert is_categorical_dtype(retr)
1062-
tm.assert_series_equal(
1063-
s_nan, retr, check_dtype=False, check_categorical=False
1064-
)
1065-
else:
1066-
tm.assert_series_equal(s_nan, retr)
1067-
1068-
# FIXME: don't leave commented-out
1069-
# fails:
1070-
# for x in examples:
1071-
# roundtrip(s, nan_rep=b'\xf8\xfc')
1058+
tm.assert_series_equal(s_nan, retr)
10721059

10731060
def test_append_some_nans(self, setup_path):
10741061

pandas/tests/io/test_stata.py

+33-15
Original file line numberDiff line numberDiff line change
@@ -1026,7 +1026,14 @@ def test_categorical_with_stata_missing_values(self, version):
10261026
original.to_stata(path, version=version)
10271027
written_and_read_again = self.read_dta(path)
10281028
res = written_and_read_again.set_index("index")
1029-
tm.assert_frame_equal(res, original, check_categorical=False)
1029+
1030+
expected = original.copy()
1031+
for col in expected:
1032+
cat = expected[col]._values
1033+
new_cats = cat.remove_unused_categories().categories
1034+
cat = cat.set_categories(new_cats, ordered=True)
1035+
expected[col] = cat
1036+
tm.assert_frame_equal(res, expected)
10301037

10311038
@pytest.mark.parametrize("file", ["dta19_115", "dta19_117"])
10321039
def test_categorical_order(self, file):
@@ -1044,15 +1051,17 @@ def test_categorical_order(self, file):
10441051
cols = []
10451052
for is_cat, col, labels, codes in expected:
10461053
if is_cat:
1047-
cols.append((col, pd.Categorical.from_codes(codes, labels)))
1054+
cols.append(
1055+
(col, pd.Categorical.from_codes(codes, labels, ordered=True))
1056+
)
10481057
else:
10491058
cols.append((col, pd.Series(labels, dtype=np.float32)))
10501059
expected = DataFrame.from_dict(dict(cols))
10511060

10521061
# Read with and with out categoricals, ensure order is identical
10531062
file = getattr(self, file)
10541063
parsed = read_stata(file)
1055-
tm.assert_frame_equal(expected, parsed, check_categorical=False)
1064+
tm.assert_frame_equal(expected, parsed)
10561065

10571066
# Check identity of codes
10581067
for col in expected:
@@ -1137,18 +1146,30 @@ def test_read_chunks_117(
11371146
chunk = itr.read(chunksize)
11381147
except StopIteration:
11391148
break
1140-
from_frame = parsed.iloc[pos : pos + chunksize, :]
1149+
from_frame = parsed.iloc[pos : pos + chunksize, :].copy()
1150+
from_frame = self._convert_categorical(from_frame)
11411151
tm.assert_frame_equal(
1142-
from_frame,
1143-
chunk,
1144-
check_dtype=False,
1145-
check_datetimelike_compat=True,
1146-
check_categorical=False,
1152+
from_frame, chunk, check_dtype=False, check_datetimelike_compat=True,
11471153
)
11481154

11491155
pos += chunksize
11501156
itr.close()
11511157

1158+
@staticmethod
1159+
def _convert_categorical(from_frame: DataFrame) -> DataFrame:
1160+
"""
1161+
Emulate the categorical casting behavior we expect from roundtripping.
1162+
"""
1163+
for col in from_frame:
1164+
ser = from_frame[col]
1165+
if is_categorical_dtype(ser.dtype):
1166+
cat = ser._values.remove_unused_categories()
1167+
if cat.categories.dtype == object:
1168+
categories = pd.Index(cat.categories._values)
1169+
cat = cat.set_categories(categories)
1170+
from_frame[col] = cat
1171+
return from_frame
1172+
11521173
def test_iterator(self):
11531174

11541175
fname = self.dta3_117
@@ -1223,13 +1244,10 @@ def test_read_chunks_115(
12231244
chunk = itr.read(chunksize)
12241245
except StopIteration:
12251246
break
1226-
from_frame = parsed.iloc[pos : pos + chunksize, :]
1247+
from_frame = parsed.iloc[pos : pos + chunksize, :].copy()
1248+
from_frame = self._convert_categorical(from_frame)
12271249
tm.assert_frame_equal(
1228-
from_frame,
1229-
chunk,
1230-
check_dtype=False,
1231-
check_datetimelike_compat=True,
1232-
check_categorical=False,
1250+
from_frame, chunk, check_dtype=False, check_datetimelike_compat=True,
12331251
)
12341252

12351253
pos += chunksize

pandas/tests/reshape/merge/test_merge.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -2077,8 +2077,7 @@ def test_merge_equal_cat_dtypes(cat_dtype, reverse):
20772077
}
20782078
).set_index("foo")
20792079

2080-
# Categorical is unordered, so don't check ordering.
2081-
tm.assert_frame_equal(result, expected, check_categorical=False)
2080+
tm.assert_frame_equal(result, expected)
20822081

20832082

20842083
def test_merge_equal_cat_dtypes2():
@@ -2100,8 +2099,7 @@ def test_merge_equal_cat_dtypes2():
21002099
{"left": [1, 2], "right": [3, 2], "foo": Series(["a", "b"]).astype(cat_dtype)}
21012100
).set_index("foo")
21022101

2103-
# Categorical is unordered, so don't check ordering.
2104-
tm.assert_frame_equal(result, expected, check_categorical=False)
2102+
tm.assert_frame_equal(result, expected)
21052103

21062104

21072105
def test_merge_on_cat_and_ext_array():

pandas/tests/series/test_dtypes.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -296,18 +296,18 @@ def cmp(a, b):
296296
# array conversion
297297
tm.assert_almost_equal(np.array(s), np.array(s.values))
298298

299-
# valid conversion
300-
for valid in [
301-
lambda x: x.astype("category"),
302-
lambda x: x.astype(CategoricalDtype()),
303-
lambda x: x.astype("object").astype("category"),
304-
lambda x: x.astype("object").astype(CategoricalDtype()),
305-
]:
306-
307-
result = valid(s)
308-
# compare series values
309-
# internal .categories can't be compared because it is sorted
310-
tm.assert_series_equal(result, s, check_categorical=False)
299+
tm.assert_series_equal(s.astype("category"), s)
300+
tm.assert_series_equal(s.astype(CategoricalDtype()), s)
301+
302+
roundtrip_expected = s.cat.set_categories(
303+
s.cat.categories.sort_values()
304+
).cat.remove_unused_categories()
305+
tm.assert_series_equal(
306+
s.astype("object").astype("category"), roundtrip_expected
307+
)
308+
tm.assert_series_equal(
309+
s.astype("object").astype(CategoricalDtype()), roundtrip_expected
310+
)
311311

312312
# invalid conversion (these are NOT a dtype)
313313
msg = (

0 commit comments

Comments
 (0)