Skip to content

Commit 87d41f3

Browse files
committed
Implementation of Stata 13 (format 117) support
1 parent ea63f36 commit 87d41f3

File tree

2 files changed

+207
-79
lines changed

2 files changed

+207
-79
lines changed

pandas/io/stata.py

+196-78
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def read_stata(filepath_or_buffer, convert_dates=True, convert_categoricals=True
4242

4343
_date_formats = ["%tc", "%tC", "%td", "%tw", "%tm", "%tq", "%th", "%ty"]
4444

45+
4546
def _stata_elapsed_date_to_datetime(date, fmt):
4647
"""
4748
Convert from SIF to datetime. http://www.stata.com/help.cgi?datetime
@@ -234,7 +235,28 @@ def __init__(self, encoding=None):
234235
(255, np.float64)
235236
]
236237
)
238+
self.DTYPE_MAP_XML = \
239+
dict(
240+
[
241+
(32768, np.string_),
242+
(65526, np.float64),
243+
(65527, np.float32),
244+
(65528, np.int64),
245+
(65529, np.int32),
246+
(65530, np.int16)
247+
]
248+
)
237249
self.TYPE_MAP = lrange(251) + list('bhlfd')
250+
self.TYPE_MAP_XML = \
251+
dict(
252+
[
253+
(65526, 'd'),
254+
(65527, 'f'),
255+
(65528, 'l'),
256+
(65529, 'h'),
257+
(65530, 'b')
258+
]
259+
)
238260
#NOTE: technically, some of these are wrong. there are more numbers
239261
# that can be represented. it's the 27 ABOVE and BELOW the max listed
240262
# numeric data type in [U] 12.2.2 of the 11.2 manual
@@ -304,86 +326,159 @@ def __init__(self, path_or_buf, encoding='cp1252'):
304326
self._read_header()
305327

306328
def _read_header(self):
307-
# header
308-
self.format_version = struct.unpack('b', self.path_or_buf.read(1))[0]
309-
if self.format_version not in [104, 105, 108, 113, 114, 115]:
310-
raise ValueError("Version of given Stata file is not 104, 105, 108, 113 (Stata 8/9), 114 (Stata 10/11) or 115 (Stata 12)")
311-
self.byteorder = self.path_or_buf.read(1) == 0x1 and '>' or '<'
312-
self.filetype = struct.unpack('b', self.path_or_buf.read(1))[0]
313-
self.path_or_buf.read(1) # unused
314-
315-
self.nvar = struct.unpack(self.byteorder + 'H', self.path_or_buf.read(2))[0]
316-
self.nobs = struct.unpack(self.byteorder + 'I', self.path_or_buf.read(4))[0]
317-
if self.format_version > 105:
318-
self.data_label = self.path_or_buf.read(81)
319-
else:
320-
self.data_label = self.path_or_buf.read(32)
321-
if self.format_version > 104:
322-
self.time_stamp = self.path_or_buf.read(18)
329+
first_char = self.path_or_buf.read(1)
330+
if struct.unpack('c', first_char)[0] is b'<': # format 117 or higher (XML like)
331+
self.path_or_buf.read(27) # stata_dta><header><release>
332+
self.format_version = int(self.path_or_buf.read(3))
333+
if self.format_version not in [117]:
334+
raise ValueError("Version of given Stata file is not 104, 105, 108, 113 (Stata 8/9), 114 (Stata 10/11), 115 (Stata 12) or 117 (Stata 13)")
335+
self.path_or_buf.read(21) # </release><byteorder>
336+
self.byteorder = self.path_or_buf.read(3) == "LSF" and '>' or '<'
337+
self.path_or_buf.read(15) # </byteorder><K>
338+
self.nvar = struct.unpack(self.byteorder + 'H', self.path_or_buf.read(2))[0]
339+
self.path_or_buf.read(7) # </K><N>
340+
self.nobs = struct.unpack(self.byteorder + 'I', self.path_or_buf.read(4))[0]
341+
self.path_or_buf.read(11) # </N><label>
342+
strlen = struct.unpack('b', self.path_or_buf.read(1))[0]
343+
self.data_label = self.path_or_buf.read(strlen)
344+
self.path_or_buf.read(19) # </label><timestamp>
345+
strlen = struct.unpack('b', self.path_or_buf.read(1))[0]
346+
self.time_stamp = self.path_or_buf.read(strlen)
347+
self.path_or_buf.read(26) # </timestamp></header><map>
348+
self.path_or_buf.read(8) # 0x0000000000000000
349+
self.path_or_buf.read(8) # position of <map>
350+
seek_vartypes = struct.unpack(self.byteorder + 'q', self.path_or_buf.read(8))[0] + 16
351+
seek_varnames = struct.unpack(self.byteorder + 'q', self.path_or_buf.read(8))[0] + 10
352+
seek_sortlist = struct.unpack(self.byteorder + 'q', self.path_or_buf.read(8))[0] + 10
353+
seek_formats = struct.unpack(self.byteorder + 'q', self.path_or_buf.read(8))[0] + 9
354+
seek_value_label_names = struct.unpack(self.byteorder + 'q', self.path_or_buf.read(8))[0] + 19
355+
seek_variable_labels = struct.unpack(self.byteorder + 'q', self.path_or_buf.read(8))[0] + 17
356+
self.path_or_buf.read(8) # <characteristics>
357+
self.data_location = struct.unpack(self.byteorder + 'q', self.path_or_buf.read(8))[0] + 6
358+
self.seek_strls = struct.unpack(self.byteorder + 'q', self.path_or_buf.read(8))[0] + 7
359+
self.seek_value_labels = struct.unpack(self.byteorder + 'q', self.path_or_buf.read(8))[0] + 14
360+
#self.path_or_buf.read(8) # </stata_dta>
361+
#self.path_or_buf.read(8) # EOF
362+
self.path_or_buf.seek(seek_vartypes)
363+
typlist = [struct.unpack(self.byteorder + 'H', self.path_or_buf.read(2))[0] for i in range(self.nvar)]
364+
self.typlist = [None]*self.nvar
365+
try:
366+
i = 0
367+
for typ in typlist:
368+
if typ <= 2045 or typ == 32768:
369+
self.typlist[i] = None
370+
else:
371+
self.typlist[i] = self.TYPE_MAP_XML[typ]
372+
i += 1
373+
except:
374+
raise ValueError("cannot convert stata types [{0}]".format(','.join(typlist)))
375+
self.dtyplist = [None]*self.nvar
376+
try:
377+
i = 0
378+
for typ in typlist:
379+
if typ <= 2045:
380+
self.dtyplist[i] = str(typ)
381+
else:
382+
self.dtyplist[i] = self.DTYPE_MAP_XML[typ]
383+
i += 1
384+
except:
385+
raise ValueError("cannot convert stata dtypes [{0}]".format(','.join(typlist)))
323386

324-
# descriptors
325-
if self.format_version > 108:
326-
typlist = [ord(self.path_or_buf.read(1)) for i in range(self.nvar)]
327-
else:
328-
typlist = [self.OLD_TYPE_MAPPING[self._decode_bytes(self.path_or_buf.read(1))] for i in range(self.nvar)]
329-
330-
try:
331-
self.typlist = [self.TYPE_MAP[typ] for typ in typlist]
332-
except:
333-
raise ValueError("cannot convert stata types [{0}]".format(','.join(typlist)))
334-
try:
335-
self.dtyplist = [self.DTYPE_MAP[typ] for typ in typlist]
336-
except:
337-
raise ValueError("cannot convert stata dtypes [{0}]".format(','.join(typlist)))
338-
339-
if self.format_version > 108:
387+
self.path_or_buf.seek(seek_varnames)
340388
self.varlist = [self._null_terminate(self.path_or_buf.read(33)) for i in range(self.nvar)]
341-
else:
342-
self.varlist = [self._null_terminate(self.path_or_buf.read(9)) for i in range(self.nvar)]
343-
self.srtlist = struct.unpack(self.byteorder + ('h' * (self.nvar + 1)), self.path_or_buf.read(2 * (self.nvar + 1)))[:-1]
344-
if self.format_version > 113:
389+
390+
self.path_or_buf.seek(seek_sortlist)
391+
self.srtlist = struct.unpack(self.byteorder + ('h' * (self.nvar + 1)), self.path_or_buf.read(2 * (self.nvar + 1)))[:-1]
392+
393+
self.path_or_buf.seek(seek_formats)
345394
self.fmtlist = [self._null_terminate(self.path_or_buf.read(49)) for i in range(self.nvar)]
346-
elif self.format_version > 104:
347-
self.fmtlist = [self._null_terminate(self.path_or_buf.read(12)) for i in range(self.nvar)]
348-
else:
349-
self.fmtlist = [self._null_terminate(self.path_or_buf.read(7)) for i in range(self.nvar)]
350-
if self.format_version > 108:
395+
396+
self.path_or_buf.seek(seek_value_label_names)
351397
self.lbllist = [self._null_terminate(self.path_or_buf.read(33)) for i in range(self.nvar)]
352-
else:
353-
self.lbllist = [self._null_terminate(self.path_or_buf.read(9)) for i in range(self.nvar)]
354-
if self.format_version > 105:
398+
399+
self.path_or_buf.seek(seek_variable_labels)
355400
self.vlblist = [self._null_terminate(self.path_or_buf.read(81)) for i in range(self.nvar)]
356401
else:
357-
self.vlblist = [self._null_terminate(self.path_or_buf.read(32)) for i in range(self.nvar)]
358-
359-
# ignore expansion fields (Format 105 and later)
360-
# When reading, read five bytes; the last four bytes now tell you the
361-
# size of the next read, which you discard. You then continue like
362-
# this until you read 5 bytes of zeros.
363-
364-
if self.format_version > 104:
365-
while True:
366-
data_type = struct.unpack(self.byteorder + 'b', self.path_or_buf.read(1))[0]
367-
if self.format_version > 108:
368-
data_len = struct.unpack(self.byteorder + 'i', self.path_or_buf.read(4))[0]
369-
else:
370-
data_len = struct.unpack(self.byteorder + 'h', self.path_or_buf.read(2))[0]
371-
if data_type == 0:
372-
break
373-
self.path_or_buf.read(data_len)
402+
# header
403+
self.format_version = struct.unpack('b', first_char)[0]
404+
if self.format_version not in [104, 105, 108, 113, 114, 115]:
405+
raise ValueError("Version of given Stata file is not 104, 105, 108, 113 (Stata 8/9), 114 (Stata 10/11), 115 (Stata 12) or 117 (Stata 13)")
406+
self.byteorder = self.path_or_buf.read(1) == 0x1 and '>' or '<'
407+
self.filetype = struct.unpack('b', self.path_or_buf.read(1))[0]
408+
self.path_or_buf.read(1) # unused
409+
410+
self.nvar = struct.unpack(self.byteorder + 'H', self.path_or_buf.read(2))[0]
411+
self.nobs = struct.unpack(self.byteorder + 'I', self.path_or_buf.read(4))[0]
412+
if self.format_version > 105:
413+
self.data_label = self.path_or_buf.read(81)
414+
else:
415+
self.data_label = self.path_or_buf.read(32)
416+
if self.format_version > 104:
417+
self.time_stamp = self.path_or_buf.read(18)
418+
419+
# descriptors
420+
if self.format_version > 108:
421+
typlist = [ord(self.path_or_buf.read(1)) for i in range(self.nvar)]
422+
else:
423+
typlist = [self.OLD_TYPE_MAPPING[self._decode_bytes(self.path_or_buf.read(1))] for i in range(self.nvar)]
424+
425+
try:
426+
self.typlist = [self.TYPE_MAP[typ] for typ in typlist]
427+
except:
428+
raise ValueError("cannot convert stata types [{0}]".format(','.join(typlist)))
429+
try:
430+
self.dtyplist = [self.DTYPE_MAP[typ] for typ in typlist]
431+
except:
432+
raise ValueError("cannot convert stata dtypes [{0}]".format(','.join(typlist)))
433+
434+
if self.format_version > 108:
435+
self.varlist = [self._null_terminate(self.path_or_buf.read(33)) for i in range(self.nvar)]
436+
else:
437+
self.varlist = [self._null_terminate(self.path_or_buf.read(9)) for i in range(self.nvar)]
438+
self.srtlist = struct.unpack(self.byteorder + ('h' * (self.nvar + 1)), self.path_or_buf.read(2 * (self.nvar + 1)))[:-1]
439+
if self.format_version > 113:
440+
self.fmtlist = [self._null_terminate(self.path_or_buf.read(49)) for i in range(self.nvar)]
441+
elif self.format_version > 104:
442+
self.fmtlist = [self._null_terminate(self.path_or_buf.read(12)) for i in range(self.nvar)]
443+
else:
444+
self.fmtlist = [self._null_terminate(self.path_or_buf.read(7)) for i in range(self.nvar)]
445+
if self.format_version > 108:
446+
self.lbllist = [self._null_terminate(self.path_or_buf.read(33)) for i in range(self.nvar)]
447+
else:
448+
self.lbllist = [self._null_terminate(self.path_or_buf.read(9)) for i in range(self.nvar)]
449+
if self.format_version > 105:
450+
self.vlblist = [self._null_terminate(self.path_or_buf.read(81)) for i in range(self.nvar)]
451+
else:
452+
self.vlblist = [self._null_terminate(self.path_or_buf.read(32)) for i in range(self.nvar)]
453+
454+
# ignore expansion fields (Format 105 and later)
455+
# When reading, read five bytes; the last four bytes now tell you the
456+
# size of the next read, which you discard. You then continue like
457+
# this until you read 5 bytes of zeros.
458+
459+
if self.format_version > 104:
460+
while True:
461+
data_type = struct.unpack(self.byteorder + 'b', self.path_or_buf.read(1))[0]
462+
if self.format_version > 108:
463+
data_len = struct.unpack(self.byteorder + 'i', self.path_or_buf.read(4))[0]
464+
else:
465+
data_len = struct.unpack(self.byteorder + 'h', self.path_or_buf.read(2))[0]
466+
if data_type == 0:
467+
break
468+
self.path_or_buf.read(data_len)
469+
470+
# necessary data to continue parsing
471+
self.data_location = self.path_or_buf.tell()
374472

375-
# necessary data to continue parsing
376-
self.data_location = self.path_or_buf.tell()
377473
self.has_string_data = len([x for x in self.typlist if type(x) is int]) > 0
378-
self._col_size()
474+
475+
"""Calculate size of a data record."""
476+
self.col_sizes = lmap(lambda x: self._calcsize(x), self.typlist)
379477

380478
def _calcsize(self, fmt):
    """Return the number of bytes occupied by one field of type ``fmt``.

    Parameters
    ----------
    fmt : int or str
        An ``int`` is returned unchanged (it is already a byte width).
        Anything else is treated as a ``struct`` format string and sized
        with ``struct.calcsize`` under ``self.byteorder``.

    Returns
    -------
    int
        Width of the field in bytes.
    """
    # A conditional expression instead of ``cond and a or b``: the old idiom
    # fell through to ``struct.calcsize`` when ``fmt`` was the int 0 and
    # raised TypeError on ``str + int``.
    # ``type(...) is int`` is kept (rather than isinstance) to stay
    # consistent with the ``has_string_data`` check elsewhere in this file.
    if type(fmt) is int:
        return fmt
    return struct.calcsize(self.byteorder + fmt)
382480

383481
def _col_size(self, k=None):
384-
"""Calculate size of a data record."""
385-
if len(self.col_sizes) == 0:
386-
self.col_sizes = lmap(lambda x: self._calcsize(x), self.typlist)
387482
if k is None:
388483
return self.col_sizes
389484
else:
@@ -427,8 +522,8 @@ def _next(self):
427522
return data
428523
else:
429524
return list(map(lambda i: self._unpack(typlist[i],
430-
self.path_or_buf.read(self._col_size(i))),
431-
range(self.nvar)))
525+
self.path_or_buf.read(self._col_size(i))),
526+
range(self.nvar)))
432527

433528
def _dataset(self):
434529
"""
@@ -450,29 +545,33 @@ def _dataset(self):
450545
be handled by your application.
451546
"""
452547

453-
try:
454-
self._file.seek(self._data_location)
455-
except Exception:
456-
pass
548+
self.path_or_buf.seek(self.data_location)
457549

458550
for i in range(self.nobs):
459551
yield self._next()
460552

461553
def _read_value_labels(self):
462-
if not self._data_read:
463-
raise Exception("Data has not been read. Because of the layout of Stata files, this is necessary before reading value labels.")
464-
if self._value_labels_read:
465-
raise Exception("Value labels have already been read.")
554+
if self.format_version >= 117:
555+
self.path_or_buf.seek(self.seek_value_labels)
556+
else:
557+
if not self._data_read:
558+
raise Exception("Data has not been read. Because of the layout of Stata files, this is necessary before reading value labels.")
559+
if self._value_labels_read:
560+
raise Exception("Value labels have already been read.")
466561

467562
self.value_label_dict = dict()
468563

469564
if self.format_version <= 108:
470565
return # Value labels are not supported in version 108 and earlier.
471566

472567
while True:
568+
if self.format_version >= 117:
569+
if self._decode_bytes(self.path_or_buf.read(5), self._encoding) == '</val': # <lbl>
570+
break # end of variable label table
571+
473572
slength = self.path_or_buf.read(4)
474573
if not slength:
475-
break # end of variable lable table
574+
break # end of variable label table (format < 117)
476575
labname = self._null_terminate(self.path_or_buf.read(33))
477576
self.path_or_buf.read(3) # padding
478577

@@ -488,8 +587,24 @@ def _read_value_labels(self):
488587
self.value_label_dict[labname] = dict()
489588
for i in range(n):
490589
self.value_label_dict[labname][val[i]] = self._null_terminate(txt[off[i]:])
590+
591+
if self.format_version >= 117:
592+
self.path_or_buf.read(6) # </lbl>
491593
self._value_labels_read = True
492594

595+
def _read_strls(self):
    """Load the strL ("GSO") records into ``self.GSO``.

    Seeks ``self.path_or_buf`` to ``self.seek_strls`` and reads records
    until the next three bytes are not the ``GSO`` marker. Each record is
    read as: 3-byte marker, 8-byte key, 1-byte type, 4-byte length, then
    ``length`` bytes of content.

    Side effects
    ------------
    Sets ``self.GSO`` to a dict mapping the 8-byte key (unpacked as one
    unsigned integer) to the raw content bytes, and leaves the stream
    positioned just after the first non-``GSO`` marker.

    Raises
    ------
    struct.error
        If a record header is truncated.
    """
    stream = self.path_or_buf
    stream.seek(self.seek_strls)
    self.GSO = dict()
    while True:
        # Compare by value: ``is not 'GSO'`` tested object identity, which
        # is effectively always true for freshly read data, so the loop
        # used to exit before reading a single record.
        if stream.read(3) != b'GSO':
            break

        # 'Q' is 8 bytes under an explicit byte order; 'L' is only 4 and
        # made struct.unpack reject the 8-byte buffer.
        v_o = struct.unpack(self.byteorder + 'Q', stream.read(8))[0]
        typ = struct.unpack('B', stream.read(1))[0]
        length = struct.unpack(self.byteorder + 'I', stream.read(4))[0]
        content = stream.read(length)
        # NOTE(review): per the dta 117 format description, type 130 is
        # text whose length includes a trailing NUL, while other types are
        # binary with no terminator - confirm against `help dta`.
        if typ == 130:
            content = content[:-1]
        self.GSO[v_o] = content
607+
493608
def data(self, convert_dates=True, convert_categoricals=True, index=None):
494609
"""
495610
Reads observations from Stata file, converting them into a dataframe
@@ -511,6 +626,9 @@ def data(self, convert_dates=True, convert_categoricals=True, index=None):
511626
raise Exception("Data has already been read.")
512627
self._data_read = True
513628

629+
if self.format_version >= 117:
630+
self._read_strls()
631+
514632
stata_dta = self._dataset()
515633

516634
data = []

pandas/io/tests/test_stata.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ def setUp(self):
3333
self.dta9 = os.path.join(self.dirpath, 'lbw.dta')
3434
self.csv9 = os.path.join(self.dirpath, 'lbw.csv')
3535
self.dta_encoding = os.path.join(self.dirpath, 'stata1_encoding.dta')
36-
3736
def read_dta(self, file):
    """Read ``file`` through ``read_stata`` with date conversion switched on."""
    frame = read_stata(file, convert_dates=True)
    return frame
3938

@@ -198,6 +197,17 @@ def test_read_write_dta10(self):
198197
tm.assert_frame_equal(written_and_read_again.set_index('index'),
199198
original)
200199

200+
def test_read_dta11(self):
    """Reading ``self.dta11`` must give a single row that is NaN in every column."""
    column_names = ['float_miss', 'double_miss', 'byte_miss',
                    'int_miss', 'long_miss']
    # pandas represents a missing value as np.nan, so each column comes
    # back as float no matter which Stata type its name refers to.
    expected = DataFrame([(np.nan,) * len(column_names)],
                         columns=column_names)

    parsed = StataReader(self.dta11).data()

    tm.assert_frame_equal(parsed, expected)
210+
201211
def test_stata_doc_examples(self):
202212
with tm.ensure_clean() as path:
203213
df = DataFrame(np.random.randn(10, 2), columns=list('AB'))

0 commit comments

Comments
 (0)