Skip to content

Commit 035e1fe

Browse files
authored
BUG/ENH: Improve categorical construction when using the iterator in StataReader (#34128)
* BUG/ENH: Correct categorical on iterators Return categoricals with the same categories if possible when reading data through an iterator. Warn if not possible. closes #31544
1 parent 7388ee5 commit 035e1fe

File tree

4 files changed

+129
-8
lines changed

4 files changed

+129
-8
lines changed

doc/source/whatsnew/v1.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -949,6 +949,7 @@ I/O
949949
- Bug in :meth:`~DataFrame.read_feather` was raising an `ArrowIOError` when reading an s3 or http file path (:issue:`29055`)
950950
- Bug in :meth:`~DataFrame.to_excel` could not handle the column name `render` and was raising an ``KeyError`` (:issue:`34331`)
951951
- Bug in :meth:`~SQLDatabase.execute` was raising a ``ProgrammingError`` for some DB-API drivers when the SQL statement contained the `%` character and no parameters were present (:issue:`34211`)
952+
- Bug in :meth:`~pandas.io.stata.StataReader` which resulted in categorical variables with different dtypes when reading data using an iterator. (:issue:`31544`)
952953

953954
Plotting
954955
^^^^^^^^

pandas/io/stata.py

+68-8
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,14 @@
106106
iterator : bool, default False
107107
Return StataReader object."""
108108

109+
_reader_notes = """\
110+
Notes
111+
-----
112+
Categorical variables read through an iterator may not have the same
113+
categories and dtype. This occurs when a variable stored in a DTA
114+
file is associated with an incomplete set of value labels that only
115+
label a strict subset of the values."""
116+
109117
_read_stata_doc = f"""
110118
Read Stata file into DataFrame.
111119
@@ -135,6 +143,8 @@
135143
io.stata.StataReader : Low-level reader for Stata data files.
136144
DataFrame.to_stata: Export Stata data files.
137145
146+
{_reader_notes}
147+
138148
Examples
139149
--------
140150
Read a Stata dta file:
@@ -176,6 +186,8 @@
176186
{_statafile_processing_params1}
177187
{_statafile_processing_params2}
178188
{_chunksize_params}
189+
190+
{_reader_notes}
179191
"""
180192

181193

@@ -497,6 +509,21 @@ class InvalidColumnName(Warning):
497509
"""
498510

499511

512+
class CategoricalConversionWarning(Warning):
513+
pass
514+
515+
516+
categorical_conversion_warning = """
517+
One or more series with value labels are not fully labeled. Reading this
518+
dataset with an iterator results in a categorical variable with different
519+
categories. This occurs since it is not possible to know all possible values
520+
until the entire dataset has been read. To avoid this warning, you can either
521+
read the dataset without an iterator, or manually convert categorical data by setting
522+
``convert_categoricals`` to False and then accessing the variable labels
523+
through the value_labels method of the reader.
524+
"""
525+
526+
500527
def _cast_to_stata_types(data: DataFrame) -> DataFrame:
501528
"""
502529
Checks the dtypes of the columns of a pandas DataFrame for
@@ -1023,6 +1050,10 @@ def __init__(
10231050
self._order_categoricals = order_categoricals
10241051
self._encoding = ""
10251052
self._chunksize = chunksize
1053+
if self._chunksize is not None and (
1054+
not isinstance(chunksize, int) or chunksize <= 0
1055+
):
1056+
raise ValueError("chunksize must be a positive integer when set.")
10261057

10271058
# State variables for the file
10281059
self._has_string_data = False
@@ -1488,6 +1519,10 @@ def _read_strls(self) -> None:
14881519
self.GSO[str(v_o)] = decoded_va
14891520

14901521
def __next__(self) -> DataFrame:
1522+
if self._chunksize is None:
1523+
raise ValueError(
1524+
"chunksize must be set to a positive integer to use as an iterator."
1525+
)
14911526
return self.read(nrows=self._chunksize or 1)
14921527

14931528
def get_chunk(self, size: Optional[int] = None) -> DataFrame:
@@ -1753,8 +1788,8 @@ def _do_select_columns(self, data: DataFrame, columns: Sequence[str]) -> DataFra
17531788

17541789
return data[columns]
17551790

1756-
@staticmethod
17571791
def _do_convert_categoricals(
1792+
self,
17581793
data: DataFrame,
17591794
value_label_dict: Dict[str, Dict[Union[float, int], str]],
17601795
lbllist: Sequence[str],
@@ -1768,14 +1803,39 @@ def _do_convert_categoricals(
17681803
for col, label in zip(data, lbllist):
17691804
if label in value_labels:
17701805
# Explicit call with ordered=True
1771-
cat_data = Categorical(data[col], ordered=order_categoricals)
1772-
categories = []
1773-
for category in cat_data.categories:
1774-
if category in value_label_dict[label]:
1775-
categories.append(value_label_dict[label][category])
1776-
else:
1777-
categories.append(category) # Partially labeled
1806+
vl = value_label_dict[label]
1807+
keys = np.array(list(vl.keys()))
1808+
column = data[col]
1809+
key_matches = column.isin(keys)
1810+
if self._chunksize is not None and key_matches.all():
1811+
initial_categories = keys
1812+
# If all categories are in the keys and we are iterating,
1813+
# use the same keys for all chunks. If some are missing
1814+
# value labels, then we will fall back to the categories
1815+
# varying across chunks.
1816+
else:
1817+
if self._chunksize is not None:
1818+
# warn if using an iterator
1819+
warnings.warn(
1820+
categorical_conversion_warning, CategoricalConversionWarning
1821+
)
1822+
initial_categories = None
1823+
cat_data = Categorical(
1824+
column, categories=initial_categories, ordered=order_categoricals
1825+
)
1826+
if initial_categories is None:
1827+
# If None here, then we need to match the cats in the Categorical
1828+
categories = []
1829+
for category in cat_data.categories:
1830+
if category in vl:
1831+
categories.append(vl[category])
1832+
else:
1833+
categories.append(category)
1834+
else:
1835+
# If all cats are matched, we can use the values
1836+
categories = list(vl.values())
17781837
try:
1838+
# Try to catch duplicate categories
17791839
cat_data.categories = categories
17801840
except ValueError as err:
17811841
vc = Series(categories).value_counts()
Binary file not shown.

pandas/tests/io/test_stata.py

+60
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from pandas.io.parsers import read_csv
2222
from pandas.io.stata import (
23+
CategoricalConversionWarning,
2324
InvalidColumnName,
2425
PossiblePrecisionLoss,
2526
StataMissingValue,
@@ -1923,3 +1924,62 @@ def test_compression_dict(method, file_ext):
19231924
fp = path
19241925
reread = read_stata(fp, index_col="index")
19251926
tm.assert_frame_equal(reread, df)
1927+
1928+
1929+
@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
1930+
def test_chunked_categorical(version):
1931+
df = DataFrame({"cats": Series(["a", "b", "a", "b", "c"], dtype="category")})
1932+
df.index.name = "index"
1933+
with tm.ensure_clean() as path:
1934+
df.to_stata(path, version=version)
1935+
reader = StataReader(path, chunksize=2, order_categoricals=False)
1936+
for i, block in enumerate(reader):
1937+
block = block.set_index("index")
1938+
assert "cats" in block
1939+
tm.assert_series_equal(block.cats, df.cats.iloc[2 * i : 2 * (i + 1)])
1940+
1941+
1942+
def test_chunked_categorical_partial(dirpath):
1943+
dta_file = os.path.join(dirpath, "stata-dta-partially-labeled.dta")
1944+
values = ["a", "b", "a", "b", 3.0]
1945+
with StataReader(dta_file, chunksize=2) as reader:
1946+
with tm.assert_produces_warning(CategoricalConversionWarning):
1947+
for i, block in enumerate(reader):
1948+
assert list(block.cats) == values[2 * i : 2 * (i + 1)]
1949+
if i < 2:
1950+
idx = pd.Index(["a", "b"])
1951+
else:
1952+
idx = pd.Float64Index([3.0])
1953+
tm.assert_index_equal(block.cats.cat.categories, idx)
1954+
with tm.assert_produces_warning(CategoricalConversionWarning):
1955+
with StataReader(dta_file, chunksize=5) as reader:
1956+
large_chunk = reader.__next__()
1957+
direct = read_stata(dta_file)
1958+
tm.assert_frame_equal(direct, large_chunk)
1959+
1960+
1961+
def test_iterator_errors(dirpath):
1962+
dta_file = os.path.join(dirpath, "stata-dta-partially-labeled.dta")
1963+
with pytest.raises(ValueError, match="chunksize must be a positive"):
1964+
StataReader(dta_file, chunksize=-1)
1965+
with pytest.raises(ValueError, match="chunksize must be a positive"):
1966+
StataReader(dta_file, chunksize=0)
1967+
with pytest.raises(ValueError, match="chunksize must be a positive"):
1968+
StataReader(dta_file, chunksize="apple")
1969+
with pytest.raises(ValueError, match="chunksize must be set to a positive"):
1970+
with StataReader(dta_file) as reader:
1971+
reader.__next__()
1972+
1973+
1974+
def test_iterator_value_labels():
1975+
# GH 31544
1976+
values = ["c_label", "b_label"] + ["a_label"] * 500
1977+
df = DataFrame({f"col{k}": pd.Categorical(values, ordered=True) for k in range(2)})
1978+
with tm.ensure_clean() as path:
1979+
df.to_stata(path, write_index=False)
1980+
reader = pd.read_stata(path, chunksize=100)
1981+
expected = pd.Index(["a_label", "b_label", "c_label"], dtype="object")
1982+
for j, chunk in enumerate(reader):
1983+
for i in range(2):
1984+
tm.assert_index_equal(chunk.dtypes[i].categories, expected)
1985+
tm.assert_frame_equal(chunk, df.iloc[j * 100 : (j + 1) * 100])

0 commit comments

Comments
 (0)