From 93a43322e68fd01d669163888390d90fb0caa93c Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Fri, 24 Feb 2023 00:09:27 +0200
Subject: [PATCH] Backport PR #49228: CLN/FIX/PERF: Don't buffer entire Stata
 file into memory

---
 doc/source/user_guide/io.rst   |   8 +
 doc/source/whatsnew/v2.0.0.rst |   3 +
 pandas/io/stata.py             | 549 ++++++++++++++++++---------------
 pandas/tests/io/test_stata.py  |  44 ++-
 4 files changed, 351 insertions(+), 253 deletions(-)

diff --git a/doc/source/user_guide/io.rst b/doc/source/user_guide/io.rst
index 91cd3335d9db6..3c3a655626bb6 100644
--- a/doc/source/user_guide/io.rst
+++ b/doc/source/user_guide/io.rst
@@ -6033,6 +6033,14 @@ values will have ``object`` data type.
 ``int64`` for all integer types and ``float64`` for floating point data. By default,
 the Stata data types are preserved when importing.

+.. note::
+
+   All :class:`~pandas.io.stata.StataReader` objects, whether created by :func:`~pandas.read_stata`
+   (when using ``iterator=True`` or ``chunksize``) or instantiated by hand, must be used as context
+   managers (e.g. using the ``with`` statement).
+   While the :meth:`~pandas.io.stata.StataReader.close` method is available, its use is unsupported.
+   It is not part of the public API and will be removed in the future without warning.
+
 .. ipython:: python
    :suppress:

diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst
index bdbde438217b9..a8d6f3fce5bb7 100644
--- a/doc/source/whatsnew/v2.0.0.rst
+++ b/doc/source/whatsnew/v2.0.0.rst
@@ -857,6 +857,7 @@ Deprecations
 - Deprecated :meth:`Series.backfill` in favor of :meth:`Series.bfill` (:issue:`33396`)
 - Deprecated :meth:`DataFrame.pad` in favor of :meth:`DataFrame.ffill` (:issue:`33396`)
 - Deprecated :meth:`DataFrame.backfill` in favor of :meth:`DataFrame.bfill` (:issue:`33396`)
+- Deprecated :meth:`~pandas.io.stata.StataReader.close`. Use :class:`~pandas.io.stata.StataReader` as a context manager instead (:issue:`49228`)

 .. ---------------------------------------------------------------------------
 .. _whatsnew_200.prior_deprecations:

@@ -1163,6 +1164,8 @@ Performance improvements
 - Fixed a reference leak in :func:`read_hdf` (:issue:`37441`)
 - Fixed a memory leak in :meth:`DataFrame.to_json` and :meth:`Series.to_json` when serializing datetimes and timedeltas (:issue:`40443`)
 - Decreased memory usage in many :class:`DataFrameGroupBy` methods (:issue:`51090`)
+- Memory improvement in :class:`StataReader` when reading seekable files (:issue:`48922`)
+
 .. ---------------------------------------------------------------------------
 .. _whatsnew_200.bug_fixes:

diff --git a/pandas/io/stata.py b/pandas/io/stata.py
index 35ca4e1f6b6c4..9542ae46a0d05 100644
--- a/pandas/io/stata.py
+++ b/pandas/io/stata.py
@@ -23,6 +23,7 @@
     TYPE_CHECKING,
     Any,
     AnyStr,
+    Callable,
     Final,
     Hashable,
     Sequence,
@@ -182,10 +183,10 @@
 >>> df = pd.DataFrame(values, columns=["i"])  # doctest: +SKIP
 >>> df.to_stata('filename.dta')  # doctest: +SKIP

->>> itr = pd.read_stata('filename.dta', chunksize=10000)  # doctest: +SKIP
->>> for chunk in itr:
-...     # Operate on a single chunk, e.g., chunk.mean()
-...     pass  # doctest: +SKIP
+>>> with pd.read_stata('filename.dta', chunksize=10000) as itr:  # doctest: +SKIP
+>>>     for chunk in itr:
+...         # Operate on a single chunk, e.g., chunk.mean()
+...         pass  # doctest: +SKIP
 """

 _read_method_doc = f"""\
@@ -1114,6 +1115,8 @@ def __init__(self) -> None:
 class StataReader(StataParser, abc.Iterator):
     __doc__ = _stata_reader_doc

+    _path_or_buf: IO[bytes]  # Will be assigned by `_open_file`.
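+    # The attribute above is populated lazily: `_ensure_open()` calls
+    # `_open_file()` on first use, so merely constructing a StataReader no
+    # longer opens or buffers the underlying file.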
+
     def __init__(
         self,
         path_or_buf: FilePath | ReadBuffer[bytes],
@@ -1129,7 +1132,7 @@ def __init__(
         storage_options: StorageOptions = None,
     ) -> None:
         super().__init__()
-        self.col_sizes: list[int] = []
+        self._col_sizes: list[int] = []

         # Arguments to the reader (can be temporarily overridden in
         # calls to read).
@@ -1140,15 +1143,20 @@ def __init__(
         self._preserve_dtypes = preserve_dtypes
         self._columns = columns
         self._order_categoricals = order_categoricals
+        self._original_path_or_buf = path_or_buf
+        self._compression = compression
+        self._storage_options = storage_options
         self._encoding = ""
         self._chunksize = chunksize
         self._using_iterator = False
+        self._entered = False
         if self._chunksize is None:
             self._chunksize = 1
         elif not isinstance(chunksize, int) or chunksize <= 0:
             raise ValueError("chunksize must be a positive integer when set.")

         # State variables for the file
+        self._close_file: Callable[[], None] | None = None
         self._has_string_data = False
         self._missing_values = False
         self._can_read_value_labels = False
@@ -1159,21 +1167,48 @@ def __init__(
         self._lines_read = 0
         self._native_byteorder = _set_endianness(sys.byteorder)

-        with get_handle(
-            path_or_buf,
+
+    def _ensure_open(self) -> None:
+        """
+        Ensure the file has been opened and its header data read.
+        """
+        if not hasattr(self, "_path_or_buf"):
+            self._open_file()
+
+    def _open_file(self) -> None:
+        """
+        Open the file (with compression options, etc.), and read header information.
+        """
+        if not self._entered:
+            warnings.warn(
+                "StataReader is being used without using a context manager. "
+                "Using StataReader as a context manager is the only supported method.",
+                ResourceWarning,
+                stacklevel=find_stack_level(),
+            )
+        handles = get_handle(
+            self._original_path_or_buf,
             "rb",
-            storage_options=storage_options,
+            storage_options=self._storage_options,
             is_text=False,
-            compression=compression,
-        ) as handles:
-            # Copy to BytesIO, and ensure no encoding
-            self.path_or_buf = BytesIO(handles.handle.read())
+            compression=self._compression,
+        )
+        if hasattr(handles.handle, "seekable") and handles.handle.seekable():
+            # If the handle is directly seekable, use it without an extra copy.
+            self._path_or_buf = handles.handle
+            self._close_file = handles.close
+        else:
+            # Copy to memory, and ensure no encoding.
+            with handles:
+                self._path_or_buf = BytesIO(handles.handle.read())
+            self._close_file = self._path_or_buf.close

         self._read_header()
         self._setup_dtype()

     def __enter__(self) -> StataReader:
         """enter context manager"""
+        self._entered = True
         return self

     def __exit__(
@@ -1182,119 +1217,142 @@ def __exit__(
         exc_value: BaseException | None,
         traceback: TracebackType | None,
     ) -> None:
-        """exit context manager"""
-        self.close()
+        if self._close_file:
+            self._close_file()

     def close(self) -> None:
-        """close the handle if its open"""
-        self.path_or_buf.close()
+        """Close the handle if it's open.
+
+        .. deprecated:: 2.0.0
+
+            The close method is not part of the public API.
+            The only supported way to use StataReader is to use it as a context manager.
+        """
+        warnings.warn(
+            "The StataReader.close() method is not part of the public API and "
+            "will be removed in a future version without notice. "
" + "Using StataReader as a context manager is the only supported method.", + FutureWarning, + stacklevel=find_stack_level(), + ) + if self._close_file: + self._close_file() def _set_encoding(self) -> None: """ Set string encoding which depends on file version """ - if self.format_version < 118: + if self._format_version < 118: self._encoding = "latin-1" else: self._encoding = "utf-8" + def _read_int8(self) -> int: + return struct.unpack("b", self._path_or_buf.read(1))[0] + + def _read_uint8(self) -> int: + return struct.unpack("B", self._path_or_buf.read(1))[0] + + def _read_uint16(self) -> int: + return struct.unpack(f"{self._byteorder}H", self._path_or_buf.read(2))[0] + + def _read_uint32(self) -> int: + return struct.unpack(f"{self._byteorder}I", self._path_or_buf.read(4))[0] + + def _read_uint64(self) -> int: + return struct.unpack(f"{self._byteorder}Q", self._path_or_buf.read(8))[0] + + def _read_int16(self) -> int: + return struct.unpack(f"{self._byteorder}h", self._path_or_buf.read(2))[0] + + def _read_int32(self) -> int: + return struct.unpack(f"{self._byteorder}i", self._path_or_buf.read(4))[0] + + def _read_int64(self) -> int: + return struct.unpack(f"{self._byteorder}q", self._path_or_buf.read(8))[0] + + def _read_char8(self) -> bytes: + return struct.unpack("c", self._path_or_buf.read(1))[0] + + def _read_int16_count(self, count: int) -> tuple[int, ...]: + return struct.unpack( + f"{self._byteorder}{'h' * count}", + self._path_or_buf.read(2 * count), + ) + def _read_header(self) -> None: - first_char = self.path_or_buf.read(1) - if struct.unpack("c", first_char)[0] == b"<": + first_char = self._read_char8() + if first_char == b"<": self._read_new_header() else: self._read_old_header(first_char) - self.has_string_data = len([x for x in self.typlist if type(x) is int]) > 0 + self._has_string_data = len([x for x in self._typlist if type(x) is int]) > 0 # calculate size of a data record - self.col_sizes = [self._calcsize(typ) for typ in self.typlist] + self._col_sizes = [self._calcsize(typ) for typ in self._typlist] def _read_new_header(self) -> None: # The first part of the header is common to 117 - 119. - self.path_or_buf.read(27) # stata_dta>
-        self.format_version = int(self.path_or_buf.read(3))
-        if self.format_version not in [117, 118, 119]:
-            raise ValueError(_version_error.format(version=self.format_version))
+        self._path_or_buf.read(27)  # stata_dta><header><release>
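+        # The <release> element read next holds the 3-byte ASCII format
+        # version (e.g. b"118"), matched against the supported versions below.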
+        self._format_version = int(self._path_or_buf.read(3))
+        if self._format_version not in [117, 118, 119]:
+            raise ValueError(_version_error.format(version=self._format_version))
         self._set_encoding()
-        self.path_or_buf.read(21)  # </release><byteorder>
-        self.byteorder = ">" if self.path_or_buf.read(3) == b"MSF" else "<"
-        self.path_or_buf.read(15)  # </byteorder><K>
-        nvar_type = "H" if self.format_version <= 118 else "I"
-        nvar_size = 2 if self.format_version <= 118 else 4
-        self.nvar = struct.unpack(
-            self.byteorder + nvar_type, self.path_or_buf.read(nvar_size)
-        )[0]
-        self.path_or_buf.read(7)  # </K><N>
-
-        self.nobs = self._get_nobs()
-        self.path_or_buf.read(11)  # </N><label>
-        self._data_label = self._get_data_label()
-        self.path_or_buf.read(19)  # </label><timestamp>
-        self.time_stamp = self._get_time_stamp()
-        self.path_or_buf.read(26)  # </timestamp></header><map>
-        self.path_or_buf.read(8)  # 0x0000000000000000
-        self.path_or_buf.read(8)  # position of <map>
-
-        self._seek_vartypes = (
-            struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 16
-        )
-        self._seek_varnames = (
-            struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 10
-        )
-        self._seek_sortlist = (
-            struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 10
-        )
-        self._seek_formats = (
-            struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 9
-        )
-        self._seek_value_label_names = (
-            struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 19
+        self._path_or_buf.read(21)  # </release><byteorder>
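+        # <byteorder> contains b"MSF" (most significant first, big-endian) or
+        # b"LSF"; it selects the struct prefix used by every later unpack.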
+        self._byteorder = ">" if self._path_or_buf.read(3) == b"MSF" else "<"
+        self._path_or_buf.read(15)  # </byteorder><K>
+        self._nvar = (
+            self._read_uint16() if self._format_version <= 118 else self._read_uint32()
         )
+        self._path_or_buf.read(7)  # </K><N>
+
+        self._nobs = self._get_nobs()
+        self._path_or_buf.read(11)  # </N><label>
+        self._data_label = self._get_data_label()
+        self._path_or_buf.read(19)  # </label><timestamp>
+        self._time_stamp = self._get_time_stamp()
+        self._path_or_buf.read(26)  # </timestamp></header><map>
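+        # The <map> section that follows is a table of 8-byte offsets into the
+        # file; the seek positions below are derived from its entries.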
+        self._path_or_buf.read(8)  # 0x0000000000000000
+        self._path_or_buf.read(8)  # position of <map>
+
+        self._seek_vartypes = self._read_int64() + 16
+        self._seek_varnames = self._read_int64() + 10
+        self._seek_sortlist = self._read_int64() + 10
+        self._seek_formats = self._read_int64() + 9
+        self._seek_value_label_names = self._read_int64() + 19

         # Requires version-specific treatment
         self._seek_variable_labels = self._get_seek_variable_labels()

-        self.path_or_buf.read(8)  # <characteristics>
-        self.data_location = (
-            struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 6
-        )
-        self.seek_strls = (
-            struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 7
-        )
-        self.seek_value_labels = (
-            struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 14
-        )
+        self._path_or_buf.read(8)  # <characteristics>
+        self._data_location = self._read_int64() + 6
+        self._seek_strls = self._read_int64() + 7
+        self._seek_value_labels = self._read_int64() + 14

-        self.typlist, self.dtyplist = self._get_dtypes(self._seek_vartypes)
+        self._typlist, self._dtyplist = self._get_dtypes(self._seek_vartypes)

-        self.path_or_buf.seek(self._seek_varnames)
-        self.varlist = self._get_varlist()
+        self._path_or_buf.seek(self._seek_varnames)
+        self._varlist = self._get_varlist()

-        self.path_or_buf.seek(self._seek_sortlist)
-        self.srtlist = struct.unpack(
-            self.byteorder + ("h" * (self.nvar + 1)),
-            self.path_or_buf.read(2 * (self.nvar + 1)),
-        )[:-1]
+        self._path_or_buf.seek(self._seek_sortlist)
+        self._srtlist = self._read_int16_count(self._nvar + 1)[:-1]

-        self.path_or_buf.seek(self._seek_formats)
-        self.fmtlist = self._get_fmtlist()
+        self._path_or_buf.seek(self._seek_formats)
+        self._fmtlist = self._get_fmtlist()

-        self.path_or_buf.seek(self._seek_value_label_names)
-        self.lbllist = self._get_lbllist()
+        self._path_or_buf.seek(self._seek_value_label_names)
+        self._lbllist = self._get_lbllist()

-        self.path_or_buf.seek(self._seek_variable_labels)
+        self._path_or_buf.seek(self._seek_variable_labels)
         self._variable_labels = self._get_variable_labels()

     # Get data type information, works for versions 117-119.
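     # Type codes at or below 2045 denote fixed-width str fields of that
     # length; larger codes map to strL ("Q") and the numeric dtypes.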
    def _get_dtypes(
         self, seek_vartypes: int
     ) -> tuple[list[int | str], list[str | np.dtype]]:
-        self.path_or_buf.seek(seek_vartypes)
-        raw_typlist = [
-            struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0]
-            for _ in range(self.nvar)
-        ]
+        self._path_or_buf.seek(seek_vartypes)
+        raw_typlist = [self._read_uint16() for _ in range(self._nvar)]

         def f(typ: int) -> int | str:
             if typ <= 2045:
@@ -1320,112 +1378,110 @@ def g(typ: int) -> str | np.dtype:

     def _get_varlist(self) -> list[str]:
         # 33 in order formats, 129 in formats 118 and 119
-        b = 33 if self.format_version < 118 else 129
-        return [self._decode(self.path_or_buf.read(b)) for _ in range(self.nvar)]
+        b = 33 if self._format_version < 118 else 129
+        return [self._decode(self._path_or_buf.read(b)) for _ in range(self._nvar)]

     # Returns the format list
     def _get_fmtlist(self) -> list[str]:
-        if self.format_version >= 118:
+        if self._format_version >= 118:
             b = 57
-        elif self.format_version > 113:
+        elif self._format_version > 113:
             b = 49
-        elif self.format_version > 104:
+        elif self._format_version > 104:
             b = 12
         else:
             b = 7

-        return [self._decode(self.path_or_buf.read(b)) for _ in range(self.nvar)]
+        return [self._decode(self._path_or_buf.read(b)) for _ in range(self._nvar)]

     # Returns the label list
     def _get_lbllist(self) -> list[str]:
-        if self.format_version >= 118:
+        if self._format_version >= 118:
             b = 129
-        elif self.format_version > 108:
+        elif self._format_version > 108:
             b = 33
         else:
             b = 9
-        return [self._decode(self.path_or_buf.read(b)) for _ in range(self.nvar)]
+        return [self._decode(self._path_or_buf.read(b)) for _ in range(self._nvar)]

     def _get_variable_labels(self) -> list[str]:
-        if self.format_version >= 118:
+        if self._format_version >= 118:
             vlblist = [
-                self._decode(self.path_or_buf.read(321)) for _ in range(self.nvar)
+                self._decode(self._path_or_buf.read(321)) for _ in range(self._nvar)
             ]
-        elif self.format_version > 105:
+        elif self._format_version > 105:
             vlblist = [
-                self._decode(self.path_or_buf.read(81)) for _ in range(self.nvar)
+                self._decode(self._path_or_buf.read(81)) for _ in range(self._nvar)
             ]
         else:
             vlblist = [
-                self._decode(self.path_or_buf.read(32)) for _ in range(self.nvar)
+                self._decode(self._path_or_buf.read(32)) for _ in range(self._nvar)
             ]
         return vlblist

     def _get_nobs(self) -> int:
-        if self.format_version >= 118:
-            return struct.unpack(self.byteorder + "Q", self.path_or_buf.read(8))[0]
+        if self._format_version >= 118:
+            return self._read_uint64()
         else:
-            return struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
+            return self._read_uint32()

     def _get_data_label(self) -> str:
-        if self.format_version >= 118:
-            strlen = struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0]
-            return self._decode(self.path_or_buf.read(strlen))
-        elif self.format_version == 117:
-            strlen = struct.unpack("b", self.path_or_buf.read(1))[0]
-            return self._decode(self.path_or_buf.read(strlen))
-        elif self.format_version > 105:
-            return self._decode(self.path_or_buf.read(81))
+        if self._format_version >= 118:
+            strlen = self._read_uint16()
+            return self._decode(self._path_or_buf.read(strlen))
+        elif self._format_version == 117:
+            strlen = self._read_int8()
+            return self._decode(self._path_or_buf.read(strlen))
+        elif self._format_version > 105:
+            return self._decode(self._path_or_buf.read(81))
         else:
-            return self._decode(self.path_or_buf.read(32))
+            return self._decode(self._path_or_buf.read(32))

     def _get_time_stamp(self) -> str:
-        if self.format_version >= 118:
-            strlen = struct.unpack("b", self.path_or_buf.read(1))[0]
-            return self.path_or_buf.read(strlen).decode("utf-8")
-        elif self.format_version == 117:
-            strlen = struct.unpack("b", self.path_or_buf.read(1))[0]
-            return self._decode(self.path_or_buf.read(strlen))
-        elif self.format_version > 104:
-            return self._decode(self.path_or_buf.read(18))
+        if self._format_version >= 118:
+            strlen = self._read_int8()
+            return self._path_or_buf.read(strlen).decode("utf-8")
+        elif self._format_version == 117:
+            strlen = self._read_int8()
+            return self._decode(self._path_or_buf.read(strlen))
+        elif self._format_version > 104:
+            return self._decode(self._path_or_buf.read(18))
         else:
             raise ValueError()

     def _get_seek_variable_labels(self) -> int:
-        if self.format_version == 117:
-            self.path_or_buf.read(8)  # <variable_labels>, throw away
+        if self._format_version == 117:
+            self._path_or_buf.read(8)  # <variable_labels>, throw away
             # Stata 117 data files do not follow the described format.  This is
             # a work around that uses the previous label, 33 bytes for each
             # variable, 20 for the closing tag and 17 for the opening tag
-            return self._seek_value_label_names + (33 * self.nvar) + 20 + 17
-        elif self.format_version >= 118:
-            return struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 17
+            return self._seek_value_label_names + (33 * self._nvar) + 20 + 17
+        elif self._format_version >= 118:
+            return self._read_int64() + 17
         else:
             raise ValueError()

     def _read_old_header(self, first_char: bytes) -> None:
-        self.format_version = struct.unpack("b", first_char)[0]
-        if self.format_version not in [104, 105, 108, 111, 113, 114, 115]:
-            raise ValueError(_version_error.format(version=self.format_version))
+        self._format_version = int(first_char[0])
+        if self._format_version not in [104, 105, 108, 111, 113, 114, 115]:
+            raise ValueError(_version_error.format(version=self._format_version))
         self._set_encoding()
-        self.byteorder = (
-            ">" if struct.unpack("b", self.path_or_buf.read(1))[0] == 0x1 else "<"
-        )
-        self.filetype = struct.unpack("b", self.path_or_buf.read(1))[0]
-        self.path_or_buf.read(1)  # unused
+        self._byteorder = ">" if self._read_int8() == 0x1 else "<"
+        self._filetype = self._read_int8()
+        self._path_or_buf.read(1)  # unused

-        self.nvar = struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0]
-        self.nobs = self._get_nobs()
+        self._nvar = self._read_uint16()
+        self._nobs = self._get_nobs()

         self._data_label = self._get_data_label()

-        self.time_stamp = self._get_time_stamp()
+        self._time_stamp = self._get_time_stamp()

         # descriptors
-        if self.format_version > 108:
-            typlist = [ord(self.path_or_buf.read(1)) for _ in range(self.nvar)]
+        if self._format_version > 108:
+            typlist = [int(c) for c in self._path_or_buf.read(self._nvar)]
         else:
-            buf = self.path_or_buf.read(self.nvar)
+            buf = self._path_or_buf.read(self._nvar)
             typlistb = np.frombuffer(buf, dtype=np.uint8)
             typlist = []
             for tp in typlistb:
@@ -1435,32 +1491,29 @@ def _read_old_header(self, first_char: bytes) -> None:
                     typlist.append(tp - 127)  # bytes

         try:
-            self.typlist = [self.TYPE_MAP[typ] for typ in typlist]
+            self._typlist = [self.TYPE_MAP[typ] for typ in typlist]
         except ValueError as err:
             invalid_types = ",".join([str(x) for x in typlist])
             raise ValueError(f"cannot convert stata types [{invalid_types}]") from err
         try:
-            self.dtyplist = [self.DTYPE_MAP[typ] for typ in typlist]
+            self._dtyplist = [self.DTYPE_MAP[typ] for typ in typlist]
         except ValueError as err:
             invalid_dtypes = ",".join([str(x) for x in typlist])
             raise ValueError(f"cannot convert stata dtypes [{invalid_dtypes}]") from err

-        if self.format_version > 108:
-            self.varlist = [
-                self._decode(self.path_or_buf.read(33)) for _ in range(self.nvar)
+        if self._format_version > 108:
+            self._varlist = [
+                self._decode(self._path_or_buf.read(33)) for _ in range(self._nvar)
             ]
         else:
-            self.varlist = [
-                self._decode(self.path_or_buf.read(9)) for _ in range(self.nvar)
+            self._varlist = [
+                self._decode(self._path_or_buf.read(9)) for _ in range(self._nvar)
             ]
-        self.srtlist = struct.unpack(
-            self.byteorder + ("h" * (self.nvar + 1)),
-            self.path_or_buf.read(2 * (self.nvar + 1)),
-        )[:-1]
+        self._srtlist = self._read_int16_count(self._nvar + 1)[:-1]

-        self.fmtlist = self._get_fmtlist()
+        self._fmtlist = self._get_fmtlist()

-        self.lbllist = self._get_lbllist()
+        self._lbllist = self._get_lbllist()

         self._variable_labels = self._get_variable_labels()

@@ -1469,25 +1522,19 @@ def _read_old_header(self, first_char: bytes) -> None:
         # the size of the next read, which you discard.  You then continue
         # like this until you read 5 bytes of zeros.

-        if self.format_version > 104:
+        if self._format_version > 104:
             while True:
-                data_type = struct.unpack(
-                    self.byteorder + "b", self.path_or_buf.read(1)
-                )[0]
-                if self.format_version > 108:
-                    data_len = struct.unpack(
-                        self.byteorder + "i", self.path_or_buf.read(4)
-                    )[0]
+                data_type = self._read_int8()
+                if self._format_version > 108:
+                    data_len = self._read_int32()
                 else:
-                    data_len = struct.unpack(
-                        self.byteorder + "h", self.path_or_buf.read(2)
-                    )[0]
+                    data_len = self._read_int16()
                 if data_type == 0:
                     break
-                self.path_or_buf.read(data_len)
+                self._path_or_buf.read(data_len)

         # necessary data to continue parsing
-        self.data_location = self.path_or_buf.tell()
+        self._data_location = self._path_or_buf.tell()

     def _setup_dtype(self) -> np.dtype:
         """Map between numpy and state dtypes"""

@@ -1495,12 +1542,12 @@ def _setup_dtype(self) -> np.dtype:
             return self._dtype

         dtypes = []  # Convert struct data types to numpy data type
-        for i, typ in enumerate(self.typlist):
+        for i, typ in enumerate(self._typlist):
             if typ in self.NUMPY_TYPE_MAP:
                 typ = cast(str, typ)  # only strs in NUMPY_TYPE_MAP
-                dtypes.append(("s" + str(i), self.byteorder + self.NUMPY_TYPE_MAP[typ]))
+                dtypes.append((f"s{i}", f"{self._byteorder}{self.NUMPY_TYPE_MAP[typ]}"))
             else:
-                dtypes.append(("s" + str(i), "S" + str(typ)))
+                dtypes.append((f"s{i}", f"S{typ}"))
         self._dtype = np.dtype(dtypes)

         return self._dtype

@@ -1508,7 +1555,7 @@ def _setup_dtype(self) -> np.dtype:
     def _calcsize(self, fmt: int | str) -> int:
         if isinstance(fmt, int):
             return fmt
-        return struct.calcsize(self.byteorder + fmt)
+        return struct.calcsize(self._byteorder + fmt)

     def _decode(self, s: bytes) -> str:
         # have bytes not strings, so must decode
@@ -1532,82 +1579,85 @@ def _decode(self, s: bytes) -> str:
         return s.decode("latin-1")

     def _read_value_labels(self) -> None:
+        self._ensure_open()
         if self._value_labels_read:
             # Don't read twice
             return
-        if self.format_version <= 108:
+        if self._format_version <= 108:
             # Value labels are not supported in version 108 and earlier.
             self._value_labels_read = True
-            self.value_label_dict: dict[str, dict[float, str]] = {}
+            self._value_label_dict: dict[str, dict[float, str]] = {}
             return

-        if self.format_version >= 117:
-            self.path_or_buf.seek(self.seek_value_labels)
+        if self._format_version >= 117:
+            self._path_or_buf.seek(self._seek_value_labels)
         else:
             assert self._dtype is not None
-            offset = self.nobs * self._dtype.itemsize
-            self.path_or_buf.seek(self.data_location + offset)
+            offset = self._nobs * self._dtype.itemsize
+            self._path_or_buf.seek(self._data_location + offset)

         self._value_labels_read = True
-        self.value_label_dict = {}
+        self._value_label_dict = {}

         while True:
-            if self.format_version >= 117:
-                if self.path_or_buf.read(5) == b"</val":  # <lbl>
+            if self._format_version >= 117:
+                if self._path_or_buf.read(5) == b"</val":  # <lbl>
                     break  # end of value label table

-            slength = self.path_or_buf.read(4)
+            slength = self._path_or_buf.read(4)
             if not slength:
                 break  # end of value label table (format < 117)
-            if self.format_version <= 117:
-                labname = self._decode(self.path_or_buf.read(33))
+            if self._format_version <= 117:
+                labname = self._decode(self._path_or_buf.read(33))
             else:
-                labname = self._decode(self.path_or_buf.read(129))
-            self.path_or_buf.read(3)  # padding
+                labname = self._decode(self._path_or_buf.read(129))
+            self._path_or_buf.read(3)  # padding

-            n = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
-            txtlen = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
+            n = self._read_uint32()
+            txtlen = self._read_uint32()
             off = np.frombuffer(
-                self.path_or_buf.read(4 * n), dtype=self.byteorder + "i4", count=n
+                self._path_or_buf.read(4 * n), dtype=f"{self._byteorder}i4", count=n
             )
             val = np.frombuffer(
-                self.path_or_buf.read(4 * n), dtype=self.byteorder + "i4", count=n
+                self._path_or_buf.read(4 * n), dtype=f"{self._byteorder}i4", count=n
             )
             ii = np.argsort(off)
             off = off[ii]
             val = val[ii]
-            txt = self.path_or_buf.read(txtlen)
-            self.value_label_dict[labname] = {}
+            txt = self._path_or_buf.read(txtlen)
+            self._value_label_dict[labname] = {}
             for i in range(n):
                 end = off[i + 1] if i < n - 1 else txtlen
-                self.value_label_dict[labname][val[i]] = self._decode(txt[off[i] : end])
-            if self.format_version >= 117:
-                self.path_or_buf.read(6)  # </lbl>
+                self._value_label_dict[labname][val[i]] = self._decode(
+                    txt[off[i] : end]
+                )
+            if self._format_version >= 117:
+                self._path_or_buf.read(6)  # </lbl>
         self._value_labels_read = True

     def _read_strls(self) -> None:
-        self.path_or_buf.seek(self.seek_strls)
+        self._path_or_buf.seek(self._seek_strls)
         # Wrap v_o in a string to allow uint64 values as keys on 32bit OS
         self.GSO = {"0": ""}
         while True:
-            if self.path_or_buf.read(3) != b"GSO":
+            if self._path_or_buf.read(3) != b"GSO":
                 break

-            if self.format_version == 117:
-                v_o = struct.unpack(self.byteorder + "Q", self.path_or_buf.read(8))[0]
+            if self._format_version == 117:
+                v_o = self._read_uint64()
             else:
-                buf = self.path_or_buf.read(12)
+                buf = self._path_or_buf.read(12)
                 # Only tested on little endian file on little endian machine.
-                v_size = 2 if self.format_version == 118 else 3
-                if self.byteorder == "<":
+                v_size = 2 if self._format_version == 118 else 3
+                if self._byteorder == "<":
                     buf = buf[0:v_size] + buf[4 : (12 - v_size)]
                 else:
                     # This path may not be correct, impossible to test
                     buf = buf[0:v_size] + buf[(4 + v_size) :]
                 v_o = struct.unpack("Q", buf)[0]
-            typ = struct.unpack("B", self.path_or_buf.read(1))[0]
-            length = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
-            va = self.path_or_buf.read(length)
+            typ = self._read_uint8()
+            length = self._read_uint32()
+            va = self._path_or_buf.read(length)
             if typ == 130:
                 decoded_va = va[0:-1].decode(self._encoding)
             else:
@@ -1649,14 +1699,14 @@ def read(
         columns: Sequence[str] | None = None,
         order_categoricals: bool | None = None,
     ) -> DataFrame:
+        self._ensure_open()
+
         # Handle empty file or chunk.  If reading incrementally raise
         # StopIteration.  If reading the whole thing return an empty
         # data frame.
-        if (self.nobs == 0) and (nrows is None):
+        if (self._nobs == 0) and (nrows is None):
             self._can_read_value_labels = True
             self._data_read = True
-            self.close()
-            return DataFrame(columns=self.varlist)
+            return DataFrame(columns=self._varlist)

         # Handle options
         if convert_dates is None:
@@ -1675,16 +1725,16 @@ def read(
             index_col = self._index_col

         if nrows is None:
-            nrows = self.nobs
+            nrows = self._nobs

-        if (self.format_version >= 117) and (not self._value_labels_read):
+        if (self._format_version >= 117) and (not self._value_labels_read):
             self._can_read_value_labels = True
             self._read_strls()

         # Read data
         assert self._dtype is not None
         dtype = self._dtype
-        max_read_len = (self.nobs - self._lines_read) * dtype.itemsize
+        max_read_len = (self._nobs - self._lines_read) * dtype.itemsize
         read_len = nrows * dtype.itemsize
         read_len = min(read_len, max_read_len)
         if read_len <= 0:
@@ -1692,31 +1742,30 @@ def read(
             # we are reading the file incrementally
             if convert_categoricals:
                 self._read_value_labels()
-            self.close()
             raise StopIteration

         offset = self._lines_read * dtype.itemsize
-        self.path_or_buf.seek(self.data_location + offset)
-        read_lines = min(nrows, self.nobs - self._lines_read)
+        self._path_or_buf.seek(self._data_location + offset)
+        read_lines = min(nrows, self._nobs - self._lines_read)
         raw_data = np.frombuffer(
-            self.path_or_buf.read(read_len), dtype=dtype, count=read_lines
+            self._path_or_buf.read(read_len), dtype=dtype, count=read_lines
         )

         self._lines_read += read_lines
-        if self._lines_read == self.nobs:
+        if self._lines_read == self._nobs:
             self._can_read_value_labels = True
             self._data_read = True

         # if necessary, swap the byte order to native here
-        if self.byteorder != self._native_byteorder:
+        if self._byteorder != self._native_byteorder:
             raw_data = raw_data.byteswap().newbyteorder()

         if convert_categoricals:
             self._read_value_labels()

         if len(raw_data) == 0:
-            data = DataFrame(columns=self.varlist)
+            data = DataFrame(columns=self._varlist)
         else:
             data = DataFrame.from_records(raw_data)
-            data.columns = Index(self.varlist)
+            data.columns = Index(self._varlist)

         # If index is not specified, use actual row number rather than
         # restarting at 0 for each chunk.
@@ -1725,32 +1774,28 @@ def read(
             data.index = Index(rng)  # set attr instead of set_index to avoid copy

         if columns is not None:
-            try:
-                data = self._do_select_columns(data, columns)
-            except ValueError:
-                self.close()
-                raise
+            data = self._do_select_columns(data, columns)

         # Decode strings
-        for col, typ in zip(data, self.typlist):
+        for col, typ in zip(data, self._typlist):
             if type(typ) is int:
                 data[col] = data[col].apply(self._decode, convert_dtype=True)

         data = self._insert_strls(data)

-        cols_ = np.where([dtyp is not None for dtyp in self.dtyplist])[0]
+        cols_ = np.where([dtyp is not None for dtyp in self._dtyplist])[0]
         # Convert columns (if needed) to match input type
         ix = data.index
         requires_type_conversion = False
         data_formatted = []
         for i in cols_:
-            if self.dtyplist[i] is not None:
+            if self._dtyplist[i] is not None:
                 col = data.columns[i]
                 dtype = data[col].dtype
-                if dtype != np.dtype(object) and dtype != self.dtyplist[i]:
+                if dtype != np.dtype(object) and dtype != self._dtyplist[i]:
                     requires_type_conversion = True
                     data_formatted.append(
-                        (col, Series(data[col], ix, self.dtyplist[i]))
+                        (col, Series(data[col], ix, self._dtyplist[i]))
                     )
                 else:
                     data_formatted.append((col, data[col]))
@@ -1765,20 +1810,16 @@ def read(
         def any_startswith(x: str) -> bool:
             return any(x.startswith(fmt) for fmt in _date_formats)

-        cols = np.where([any_startswith(x) for x in self.fmtlist])[0]
+        cols = np.where([any_startswith(x) for x in self._fmtlist])[0]
         for i in cols:
             col = data.columns[i]
-            try:
-                data[col] = _stata_elapsed_date_to_datetime_vec(
-                    data[col], self.fmtlist[i]
-                )
-            except ValueError:
-                self.close()
-                raise
+            data[col] = _stata_elapsed_date_to_datetime_vec(
+                data[col], self._fmtlist[i]
+            )

-        if convert_categoricals and self.format_version > 108:
+        if convert_categoricals and self._format_version > 108:
             data = self._do_convert_categoricals(
-                data, self.value_label_dict, self.lbllist, order_categoricals
+                data, self._value_label_dict, self._lbllist, order_categoricals
             )

         if not preserve_dtypes:
@@ -1809,7 +1850,7 @@ def _do_convert_missing(self, data: DataFrame, convert_missing: bool) -> DataFra
         # Check for missing values, and replace if found
         replacements = {}
         for i, colname in enumerate(data):
-            fmt = self.typlist[i]
+            fmt = self._typlist[i]
             if fmt not in self.VALID_RANGE:
                 continue

@@ -1855,7 +1896,7 @@ def _do_convert_missing(self, data: DataFrame, convert_missing: bool) -> DataFra
     def _insert_strls(self, data: DataFrame) -> DataFrame:
         if not hasattr(self, "GSO") or len(self.GSO) == 0:
             return data
-        for i, typ in enumerate(self.typlist):
+        for i, typ in enumerate(self._typlist):
             if typ != "Q":
                 continue
             # Wrap v_o in a string to allow uint64 values as keys on 32bit OS
@@ -1881,15 +1922,15 @@ def _do_select_columns(self, data: DataFrame, columns: Sequence[str]) -> DataFra
         lbllist = []
         for col in columns:
             i = data.columns.get_loc(col)
-            dtyplist.append(self.dtyplist[i])
-            typlist.append(self.typlist[i])
-            fmtlist.append(self.fmtlist[i])
-            lbllist.append(self.lbllist[i])
-
-        self.dtyplist = dtyplist
-        self.typlist = typlist
-        self.fmtlist = fmtlist
-        self.lbllist = lbllist
+            dtyplist.append(self._dtyplist[i])
+            typlist.append(self._typlist[i])
+            fmtlist.append(self._fmtlist[i])
+            lbllist.append(self._lbllist[i])
+
+        self._dtyplist = dtyplist
+        self._typlist = typlist
+        self._fmtlist = fmtlist
+        self._lbllist = lbllist

         self._column_selector_set = True
         return data[columns]

@@ -1976,8 +2017,17 @@ def data_label(self) -> str:
         """
         Return data label of Stata file.
""" + self._ensure_open() return self._data_label + @property + def time_stamp(self) -> str: + """ + Return time stamp of Stata file. + """ + self._ensure_open() + return self._time_stamp + def variable_labels(self) -> dict[str, str]: """ Return a dict associating each variable name with corresponding label. @@ -1986,7 +2036,8 @@ def variable_labels(self) -> dict[str, str]: ------- dict """ - return dict(zip(self.varlist, self._variable_labels)) + self._ensure_open() + return dict(zip(self._varlist, self._variable_labels)) def value_labels(self) -> dict[str, dict[float, str]]: """ @@ -1999,7 +2050,7 @@ def value_labels(self) -> dict[str, dict[float, str]]: if not self._value_labels_read: self._read_value_labels() - return self.value_label_dict + return self._value_label_dict @Appender(_read_stata_doc) diff --git a/pandas/tests/io/test_stata.py b/pandas/tests/io/test_stata.py index 5393a15cff19b..75e9f7b744caa 100644 --- a/pandas/tests/io/test_stata.py +++ b/pandas/tests/io/test_stata.py @@ -736,10 +736,8 @@ def test_minimal_size_col(self): original.to_stata(path, write_index=False) with StataReader(path) as sr: - typlist = sr.typlist - variables = sr.varlist - formats = sr.fmtlist - for variable, fmt, typ in zip(variables, formats, typlist): + sr._ensure_open() # The `_*list` variables are initialized here + for variable, fmt, typ in zip(sr._varlist, sr._fmtlist, sr._typlist): assert int(variable[1:]) == int(fmt[1:-1]) assert int(variable[1:]) == typ @@ -1891,6 +1889,44 @@ def test_backward_compat(version, datapath): tm.assert_frame_equal(old_dta, expected, check_dtype=False) +def test_direct_read(datapath, monkeypatch): + file_path = datapath("io", "data", "stata", "stata-compat-118.dta") + + # Test that opening a file path doesn't buffer the file. + with StataReader(file_path) as reader: + # Must not have been buffered to memory + assert not reader.read().empty + assert not isinstance(reader._path_or_buf, io.BytesIO) + + # Test that we use a given fp exactly, if possible. + with open(file_path, "rb") as fp: + with StataReader(fp) as reader: + assert not reader.read().empty + assert reader._path_or_buf is fp + + # Test that we use a given BytesIO exactly, if possible. + with open(file_path, "rb") as fp: + with io.BytesIO(fp.read()) as bio: + with StataReader(bio) as reader: + assert not reader.read().empty + assert reader._path_or_buf is bio + + +def test_statareader_warns_when_used_without_context(datapath): + file_path = datapath("io", "data", "stata", "stata-compat-118.dta") + with tm.assert_produces_warning( + ResourceWarning, + match="without using a context manager", + ): + sr = StataReader(file_path) + sr.read() + with tm.assert_produces_warning( + FutureWarning, + match="is not part of the public API", + ): + sr.close() + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) @pytest.mark.parametrize("use_dict", [True, False]) @pytest.mark.parametrize("infer", [True, False])