Skip to content

Commit 3a7d8f1

Browse files
committed
Merge pull request #8045 from bashtage/stata_unpack_performance
PERF: StataReader is slow due to excessive lookups for missing values
2 parents 989a51b + 53f5598 commit 3a7d8f1

File tree

8 files changed

+224
-41
lines changed

8 files changed

+224
-41
lines changed

doc/source/io.rst

+7
Original file line numberDiff line numberDiff line change
@@ -3558,6 +3558,13 @@ read and used to create a ``Categorical`` variable from them. Value labels can
35583558
also be retrieved by the function ``variable_labels``, which requires data to be
35593559
called before (see ``pandas.io.stata.StataReader``).
35603560

3561+
The parameter ``convert_missing`` indicates whether missing value
3562+
representations in Stata should be preserved. If ``False`` (the default),
3563+
missing values are represented as ``np.nan``. If ``True``, missing values are
3564+
represented using ``StataMissingValue`` objects, and columns containing missing
3565+
values will have ``dtype`` set to ``object``.
3566+
3567+
35613568
The StataReader supports .dta Formats 104, 105, 108, 113-115 and 117.
35623569
Alternatively, the function :func:`~pandas.io.stata.read_stata` can be used
35633570

doc/source/v0.15.0.txt

+6
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,11 @@ API changes
144144
strings must contain 244 or fewer characters. Attempting to write Stata
145145
dta files with strings longer than 244 characters raises a ``ValueError``. (:issue:`7858`)
146146

147+
- ``read_stata`` and ``StataReader`` can import missing data information into a
148+
``DataFrame`` by setting the argument ``convert_missing`` to ``True``. When
149+
using this option, missing values are returned as ``StataMissingValue``
150+
objects and columns containing missing values have ``object`` data type. (:issue:`8045`)
151+
147152
- ``Index.isin`` now supports a ``level`` argument to specify which index level
148153
to use for membership tests (:issue:`7892`, :issue:`7890`)
149154

@@ -414,6 +419,7 @@ Performance
414419
- Performance improvements in ``DatetimeIndex.__iter__`` to allow faster iteration (:issue:`7683`)
415420
- Performance improvements in ``Period`` creation (and ``PeriodIndex`` setitem) (:issue:`5155`)
416421
- Improvements in Series.transform for significant performance gains (revised) (:issue:`6496`)
422+
- Performance improvements in ``StataReader`` when reading large files (:issue:`8040`)
417423

418424

419425

pandas/io/stata.py

+138-40
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
You can find more information on http://presbrey.mit.edu/PyDTA and
1010
http://statsmodels.sourceforge.net/devel/
1111
"""
12-
# TODO: Fix this module so it can use cross-compatible zip, map, and range
1312
import numpy as np
1413

1514
import sys
@@ -20,14 +19,16 @@
2019
from pandas.core.categorical import Categorical
2120
import datetime
2221
from pandas import compat
23-
from pandas.compat import long, lrange, lmap, lzip, text_type, string_types
22+
from pandas.compat import lrange, lmap, lzip, text_type, string_types, range, \
23+
zip
2424
from pandas import isnull
2525
from pandas.io.common import get_filepath_or_buffer
2626
from pandas.lib import max_len_string_array, is_string_array
2727
from pandas.tslib import NaT
2828

2929
def read_stata(filepath_or_buffer, convert_dates=True,
30-
convert_categoricals=True, encoding=None, index=None):
30+
convert_categoricals=True, encoding=None, index=None,
31+
convert_missing=False):
3132
"""
3233
Read Stata file into DataFrame
3334
@@ -44,10 +45,19 @@ def read_stata(filepath_or_buffer, convert_dates=True,
4445
support unicode. None defaults to cp1252.
4546
index : identifier of index column
4647
identifier of column that should be used as index of the DataFrame
48+
convert_missing : boolean, defaults to False
49+
Flag indicating whether to convert missing values to their Stata
50+
representations. If False, missing values are replaced with nans.
51+
If True, columns containing missing values are returned with
52+
object data types and missing values are represented by
53+
StataMissingValue objects.
4754
"""
4855
reader = StataReader(filepath_or_buffer, encoding)
4956

50-
return reader.data(convert_dates, convert_categoricals, index)
57+
return reader.data(convert_dates,
58+
convert_categoricals,
59+
index,
60+
convert_missing)
5161

5262
_date_formats = ["%tc", "%tC", "%td", "%d", "%tw", "%tm", "%tq", "%th", "%ty"]
5363

@@ -291,35 +301,76 @@ class StataMissingValue(StringMixin):
291301
292302
Parameters
293303
-----------
294-
offset
295-
value
304+
value : int8, int16, int32, float32 or float64
305+
The Stata missing value code
296306
297307
Attributes
298308
----------
299-
string
300-
value
309+
string : string
310+
String representation of the Stata missing value
311+
value : int8, int16, int32, float32 or float64
312+
The original encoded missing value
301313
302314
Notes
303315
-----
304316
More information: <http://www.stata.com/help.cgi?missing>
317+
318+
Integer missing values map the codes '.', '.a', ..., '.z' to the ranges
319+
101 ... 127 (for int8), 32741 ... 32767 (for int16) and 2147483621 ...
320+
2147483647 (for int32). Missing values for floating point data types are
321+
more complex but the pattern is simple to discern from the following table.
322+
323+
np.float32 missing values (float in Stata)
324+
0000007f .
325+
0008007f .a
326+
0010007f .b
327+
...
328+
00c0007f .x
329+
00c8007f .y
330+
00d0007f .z
331+
332+
np.float64 missing values (double in Stata)
333+
000000000000e07f .
334+
000000000001e07f .a
335+
000000000002e07f .b
336+
...
337+
000000000018e07f .x
338+
000000000019e07f .y
339+
00000000001ae07f .z
305340
"""
306-
# TODO: Needs test
307-
def __init__(self, offset, value):
341+
342+
# Construct a dictionary of missing values
343+
MISSING_VALUES = {}
344+
bases = (101, 32741, 2147483621)
345+
for b in bases:
346+
MISSING_VALUES[b] = '.'
347+
for i in range(1, 27):
348+
MISSING_VALUES[i + b] = '.' + chr(96 + i)
349+
350+
base = b'\x00\x00\x00\x7f'
351+
increment = struct.unpack('<i', b'\x00\x08\x00\x00')[0]
352+
for i in range(27):
353+
value = struct.unpack('<f', base)[0]
354+
MISSING_VALUES[value] = '.'
355+
if i > 0:
356+
MISSING_VALUES[value] += chr(96 + i)
357+
int_value = struct.unpack('<i', struct.pack('<f', value))[0] + increment
358+
base = struct.pack('<i', int_value)
359+
360+
base = b'\x00\x00\x00\x00\x00\x00\xe0\x7f'
361+
increment = struct.unpack('q', b'\x00\x00\x00\x00\x00\x01\x00\x00')[0]
362+
for i in range(27):
363+
value = struct.unpack('<d', base)[0]
364+
MISSING_VALUES[value] = '.'
365+
if i > 0:
366+
MISSING_VALUES[value] += chr(96 + i)
367+
int_value = struct.unpack('q', struct.pack('<d', value))[0] + increment
368+
base = struct.pack('q', int_value)
369+
370+
def __init__(self, value):
308371
self._value = value
309-
value_type = type(value)
310-
if value_type in int:
311-
loc = value - offset
312-
elif value_type in (float, np.float32, np.float64):
313-
if value <= np.finfo(np.float32).max: # float32
314-
conv_str, byte_loc, scale = '<f', 1, 8
315-
else:
316-
conv_str, byte_loc, scale = '<d', 5, 1
317-
value_bytes = struct.pack(conv_str, value)
318-
loc = (struct.unpack('<b', value_bytes[byte_loc])[0] / scale) + 0
319-
else:
320-
# Should never be hit
321-
loc = 0
322-
self._str = loc is 0 and '.' or ('.' + chr(loc + 96))
372+
self._str = self.MISSING_VALUES[value]
373+
323374
string = property(lambda self: self._str,
324375
doc="The Stata representation of the missing value: "
325376
"'.', '.a'..'.z'")
@@ -333,6 +384,10 @@ def __repr__(self):
333384
# not perfect :-/
334385
return "%s(%s)" % (self.__class__, self)
335386

387+
def __eq__(self, other):
388+
return (isinstance(other, self.__class__)
389+
and self.string == other.string and self.value == other.value)
390+
336391

337392
class StataParser(object):
338393
_default_encoding = 'cp1252'
@@ -711,15 +766,7 @@ def _col_size(self, k=None):
711766
return self.col_sizes[k]
712767

713768
def _unpack(self, fmt, byt):
714-
d = struct.unpack(self.byteorder + fmt, byt)[0]
715-
if fmt[-1] in self.VALID_RANGE:
716-
nmin, nmax = self.VALID_RANGE[fmt[-1]]
717-
if d < nmin or d > nmax:
718-
if self._missing_values:
719-
return StataMissingValue(nmax, d)
720-
else:
721-
return None
722-
return d
769+
return struct.unpack(self.byteorder + fmt, byt)[0]
723770

724771
def _null_terminate(self, s):
725772
if compat.PY3 or self._encoding is not None: # have bytes not strings,
@@ -752,16 +799,15 @@ def _next(self):
752799
)
753800
return data
754801
else:
755-
return list(
756-
map(
802+
return lmap(
757803
lambda i: self._unpack(typlist[i],
758804
self.path_or_buf.read(
759805
self._col_size(i)
760806
)),
761807
range(self.nvar)
762-
)
763808
)
764809

810+
765811
def _dataset(self):
766812
"""
767813
Returns a Python generator object for iterating over the dataset.
@@ -853,7 +899,8 @@ def _read_strls(self):
853899
self.GSO[v_o] = self.path_or_buf.read(length-1)
854900
self.path_or_buf.read(1) # zero-termination
855901

856-
def data(self, convert_dates=True, convert_categoricals=True, index=None):
902+
def data(self, convert_dates=True, convert_categoricals=True, index=None,
903+
convert_missing=False):
857904
"""
858905
Reads observations from Stata file, converting them into a dataframe
859906
@@ -866,11 +913,18 @@ def data(self, convert_dates=True, convert_categoricals=True, index=None):
866913
variables
867914
index : identifier of index column
868915
identifier of column that should be used as index of the DataFrame
916+
convert_missing : boolean, defaults to False
917+
Flag indicating whether to convert missing values to their Stata
918+
representation. If False, missing values are replaced with
919+
nans. If True, columns containing missing values are returned with
920+
object data types and missing values are represented by
921+
StataMissingValue objects.
869922
870923
Returns
871924
-------
872925
y : DataFrame instance
873926
"""
927+
self._missing_values = convert_missing
874928
if self._data_read:
875929
raise Exception("Data has already been read.")
876930
self._data_read = True
@@ -894,18 +948,62 @@ def data(self, convert_dates=True, convert_categoricals=True, index=None):
894948
if convert_categoricals:
895949
self._read_value_labels()
896950

951+
# TODO: Refactor to use a dictionary constructor and the correct dtype from the start?
897952
if len(data)==0:
898953
data = DataFrame(columns=self.varlist, index=index)
899954
else:
900955
data = DataFrame(data, columns=self.varlist, index=index)
901956

902957
cols_ = np.where(self.dtyplist)[0]
958+
959+
# Convert columns (if needed) to match input type
960+
index = data.index
961+
requires_type_conversion = False
962+
data_formatted = []
903963
for i in cols_:
904964
if self.dtyplist[i] is not None:
905965
col = data.columns[i]
906-
if data[col].dtype is not np.dtype(object):
907-
data[col] = Series(data[col], data[col].index,
908-
self.dtyplist[i])
966+
dtype = data[col].dtype
967+
if (dtype != np.dtype(object)) and (dtype != self.dtyplist[i]):
968+
requires_type_conversion = True
969+
data_formatted.append((col, Series(data[col], index, self.dtyplist[i])))
970+
else:
971+
data_formatted.append((col, data[col]))
972+
if requires_type_conversion:
973+
data = DataFrame.from_items(data_formatted)
974+
del data_formatted
975+
976+
# Check for missing values, and replace if found
977+
for i, colname in enumerate(data):
978+
fmt = self.typlist[i]
979+
if fmt not in self.VALID_RANGE:
980+
continue
981+
982+
nmin, nmax = self.VALID_RANGE[fmt]
983+
series = data[colname]
984+
missing = np.logical_or(series < nmin, series > nmax)
985+
986+
if not missing.any():
987+
continue
988+
989+
if self._missing_values: # Replacement follows Stata notation
990+
missing_loc = np.argwhere(missing)
991+
umissing, umissing_loc = np.unique(series[missing],
992+
return_inverse=True)
993+
replacement = Series(series, dtype=np.object)
994+
for i, um in enumerate(umissing):
995+
missing_value = StataMissingValue(um)
996+
997+
loc = missing_loc[umissing_loc == i]
998+
replacement.iloc[loc] = missing_value
999+
else: # All replacements are identical
1000+
dtype = series.dtype
1001+
if dtype not in (np.float32, np.float64):
1002+
dtype = np.float64
1003+
replacement = Series(series, dtype=dtype)
1004+
replacement[missing] = np.nan
1005+
1006+
data[colname] = replacement
9091007

9101008
if convert_dates:
9111009
cols = np.where(lmap(lambda x: x in _date_formats,

pandas/io/tests/data/stata8_113.dta

1.41 KB
Binary file not shown.

pandas/io/tests/data/stata8_115.dta

1.59 KB
Binary file not shown.

pandas/io/tests/data/stata8_117.dta

2.01 KB
Binary file not shown.

0 commit comments

Comments
 (0)