Skip to content

Commit 3f09064

Browse files
Kevin Sheppardbashtage
Kevin Sheppard
authored andcommitted
BUG/ENH: Correct categorical on iterators
Return categoricals with the same categories, if possible, when reading data through an iterator. Warn if this is not possible. Closes pandas-dev#31544
1 parent 9a741d3 commit 3f09064

File tree

4 files changed

+78
-8
lines changed

4 files changed

+78
-8
lines changed

doc/source/whatsnew/v1.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,7 @@ I/O
777777
- Bug in :meth:`~DataFrame.read_feather` was raising an `ArrowIOError` when reading an s3 or http file path (:issue:`29055`)
778778
- Bug in :meth:`read_parquet` was raising a ``FileNotFoundError`` when passed an s3 directory path. (:issue:`26388`)
779779
- Bug in :meth:`~DataFrame.to_parquet` was throwing an ``AttributeError`` when writing a partitioned parquet file to s3 (:issue:`27596`)
780+
- Bug in :meth:`~pandas.io.stata.StataReader` which resulted in categorical variables with different dtypes when reading data using an iterator. (:issue:`31544`)
780781

781782
Plotting
782783
^^^^^^^^

pandas/io/stata.py

+45-8
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,21 @@ class InvalidColumnName(Warning):
480480
"""
481481

482482

483+
class CategoricalConversionWarning(Warning):
    """
    Warning category used when converting value-labeled Stata data to
    categoricals while reading a dataset in chunks.

    It is emitted together with the ``categorical_conversion_warning``
    message.
    """
485+
486+
487+
# Message paired with CategoricalConversionWarning. The opening sentence is
# matched by the test-suite ("One or more series"), so keep it stable.
categorical_conversion_warning = """
One or more series with value labels are not fully labeled. Reading this
dataset with an iterator results in categorical variables with different
categories. This occurs since it is not possible to know all possible values
until the entire dataset has been read. To avoid this warning, you can either
read the dataset without an iterator, or manually convert categorical data by
setting ``convert_categoricals`` to False and then accessing the variable
labels through the value_labels method of the reader.
"""
496+
497+
483498
def _cast_to_stata_types(data: DataFrame) -> DataFrame:
484499
"""
485500
Checks the dtypes of the columns of a pandas DataFrame for
@@ -1736,8 +1751,8 @@ def _do_select_columns(self, data: DataFrame, columns: Sequence[str]) -> DataFra
17361751

17371752
return data[columns]
17381753

1739-
@staticmethod
17401754
def _do_convert_categoricals(
1755+
self,
17411756
data: DataFrame,
17421757
value_label_dict: Dict[str, Dict[Union[float, int], str]],
17431758
lbllist: Sequence[str],
@@ -1751,14 +1766,36 @@ def _do_convert_categoricals(
17511766
for col, label in zip(data, lbllist):
17521767
if label in value_labels:
17531768
# Explicit call with ordered=True
1754-
cat_data = Categorical(data[col], ordered=order_categoricals)
1755-
categories = []
1756-
for category in cat_data.categories:
1757-
if category in value_label_dict[label]:
1758-
categories.append(value_label_dict[label][category])
1759-
else:
1760-
categories.append(category) # Partially labeled
1769+
vl = value_label_dict[label]
1770+
keys = np.array([k for k in vl.keys()])
1771+
column = data[col]
1772+
if column.isin(keys).all() and self._chunksize:
1773+
# If all categories are in the keys and we are iterating,
1774+
# use the same keys for all chunks. If some are missing
1775+
# value labels, then we will fall back to the categories
1776+
# varying across chunks.
1777+
initial_categories = keys
1778+
warnings.warn(
1779+
categorical_conversion_warning, CategoricalConversionWarning
1780+
)
1781+
else:
1782+
initial_categories = None
1783+
cat_data = Categorical(
1784+
column, categories=initial_categories, ordered=order_categoricals
1785+
)
1786+
if initial_categories is None:
1787+
# If None here, then we need to match the cats in the Categorical
1788+
categories = []
1789+
for category in cat_data.categories:
1790+
if category in vl:
1791+
categories.append(vl[category])
1792+
else:
1793+
categories.append(category)
1794+
else:
1795+
# If all cats are matched, we can use the values
1796+
categories = [v for v in vl.values()]
17611797
try:
1798+
# Try to catch duplicate categories
17621799
cat_data.categories = categories
17631800
except ValueError as err:
17641801
vc = Series(categories).value_counts()
Binary file not shown.

pandas/tests/io/test_stata.py

+32
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from pandas.io.parsers import read_csv
1919
from pandas.io.stata import (
20+
CategoricalConversionWarning,
2021
InvalidColumnName,
2122
PossiblePrecisionLoss,
2223
StataMissingValue,
@@ -1853,3 +1854,34 @@ def test_writer_118_exceptions(self):
18531854
with tm.ensure_clean() as path:
18541855
with pytest.raises(ValueError, match="You must use version 119"):
18551856
StataWriterUTF8(path, df, version=118)
1857+
1858+
1859+
@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
def test_chunked_categorical(version):
    # GH 31544: each chunk read through the iterator must compare equal to
    # the matching slice of the original fully labeled categorical column.
    df = DataFrame({"cats": Series(["a", "b", "a", "b", "c"], dtype="category")})
    df.index.name = "index"
    with tm.ensure_clean() as path:
        df.to_stata(path, version=version)
        # Use the reader as a context manager so its file handle is closed
        # as soon as iteration finishes instead of being leaked.
        with StataReader(path, chunksize=2, order_categoricals=False) as reader:
            for i, block in enumerate(reader):
                block = block.set_index("index")
                assert "cats" in block
                tm.assert_series_equal(
                    block.cats, df.cats.iloc[2 * i : 2 * (i + 1)]
                )
1870+
1871+
1872+
def test_chunked_categorical_partial(dirpath):
    # GH 31544: a partially labeled column read in chunks emits
    # CategoricalConversionWarning and its categories differ between chunks;
    # a single chunk covering the whole file must match read_stata.
    dta_file = os.path.join(dirpath, "stata-dta-partially-labeled.dta")
    values = ["a", "b", "a", "b", 3.0]
    # Context managers close each reader's file handle when done.
    with StataReader(dta_file, chunksize=2) as reader:
        with pytest.warns(CategoricalConversionWarning, match="One or more series"):
            for i, block in enumerate(reader):
                assert list(block.cats) == values[2 * i : 2 * (i + 1)]
                if i < 2:
                    # Chunks made up only of labeled values.
                    idx = pd.Index(["a", "b"])
                else:
                    # Final chunk holds only the unlabeled value.
                    idx = pd.Float64Index([3.0])
                tm.assert_index_equal(block.cats.cat.categories, idx)
    with StataReader(dta_file, chunksize=5) as reader:
        large_chunk = next(reader)
    direct = read_stata(dta_file)
    tm.assert_frame_equal(direct, large_chunk)

0 commit comments

Comments
 (0)