Commit b3fa178

PERF: read_stata for wide columns (pandas-dev#55515)
* PERF: read_stata for wide columns
* Add PR number
* Remove unused variable
* typing
* Don't materialize
* Simplify
1 parent c9dc91d commit b3fa178
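
The change targets read performance when a .dta file is wide (many variables, comparatively few rows). A rough, illustrative way to exercise that path — the row/column counts, the wide.dta file name, and the timing harness are assumptions, not taken from the PR:

# Illustrative sketch only; not part of the commit.
import time

import numpy as np
import pandas as pd

n_rows, n_cols = 1_000, 2_000  # "wide": many variables, modest row count
df = pd.DataFrame(
    np.random.default_rng(0).standard_normal((n_rows, n_cols)),
    columns=[f"col_{i}" for i in range(n_cols)],
)
df.to_stata("wide.dta", write_index=False)

start = time.perf_counter()
wide = pd.read_stata("wide.dta")
print(f"read {wide.shape} in {time.perf_counter() - start:.3f}s")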

2 files changed: +42 -81 lines changed


doc/source/whatsnew/v2.2.0.rst (+1)
@@ -277,6 +277,7 @@ Other Deprecations
 Performance improvements
 ~~~~~~~~~~~~~~~~~~~~~~~~
 - Performance improvement in :func:`concat` with ``axis=1`` and objects with unaligned indexes (:issue:`55084`)
+- Performance improvement in :func:`read_stata` for files with many variables (:issue:`55515`)
 - Performance improvement in :func:`to_dict` on converting DataFrame to dictionary (:issue:`50990`)
 - Performance improvement in :meth:`DataFrame.groupby` when aggregating pyarrow timestamp and duration dtypes (:issue:`55031`)
 - Performance improvement in :meth:`DataFrame.sort_index` and :meth:`Series.sort_index` when indexed by a :class:`MultiIndex` (:issue:`54835`)

pandas/io/stata.py (+41 -81)
@@ -66,6 +66,7 @@
 from pandas.core.arrays.integer import IntegerDtype
 from pandas.core.frame import DataFrame
 from pandas.core.indexes.base import Index
+from pandas.core.indexes.range import RangeIndex
 from pandas.core.series import Series
 from pandas.core.shared_docs import _shared_docs
 
@@ -690,10 +691,7 @@ def __init__(
         self.labname = catarray.name
         self._encoding = encoding
         categories = catarray.cat.categories
-        self.value_labels: list[tuple[float, str]] = list(
-            zip(np.arange(len(categories)), categories)
-        )
-        self.value_labels.sort(key=lambda x: x[0])
+        self.value_labels = enumerate(categories)
 
         self._prepare_value_labels()
 
@@ -819,7 +817,7 @@ def __init__(
 
         self.labname = labname
         self._encoding = encoding
-        self.value_labels: list[tuple[float, str]] = sorted(
+        self.value_labels = sorted(  # type: ignore[assignment]
             value_labels.items(), key=lambda x: x[0]
         )
         self._prepare_value_labels()
@@ -1054,7 +1052,7 @@ def __init__(self) -> None:
         }
 
         # Reserved words cannot be used as variable names
-        self.RESERVED_WORDS = (
+        self.RESERVED_WORDS = {
             "aggregate",
             "array",
             "boolean",
@@ -1115,7 +1113,7 @@ def __init__(self) -> None:
             "_se",
             "with",
             "_n",
-        )
+        }
 
 
 class StataReader(StataParser, abc.Iterator):
@@ -1138,7 +1136,6 @@ def __init__(
         storage_options: StorageOptions | None = None,
     ) -> None:
         super().__init__()
-        self._col_sizes: list[int] = []
 
         # Arguments to the reader (can be temporarily overridden in
         # calls to read).
@@ -1163,7 +1160,6 @@ def __init__(
 
         # State variables for the file
         self._close_file: Callable[[], None] | None = None
-        self._has_string_data = False
         self._missing_values = False
         self._can_read_value_labels = False
         self._column_selector_set = False
@@ -1293,13 +1289,6 @@ def _read_header(self) -> None:
         else:
             self._read_old_header(first_char)
 
-        self._has_string_data = (
-            len([x for x in self._typlist if isinstance(x, int)]) > 0
-        )
-
-        # calculate size of a data record
-        self._col_sizes = [self._calcsize(typ) for typ in self._typlist]
-
     def _read_new_header(self) -> None:
         # The first part of the header is common to 117 - 119.
         self._path_or_buf.read(27)  # stata_dta><header><release>
@@ -1360,29 +1349,21 @@ def _get_dtypes(
         self, seek_vartypes: int
     ) -> tuple[list[int | str], list[str | np.dtype]]:
         self._path_or_buf.seek(seek_vartypes)
-        raw_typlist = [self._read_uint16() for _ in range(self._nvar)]
-
-        def f(typ: int) -> int | str:
+        typlist = []
+        dtyplist = []
+        for _ in range(self._nvar):
+            typ = self._read_uint16()
             if typ <= 2045:
-                return typ
-            try:
-                return self.TYPE_MAP_XML[typ]
-            except KeyError as err:
-                raise ValueError(f"cannot convert stata types [{typ}]") from err
-
-        typlist = [f(x) for x in raw_typlist]
-
-        def g(typ: int) -> str | np.dtype:
-            if typ <= 2045:
-                return str(typ)
-            try:
-                return self.DTYPE_MAP_XML[typ]
-            except KeyError as err:
-                raise ValueError(f"cannot convert stata dtype [{typ}]") from err
-
-        dtyplist = [g(x) for x in raw_typlist]
+                typlist.append(typ)
+                dtyplist.append(str(typ))
+            else:
+                try:
+                    typlist.append(self.TYPE_MAP_XML[typ])  # type: ignore[arg-type]
+                    dtyplist.append(self.DTYPE_MAP_XML[typ])  # type: ignore[arg-type]
+                except KeyError as err:
+                    raise ValueError(f"cannot convert stata types [{typ}]") from err
 
-        return typlist, dtyplist
+        return typlist, dtyplist  # type: ignore[return-value]
 
     def _get_varlist(self) -> list[str]:
         # 33 in order formats, 129 in formats 118 and 119
@@ -1560,11 +1541,6 @@ def _setup_dtype(self) -> np.dtype:
 
         return self._dtype
 
-    def _calcsize(self, fmt: int | str) -> int:
-        if isinstance(fmt, int):
-            return fmt
-        return struct.calcsize(self._byteorder + fmt)
-
     def _decode(self, s: bytes) -> str:
         # have bytes not strings, so must decode
         s = s.partition(b"\0")[0]
@@ -1787,8 +1763,9 @@ def read(
         # If index is not specified, use actual row number rather than
         # restarting at 0 for each chunk.
         if index_col is None:
-            rng = range(self._lines_read - read_lines, self._lines_read)
-            data.index = Index(rng)  # set attr instead of set_index to avoid copy
+            data.index = RangeIndex(
+                self._lines_read - read_lines, self._lines_read
+            )  # set attr instead of set_index to avoid copy
 
         if columns is not None:
             data = self._do_select_columns(data, columns)
@@ -1800,39 +1777,22 @@ def read(
 
         data = self._insert_strls(data)
 
-        cols_ = np.where([dtyp is not None for dtyp in self._dtyplist])[0]
         # Convert columns (if needed) to match input type
-        ix = data.index
-        requires_type_conversion = False
-        data_formatted = []
-        for i in cols_:
-            if self._dtyplist[i] is not None:
-                col = data.columns[i]
-                dtype = data[col].dtype
-                if dtype != np.dtype(object) and dtype != self._dtyplist[i]:
-                    requires_type_conversion = True
-                    data_formatted.append(
-                        (col, Series(data[col], ix, self._dtyplist[i]))
-                    )
-                else:
-                    data_formatted.append((col, data[col]))
-        if requires_type_conversion:
-            data = DataFrame.from_dict(dict(data_formatted))
-        del data_formatted
+        valid_dtypes = [i for i, dtyp in enumerate(self._dtyplist) if dtyp is not None]
+        object_type = np.dtype(object)
+        for idx in valid_dtypes:
+            dtype = data.iloc[:, idx].dtype
+            if dtype not in (object_type, self._dtyplist[idx]):
+                data.iloc[:, idx] = data.iloc[:, idx].astype(dtype)
 
         data = self._do_convert_missing(data, convert_missing)
 
         if convert_dates:
-
-            def any_startswith(x: str) -> bool:
-                return any(x.startswith(fmt) for fmt in _date_formats)
-
-            cols = np.where([any_startswith(x) for x in self._fmtlist])[0]
-            for i in cols:
-                col = data.columns[i]
-                data[col] = _stata_elapsed_date_to_datetime_vec(
-                    data[col], self._fmtlist[i]
-                )
+            for i, fmt in enumerate(self._fmtlist):
+                if any(fmt.startswith(date_fmt) for date_fmt in _date_formats):
+                    data.iloc[:, i] = _stata_elapsed_date_to_datetime_vec(
+                        data.iloc[:, i], fmt
+                    )
 
         if convert_categoricals and self._format_version > 108:
             data = self._do_convert_categoricals(
@@ -1866,14 +1826,14 @@ def any_startswith(x: str) -> bool:
     def _do_convert_missing(self, data: DataFrame, convert_missing: bool) -> DataFrame:
         # Check for missing values, and replace if found
         replacements = {}
-        for i, colname in enumerate(data):
+        for i in range(len(data.columns)):
             fmt = self._typlist[i]
             if fmt not in self.VALID_RANGE:
                 continue
 
             fmt = cast(str, fmt)  # only strs in VALID_RANGE
             nmin, nmax = self.VALID_RANGE[fmt]
-            series = data[colname]
+            series = data.iloc[:, i]
 
             # appreciably faster to do this with ndarray instead of Series
             svals = series._values
@@ -1903,11 +1863,10 @@ def _do_convert_missing(self, data: DataFrame, convert_missing: bool) -> DataFra
                 # Note: operating on ._values is much faster than directly
                 # TODO: can we fix that?
                 replacement._values[missing] = np.nan
-            replacements[colname] = replacement
-
+            replacements[i] = replacement
         if replacements:
-            for col, value in replacements.items():
-                data[col] = value
+            for idx, value in replacements.items():
+                data.iloc[:, idx] = value
         return data
 
     def _insert_strls(self, data: DataFrame) -> DataFrame:
@@ -1962,10 +1921,11 @@ def _do_convert_categoricals(
         """
         Converts categorical columns to Categorical type.
         """
-        value_labels = list(value_label_dict.keys())
+        if not value_label_dict:
+            return data
         cat_converted_data = []
         for col, label in zip(data, lbllist):
-            if label in value_labels:
+            if label in value_label_dict:
                 # Explicit call with ordered=True
                 vl = value_label_dict[label]
                 keys = np.array(list(vl.keys()))
@@ -2466,7 +2426,7 @@ def _prepare_categoricals(self, data: DataFrame) -> DataFrame:
         Check for categorical columns, retain categorical information for
         Stata file and convert categorical data to int
         """
-        is_cat = [isinstance(data[col].dtype, CategoricalDtype) for col in data]
+        is_cat = [isinstance(dtype, CategoricalDtype) for dtype in data.dtypes]
         if not any(is_cat):
             return data

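The stata.py hunks above repeatedly swap label-based column access (data[col], data.columns[i]) for positional access (data.iloc[:, i]) and replace a materialized Index(range(...)) with RangeIndex; as the diff suggests, the intent is to avoid per-column label lookups and intermediate frame rebuilds when a file has thousands of variables. A minimal standalone sketch of that pattern — not pandas internals, with a toy threshold of 100.0 standing in for Stata's VALID_RANGE sentinel check:

# Standalone illustration of the positional-update pattern; assumptions only.
import numpy as np
import pandas as pd

df = pd.DataFrame({"x": [1.0, 101.0, 3.0], "y": [4.0, 5.0, 102.0]})

for i in range(len(df.columns)):
    series = df.iloc[:, i]           # column by position, no label lookup
    missing = series > 100.0         # toy stand-in for the VALID_RANGE check
    if missing.any():
        replacement = series.copy()
        replacement[missing] = np.nan
        df.iloc[:, i] = replacement  # write back by position

# RangeIndex(start, stop) instead of Index(range(start, stop)), as in read():
df.index = pd.RangeIndex(10, 13)     # e.g. rows 10..12 of a chunked read
print(df)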