diff --git a/doc/source/user_guide/io.rst b/doc/source/user_guide/io.rst
index 91cd3335d9db6..3c3a655626bb6 100644
--- a/doc/source/user_guide/io.rst
+++ b/doc/source/user_guide/io.rst
@@ -6033,6 +6033,14 @@ values will have ``object`` data type.
``int64`` for all integer types and ``float64`` for floating point data. By default,
the Stata data types are preserved when importing.
+.. note::
+
+   All :class:`~pandas.io.stata.StataReader` objects, whether created by :func:`~pandas.read_stata`
+   (when using ``iterator=True`` or ``chunksize``) or instantiated by hand, must be used as context
+   managers (e.g. via the ``with`` statement).
+   While the :meth:`~pandas.io.stata.StataReader.close` method is available, its use is unsupported.
+   It is not part of the public API and will be removed in a future version without warning.
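+
+   For example, reading in chunks (a minimal sketch; ``filename.dta`` is a placeholder path):
+
+   .. code-block:: python
+
+      with pd.read_stata("filename.dta", chunksize=10000) as reader:
+          for chunk in reader:
+              pass  # operate on each chunk, e.g. chunk.mean()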
+
.. ipython:: python
:suppress:
diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst
index bdbde438217b9..a8d6f3fce5bb7 100644
--- a/doc/source/whatsnew/v2.0.0.rst
+++ b/doc/source/whatsnew/v2.0.0.rst
@@ -857,6 +857,7 @@ Deprecations
- Deprecated :meth:`Series.backfill` in favor of :meth:`Series.bfill` (:issue:`33396`)
- Deprecated :meth:`DataFrame.pad` in favor of :meth:`DataFrame.ffill` (:issue:`33396`)
- Deprecated :meth:`DataFrame.backfill` in favor of :meth:`DataFrame.bfill` (:issue:`33396`)
+- Deprecated :meth:`~pandas.io.stata.StataReader.close`. Use :class:`~pandas.io.stata.StataReader` as a context manager instead (:issue:`49228`)
.. ---------------------------------------------------------------------------
.. _whatsnew_200.prior_deprecations:
@@ -1163,6 +1164,8 @@ Performance improvements
- Fixed a reference leak in :func:`read_hdf` (:issue:`37441`)
- Fixed a memory leak in :meth:`DataFrame.to_json` and :meth:`Series.to_json` when serializing datetimes and timedeltas (:issue:`40443`)
- Decreased memory usage in many :class:`DataFrameGroupBy` methods (:issue:`51090`)
+- Memory improvement in :class:`StataReader` when reading seekable files (:issue:`48922`)
+
.. ---------------------------------------------------------------------------
.. _whatsnew_200.bug_fixes:
diff --git a/pandas/io/stata.py b/pandas/io/stata.py
index 35ca4e1f6b6c4..9542ae46a0d05 100644
--- a/pandas/io/stata.py
+++ b/pandas/io/stata.py
@@ -23,6 +23,7 @@
TYPE_CHECKING,
Any,
AnyStr,
+ Callable,
Final,
Hashable,
Sequence,
@@ -182,10 +183,10 @@
>>> df = pd.DataFrame(values, columns=["i"]) # doctest: +SKIP
>>> df.to_stata('filename.dta') # doctest: +SKIP
->>> itr = pd.read_stata('filename.dta', chunksize=10000) # doctest: +SKIP
->>> for chunk in itr:
-... # Operate on a single chunk, e.g., chunk.mean()
-... pass # doctest: +SKIP
+>>> with pd.read_stata('filename.dta', chunksize=10000) as itr:  # doctest: +SKIP
+...     for chunk in itr:
+...         # Operate on a single chunk, e.g., chunk.mean()
+...         pass  # doctest: +SKIP
"""
_read_method_doc = f"""\
@@ -1114,6 +1115,8 @@ def __init__(self) -> None:
class StataReader(StataParser, abc.Iterator):
__doc__ = _stata_reader_doc
+ _path_or_buf: IO[bytes] # Will be assigned by `_open_file`.
+
def __init__(
self,
path_or_buf: FilePath | ReadBuffer[bytes],
@@ -1129,7 +1132,7 @@ def __init__(
storage_options: StorageOptions = None,
) -> None:
super().__init__()
- self.col_sizes: list[int] = []
+ self._col_sizes: list[int] = []
# Arguments to the reader (can be temporarily overridden in
# calls to read).
@@ -1140,15 +1143,20 @@ def __init__(
self._preserve_dtypes = preserve_dtypes
self._columns = columns
self._order_categoricals = order_categoricals
+ self._original_path_or_buf = path_or_buf
+ self._compression = compression
+ self._storage_options = storage_options
self._encoding = ""
self._chunksize = chunksize
self._using_iterator = False
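+        # Set to True by __enter__; _open_file warns when the reader is used
+        # without entering the context manager.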
+ self._entered = False
if self._chunksize is None:
self._chunksize = 1
elif not isinstance(chunksize, int) or chunksize <= 0:
raise ValueError("chunksize must be a positive integer when set.")
# State variables for the file
+ self._close_file: Callable[[], None] | None = None
self._has_string_data = False
self._missing_values = False
self._can_read_value_labels = False
@@ -1159,21 +1167,48 @@ def __init__(
self._lines_read = 0
self._native_byteorder = _set_endianness(sys.byteorder)
- with get_handle(
- path_or_buf,
+
+ def _ensure_open(self) -> None:
+ """
+ Ensure the file has been opened and its header data read.
+ """
+ if not hasattr(self, "_path_or_buf"):
+ self._open_file()
+
+ def _open_file(self) -> None:
+ """
+ Open the file (with compression options, etc.), and read header information.
+ """
+ if not self._entered:
+ warnings.warn(
+ "StataReader is being used without using a context manager. "
+ "Using StataReader as a context manager is the only supported method.",
+ ResourceWarning,
+ stacklevel=find_stack_level(),
+ )
+ handles = get_handle(
+ self._original_path_or_buf,
"rb",
- storage_options=storage_options,
+ storage_options=self._storage_options,
is_text=False,
- compression=compression,
- ) as handles:
- # Copy to BytesIO, and ensure no encoding
- self.path_or_buf = BytesIO(handles.handle.read())
+ compression=self._compression,
+ )
+ if hasattr(handles.handle, "seekable") and handles.handle.seekable():
+ # If the handle is directly seekable, use it without an extra copy.
+ self._path_or_buf = handles.handle
+ self._close_file = handles.close
+ else:
+ # Copy to memory, and ensure no encoding.
+ with handles:
+ self._path_or_buf = BytesIO(handles.handle.read())
+ self._close_file = self._path_or_buf.close
self._read_header()
self._setup_dtype()
def __enter__(self) -> StataReader:
"""enter context manager"""
+ self._entered = True
return self
def __exit__(
@@ -1182,119 +1217,142 @@ def __exit__(
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
- """exit context manager"""
- self.close()
+ if self._close_file:
+ self._close_file()
def close(self) -> None:
- """close the handle if its open"""
- self.path_or_buf.close()
+ """Close the handle if its open.
+
+ .. deprecated: 2.0.0
+
+ The close method is not part of the public API.
+ The only supported way to use StataReader is to use it as a context manager.
+ """
+ warnings.warn(
+ "The StataReader.close() method is not part of the public API and "
+ "will be removed in a future version without notice. "
+ "Using StataReader as a context manager is the only supported method.",
+ FutureWarning,
+ stacklevel=find_stack_level(),
+ )
+ if self._close_file:
+ self._close_file()
def _set_encoding(self) -> None:
"""
Set string encoding which depends on file version
"""
- if self.format_version < 118:
+ if self._format_version < 118:
self._encoding = "latin-1"
else:
self._encoding = "utf-8"
+ def _read_int8(self) -> int:
+ return struct.unpack("b", self._path_or_buf.read(1))[0]
+
+ def _read_uint8(self) -> int:
+ return struct.unpack("B", self._path_or_buf.read(1))[0]
+
+ def _read_uint16(self) -> int:
+ return struct.unpack(f"{self._byteorder}H", self._path_or_buf.read(2))[0]
+
+ def _read_uint32(self) -> int:
+ return struct.unpack(f"{self._byteorder}I", self._path_or_buf.read(4))[0]
+
+ def _read_uint64(self) -> int:
+ return struct.unpack(f"{self._byteorder}Q", self._path_or_buf.read(8))[0]
+
+ def _read_int16(self) -> int:
+ return struct.unpack(f"{self._byteorder}h", self._path_or_buf.read(2))[0]
+
+ def _read_int32(self) -> int:
+ return struct.unpack(f"{self._byteorder}i", self._path_or_buf.read(4))[0]
+
+ def _read_int64(self) -> int:
+ return struct.unpack(f"{self._byteorder}q", self._path_or_buf.read(8))[0]
+
+ def _read_char8(self) -> bytes:
+ return struct.unpack("c", self._path_or_buf.read(1))[0]
+
+ def _read_int16_count(self, count: int) -> tuple[int, ...]:
+ return struct.unpack(
+ f"{self._byteorder}{'h' * count}",
+ self._path_or_buf.read(2 * count),
+ )
+
def _read_header(self) -> None:
- first_char = self.path_or_buf.read(1)
- if struct.unpack("c", first_char)[0] == b"<":
+ first_char = self._read_char8()
+ if first_char == b"<":
self._read_new_header()
else:
self._read_old_header(first_char)
- self.has_string_data = len([x for x in self.typlist if type(x) is int]) > 0
+ self._has_string_data = len([x for x in self._typlist if type(x) is int]) > 0
# calculate size of a data record
- self.col_sizes = [self._calcsize(typ) for typ in self.typlist]
+ self._col_sizes = [self._calcsize(typ) for typ in self._typlist]
def _read_new_header(self) -> None:
# The first part of the header is common to 117 - 119.
-        self.path_or_buf.read(27)  # stata_dta><header><release>
- self.format_version = int(self.path_or_buf.read(3))
- if self.format_version not in [117, 118, 119]:
- raise ValueError(_version_error.format(version=self.format_version))
+        self._path_or_buf.read(27)  # stata_dta><header><release>
+ self._format_version = int(self._path_or_buf.read(3))
+ if self._format_version not in [117, 118, 119]:
+ raise ValueError(_version_error.format(version=self._format_version))
self._set_encoding()
-        self.path_or_buf.read(21)  # </release><byteorder>
-        self.byteorder = ">" if self.path_or_buf.read(3) == b"MSF" else "<"
-        self.path_or_buf.read(15)  # </byteorder><K>
- nvar_type = "H" if self.format_version <= 118 else "I"
- nvar_size = 2 if self.format_version <= 118 else 4
- self.nvar = struct.unpack(
- self.byteorder + nvar_type, self.path_or_buf.read(nvar_size)
- )[0]
-        self.path_or_buf.read(7)  # </K><N>
-
- self.nobs = self._get_nobs()
-        self.path_or_buf.read(11)  # </N><label>
-        self.time_stamp = self._get_time_stamp()
-        self.path_or_buf.read(26)  # </timestamp></header><map>
+ self._byteorder = ">" if self._path_or_buf.read(3) == b"MSF" else "<"
+        self._path_or_buf.read(15)  # </byteorder><K>
+ self._nvar = (
+ self._read_uint16() if self._format_version <= 118 else self._read_uint32()
+        )
+        self._path_or_buf.read(7)  # </K><N>
+
+ self._nobs = self._get_nobs()
+        self._path_or_buf.read(11)  # </N><label>
+        self._time_stamp = self._get_time_stamp()
+        self._path_or_buf.read(26)  # </timestamp></header><map>
+        self._path_or_buf.read(8)  # 0x0000000000000000
+        self._path_or_buf.read(8)  # position of <map>
+
+ self._seek_vartypes = self._read_int64() + 16
+ self._seek_varnames = self._read_int64() + 10
+ self._seek_sortlist = self._read_int64() + 10
+ self._seek_formats = self._read_int64() + 9
+ self._seek_value_label_names = self._read_int64() + 19
# Requires version-specific treatment
self._seek_variable_labels = self._get_seek_variable_labels()
-        self.path_or_buf.read(8)  # <characteristics>
- self.data_location = (
- struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 6
- )
- self.seek_strls = (
- struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 7
- )
- self.seek_value_labels = (
- struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 14
- )
+        self._path_or_buf.read(8)  # <characteristics>
+ self._data_location = self._read_int64() + 6
+ self._seek_strls = self._read_int64() + 7
+ self._seek_value_labels = self._read_int64() + 14
- self.typlist, self.dtyplist = self._get_dtypes(self._seek_vartypes)
+ self._typlist, self._dtyplist = self._get_dtypes(self._seek_vartypes)
- self.path_or_buf.seek(self._seek_varnames)
- self.varlist = self._get_varlist()
+ self._path_or_buf.seek(self._seek_varnames)
+ self._varlist = self._get_varlist()
- self.path_or_buf.seek(self._seek_sortlist)
- self.srtlist = struct.unpack(
- self.byteorder + ("h" * (self.nvar + 1)),
- self.path_or_buf.read(2 * (self.nvar + 1)),
- )[:-1]
+ self._path_or_buf.seek(self._seek_sortlist)
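+        # nvar + 1 int16 entries; the final entry appears to be a list
+        # terminator and is dropped, as in the struct.unpack code it replaces.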
+ self._srtlist = self._read_int16_count(self._nvar + 1)[:-1]
- self.path_or_buf.seek(self._seek_formats)
- self.fmtlist = self._get_fmtlist()
+ self._path_or_buf.seek(self._seek_formats)
+ self._fmtlist = self._get_fmtlist()
- self.path_or_buf.seek(self._seek_value_label_names)
- self.lbllist = self._get_lbllist()
+ self._path_or_buf.seek(self._seek_value_label_names)
+ self._lbllist = self._get_lbllist()
- self.path_or_buf.seek(self._seek_variable_labels)
+ self._path_or_buf.seek(self._seek_variable_labels)
self._variable_labels = self._get_variable_labels()
# Get data type information, works for versions 117-119.
def _get_dtypes(
self, seek_vartypes: int
) -> tuple[list[int | str], list[str | np.dtype]]:
- self.path_or_buf.seek(seek_vartypes)
- raw_typlist = [
- struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0]
- for _ in range(self.nvar)
- ]
+ self._path_or_buf.seek(seek_vartypes)
+ raw_typlist = [self._read_uint16() for _ in range(self._nvar)]
def f(typ: int) -> int | str:
if typ <= 2045:
@@ -1320,112 +1378,110 @@ def g(typ: int) -> str | np.dtype:
def _get_varlist(self) -> list[str]:
# 33 in order formats, 129 in formats 118 and 119
- b = 33 if self.format_version < 118 else 129
- return [self._decode(self.path_or_buf.read(b)) for _ in range(self.nvar)]
+ b = 33 if self._format_version < 118 else 129
+ return [self._decode(self._path_or_buf.read(b)) for _ in range(self._nvar)]
# Returns the format list
def _get_fmtlist(self) -> list[str]:
- if self.format_version >= 118:
+ if self._format_version >= 118:
b = 57
- elif self.format_version > 113:
+ elif self._format_version > 113:
b = 49
- elif self.format_version > 104:
+ elif self._format_version > 104:
b = 12
else:
b = 7
- return [self._decode(self.path_or_buf.read(b)) for _ in range(self.nvar)]
+ return [self._decode(self._path_or_buf.read(b)) for _ in range(self._nvar)]
# Returns the label list
def _get_lbllist(self) -> list[str]:
- if self.format_version >= 118:
+ if self._format_version >= 118:
b = 129
- elif self.format_version > 108:
+ elif self._format_version > 108:
b = 33
else:
b = 9
- return [self._decode(self.path_or_buf.read(b)) for _ in range(self.nvar)]
+ return [self._decode(self._path_or_buf.read(b)) for _ in range(self._nvar)]
def _get_variable_labels(self) -> list[str]:
- if self.format_version >= 118:
+ if self._format_version >= 118:
vlblist = [
- self._decode(self.path_or_buf.read(321)) for _ in range(self.nvar)
+ self._decode(self._path_or_buf.read(321)) for _ in range(self._nvar)
]
- elif self.format_version > 105:
+ elif self._format_version > 105:
vlblist = [
- self._decode(self.path_or_buf.read(81)) for _ in range(self.nvar)
+ self._decode(self._path_or_buf.read(81)) for _ in range(self._nvar)
]
else:
vlblist = [
- self._decode(self.path_or_buf.read(32)) for _ in range(self.nvar)
+ self._decode(self._path_or_buf.read(32)) for _ in range(self._nvar)
]
return vlblist
def _get_nobs(self) -> int:
- if self.format_version >= 118:
- return struct.unpack(self.byteorder + "Q", self.path_or_buf.read(8))[0]
+ if self._format_version >= 118:
+ return self._read_uint64()
else:
- return struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
+ return self._read_uint32()
def _get_data_label(self) -> str:
- if self.format_version >= 118:
- strlen = struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0]
- return self._decode(self.path_or_buf.read(strlen))
- elif self.format_version == 117:
- strlen = struct.unpack("b", self.path_or_buf.read(1))[0]
- return self._decode(self.path_or_buf.read(strlen))
- elif self.format_version > 105:
- return self._decode(self.path_or_buf.read(81))
+ if self._format_version >= 118:
+ strlen = self._read_uint16()
+ return self._decode(self._path_or_buf.read(strlen))
+ elif self._format_version == 117:
+ strlen = self._read_int8()
+ return self._decode(self._path_or_buf.read(strlen))
+ elif self._format_version > 105:
+ return self._decode(self._path_or_buf.read(81))
else:
- return self._decode(self.path_or_buf.read(32))
+ return self._decode(self._path_or_buf.read(32))
def _get_time_stamp(self) -> str:
- if self.format_version >= 118:
- strlen = struct.unpack("b", self.path_or_buf.read(1))[0]
- return self.path_or_buf.read(strlen).decode("utf-8")
- elif self.format_version == 117:
- strlen = struct.unpack("b", self.path_or_buf.read(1))[0]
- return self._decode(self.path_or_buf.read(strlen))
- elif self.format_version > 104:
- return self._decode(self.path_or_buf.read(18))
+ if self._format_version >= 118:
+ strlen = self._read_int8()
+ return self._path_or_buf.read(strlen).decode("utf-8")
+ elif self._format_version == 117:
+ strlen = self._read_int8()
+ return self._decode(self._path_or_buf.read(strlen))
+ elif self._format_version > 104:
+ return self._decode(self._path_or_buf.read(18))
else:
raise ValueError()
def _get_seek_variable_labels(self) -> int:
-        if self.format_version == 117:
-            self.path_or_buf.read(8)  # <variable_labels>, throw away
+        if self._format_version == 117:
+            self._path_or_buf.read(8)  # <variable_labels>, throw away
# Stata 117 data files do not follow the described format. This is
# a work around that uses the previous label, 33 bytes for each
# variable, 20 for the closing tag and 17 for the opening tag
- return self._seek_value_label_names + (33 * self.nvar) + 20 + 17
- elif self.format_version >= 118:
- return struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 17
+ return self._seek_value_label_names + (33 * self._nvar) + 20 + 17
+ elif self._format_version >= 118:
+ return self._read_int64() + 17
else:
raise ValueError()
def _read_old_header(self, first_char: bytes) -> None:
- self.format_version = struct.unpack("b", first_char)[0]
- if self.format_version not in [104, 105, 108, 111, 113, 114, 115]:
- raise ValueError(_version_error.format(version=self.format_version))
+ self._format_version = int(first_char[0])
+ if self._format_version not in [104, 105, 108, 111, 113, 114, 115]:
+ raise ValueError(_version_error.format(version=self._format_version))
self._set_encoding()
- self.byteorder = (
- ">" if struct.unpack("b", self.path_or_buf.read(1))[0] == 0x1 else "<"
- )
- self.filetype = struct.unpack("b", self.path_or_buf.read(1))[0]
- self.path_or_buf.read(1) # unused
+ self._byteorder = ">" if self._read_int8() == 0x1 else "<"
+ self._filetype = self._read_int8()
+ self._path_or_buf.read(1) # unused
- self.nvar = struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0]
- self.nobs = self._get_nobs()
+ self._nvar = self._read_uint16()
+ self._nobs = self._get_nobs()
self._data_label = self._get_data_label()
- self.time_stamp = self._get_time_stamp()
+ self._time_stamp = self._get_time_stamp()
# descriptors
- if self.format_version > 108:
- typlist = [ord(self.path_or_buf.read(1)) for _ in range(self.nvar)]
+ if self._format_version > 108:
+ typlist = [int(c) for c in self._path_or_buf.read(self._nvar)]
else:
- buf = self.path_or_buf.read(self.nvar)
+ buf = self._path_or_buf.read(self._nvar)
typlistb = np.frombuffer(buf, dtype=np.uint8)
typlist = []
for tp in typlistb:
@@ -1435,32 +1491,29 @@ def _read_old_header(self, first_char: bytes) -> None:
typlist.append(tp - 127) # bytes
try:
- self.typlist = [self.TYPE_MAP[typ] for typ in typlist]
+ self._typlist = [self.TYPE_MAP[typ] for typ in typlist]
except ValueError as err:
invalid_types = ",".join([str(x) for x in typlist])
raise ValueError(f"cannot convert stata types [{invalid_types}]") from err
try:
- self.dtyplist = [self.DTYPE_MAP[typ] for typ in typlist]
+ self._dtyplist = [self.DTYPE_MAP[typ] for typ in typlist]
except ValueError as err:
invalid_dtypes = ",".join([str(x) for x in typlist])
raise ValueError(f"cannot convert stata dtypes [{invalid_dtypes}]") from err
- if self.format_version > 108:
- self.varlist = [
- self._decode(self.path_or_buf.read(33)) for _ in range(self.nvar)
+ if self._format_version > 108:
+ self._varlist = [
+ self._decode(self._path_or_buf.read(33)) for _ in range(self._nvar)
]
else:
- self.varlist = [
- self._decode(self.path_or_buf.read(9)) for _ in range(self.nvar)
+ self._varlist = [
+ self._decode(self._path_or_buf.read(9)) for _ in range(self._nvar)
]
- self.srtlist = struct.unpack(
- self.byteorder + ("h" * (self.nvar + 1)),
- self.path_or_buf.read(2 * (self.nvar + 1)),
- )[:-1]
+ self._srtlist = self._read_int16_count(self._nvar + 1)[:-1]
- self.fmtlist = self._get_fmtlist()
+ self._fmtlist = self._get_fmtlist()
- self.lbllist = self._get_lbllist()
+ self._lbllist = self._get_lbllist()
self._variable_labels = self._get_variable_labels()
@@ -1469,25 +1522,19 @@ def _read_old_header(self, first_char: bytes) -> None:
# the size of the next read, which you discard. You then continue
# like this until you read 5 bytes of zeros.
- if self.format_version > 104:
+ if self._format_version > 104:
while True:
- data_type = struct.unpack(
- self.byteorder + "b", self.path_or_buf.read(1)
- )[0]
- if self.format_version > 108:
- data_len = struct.unpack(
- self.byteorder + "i", self.path_or_buf.read(4)
- )[0]
+ data_type = self._read_int8()
+ if self._format_version > 108:
+ data_len = self._read_int32()
else:
- data_len = struct.unpack(
- self.byteorder + "h", self.path_or_buf.read(2)
- )[0]
+ data_len = self._read_int16()
if data_type == 0:
break
- self.path_or_buf.read(data_len)
+ self._path_or_buf.read(data_len)
# necessary data to continue parsing
- self.data_location = self.path_or_buf.tell()
+ self._data_location = self._path_or_buf.tell()
def _setup_dtype(self) -> np.dtype:
"""Map between numpy and state dtypes"""
@@ -1495,12 +1542,12 @@ def _setup_dtype(self) -> np.dtype:
return self._dtype
dtypes = [] # Convert struct data types to numpy data type
- for i, typ in enumerate(self.typlist):
+ for i, typ in enumerate(self._typlist):
if typ in self.NUMPY_TYPE_MAP:
typ = cast(str, typ) # only strs in NUMPY_TYPE_MAP
- dtypes.append(("s" + str(i), self.byteorder + self.NUMPY_TYPE_MAP[typ]))
+ dtypes.append((f"s{i}", f"{self._byteorder}{self.NUMPY_TYPE_MAP[typ]}"))
else:
- dtypes.append(("s" + str(i), "S" + str(typ)))
+ dtypes.append((f"s{i}", f"S{typ}"))
self._dtype = np.dtype(dtypes)
return self._dtype
@@ -1508,7 +1555,7 @@ def _setup_dtype(self) -> np.dtype:
def _calcsize(self, fmt: int | str) -> int:
if isinstance(fmt, int):
return fmt
- return struct.calcsize(self.byteorder + fmt)
+ return struct.calcsize(self._byteorder + fmt)
def _decode(self, s: bytes) -> str:
# have bytes not strings, so must decode
@@ -1532,82 +1579,85 @@ def _decode(self, s: bytes) -> str:
return s.decode("latin-1")
def _read_value_labels(self) -> None:
+ self._ensure_open()
if self._value_labels_read:
# Don't read twice
return
- if self.format_version <= 108:
+ if self._format_version <= 108:
# Value labels are not supported in version 108 and earlier.
self._value_labels_read = True
- self.value_label_dict: dict[str, dict[float, str]] = {}
+ self._value_label_dict: dict[str, dict[float, str]] = {}
return
- if self.format_version >= 117:
- self.path_or_buf.seek(self.seek_value_labels)
+ if self._format_version >= 117:
+ self._path_or_buf.seek(self._seek_value_labels)
else:
assert self._dtype is not None
- offset = self.nobs * self._dtype.itemsize
- self.path_or_buf.seek(self.data_location + offset)
+ offset = self._nobs * self._dtype.itemsize
+ self._path_or_buf.seek(self._data_location + offset)
self._value_labels_read = True
- self.value_label_dict = {}
+ self._value_label_dict = {}
while True:
- if self.format_version >= 117:
-                if self.path_or_buf.read(5) == b"</val":  # <lbl>
+            if self._format_version >= 117:
+                if self._path_or_buf.read(5) == b"</val":  # <lbl>
break # end of value label table
- slength = self.path_or_buf.read(4)
+ slength = self._path_or_buf.read(4)
if not slength:
break # end of value label table (format < 117)
- if self.format_version <= 117:
- labname = self._decode(self.path_or_buf.read(33))
+ if self._format_version <= 117:
+ labname = self._decode(self._path_or_buf.read(33))
else:
- labname = self._decode(self.path_or_buf.read(129))
- self.path_or_buf.read(3) # padding
+ labname = self._decode(self._path_or_buf.read(129))
+ self._path_or_buf.read(3) # padding
- n = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
- txtlen = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
+ n = self._read_uint32()
+ txtlen = self._read_uint32()
off = np.frombuffer(
- self.path_or_buf.read(4 * n), dtype=self.byteorder + "i4", count=n
+ self._path_or_buf.read(4 * n), dtype=f"{self._byteorder}i4", count=n
)
val = np.frombuffer(
- self.path_or_buf.read(4 * n), dtype=self.byteorder + "i4", count=n
+ self._path_or_buf.read(4 * n), dtype=f"{self._byteorder}i4", count=n
)
ii = np.argsort(off)
off = off[ii]
val = val[ii]
- txt = self.path_or_buf.read(txtlen)
- self.value_label_dict[labname] = {}
+ txt = self._path_or_buf.read(txtlen)
+ self._value_label_dict[labname] = {}
for i in range(n):
end = off[i + 1] if i < n - 1 else txtlen
- self.value_label_dict[labname][val[i]] = self._decode(txt[off[i] : end])
- if self.format_version >= 117:
-            self.path_or_buf.read(6)  # </lbl>
+ self._value_label_dict[labname][val[i]] = self._decode(
+ txt[off[i] : end]
+ )
+ if self._format_version >= 117:
+                self._path_or_buf.read(6)  # </lbl>
self._value_labels_read = True
def _read_strls(self) -> None:
- self.path_or_buf.seek(self.seek_strls)
+ self._path_or_buf.seek(self._seek_strls)
# Wrap v_o in a string to allow uint64 values as keys on 32bit OS
self.GSO = {"0": ""}
while True:
- if self.path_or_buf.read(3) != b"GSO":
+ if self._path_or_buf.read(3) != b"GSO":
break
- if self.format_version == 117:
- v_o = struct.unpack(self.byteorder + "Q", self.path_or_buf.read(8))[0]
+ if self._format_version == 117:
+ v_o = self._read_uint64()
else:
- buf = self.path_or_buf.read(12)
+ buf = self._path_or_buf.read(12)
# Only tested on little endian file on little endian machine.
- v_size = 2 if self.format_version == 118 else 3
- if self.byteorder == "<":
+ v_size = 2 if self._format_version == 118 else 3
+ if self._byteorder == "<":
buf = buf[0:v_size] + buf[4 : (12 - v_size)]
else:
# This path may not be correct, impossible to test
buf = buf[0:v_size] + buf[(4 + v_size) :]
v_o = struct.unpack("Q", buf)[0]
- typ = struct.unpack("B", self.path_or_buf.read(1))[0]
- length = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
- va = self.path_or_buf.read(length)
+ typ = self._read_uint8()
+ length = self._read_uint32()
+ va = self._path_or_buf.read(length)
if typ == 130:
decoded_va = va[0:-1].decode(self._encoding)
else:
@@ -1649,14 +1699,14 @@ def read(
columns: Sequence[str] | None = None,
order_categoricals: bool | None = None,
) -> DataFrame:
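+        # Open the file and parse the header on first use; a no-op when the
+        # reader is already open.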
+ self._ensure_open()
# Handle empty file or chunk. If reading incrementally raise
# StopIteration. If reading the whole thing return an empty
# data frame.
- if (self.nobs == 0) and (nrows is None):
+ if (self._nobs == 0) and (nrows is None):
self._can_read_value_labels = True
self._data_read = True
- self.close()
- return DataFrame(columns=self.varlist)
+ return DataFrame(columns=self._varlist)
# Handle options
if convert_dates is None:
@@ -1675,16 +1725,16 @@ def read(
index_col = self._index_col
if nrows is None:
- nrows = self.nobs
+ nrows = self._nobs
- if (self.format_version >= 117) and (not self._value_labels_read):
+ if (self._format_version >= 117) and (not self._value_labels_read):
self._can_read_value_labels = True
self._read_strls()
# Read data
assert self._dtype is not None
dtype = self._dtype
- max_read_len = (self.nobs - self._lines_read) * dtype.itemsize
+ max_read_len = (self._nobs - self._lines_read) * dtype.itemsize
read_len = nrows * dtype.itemsize
read_len = min(read_len, max_read_len)
if read_len <= 0:
@@ -1692,31 +1742,30 @@ def read(
# we are reading the file incrementally
if convert_categoricals:
self._read_value_labels()
- self.close()
raise StopIteration
offset = self._lines_read * dtype.itemsize
- self.path_or_buf.seek(self.data_location + offset)
- read_lines = min(nrows, self.nobs - self._lines_read)
+ self._path_or_buf.seek(self._data_location + offset)
+ read_lines = min(nrows, self._nobs - self._lines_read)
raw_data = np.frombuffer(
- self.path_or_buf.read(read_len), dtype=dtype, count=read_lines
+ self._path_or_buf.read(read_len), dtype=dtype, count=read_lines
)
self._lines_read += read_lines
- if self._lines_read == self.nobs:
+ if self._lines_read == self._nobs:
self._can_read_value_labels = True
self._data_read = True
# if necessary, swap the byte order to native here
- if self.byteorder != self._native_byteorder:
+ if self._byteorder != self._native_byteorder:
raw_data = raw_data.byteswap().newbyteorder()
if convert_categoricals:
self._read_value_labels()
if len(raw_data) == 0:
- data = DataFrame(columns=self.varlist)
+ data = DataFrame(columns=self._varlist)
else:
data = DataFrame.from_records(raw_data)
- data.columns = Index(self.varlist)
+ data.columns = Index(self._varlist)
# If index is not specified, use actual row number rather than
# restarting at 0 for each chunk.
@@ -1725,32 +1774,28 @@ def read(
data.index = Index(rng) # set attr instead of set_index to avoid copy
if columns is not None:
- try:
- data = self._do_select_columns(data, columns)
- except ValueError:
- self.close()
- raise
+ data = self._do_select_columns(data, columns)
# Decode strings
- for col, typ in zip(data, self.typlist):
+ for col, typ in zip(data, self._typlist):
if type(typ) is int:
data[col] = data[col].apply(self._decode, convert_dtype=True)
data = self._insert_strls(data)
- cols_ = np.where([dtyp is not None for dtyp in self.dtyplist])[0]
+ cols_ = np.where([dtyp is not None for dtyp in self._dtyplist])[0]
# Convert columns (if needed) to match input type
ix = data.index
requires_type_conversion = False
data_formatted = []
for i in cols_:
- if self.dtyplist[i] is not None:
+ if self._dtyplist[i] is not None:
col = data.columns[i]
dtype = data[col].dtype
- if dtype != np.dtype(object) and dtype != self.dtyplist[i]:
+ if dtype != np.dtype(object) and dtype != self._dtyplist[i]:
requires_type_conversion = True
data_formatted.append(
- (col, Series(data[col], ix, self.dtyplist[i]))
+ (col, Series(data[col], ix, self._dtyplist[i]))
)
else:
data_formatted.append((col, data[col]))
@@ -1765,20 +1810,16 @@ def read(
def any_startswith(x: str) -> bool:
return any(x.startswith(fmt) for fmt in _date_formats)
- cols = np.where([any_startswith(x) for x in self.fmtlist])[0]
+ cols = np.where([any_startswith(x) for x in self._fmtlist])[0]
for i in cols:
col = data.columns[i]
- try:
- data[col] = _stata_elapsed_date_to_datetime_vec(
- data[col], self.fmtlist[i]
- )
- except ValueError:
- self.close()
- raise
+ data[col] = _stata_elapsed_date_to_datetime_vec(
+ data[col], self._fmtlist[i]
+ )
- if convert_categoricals and self.format_version > 108:
+ if convert_categoricals and self._format_version > 108:
data = self._do_convert_categoricals(
- data, self.value_label_dict, self.lbllist, order_categoricals
+ data, self._value_label_dict, self._lbllist, order_categoricals
)
if not preserve_dtypes:
@@ -1809,7 +1850,7 @@ def _do_convert_missing(self, data: DataFrame, convert_missing: bool) -> DataFra
# Check for missing values, and replace if found
replacements = {}
for i, colname in enumerate(data):
- fmt = self.typlist[i]
+ fmt = self._typlist[i]
if fmt not in self.VALID_RANGE:
continue
@@ -1855,7 +1896,7 @@ def _do_convert_missing(self, data: DataFrame, convert_missing: bool) -> DataFra
def _insert_strls(self, data: DataFrame) -> DataFrame:
if not hasattr(self, "GSO") or len(self.GSO) == 0:
return data
- for i, typ in enumerate(self.typlist):
+ for i, typ in enumerate(self._typlist):
if typ != "Q":
continue
# Wrap v_o in a string to allow uint64 values as keys on 32bit OS
@@ -1881,15 +1922,15 @@ def _do_select_columns(self, data: DataFrame, columns: Sequence[str]) -> DataFra
lbllist = []
for col in columns:
i = data.columns.get_loc(col)
- dtyplist.append(self.dtyplist[i])
- typlist.append(self.typlist[i])
- fmtlist.append(self.fmtlist[i])
- lbllist.append(self.lbllist[i])
-
- self.dtyplist = dtyplist
- self.typlist = typlist
- self.fmtlist = fmtlist
- self.lbllist = lbllist
+ dtyplist.append(self._dtyplist[i])
+ typlist.append(self._typlist[i])
+ fmtlist.append(self._fmtlist[i])
+ lbllist.append(self._lbllist[i])
+
+ self._dtyplist = dtyplist
+ self._typlist = typlist
+ self._fmtlist = fmtlist
+ self._lbllist = lbllist
self._column_selector_set = True
return data[columns]
@@ -1976,8 +2017,17 @@ def data_label(self) -> str:
"""
Return data label of Stata file.
"""
+ self._ensure_open()
return self._data_label
+ @property
+ def time_stamp(self) -> str:
+ """
+ Return time stamp of Stata file.
+ """
+ self._ensure_open()
+ return self._time_stamp
+
def variable_labels(self) -> dict[str, str]:
"""
Return a dict associating each variable name with corresponding label.
@@ -1986,7 +2036,8 @@ def variable_labels(self) -> dict[str, str]:
-------
dict
"""
- return dict(zip(self.varlist, self._variable_labels))
+ self._ensure_open()
+ return dict(zip(self._varlist, self._variable_labels))
def value_labels(self) -> dict[str, dict[float, str]]:
"""
@@ -1999,7 +2050,7 @@ def value_labels(self) -> dict[str, dict[float, str]]:
if not self._value_labels_read:
self._read_value_labels()
- return self.value_label_dict
+ return self._value_label_dict
@Appender(_read_stata_doc)
diff --git a/pandas/tests/io/test_stata.py b/pandas/tests/io/test_stata.py
index 5393a15cff19b..75e9f7b744caa 100644
--- a/pandas/tests/io/test_stata.py
+++ b/pandas/tests/io/test_stata.py
@@ -736,10 +736,8 @@ def test_minimal_size_col(self):
original.to_stata(path, write_index=False)
with StataReader(path) as sr:
- typlist = sr.typlist
- variables = sr.varlist
- formats = sr.fmtlist
- for variable, fmt, typ in zip(variables, formats, typlist):
+ sr._ensure_open() # The `_*list` variables are initialized here
+ for variable, fmt, typ in zip(sr._varlist, sr._fmtlist, sr._typlist):
assert int(variable[1:]) == int(fmt[1:-1])
assert int(variable[1:]) == typ
@@ -1891,6 +1889,44 @@ def test_backward_compat(version, datapath):
tm.assert_frame_equal(old_dta, expected, check_dtype=False)
+def test_direct_read(datapath, monkeypatch):
+ file_path = datapath("io", "data", "stata", "stata-compat-118.dta")
+
+ # Test that opening a file path doesn't buffer the file.
+ with StataReader(file_path) as reader:
+ # Must not have been buffered to memory
+ assert not reader.read().empty
+ assert not isinstance(reader._path_or_buf, io.BytesIO)
+
+ # Test that we use a given fp exactly, if possible.
+ with open(file_path, "rb") as fp:
+ with StataReader(fp) as reader:
+ assert not reader.read().empty
+ assert reader._path_or_buf is fp
+
+ # Test that we use a given BytesIO exactly, if possible.
+ with open(file_path, "rb") as fp:
+ with io.BytesIO(fp.read()) as bio:
+ with StataReader(bio) as reader:
+ assert not reader.read().empty
+ assert reader._path_or_buf is bio
+
+
+def test_statareader_warns_when_used_without_context(datapath):
+ file_path = datapath("io", "data", "stata", "stata-compat-118.dta")
+ with tm.assert_produces_warning(
+ ResourceWarning,
+ match="without using a context manager",
+ ):
+ sr = StataReader(file_path)
+ sr.read()
+ with tm.assert_produces_warning(
+ FutureWarning,
+ match="is not part of the public API",
+ ):
+ sr.close()
+
+
@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
@pytest.mark.parametrize("use_dict", [True, False])
@pytest.mark.parametrize("infer", [True, False])