Skip to content

Commit dd3c2a2

Browse files
BUG/TST (string dtype): fix and update tests for Stata IO (pandas-dev#60130)
(cherry picked from commit e7d54a5)
1 parent fa7c87b commit dd3c2a2

File tree

2 files changed

+48
-39
lines changed

2 files changed

+48
-39
lines changed

pandas/io/stata.py

+5
Original file line number | Diff line number | Diff line change
@@ -605,7 +605,11 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame:
605605
if getattr(data[col].dtype, "numpy_dtype", None) is not None:
606606
data[col] = data[col].astype(data[col].dtype.numpy_dtype)
607607
elif is_string_dtype(data[col].dtype):
608+
# TODO could avoid converting string dtype to object here,
609+
# but handle string dtype in _encode_strings
608610
data[col] = data[col].astype("object")
611+
# generate_table checks for None values
612+
data.loc[data[col].isna(), col] = None
609613

610614
dtype = data[col].dtype
611615
empty_df = data.shape[0] == 0
@@ -2671,6 +2675,7 @@ def _encode_strings(self) -> None:
26712675
continue
26722676
column = self.data[col]
26732677
dtype = column.dtype
2678+
# TODO could also handle string dtype here specifically
26742679
if dtype.type is np.object_:
26752680
inferred_dtype = infer_dtype(column, skipna=True)
26762681
if not ((inferred_dtype == "string") or len(column) == 0):

pandas/tests/io/test_stata.py

+43 -39
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,6 @@
1111
import numpy as np
1212
import pytest
1313

14-
from pandas._config import using_string_dtype
15-
1614
import pandas.util._test_decorators as td
1715

1816
import pandas as pd
@@ -347,9 +345,8 @@ def test_write_dta6(self, datapath):
347345
check_index_type=False,
348346
)
349347

350-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
351348
@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
352-
def test_read_write_dta10(self, version):
349+
def test_read_write_dta10(self, version, using_infer_string):
353350
original = DataFrame(
354351
data=[["string", "object", 1, 1.1, np.datetime64("2003-12-25")]],
355352
columns=["string", "object", "integer", "floating", "datetime"],
@@ -362,12 +359,17 @@ def test_read_write_dta10(self, version):
362359
with tm.ensure_clean() as path:
363360
original.to_stata(path, convert_dates={"datetime": "tc"}, version=version)
364361
written_and_read_again = self.read_dta(path)
365-
# original.index is np.int32, read index is np.int64
366-
tm.assert_frame_equal(
367-
written_and_read_again.set_index("index"),
368-
original,
369-
check_index_type=False,
370-
)
362+
363+
expected = original.copy()
364+
if using_infer_string:
365+
expected["object"] = expected["object"].astype("str")
366+
367+
# original.index is np.int32, read index is np.int64
368+
tm.assert_frame_equal(
369+
written_and_read_again.set_index("index"),
370+
expected,
371+
check_index_type=False,
372+
)
371373

372374
def test_stata_doc_examples(self):
373375
with tm.ensure_clean() as path:
@@ -1153,7 +1155,6 @@ def test_categorical_ordering(self, file, datapath):
11531155
assert parsed[col].cat.ordered
11541156
assert not parsed_unordered[col].cat.ordered
11551157

1156-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
11571158
@pytest.mark.filterwarnings("ignore::UserWarning")
11581159
@pytest.mark.parametrize(
11591160
"file",
@@ -1215,6 +1216,10 @@ def _convert_categorical(from_frame: DataFrame) -> DataFrame:
12151216
if cat.categories.dtype == object:
12161217
categories = pd.Index._with_infer(cat.categories._values)
12171218
cat = cat.set_categories(categories)
1219+
elif cat.categories.dtype == "string" and len(cat.categories) == 0:
1220+
# if the read categories are empty, it comes back as object dtype
1221+
categories = cat.categories.astype(object)
1222+
cat = cat.set_categories(categories)
12181223
from_frame[col] = cat
12191224
return from_frame
12201225

@@ -1244,7 +1249,6 @@ def test_iterator(self, datapath):
12441249
from_chunks = pd.concat(itr)
12451250
tm.assert_frame_equal(parsed, from_chunks)
12461251

1247-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
12481252
@pytest.mark.filterwarnings("ignore::UserWarning")
12491253
@pytest.mark.parametrize(
12501254
"file",
@@ -1548,12 +1552,11 @@ def test_inf(self, infval):
15481552
with tm.ensure_clean() as path:
15491553
df.to_stata(path)
15501554

1551-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
15521555
def test_path_pathlib(self):
15531556
df = DataFrame(
15541557
1.1 * np.arange(120).reshape((30, 4)),
1555-
columns=pd.Index(list("ABCD"), dtype=object),
1556-
index=pd.Index([f"i-{i}" for i in range(30)], dtype=object),
1558+
columns=pd.Index(list("ABCD")),
1559+
index=pd.Index([f"i-{i}" for i in range(30)]),
15571560
)
15581561
df.index.name = "index"
15591562
reader = lambda x: read_stata(x).set_index("index")
@@ -1584,13 +1587,12 @@ def test_value_labels_iterator(self, write_index):
15841587
value_labels = dta_iter.value_labels()
15851588
assert value_labels == {"A": {0: "A", 1: "B", 2: "C", 3: "E"}}
15861589

1587-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
15881590
def test_set_index(self):
15891591
# GH 17328
15901592
df = DataFrame(
15911593
1.1 * np.arange(120).reshape((30, 4)),
1592-
columns=pd.Index(list("ABCD"), dtype=object),
1593-
index=pd.Index([f"i-{i}" for i in range(30)], dtype=object),
1594+
columns=pd.Index(list("ABCD")),
1595+
index=pd.Index([f"i-{i}" for i in range(30)]),
15941596
)
15951597
df.index.name = "index"
15961598
with tm.ensure_clean() as path:
@@ -1618,8 +1620,7 @@ def test_date_parsing_ignores_format_details(self, column, datapath):
16181620
formatted = df.loc[0, column + "_fmt"]
16191621
assert unformatted == formatted
16201622

1621-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
1622-
def test_writer_117(self):
1623+
def test_writer_117(self, using_infer_string):
16231624
original = DataFrame(
16241625
data=[
16251626
[
@@ -1682,13 +1683,17 @@ def test_writer_117(self):
16821683
version=117,
16831684
)
16841685
written_and_read_again = self.read_dta(path)
1685-
# original.index is np.int32, read index is np.int64
1686-
tm.assert_frame_equal(
1687-
written_and_read_again.set_index("index"),
1688-
original,
1689-
check_index_type=False,
1690-
)
1691-
tm.assert_frame_equal(original, copy)
1686+
1687+
expected = original[:]
1688+
if using_infer_string:
1689+
# object dtype (with only strings/None) comes back as string dtype
1690+
expected["object"] = expected["object"].astype("str")
1691+
1692+
tm.assert_frame_equal(
1693+
written_and_read_again.set_index("index"),
1694+
expected,
1695+
)
1696+
tm.assert_frame_equal(original, copy)
16921697

16931698
def test_convert_strl_name_swap(self):
16941699
original = DataFrame(
@@ -1725,15 +1730,14 @@ def test_invalid_date_conversion(self):
17251730
with pytest.raises(ValueError, match=msg):
17261731
original.to_stata(path, convert_dates={"wrong_name": "tc"})
17271732

1728-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
17291733
@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
17301734
def test_nonfile_writing(self, version):
17311735
# GH 21041
17321736
bio = io.BytesIO()
17331737
df = DataFrame(
17341738
1.1 * np.arange(120).reshape((30, 4)),
1735-
columns=pd.Index(list("ABCD"), dtype=object),
1736-
index=pd.Index([f"i-{i}" for i in range(30)], dtype=object),
1739+
columns=pd.Index(list("ABCD")),
1740+
index=pd.Index([f"i-{i}" for i in range(30)]),
17371741
)
17381742
df.index.name = "index"
17391743
with tm.ensure_clean() as path:
@@ -1744,13 +1748,12 @@ def test_nonfile_writing(self, version):
17441748
reread = read_stata(path, index_col="index")
17451749
tm.assert_frame_equal(df, reread)
17461750

1747-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
17481751
def test_gzip_writing(self):
17491752
# writing version 117 requires seek and cannot be used with gzip
17501753
df = DataFrame(
17511754
1.1 * np.arange(120).reshape((30, 4)),
1752-
columns=pd.Index(list("ABCD"), dtype=object),
1753-
index=pd.Index([f"i-{i}" for i in range(30)], dtype=object),
1755+
columns=pd.Index(list("ABCD")),
1756+
index=pd.Index([f"i-{i}" for i in range(30)]),
17541757
)
17551758
df.index.name = "index"
17561759
with tm.ensure_clean() as path:
@@ -1777,8 +1780,7 @@ def test_unicode_dta_118(self, datapath):
17771780

17781781
tm.assert_frame_equal(unicode_df, expected)
17791782

1780-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
1781-
def test_mixed_string_strl(self):
1783+
def test_mixed_string_strl(self, using_infer_string):
17821784
# GH 23633
17831785
output = [{"mixed": "string" * 500, "number": 0}, {"mixed": None, "number": 1}]
17841786
output = DataFrame(output)
@@ -1796,7 +1798,10 @@ def test_mixed_string_strl(self):
17961798
path, write_index=False, convert_strl=["mixed"], version=117
17971799
)
17981800
reread = read_stata(path)
1799-
expected = output.fillna("")
1801+
expected = output.copy()
1802+
if using_infer_string:
1803+
expected["mixed"] = expected["mixed"].astype("str")
1804+
expected = expected.fillna("")
18001805
tm.assert_frame_equal(reread, expected)
18011806

18021807
@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
@@ -1875,7 +1880,7 @@ def test_stata_119(self, datapath):
18751880
reader._ensure_open()
18761881
assert reader._nvar == 32999
18771882

1878-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
1883+
@pytest.mark.filterwarnings("ignore:Downcasting behavior:FutureWarning")
18791884
@pytest.mark.parametrize("version", [118, 119, None])
18801885
def test_utf8_writer(self, version):
18811886
cat = pd.Categorical(["a", "β", "ĉ"], ordered=True)
@@ -2143,14 +2148,13 @@ def test_iterator_errors(datapath, chunksize):
21432148
pass
21442149

21452150

2146-
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
21472151
def test_iterator_value_labels():
21482152
# GH 31544
21492153
values = ["c_label", "b_label"] + ["a_label"] * 500
21502154
df = DataFrame({f"col{k}": pd.Categorical(values, ordered=True) for k in range(2)})
21512155
with tm.ensure_clean() as path:
21522156
df.to_stata(path, write_index=False)
2153-
expected = pd.Index(["a_label", "b_label", "c_label"], dtype="object")
2157+
expected = pd.Index(["a_label", "b_label", "c_label"])
21542158
with read_stata(path, chunksize=100) as reader:
21552159
for j, chunk in enumerate(reader):
21562160
for i in range(2):

0 commit comments

Comments
 (0)