Skip to content

Commit 5ad63c0

Browse files
committed
BUG: Ensure chunksize is set if not provided
Remove error message that was incorrectly added. Fixed new issues identified by mypy. Add test to ensure conversion of large ints is correct. closes #37280
1 parent 6ac3765 commit 5ad63c0

File tree

3 files changed

+61
-43
lines changed

3 files changed

+61
-43
lines changed

doc/source/whatsnew/v1.1.4.rst

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ Fixed regressions
2222
- Fixed regression in :class:`RollingGroupby` causing a segmentation fault with Index of dtype object (:issue:`36727`)
2323
- Fixed regression in :meth:`DataFrame.resample(...).apply(...)` raised ``AttributeError`` when input was a :class:`DataFrame` and only a :class:`Series` was evaluated (:issue:`36951`)
2424
- Fixed regression in :class:`PeriodDtype` comparing both equal and unequal to its string representation (:issue:`37265`)
25+
- Fixed regression in :class:`StataReader` which required ``chunksize`` to be manually set when using an iterator to read a dataset (:issue:`37280`)
2526

2627
.. ---------------------------------------------------------------------------
2728

pandas/io/stata.py

+43-40
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ class PossiblePrecisionLoss(Warning):
469469

470470

471471
precision_loss_doc = """
472-
Column converted from %s to %s, and some data are outside of the lossless
472+
Column converted from {0} to {1}, and some data are outside of the lossless
473473
conversion range. This may result in a loss of precision in the saved data.
474474
"""
475475

@@ -543,7 +543,7 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame:
543543
object in a DataFrame.
544544
"""
545545
ws = ""
546-
# original, if small, if large
546+
# original, if small, if large
547547
conversion_data = (
548548
(np.bool_, np.int8, np.int8),
549549
(np.uint8, np.int8, np.int16),
@@ -563,7 +563,7 @@ def _cast_to_stata_types(data: DataFrame) -> DataFrame:
563563
dtype = c_data[1]
564564
else:
565565
dtype = c_data[2]
566-
if c_data[2] == np.float64: # Warn if necessary
566+
if c_data[2] == np.int64: # Warn if necessary
567567
if data[col].max() >= 2 ** 53:
568568
ws = precision_loss_doc.format("uint64", "float64")
569569

@@ -627,12 +627,12 @@ def __init__(self, catarray: Series, encoding: str = "latin-1"):
627627
self.value_labels = list(zip(np.arange(len(categories)), categories))
628628
self.value_labels.sort(key=lambda x: x[0])
629629
self.text_len = 0
630-
self.off: List[int] = []
631-
self.val: List[int] = []
632630
self.txt: List[bytes] = []
633631
self.n = 0
634632

635633
# Compute lengths and setup lists of offsets and labels
634+
offsets: List[int] = []
635+
values: List[int] = []
636636
for vl in self.value_labels:
637637
category = vl[1]
638638
if not isinstance(category, str):
@@ -642,9 +642,9 @@ def __init__(self, catarray: Series, encoding: str = "latin-1"):
642642
ValueLabelTypeMismatch,
643643
)
644644
category = category.encode(encoding)
645-
self.off.append(self.text_len)
645+
offsets.append(self.text_len)
646646
self.text_len += len(category) + 1 # +1 for the padding
647-
self.val.append(vl[0])
647+
values.append(vl[0])
648648
self.txt.append(category)
649649
self.n += 1
650650

@@ -655,8 +655,8 @@ def __init__(self, catarray: Series, encoding: str = "latin-1"):
655655
)
656656

657657
# Ensure int32
658-
self.off = np.array(self.off, dtype=np.int32)
659-
self.val = np.array(self.val, dtype=np.int32)
658+
self.off = np.array(offsets, dtype=np.int32)
659+
self.val = np.array(values, dtype=np.int32)
660660

661661
# Total length
662662
self.len = 4 + 4 + 4 * self.n + 4 * self.n + self.text_len
@@ -868,23 +868,23 @@ def __init__(self):
868868
# with a label, but the underlying variable is -127 to 100
869869
# we're going to drop the label and cast to int
870870
self.DTYPE_MAP = dict(
871-
list(zip(range(1, 245), ["a" + str(i) for i in range(1, 245)]))
871+
list(zip(range(1, 245), [np.dtype("a" + str(i)) for i in range(1, 245)]))
872872
+ [
873-
(251, np.int8),
874-
(252, np.int16),
875-
(253, np.int32),
876-
(254, np.float32),
877-
(255, np.float64),
873+
(251, np.dtype(np.int8)),
874+
(252, np.dtype(np.int16)),
875+
(253, np.dtype(np.int32)),
876+
(254, np.dtype(np.float32)),
877+
(255, np.dtype(np.float64)),
878878
]
879879
)
880880
self.DTYPE_MAP_XML = dict(
881881
[
882-
(32768, np.uint8), # Keys to GSO
883-
(65526, np.float64),
884-
(65527, np.float32),
885-
(65528, np.int32),
886-
(65529, np.int16),
887-
(65530, np.int8),
882+
(32768, np.dtype(np.uint8)), # Keys to GSO
883+
(65526, np.dtype(np.float64)),
884+
(65527, np.dtype(np.float32)),
885+
(65528, np.dtype(np.int32)),
886+
(65529, np.dtype(np.int16)),
887+
(65530, np.dtype(np.int8)),
888888
]
889889
)
890890
# error: Argument 1 to "list" has incompatible type "str";
@@ -1045,10 +1045,12 @@ def __init__(
10451045
self._order_categoricals = order_categoricals
10461046
self._encoding = ""
10471047
self._chunksize = chunksize
1048-
if self._chunksize is not None and (
1049-
not isinstance(chunksize, int) or chunksize <= 0
1050-
):
1051-
raise ValueError("chunksize must be a positive integer when set.")
1048+
self._using_iterator = False
1049+
if self._chunksize is None:
1050+
self._chunksize = 1
1051+
else:
1052+
if not isinstance(chunksize, int) or chunksize <= 0:
1053+
raise ValueError("chunksize must be a positive integer when set.")
10521054

10531055
# State variables for the file
10541056
self._has_string_data = False
@@ -1057,7 +1059,7 @@ def __init__(
10571059
self._column_selector_set = False
10581060
self._value_labels_read = False
10591061
self._data_read = False
1060-
self._dtype = None
1062+
self._dtype: Optional[np.dtype] = None
10611063
self._lines_read = 0
10621064

10631065
self._native_byteorder = _set_endianness(sys.byteorder)
@@ -1193,7 +1195,7 @@ def _read_new_header(self) -> None:
11931195
# Get data type information, works for versions 117-119.
11941196
def _get_dtypes(
11951197
self, seek_vartypes: int
1196-
) -> Tuple[List[Union[int, str]], List[Union[int, np.dtype]]]:
1198+
) -> Tuple[List[Union[int, str]], List[Union[str, np.dtype]]]:
11971199

11981200
self.path_or_buf.seek(seek_vartypes)
11991201
raw_typlist = [
@@ -1518,11 +1520,8 @@ def _read_strls(self) -> None:
15181520
self.GSO[str(v_o)] = decoded_va
15191521

15201522
def __next__(self) -> DataFrame:
1521-
if self._chunksize is None:
1522-
raise ValueError(
1523-
"chunksize must be set to a positive integer to use as an iterator."
1524-
)
1525-
return self.read(nrows=self._chunksize or 1)
1523+
self._using_iterator = True
1524+
return self.read(nrows=self._chunksize)
15261525

15271526
def get_chunk(self, size: Optional[int] = None) -> DataFrame:
15281527
"""
@@ -1690,11 +1689,15 @@ def any_startswith(x: str) -> bool:
16901689
convert = False
16911690
for col in data:
16921691
dtype = data[col].dtype
1693-
if dtype in (np.float16, np.float32):
1694-
dtype = np.float64
1692+
if dtype in (np.dtype(np.float16), np.dtype(np.float32)):
1693+
dtype = np.dtype(np.float64)
16951694
convert = True
1696-
elif dtype in (np.int8, np.int16, np.int32):
1697-
dtype = np.int64
1695+
elif dtype in (
1696+
np.dtype(np.int8),
1697+
np.dtype(np.int16),
1698+
np.dtype(np.int32),
1699+
):
1700+
dtype = np.dtype(np.int64)
16981701
convert = True
16991702
retyped_data.append((col, data[col].astype(dtype)))
17001703
if convert:
@@ -1806,14 +1809,14 @@ def _do_convert_categoricals(
18061809
keys = np.array(list(vl.keys()))
18071810
column = data[col]
18081811
key_matches = column.isin(keys)
1809-
if self._chunksize is not None and key_matches.all():
1810-
initial_categories = keys
1812+
if self._using_iterator and key_matches.all():
1813+
initial_categories: Optional[np.ndarray] = keys
18111814
# If all categories are in the keys and we are iterating,
18121815
# use the same keys for all chunks. If some are missing
18131816
# value labels, then we will fall back to the categories
18141817
# varying across chunks.
18151818
else:
1816-
if self._chunksize is not None:
1819+
if self._using_iterator:
18171820
# warn is using an iterator
18181821
warnings.warn(
18191822
categorical_conversion_warning, CategoricalConversionWarning
@@ -2024,7 +2027,7 @@ def _convert_datetime_to_stata_type(fmt: str) -> np.dtype:
20242027
"ty",
20252028
"%ty",
20262029
]:
2027-
return np.float64 # Stata expects doubles for SIFs
2030+
return np.dtype(np.float64) # Stata expects doubles for SIFs
20282031
else:
20292032
raise NotImplementedError(f"Format {fmt} not implemented")
20302033

pandas/tests/io/test_stata.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -1966,9 +1966,6 @@ def test_iterator_errors(dirpath):
19661966
StataReader(dta_file, chunksize=0)
19671967
with pytest.raises(ValueError, match="chunksize must be a positive"):
19681968
StataReader(dta_file, chunksize="apple")
1969-
with pytest.raises(ValueError, match="chunksize must be set to a positive"):
1970-
with StataReader(dta_file) as reader:
1971-
reader.__next__()
19721969

19731970

19741971
def test_iterator_value_labels():
@@ -1983,3 +1980,20 @@ def test_iterator_value_labels():
19831980
for i in range(2):
19841981
tm.assert_index_equal(chunk.dtypes[i].categories, expected)
19851982
tm.assert_frame_equal(chunk, df.iloc[j * 100 : (j + 1) * 100])
1983+
1984+
1985+
def test_precision_loss():
1986+
df = DataFrame(
1987+
[[sum(2 ** i for i in range(60)), sum(2 ** i for i in range(52))]],
1988+
columns=["big", "little"],
1989+
)
1990+
with tm.ensure_clean() as path:
1991+
with tm.assert_produces_warning(
1992+
PossiblePrecisionLoss, match="Column converted from int64 to float64"
1993+
):
1994+
df.to_stata(path, write_index=False)
1995+
reread = read_stata(path)
1996+
expected_dt = Series([np.float64, np.float64], index=["big", "little"])
1997+
tm.assert_series_equal(reread.dtypes, expected_dt)
1998+
assert reread.loc[0, "little"] == df.loc[0, "little"]
1999+
assert reread.loc[0, "big"] == float(df.loc[0, "big"])

0 commit comments

Comments
 (0)