Skip to content

ENH: Allow timestamp and data label to be set when exporting to Stata #6553

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ Improvements to existing features
- perf improvements in DataFrame construction with certain offsets, by removing faulty caching
(e.g. MonthEnd,BusinessMonthEnd), (:issue:`6479`)
- perf improvements in single-dtyped indexing (:issue:`6484`)
- ``StataWriter`` and ``DataFrame.to_stata`` accept time stamp and data labels (:issue:`6545`)

.. _release.bug_fixes-0.14.0:

Expand Down
3 changes: 3 additions & 0 deletions doc/source/v0.14.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,9 @@ Enhancements
- ``DataFrame.to_stata`` will now check data for compatibility with Stata data types
and will upcast when needed. When it isn't possible to losslessly upcast, a warning
is raised (:issue:`6327`)
- ``DataFrame.to_stata`` and ``StataWriter`` will accept keyword arguments time_stamp
and data_label which allow the time stamp and dataset label to be set when creating a
file. (:issue:`6545`)

Performance
~~~~~~~~~~~
Expand Down
5 changes: 3 additions & 2 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,7 +1216,7 @@ def to_excel(self, excel_writer, sheet_name='Sheet1', na_rep='',

def to_stata(
self, fname, convert_dates=None, write_index=True, encoding="latin-1",
byteorder=None):
byteorder=None, time_stamp=None, data_label=None):
"""
Write the DataFrame to a Stata binary dta file

Expand Down Expand Up @@ -1247,7 +1247,8 @@ def to_stata(
"""
from pandas.io.stata import StataWriter
writer = StataWriter(fname, self, convert_dates=convert_dates,
encoding=encoding, byteorder=byteorder)
encoding=encoding, byteorder=byteorder,
time_stamp=time_stamp, data_label=data_label)
writer.write_file()

def to_sql(self, name, con, flavor='sqlite', if_exists='fail', **kwargs):
Expand Down
43 changes: 33 additions & 10 deletions pandas/io/stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,18 @@ def __init__(self, encoding):
'd': np.float64(struct.unpack('<d', b'\x00\x00\x00\x00\x00\x00\xe0\x7f')[0])
}

# Reserved words cannot be used as variable names
self.RESERVED_WORDS = ('aggregate', 'array', 'boolean', 'break',
'byte', 'case', 'catch', 'class', 'colvector',
'complex', 'const', 'continue', 'default',
'delegate', 'delete', 'do', 'double', 'else',
'eltypedef', 'end', 'enum', 'explicit',
'export', 'external', 'float', 'for', 'friend',
'function', 'global', 'goto', 'if', 'inline',
'int', 'local', 'long', 'NULL', 'pragma',
'protected', 'quad', 'rowvector', 'short',
'typedef', 'typename', 'virtual')

def _decode_bytes(self, str, errors=None):
if compat.PY3 or self._encoding is not None:
return str.decode(self._encoding, errors)
Expand Down Expand Up @@ -449,10 +461,10 @@ def _read_header(self):
self.path_or_buf.read(4))[0]
self.path_or_buf.read(11) # </N><label>
strlen = struct.unpack('b', self.path_or_buf.read(1))[0]
self.data_label = self.path_or_buf.read(strlen)
self.data_label = self._null_terminate(self.path_or_buf.read(strlen))
self.path_or_buf.read(19) # </label><timestamp>
strlen = struct.unpack('b', self.path_or_buf.read(1))[0]
self.time_stamp = self.path_or_buf.read(strlen)
self.time_stamp = self._null_terminate(self.path_or_buf.read(strlen))
self.path_or_buf.read(26) # </timestamp></header><map>
self.path_or_buf.read(8) # 0x0000000000000000
self.path_or_buf.read(8) # position of <map>
Expand Down Expand Up @@ -543,11 +555,11 @@ def _read_header(self):
self.nobs = struct.unpack(self.byteorder + 'I',
self.path_or_buf.read(4))[0]
if self.format_version > 105:
self.data_label = self.path_or_buf.read(81)
self.data_label = self._null_terminate(self.path_or_buf.read(81))
else:
self.data_label = self.path_or_buf.read(32)
self.data_label = self._null_terminate(self.path_or_buf.read(32))
if self.format_version > 104:
self.time_stamp = self.path_or_buf.read(18)
self.time_stamp = self._null_terminate(self.path_or_buf.read(18))

# descriptors
if self.format_version > 108:
Expand Down Expand Up @@ -1029,6 +1041,11 @@ class StataWriter(StataParser):
byteorder : str
Can be ">", "<", "little", or "big". The default is None which uses
`sys.byteorder`
time_stamp : datetime
A date time to use when writing the file. Can be None, in which
case the current time is used.
data_label : str
A label for the data set. Should be 80 characters or fewer.

Returns
-------
Expand All @@ -1047,10 +1064,13 @@ class StataWriter(StataParser):
>>> writer.write_file()
"""
def __init__(self, fname, data, convert_dates=None, write_index=True,
encoding="latin-1", byteorder=None):
encoding="latin-1", byteorder=None, time_stamp=None,
data_label=None):
super(StataWriter, self).__init__(encoding)
self._convert_dates = convert_dates
self._write_index = write_index
self._time_stamp = time_stamp
self._data_label = data_label
# attach nobs, nvars, data, varlist, typlist
self._prepare_pandas(data)

Expand Down Expand Up @@ -1086,7 +1106,7 @@ def __iter__(self):

if self._write_index:
data = data.reset_index()
# Check columns for compatbaility with stata
# Check columns for compatibility with stata
data = _cast_to_stata_types(data)
self.datarows = DataFrameRowIter(data)
self.nobs, self.nvar = data.shape
Expand All @@ -1110,7 +1130,8 @@ def __iter__(self):
self.fmtlist[key] = self._convert_dates[key]

def write_file(self):
self._write_header()
self._write_header(time_stamp=self._time_stamp,
data_label=self._data_label)
self._write_descriptors()
self._write_variable_labels()
# write 5 zeros for expansion fields
Expand Down Expand Up @@ -1147,7 +1168,7 @@ def _write_header(self, data_label=None, time_stamp=None):
# format dd Mon yyyy hh:mm
if time_stamp is None:
time_stamp = datetime.datetime.now()
elif not isinstance(time_stamp, datetime):
elif not isinstance(time_stamp, datetime.datetime):
raise ValueError("time_stamp should be datetime type")
self._file.write(
self._null_terminate(time_stamp.strftime("%d %b %Y %H:%M"))
Expand All @@ -1169,7 +1190,9 @@ def _write_descriptors(self, typlist=None, varlist=None, srtlist=None,
for c in name:
if (c < 'A' or c > 'Z') and (c < 'a' or c > 'z') and (c < '0' or c > '9') and c != '_':
name = name.replace(c, '_')

# Variable name must not be a reserved word
if name in self.RESERVED_WORDS:
name = '_' + name
# Variable name may not start with a number
if name[0] > '0' and name[0] < '9':
name = '_' + name
Expand Down
32 changes: 27 additions & 5 deletions pandas/io/tests/test_stata.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pylint: disable=E1101

from datetime import datetime
import datetime as dt
import os
import warnings
import nose
Expand Down Expand Up @@ -248,7 +249,7 @@ def test_read_write_dta10(self):

original = DataFrame(data=[["string", "object", 1, 1.1,
np.datetime64('2003-12-25')]],
columns=['string', 'object', 'integer', 'float',
columns=['string', 'object', 'integer', 'floating',
'datetime'])
original["object"] = Series(original["object"], dtype=object)
original.index.name = 'index'
Expand Down Expand Up @@ -304,10 +305,20 @@ def test_read_write_dta11(self):
def test_read_write_dta12(self):
# skip_if_not_little_endian()

original = DataFrame([(1, 2, 3, 4)],
columns=['astringwithmorethan32characters_1', 'astringwithmorethan32characters_2', '+', '-'])
formatted = DataFrame([(1, 2, 3, 4)],
columns=['astringwithmorethan32characters_', '_0astringwithmorethan32character', '_', '_1_'])
original = DataFrame([(1, 2, 3, 4, 5, 6)],
columns=['astringwithmorethan32characters_1',
'astringwithmorethan32characters_2',
'+',
'-',
'short',
'delete'])
formatted = DataFrame([(1, 2, 3, 4, 5, 6)],
columns=['astringwithmorethan32characters_',
'_0astringwithmorethan32character',
'_',
'_1_',
'_short',
'_delete'])
formatted.index.name = 'index'
formatted = formatted.astype(np.int32)

Expand Down Expand Up @@ -376,6 +387,17 @@ def test_read_write_reread_dta15(self):
tm.assert_frame_equal(parsed_113, parsed_114)
tm.assert_frame_equal(parsed_114, parsed_115)

def test_timestamp_and_label(self):
    """Round-trip a custom time stamp and data label through a Stata file.

    Writes a one-cell frame with ``to_stata`` using explicit ``time_stamp``
    and ``data_label`` values, then checks that ``StataReader`` exposes the
    same values from the file header.
    """
    expected_stamp = datetime(2000, 2, 29, 14, 21)
    expected_label = 'This is a data file.'
    frame = DataFrame([(1,)], columns=['var'])
    with tm.ensure_clean() as path:
        frame.to_stata(path, time_stamp=expected_stamp,
                       data_label=expected_label)
        reader = StataReader(path)
        # The header stores the stamp as text ("dd Mon yyyy hh:mm"), so it
        # is parsed back into a datetime before comparing.
        stamp_format = '%d %b %Y %H:%M'
        round_tripped = dt.datetime.strptime(reader.time_stamp, stamp_format)
        assert round_tripped == expected_stamp
        assert reader.data_label == expected_label



if __name__ == '__main__':
Expand Down