Skip to content

Commit 4602aed

Browse files
committed
FIX: StataReader: defer opening the file until data is required
1 parent b4db2b4 commit 4602aed

File tree

2 files changed

+28
-42
lines changed

2 files changed

+28
-42
lines changed

pandas/io/stata.py

+26-38
Original file line numberDiff line numberDiff line change
@@ -1114,6 +1114,8 @@ def __init__(self) -> None:
11141114
class StataReader(StataParser, abc.Iterator):
11151115
__doc__ = _stata_reader_doc
11161116

1117+
_path_or_buf: IO[bytes] # Will be assigned by `_open_file`.
1118+
11171119
def __init__(
11181120
self,
11191121
path_or_buf: FilePath | ReadBuffer[bytes],
@@ -1140,6 +1142,9 @@ def __init__(
11401142
self._preserve_dtypes = preserve_dtypes
11411143
self._columns = columns
11421144
self._order_categoricals = order_categoricals
1145+
self._original_path_or_buf = path_or_buf
1146+
self._compression = compression
1147+
self._storage_options = storage_options
11431148
self._encoding = ""
11441149
self._chunksize = chunksize
11451150
self._using_iterator = False
@@ -1149,6 +1154,7 @@ def __init__(
11491154
raise ValueError("chunksize must be a positive integer when set.")
11501155

11511156
# State variables for the file
1157+
self._close_file: Callable[[], None] | None = None
11521158
self._has_string_data = False
11531159
self._missing_values = False
11541160
self._can_read_value_labels = False
@@ -1159,12 +1165,24 @@ def __init__(
11591165
self._lines_read = 0
11601166

11611167
self._native_byteorder = _set_endianness(sys.byteorder)
1168+
1169+
def _ensure_open(self) -> None:
1170+
"""
1171+
Ensure the file has been opened and its header data read.
1172+
"""
1173+
if not hasattr(self, "_path_or_buf"):
1174+
self._open_file()
1175+
1176+
def _open_file(self) -> None:
1177+
"""
1178+
Open the file (with compression options, etc.), and read header information.
1179+
"""
11621180
with get_handle(
1163-
path_or_buf,
1181+
self._original_path_or_buf,
11641182
"rb",
1165-
storage_options=storage_options,
1183+
storage_options=self._storage_options,
11661184
is_text=False,
1167-
compression=compression,
1185+
compression=self._compression,
11681186
) as handles:
11691187
# Copy to BytesIO, and ensure no encoding
11701188
self._path_or_buf = BytesIO(handles.handle.read())
@@ -1530,6 +1548,7 @@ def _decode(self, s: bytes) -> str:
15301548
return s.decode("latin-1")
15311549

15321550
def _read_value_labels(self) -> None:
1551+
self._ensure_open()
15331552
if self._value_labels_read:
15341553
# Don't read twice
15351554
return
@@ -1649,6 +1668,7 @@ def read(
16491668
columns: Sequence[str] | None = None,
16501669
order_categoricals: bool | None = None,
16511670
) -> DataFrame:
1671+
self._ensure_open()
16521672
# Handle empty file or chunk. If reading incrementally raise
16531673
# StopIteration. If reading the whole thing return an empty
16541674
# data frame.
@@ -1976,48 +1996,15 @@ def data_label(self) -> str:
19761996
"""
19771997
Return data label of Stata file.
19781998
"""
1999+
self._ensure_open()
19792000
return self._data_label
19802001

1981-
@property
1982-
def typlist(self) -> list[int | str]:
1983-
"""
1984-
Return list of variable types.
1985-
"""
1986-
return self._typlist
1987-
1988-
@property
1989-
def dtyplist(self) -> list[str | np.dtype]:
1990-
"""
1991-
Return list of variable types.
1992-
"""
1993-
return self._dtyplist
1994-
1995-
@property
1996-
def lbllist(self) -> list[str]:
1997-
"""
1998-
Return list of variable labels.
1999-
"""
2000-
return self._lbllist
2001-
2002-
@property
2003-
def varlist(self) -> list[str]:
2004-
"""
2005-
Return list of variable names.
2006-
"""
2007-
return self._varlist
2008-
2009-
@property
2010-
def fmtlist(self) -> list[str]:
2011-
"""
2012-
Return list of variable formats.
2013-
"""
2014-
return self._fmtlist
2015-
20162002
@property
20172003
def time_stamp(self) -> str:
20182004
"""
20192005
Return time stamp of Stata file.
20202006
"""
2007+
self._ensure_open()
20212008
return self._time_stamp
20222009

20232010
def variable_labels(self) -> dict[str, str]:
@@ -2028,6 +2015,7 @@ def variable_labels(self) -> dict[str, str]:
20282015
-------
20292016
dict
20302017
"""
2018+
self._ensure_open()
20312019
return dict(zip(self._varlist, self._variable_labels))
20322020

20332021
def value_labels(self) -> dict[str, dict[float, str]]:

pandas/tests/io/test_stata.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -736,10 +736,8 @@ def test_minimal_size_col(self):
736736
original.to_stata(path, write_index=False)
737737

738738
with StataReader(path) as sr:
739-
typlist = sr.typlist
740-
variables = sr.varlist
741-
formats = sr.fmtlist
742-
for variable, fmt, typ in zip(variables, formats, typlist):
739+
sr._ensure_open() # The `_*list` variables are initialized here
740+
for variable, fmt, typ in zip(sr._varlist, sr._fmtlist, sr._typlist):
743741
assert int(variable[1:]) == int(fmt[1:-1])
744742
assert int(variable[1:]) == typ
745743

0 commit comments

Comments
 (0)