diff --git a/doc/source/whatsnew/v0.19.0.txt b/doc/source/whatsnew/v0.19.0.txt index 688f3b7ff6ada..a5478e3bed459 100644 --- a/doc/source/whatsnew/v0.19.0.txt +++ b/doc/source/whatsnew/v0.19.0.txt @@ -250,6 +250,8 @@ Other enhancements - A function :func:`union_categorical` has been added for combining categoricals, see :ref:`Unioning Categoricals` (:issue:`13361`) - ``Series`` has gained the properties ``.is_monotonic``, ``.is_monotonic_increasing``, ``.is_monotonic_decreasing``, similar to ``Index`` (:issue:`13336`) +- ``to_stata`` and ```StataWriter`` can now write variable labels to Stata dta files using a dictionary to make column names to labels (:issue:`13535`, :issue:`13535`) + .. _whatsnew_0190.api: API changes diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 334526b424be5..4fe7b318b3a18 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -1467,7 +1467,7 @@ def to_excel(self, excel_writer, sheet_name='Sheet1', na_rep='', def to_stata(self, fname, convert_dates=None, write_index=True, encoding="latin-1", byteorder=None, time_stamp=None, - data_label=None): + data_label=None, variable_labels=None): """ A class for writing Stata binary dta files from array-like objects @@ -1480,11 +1480,24 @@ def to_stata(self, fname, convert_dates=None, write_index=True, format that you want to use for the dates. Options are 'tc', 'td', 'tm', 'tw', 'th', 'tq', 'ty'. Column can be either a number or a name. + write_index : bool + Write the index to Stata dataset. encoding : str Default is latin-1. Note that Stata does not support unicode. byteorder : str Can be ">", "<", "little", or "big". The default is None which uses `sys.byteorder` + time_stamp : datetime + A date time to use when writing the file. Can be None, in which + case the current time is used. + dataset_label : str + A label for the data set. Should be 80 characters or smaller. + + .. versionadded:: 0.19.0 + + variable_labels : dict + Dictionary containing columns as keys and variable labels as + values. Each label must be 80 characters or smaller. Examples -------- @@ -1500,7 +1513,8 @@ def to_stata(self, fname, convert_dates=None, write_index=True, writer = StataWriter(fname, self, convert_dates=convert_dates, encoding=encoding, byteorder=byteorder, time_stamp=time_stamp, data_label=data_label, - write_index=write_index) + write_index=write_index, + variable_labels=variable_labels) writer.write_file() @Appender(fmt.docstring_to_string, indents=1) diff --git a/pandas/io/stata.py b/pandas/io/stata.py index bd19102c7f18c..d35466e8896ba 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -1059,7 +1059,7 @@ def _read_new_header(self, first_char): self.lbllist = self._get_lbllist() self.path_or_buf.seek(self._seek_variable_labels) - self.vlblist = self._get_vlblist() + self._variable_labels = self._get_variable_labels() # Get data type information, works for versions 117-118. def _get_dtypes(self, seek_vartypes): @@ -1127,7 +1127,7 @@ def _get_lbllist(self): return [self._null_terminate(self.path_or_buf.read(b)) for i in range(self.nvar)] - def _get_vlblist(self): + def _get_variable_labels(self): if self.format_version == 118: vlblist = [self._decode(self.path_or_buf.read(321)) for i in range(self.nvar)] @@ -1242,7 +1242,7 @@ def _read_old_header(self, first_char): self.lbllist = self._get_lbllist() - self.vlblist = self._get_vlblist() + self._variable_labels = self._get_variable_labels() # ignore expansion fields (Format 105 and later) # When reading, read five bytes; the last four bytes now tell you @@ -1306,11 +1306,11 @@ def _read_value_labels(self): while True: if self.format_version >= 117: if self.path_or_buf.read(5) == b' - break # end of variable label table + break # end of value label table slength = self.path_or_buf.read(4) if not slength: - break # end of variable label table (format < 117) + break # end of value label table (format < 117) if self.format_version <= 117: labname = self._null_terminate(self.path_or_buf.read(33)) else: @@ -1666,7 +1666,7 @@ def variable_labels(self): """Returns variable labels as a dict, associating each variable name with corresponding label """ - return dict(zip(self.varlist, self.vlblist)) + return dict(zip(self.varlist, self._variable_labels)) def value_labels(self): """Returns a dict, associating each variable name a dict, associating @@ -1696,7 +1696,7 @@ def _set_endianness(endianness): def _pad_bytes(name, length): """ - Takes a char string and pads it wih null bytes until it's length chars + Takes a char string and pads it with null bytes until it's length chars """ return name + "\x00" * (length - len(name)) @@ -1831,6 +1831,12 @@ class StataWriter(StataParser): dataset_label : str A label for the data set. Should be 80 characters or smaller. + .. versionadded:: 0.19.0 + + variable_labels : dict + Dictionary containing columns as keys and variable labels as values. + Each label must be 80 characters or smaller. + Returns ------- writer : StataWriter instance @@ -1853,12 +1859,13 @@ class StataWriter(StataParser): def __init__(self, fname, data, convert_dates=None, write_index=True, encoding="latin-1", byteorder=None, time_stamp=None, - data_label=None): + data_label=None, variable_labels=None): super(StataWriter, self).__init__(encoding) self._convert_dates = convert_dates self._write_index = write_index self._time_stamp = time_stamp self._data_label = data_label + self._variable_labels = variable_labels # attach nobs, nvars, data, varlist, typlist self._prepare_pandas(data) @@ -2135,11 +2142,29 @@ def _write_descriptors(self, typlist=None, varlist=None, srtlist=None, else: # Default is empty label self._write(_pad_bytes("", 33)) - def _write_variable_labels(self, labels=None): - nvar = self.nvar - if labels is None: - for i in range(nvar): - self._write(_pad_bytes("", 81)) + def _write_variable_labels(self): + # Missing labels are 80 blank characters plus null termination + blank = _pad_bytes('', 81) + + if self._variable_labels is None: + for i in range(self.nvar): + self._write(blank) + return + + for col in self.data: + if col in self._variable_labels: + label = self._variable_labels[col] + if len(label) > 80: + raise ValueError('Variable labels must be 80 characters ' + 'or fewer') + is_latin1 = all(ord(c) < 256 for c in label) + if not is_latin1: + raise ValueError('Variable labels must contain only ' + 'characters that can be encoded in ' + 'Latin-1') + self._write(_pad_bytes(label, 81)) + else: + self._write(blank) def _prepare_data(self): data = self.data diff --git a/pandas/io/tests/test_stata.py b/pandas/io/tests/test_stata.py index 5f45d1b547e62..91850e6ffe9b9 100644 --- a/pandas/io/tests/test_stata.py +++ b/pandas/io/tests/test_stata.py @@ -1,27 +1,27 @@ # -*- coding: utf-8 -*- # pylint: disable=E1101 -from datetime import datetime import datetime as dt import os -import warnings -import nose import struct import sys +import warnings +from datetime import datetime from distutils.version import LooseVersion +import nose import numpy as np import pandas as pd +import pandas.util.testing as tm +from pandas import compat from pandas.compat import iterkeys from pandas.core.frame import DataFrame, Series from pandas.types.common import is_categorical_dtype +from pandas.tslib import NaT from pandas.io.parsers import read_csv from pandas.io.stata import (read_stata, StataReader, InvalidColumnName, PossiblePrecisionLoss, StataMissingValue) -import pandas.util.testing as tm -from pandas.tslib import NaT -from pandas import compat class TestStata(tm.TestCase): @@ -1113,6 +1113,58 @@ def test_read_chunks_columns(self): tm.assert_frame_equal(from_frame, chunk, check_dtype=False) pos += chunksize + def test_write_variable_labels(self): + # GH 13631, add support for writing variable labels + original = pd.DataFrame({'a': [1, 2, 3, 4], + 'b': [1.0, 3.0, 27.0, 81.0], + 'c': ['Atlanta', 'Birmingham', + 'Cincinnati', 'Detroit']}) + original.index.name = 'index' + variable_labels = {'a': 'City Rank', 'b': 'City Exponent', 'c': 'City'} + with tm.ensure_clean() as path: + original.to_stata(path, variable_labels=variable_labels) + with StataReader(path) as sr: + read_labels = sr.variable_labels() + expected_labels = {'index': '', + 'a': 'City Rank', + 'b': 'City Exponent', + 'c': 'City'} + tm.assert_equal(read_labels, expected_labels) + + variable_labels['index'] = 'The Index' + with tm.ensure_clean() as path: + original.to_stata(path, variable_labels=variable_labels) + with StataReader(path) as sr: + read_labels = sr.variable_labels() + tm.assert_equal(read_labels, variable_labels) + + def test_write_variable_label_errors(self): + original = pd.DataFrame({'a': [1, 2, 3, 4], + 'b': [1.0, 3.0, 27.0, 81.0], + 'c': ['Atlanta', 'Birmingham', + 'Cincinnati', 'Detroit']}) + values = [u'\u03A1', u'\u0391', + u'\u039D', u'\u0394', + u'\u0391', u'\u03A3'] + + variable_labels_utf8 = {'a': 'City Rank', + 'b': 'City Exponent', + 'c': u''.join(values)} + + with tm.assertRaises(ValueError): + with tm.ensure_clean() as path: + original.to_stata(path, variable_labels=variable_labels_utf8) + + variable_labels_long = {'a': 'City Rank', + 'b': 'City Exponent', + 'c': 'A very, very, very long variable label ' + 'that is too long for Stata which means ' + 'that it has more than 80 characters'} + + with tm.assertRaises(ValueError): + with tm.ensure_clean() as path: + original.to_stata(path, variable_labels=variable_labels_long) + if __name__ == '__main__': nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'],