Skip to content

Commit 0a069be

Browse files
Kevin Sheppardbashtage
Kevin Sheppard
authored and committed
BUG/ENH: Correct categorical on iterators
Return categoricals with the same categories if possible when reading data through an iterator. Warn if not possible. closes pandas-dev#31544
1 parent a493fe1 commit 0a069be

File tree

4 files changed

+81
-8
lines changed

4 files changed

+81
-8
lines changed

doc/source/whatsnew/v1.1.0.rst

+4
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,10 @@ I/O
905905
- Bug in :meth:`~DataFrame.read_feather` was raising an `ArrowIOError` when reading an s3 or http file path (:issue:`29055`)
906906
- Bug in :meth:`~DataFrame.to_excel` could not handle the column name `render` and was raising a ``KeyError`` (:issue:`34331`)
907907
- Bug in :meth:`~SQLDatabase.execute` was raising a ``ProgrammingError`` for some DB-API drivers when the SQL statement contained the `%` character and no parameters were present (:issue:`34211`)
908+
- Bug in :meth:`read_parquet` was raising a ``FileNotFoundError`` when passed an s3 directory path. (:issue:`26388`)
909+
- Bug in :meth:`~DataFrame.to_parquet` was throwing an ``AttributeError`` when writing a partitioned parquet file to s3 (:issue:`27596`)
910+
- Bug in :class:`~pandas.io.stata.StataReader` which resulted in categorical variables with different dtypes when reading data using an iterator. (:issue:`31544`)
911+
908912

909913
Plotting
910914
^^^^^^^^

pandas/io/stata.py

+45-8
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,21 @@ class InvalidColumnName(Warning):
497497
"""
498498

499499

500+
class CategoricalConversionWarning(Warning):
    """
    Warning category used with ``categorical_conversion_warning``.

    Signals that reading value-labeled data with an iterator may produce
    categoricals whose categories are not the same in every chunk.
    """


# Message text paired with CategoricalConversionWarning. The leading and
# trailing newlines are part of the message.
categorical_conversion_warning = """
One or more series with value labels are not fully labeled. Reading this
dataset with an iterator results in categorical variables with different
categories. This occurs since it is not possible to know all possible values
until the entire dataset has been read. To avoid this warning, you can either
read the dataset without an iterator, or manually convert categorical data by
setting ``convert_categoricals`` to False and then accessing the value labels
through the value_labels method of the reader.
"""
513+
514+
500515
def _cast_to_stata_types(data: DataFrame) -> DataFrame:
501516
"""
502517
Checks the dtypes of the columns of a pandas DataFrame for
@@ -1753,8 +1768,8 @@ def _do_select_columns(self, data: DataFrame, columns: Sequence[str]) -> DataFra
17531768

17541769
return data[columns]
17551770

1756-
@staticmethod
17571771
def _do_convert_categoricals(
1772+
self,
17581773
data: DataFrame,
17591774
value_label_dict: Dict[str, Dict[Union[float, int], str]],
17601775
lbllist: Sequence[str],
@@ -1768,14 +1783,36 @@ def _do_convert_categoricals(
17681783
for col, label in zip(data, lbllist):
17691784
if label in value_labels:
17701785
# Explicit call with ordered=True
1771-
cat_data = Categorical(data[col], ordered=order_categoricals)
1772-
categories = []
1773-
for category in cat_data.categories:
1774-
if category in value_label_dict[label]:
1775-
categories.append(value_label_dict[label][category])
1776-
else:
1777-
categories.append(category) # Partially labeled
1786+
vl = value_label_dict[label]
1787+
keys = np.array([k for k in vl.keys()])
1788+
column = data[col]
1789+
if column.isin(keys).all() and self._chunksize:
1790+
# If all categories are in the keys and we are iterating,
1791+
# use the same keys for all chunks. If some are missing
1792+
# value labels, then we will fall back to the categories
1793+
# varying across chunks.
1794+
initial_categories = keys
1795+
warnings.warn(
1796+
categorical_conversion_warning, CategoricalConversionWarning
1797+
)
1798+
else:
1799+
initial_categories = None
1800+
cat_data = Categorical(
1801+
column, categories=initial_categories, ordered=order_categoricals
1802+
)
1803+
if initial_categories is None:
1804+
# If None here, then we need to match the cats in the Categorical
1805+
categories = []
1806+
for category in cat_data.categories:
1807+
if category in vl:
1808+
categories.append(vl[category])
1809+
else:
1810+
categories.append(category)
1811+
else:
1812+
# If all cats are matched, we can use the values
1813+
categories = [v for v in vl.values()]
17781814
try:
1815+
# Try to catch duplicate categories
17791816
cat_data.categories = categories
17801817
except ValueError as err:
17811818
vc = Series(categories).value_counts()
Binary file not shown.

pandas/tests/io/test_stata.py

+32
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from pandas.io.parsers import read_csv
2222
from pandas.io.stata import (
23+
CategoricalConversionWarning,
2324
InvalidColumnName,
2425
PossiblePrecisionLoss,
2526
StataMissingValue,
@@ -1923,3 +1924,34 @@ def test_compression_dict(method, file_ext):
19231924
fp = path
19241925
reread = read_stata(fp, index_col="index")
19251926
tm.assert_frame_equal(reread, df)
1927+
1928+
1929+
@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
def test_chunked_categorical(version):
    # A fully labeled categorical column read in chunks of two rows must
    # compare equal to the matching slice of the frame that was written.
    df = DataFrame({"cats": Series(["a", "b", "a", "b", "c"], dtype="category")})
    df.index.name = "index"
    with tm.ensure_clean() as path:
        df.to_stata(path, version=version)
        # Use the reader as a context manager so it is closed when the
        # iteration finishes instead of being left open on the temp file.
        with StataReader(path, chunksize=2, order_categoricals=False) as reader:
            for i, block in enumerate(reader):
                block = block.set_index("index")
                assert "cats" in block
                tm.assert_series_equal(block.cats, df.cats.iloc[2 * i : 2 * (i + 1)])
1940+
1941+
1942+
def test_chunked_categorical_partial(dirpath):
    # Partially labeled data: chunked reads warn and the categories differ
    # between chunks, while a single chunk covering every row matches
    # read_stata.
    dta_file = os.path.join(dirpath, "stata-dta-partially-labeled.dta")
    values = ["a", "b", "a", "b", 3.0]
    # Each reader is used as a context manager so it is closed after use.
    with StataReader(dta_file, chunksize=2) as reader:
        with pytest.warns(CategoricalConversionWarning, match="One or more series"):
            for i, block in enumerate(reader):
                assert list(block.cats) == values[2 * i : 2 * (i + 1)]
                if i < 2:
                    # The first two chunks hold only labeled values.
                    idx = pd.Index(["a", "b"])
                else:
                    # The last chunk holds only the unlabeled value.
                    idx = pd.Float64Index([3.0])
                tm.assert_index_equal(block.cats.cat.categories, idx)
    with StataReader(dta_file, chunksize=5) as reader:
        large_chunk = next(reader)
    direct = read_stata(dta_file)
    tm.assert_frame_equal(direct, large_chunk)

0 commit comments

Comments
 (0)