Skip to content

Commit c132858

Browse files
meeseeksmachine, bashtage, simonjayhawkins
authored
Backport PR #37302 on branch 1.1.x (BUG: Allow empty chunksize in stata reader when using iterator) (#37364)
* Backport PR #37302: BUG: Allow empty chunksize in stata reader when using iterator * remove match argument to assert_produces_warning Co-authored-by: Kevin Sheppard <[email protected]> Co-authored-by: Simon Hawkins <[email protected]>
1 parent 60edd86 commit c132858

File tree

3 files changed

+57
-42
lines changed

3 files changed

+57
-42
lines changed

doc/source/whatsnew/v1.1.4.rst

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ Fixed regressions
2222
- Fixed regression in :class:`RollingGroupby` causing a segmentation fault with Index of dtype object (:issue:`36727`)
2323
- Fixed regression in :class:`PeriodDtype` comparing both equal and unequal to its string representation (:issue:`37265`)
2424
- Fixed regression in certain offsets (:meth:`pd.offsets.Day() <pandas.tseries.offsets.Day>` and below) no longer being hashable (:issue:`37267`)
25+
- Fixed regression in :class:`StataReader` which required ``chunksize`` to be manually set when using an iterator to read a dataset (:issue:`37280`)
2526

2627
.. ---------------------------------------------------------------------------
2728

pandas/io/stata.py

+41-39
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ class PossiblePrecisionLoss(Warning):
477477

478478

479479
precision_loss_doc = """
480-
Column converted from %s to %s, and some data are outside of the lossless
480+
Column converted from {0} to {1}, and some data are outside of the lossless
481481
conversion range. This may result in a loss of precision in the saved data.
482482
"""
483483

@@ -551,7 +551,7 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame:
551551
object in a DataFrame.
552552
"""
553553
ws = ""
554-
# original, if small, if large
554+
# original, if small, if large
555555
conversion_data = (
556556
(np.bool_, np.int8, np.int8),
557557
(np.uint8, np.int8, np.int16),
@@ -571,7 +571,7 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame:
571571
dtype = c_data[1]
572572
else:
573573
dtype = c_data[2]
574-
if c_data[2] == np.float64: # Warn if necessary
574+
if c_data[2] == np.int64: # Warn if necessary
575575
if data[col].max() >= 2 ** 53:
576576
ws = precision_loss_doc.format("uint64", "float64")
577577

@@ -635,12 +635,12 @@ def __init__(self, catarray: Series, encoding: str = "latin-1"):
635635
self.value_labels = list(zip(np.arange(len(categories)), categories))
636636
self.value_labels.sort(key=lambda x: x[0])
637637
self.text_len = 0
638-
self.off: List[int] = []
639-
self.val: List[int] = []
640638
self.txt: List[bytes] = []
641639
self.n = 0
642640

643641
# Compute lengths and setup lists of offsets and labels
642+
offsets: List[int] = []
643+
values: List[int] = []
644644
for vl in self.value_labels:
645645
category = vl[1]
646646
if not isinstance(category, str):
@@ -650,9 +650,9 @@ def __init__(self, catarray: Series, encoding: str = "latin-1"):
650650
ValueLabelTypeMismatch,
651651
)
652652
category = category.encode(encoding)
653-
self.off.append(self.text_len)
653+
offsets.append(self.text_len)
654654
self.text_len += len(category) + 1 # +1 for the padding
655-
self.val.append(vl[0])
655+
values.append(vl[0])
656656
self.txt.append(category)
657657
self.n += 1
658658

@@ -663,8 +663,8 @@ def __init__(self, catarray: Series, encoding: str = "latin-1"):
663663
)
664664

665665
# Ensure int32
666-
self.off = np.array(self.off, dtype=np.int32)
667-
self.val = np.array(self.val, dtype=np.int32)
666+
self.off = np.array(offsets, dtype=np.int32)
667+
self.val = np.array(values, dtype=np.int32)
668668

669669
# Total length
670670
self.len = 4 + 4 + 4 * self.n + 4 * self.n + self.text_len
@@ -876,23 +876,23 @@ def __init__(self):
876876
# with a label, but the underlying variable is -127 to 100
877877
# we're going to drop the label and cast to int
878878
self.DTYPE_MAP = dict(
879-
list(zip(range(1, 245), ["a" + str(i) for i in range(1, 245)]))
879+
list(zip(range(1, 245), [np.dtype("a" + str(i)) for i in range(1, 245)]))
880880
+ [
881-
(251, np.int8),
882-
(252, np.int16),
883-
(253, np.int32),
884-
(254, np.float32),
885-
(255, np.float64),
881+
(251, np.dtype(np.int8)),
882+
(252, np.dtype(np.int16)),
883+
(253, np.dtype(np.int32)),
884+
(254, np.dtype(np.float32)),
885+
(255, np.dtype(np.float64)),
886886
]
887887
)
888888
self.DTYPE_MAP_XML = dict(
889889
[
890-
(32768, np.uint8), # Keys to GSO
891-
(65526, np.float64),
892-
(65527, np.float32),
893-
(65528, np.int32),
894-
(65529, np.int16),
895-
(65530, np.int8),
890+
(32768, np.dtype(np.uint8)), # Keys to GSO
891+
(65526, np.dtype(np.float64)),
892+
(65527, np.dtype(np.float32)),
893+
(65528, np.dtype(np.int32)),
894+
(65529, np.dtype(np.int16)),
895+
(65530, np.dtype(np.int8)),
896896
]
897897
)
898898
self.TYPE_MAP = list(range(251)) + list("bhlfd")
@@ -1050,9 +1050,10 @@ def __init__(
10501050
self._order_categoricals = order_categoricals
10511051
self._encoding = ""
10521052
self._chunksize = chunksize
1053-
if self._chunksize is not None and (
1054-
not isinstance(chunksize, int) or chunksize <= 0
1055-
):
1053+
self._using_iterator = False
1054+
if self._chunksize is None:
1055+
self._chunksize = 1
1056+
elif not isinstance(chunksize, int) or chunksize <= 0:
10561057
raise ValueError("chunksize must be a positive integer when set.")
10571058

10581059
# State variables for the file
@@ -1062,7 +1063,7 @@ def __init__(
10621063
self._column_selector_set = False
10631064
self._value_labels_read = False
10641065
self._data_read = False
1065-
self._dtype = None
1066+
self._dtype: Optional[np.dtype] = None
10661067
self._lines_read = 0
10671068

10681069
self._native_byteorder = _set_endianness(sys.byteorder)
@@ -1195,7 +1196,7 @@ def _read_new_header(self) -> None:
11951196
# Get data type information, works for versions 117-119.
11961197
def _get_dtypes(
11971198
self, seek_vartypes: int
1198-
) -> Tuple[List[Union[int, str]], List[Union[int, np.dtype]]]:
1199+
) -> Tuple[List[Union[int, str]], List[Union[str, np.dtype]]]:
11991200

12001201
self.path_or_buf.seek(seek_vartypes)
12011202
raw_typlist = [
@@ -1519,11 +1520,8 @@ def _read_strls(self) -> None:
15191520
self.GSO[str(v_o)] = decoded_va
15201521

15211522
def __next__(self) -> DataFrame:
1522-
if self._chunksize is None:
1523-
raise ValueError(
1524-
"chunksize must be set to a positive integer to use as an iterator."
1525-
)
1526-
return self.read(nrows=self._chunksize or 1)
1523+
self._using_iterator = True
1524+
return self.read(nrows=self._chunksize)
15271525

15281526
def get_chunk(self, size: Optional[int] = None) -> DataFrame:
15291527
"""
@@ -1692,11 +1690,15 @@ def any_startswith(x: str) -> bool:
16921690
convert = False
16931691
for col in data:
16941692
dtype = data[col].dtype
1695-
if dtype in (np.float16, np.float32):
1696-
dtype = np.float64
1693+
if dtype in (np.dtype(np.float16), np.dtype(np.float32)):
1694+
dtype = np.dtype(np.float64)
16971695
convert = True
1698-
elif dtype in (np.int8, np.int16, np.int32):
1699-
dtype = np.int64
1696+
elif dtype in (
1697+
np.dtype(np.int8),
1698+
np.dtype(np.int16),
1699+
np.dtype(np.int32),
1700+
):
1701+
dtype = np.dtype(np.int64)
17001702
convert = True
17011703
retyped_data.append((col, data[col].astype(dtype)))
17021704
if convert:
@@ -1807,14 +1809,14 @@ def _do_convert_categoricals(
18071809
keys = np.array(list(vl.keys()))
18081810
column = data[col]
18091811
key_matches = column.isin(keys)
1810-
if self._chunksize is not None and key_matches.all():
1811-
initial_categories = keys
1812+
if self._using_iterator and key_matches.all():
1813+
initial_categories: Optional[np.ndarray] = keys
18121814
# If all categories are in the keys and we are iterating,
18131815
# use the same keys for all chunks. If some are missing
18141816
# value labels, then we will fall back to the categories
18151817
# varying across chunks.
18161818
else:
1817-
if self._chunksize is not None:
1819+
if self._using_iterator:
18181820
# warn is using an iterator
18191821
warnings.warn(
18201822
categorical_conversion_warning, CategoricalConversionWarning
@@ -2010,7 +2012,7 @@ def _convert_datetime_to_stata_type(fmt: str) -> np.dtype:
20102012
"ty",
20112013
"%ty",
20122014
]:
2013-
return np.float64 # Stata expects doubles for SIFs
2015+
return np.dtype(np.float64) # Stata expects doubles for SIFs
20142016
else:
20152017
raise NotImplementedError(f"Format {fmt} not implemented")
20162018

pandas/tests/io/test_stata.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -1966,9 +1966,6 @@ def test_iterator_errors(dirpath):
19661966
StataReader(dta_file, chunksize=0)
19671967
with pytest.raises(ValueError, match="chunksize must be a positive"):
19681968
StataReader(dta_file, chunksize="apple")
1969-
with pytest.raises(ValueError, match="chunksize must be set to a positive"):
1970-
with StataReader(dta_file) as reader:
1971-
reader.__next__()
19721969

19731970

19741971
def test_iterator_value_labels():
@@ -1983,3 +1980,18 @@ def test_iterator_value_labels():
19831980
for i in range(2):
19841981
tm.assert_index_equal(chunk.dtypes[i].categories, expected)
19851982
tm.assert_frame_equal(chunk, df.iloc[j * 100 : (j + 1) * 100])
1983+
1984+
1985+
def test_precision_loss():
1986+
df = DataFrame(
1987+
[[sum(2 ** i for i in range(60)), sum(2 ** i for i in range(52))]],
1988+
columns=["big", "little"],
1989+
)
1990+
with tm.ensure_clean() as path:
1991+
with tm.assert_produces_warning(PossiblePrecisionLoss):
1992+
df.to_stata(path, write_index=False)
1993+
reread = read_stata(path)
1994+
expected_dt = Series([np.float64, np.float64], index=["big", "little"])
1995+
tm.assert_series_equal(reread.dtypes, expected_dt)
1996+
assert reread.loc[0, "little"] == df.loc[0, "little"]
1997+
assert reread.loc[0, "big"] == float(df.loc[0, "big"])

0 commit comments

Comments (0)