diff --git a/doc/source/release.rst b/doc/source/release.rst index 4890f22e98468..12a83f48706e5 100644 --- a/doc/source/release.rst +++ b/doc/source/release.rst @@ -233,6 +233,7 @@ Bug Fixes - Bug in popping from a Series (:issue:`6600`) - Bug in ``iloc`` indexing when positional indexer matched Int64Index of corresponding axis no reordering happened (:issue:`6612`) - Bug in ``fillna`` with ``limit`` and ``value`` specified +- Bug in ``DataFrame.to_stata`` when columns have non-string names (:issue:`4558`) pandas 0.13.1 ------------- diff --git a/pandas/io/stata.py b/pandas/io/stata.py index 7d9d272eea1b6..4bb61e385a75c 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -20,7 +20,7 @@ from pandas.core.categorical import Categorical import datetime from pandas import compat -from pandas.compat import long, lrange, lmap, lzip +from pandas.compat import long, lrange, lmap, lzip, text_type, string_types from pandas import isnull from pandas.io.common import get_filepath_or_buffer from pandas.tslib import NaT @@ -191,6 +191,21 @@ class PossiblePrecisionLoss(Warning): """ +class InvalidColumnName(Warning): + pass + + +invalid_name_doc = """ +Not all pandas column names were valid Stata variable names. 
+The following replacements have been made: + + {0} + +If this is not what you expect, please make sure you have Stata-compliant +column names in your DataFrame (strings only, max 32 characters, only alphanumerics and +underscores, no Stata reserved words) +""" + def _cast_to_stata_types(data): """Checks the dtypes of the columns of a pandas DataFrame for compatibility with the data types and ranges supported by Stata, and @@ -942,7 +957,7 @@ def _maybe_convert_to_int_keys(convert_dates, varlist): else: if not isinstance(key, int): raise ValueError( - "convery_dates key is not in varlist and is not an int" + "convert_dates key is not in varlist and is not an int" ) new_dict.update({key: convert_dates[key]}) return new_dict @@ -1092,6 +1107,78 @@ def _write(self, to_write): else: self._file.write(to_write) + + def _check_column_names(self, data): + """Checks column names to ensure that they are valid Stata column names. + This includes checks for: + * Non-string names + * Stata keywords + * Variables that start with numbers + * Variables with names that are too long + + When an illegal variable name is detected, it is converted, and if dates + are exported, the variable name is propagated to the date conversion + dictionary + """ + converted_names = [] + columns = list(data.columns) + original_columns = columns[:] + + duplicate_var_id = 0 + for j, name in enumerate(columns): + orig_name = name + if not isinstance(name, string_types): + name = text_type(name) + + for c in name: + if (c < 'A' or c > 'Z') and (c < 'a' or c > 'z') and \ + (c < '0' or c > '9') and c != '_': + name = name.replace(c, '_') + + # Variable name must not be a reserved word + if name in self.RESERVED_WORDS: + name = '_' + name + + # Variable name may not start with a number + if name[0] >= '0' and name[0] <= '9': + name = '_' + name + + name = name[:min(len(name), 32)] + + if not name == orig_name: + # check for duplicates + while columns.count(name) > 0: + # prepend ascending number to avoid 
duplicates + name = '_' + str(duplicate_var_id) + name + name = name[:min(len(name), 32)] + duplicate_var_id += 1 + + # need to possibly encode the orig name if it's unicode + try: + orig_name = orig_name.encode('utf-8') + except: + pass + converted_names.append('{0} -> {1}'.format(orig_name, name)) + + columns[j] = name + + data.columns = columns + + # Check date conversion, and fix key if needed + if self._convert_dates: + for c, o in zip(columns, original_columns): + if c != o: + self._convert_dates[c] = self._convert_dates[o] + del self._convert_dates[o] + + if converted_names: + import warnings + + ws = invalid_name_doc.format('\n '.join(converted_names)) + warnings.warn(ws, InvalidColumnName) + + return data + def _prepare_pandas(self, data): #NOTE: we might need a different API / class for pandas objects so # we can set different semantics - handle this with a PR to pandas.io @@ -1108,6 +1195,8 @@ def __iter__(self): data = data.reset_index() # Check columns for compatibility with stata data = _cast_to_stata_types(data) + # Ensure column names are valid Stata variable names + data = self._check_column_names(data) self.datarows = DataFrameRowIter(data) self.nobs, self.nvar = data.shape self.data = data @@ -1181,58 +1270,13 @@ def _write_descriptors(self, typlist=None, varlist=None, srtlist=None, for typ in self.typlist: self._write(typ) - # varlist, length 33*nvar, char array, null terminated - converted_names = [] - duplicate_var_id = 0 - for j, name in enumerate(self.varlist): - orig_name = name - # Replaces all characters disallowed in .dta format by their integral representation. 
- for c in name: - if (c < 'A' or c > 'Z') and (c < 'a' or c > 'z') and (c < '0' or c > '9') and c != '_': - name = name.replace(c, '_') - # Variable name must not be a reserved word - if name in self.RESERVED_WORDS: - name = '_' + name - # Variable name may not start with a number - if name[0] > '0' and name[0] < '9': - name = '_' + name - - name = name[:min(len(name), 32)] - - if not name == orig_name: - # check for duplicates - while self.varlist.count(name) > 0: - # prepend ascending number to avoid duplicates - name = '_' + str(duplicate_var_id) + name - name = name[:min(len(name), 32)] - duplicate_var_id += 1 - - # need to possibly encode the orig name if its unicode - try: - orig_name = orig_name.encode('utf-8') - except: - pass - - converted_names.append('{0} -> {1}'.format(orig_name, name)) - self.varlist[j] = name - + # varlist names are checked by _check_column_names + # varlist, requires null terminated for name in self.varlist: name = self._null_terminate(name, True) name = _pad_bytes(name[:32], 33) self._write(name) - if converted_names: - from warnings import warn - warn("""Not all pandas column names were valid Stata variable names. 
- Made the following replacements: - - {0} - - If this is not what you expect, please make sure you have Stata-compliant - column names in your DataFrame (max 32 characters, only alphanumerics and - underscores)/ - """.format('\n '.join(converted_names))) - # srtlist, 2*(nvar+1), int array, encoded by byteorder srtlist = _pad_bytes("", (2*(nvar+1))) self._write(srtlist) diff --git a/pandas/io/tests/test_stata.py b/pandas/io/tests/test_stata.py index a99420493d047..fe79bf20615bb 100644 --- a/pandas/io/tests/test_stata.py +++ b/pandas/io/tests/test_stata.py @@ -13,7 +13,7 @@ import pandas as pd from pandas.core.frame import DataFrame, Series from pandas.io.parsers import read_csv -from pandas.io.stata import read_stata, StataReader +from pandas.io.stata import read_stata, StataReader, InvalidColumnName import pandas.util.testing as tm from pandas.util.misc import is_little_endian from pandas import compat @@ -332,10 +332,10 @@ def test_read_write_dta12(self): tm.assert_frame_equal(written_and_read_again.set_index('index'), formatted) def test_read_write_dta13(self): - s1 = Series(2**9,dtype=np.int16) - s2 = Series(2**17,dtype=np.int32) - s3 = Series(2**33,dtype=np.int64) - original = DataFrame({'int16':s1,'int32':s2,'int64':s3}) + s1 = Series(2**9, dtype=np.int16) + s2 = Series(2**17, dtype=np.int32) + s3 = Series(2**33, dtype=np.int64) + original = DataFrame({'int16': s1, 'int32': s2, 'int64': s3}) original.index.name = 'index' formatted = original @@ -398,6 +398,22 @@ def test_timestamp_and_label(self): assert parsed_time_stamp == time_stamp assert reader.data_label == data_label + def test_numeric_column_names(self): + original = DataFrame(np.reshape(np.arange(25.0), (5, 5))) + original.index.name = 'index' + with tm.ensure_clean() as path: + # should get a warning for that format. 
+ with warnings.catch_warnings(record=True) as w: + tm.assert_produces_warning(original.to_stata(path), InvalidColumnName) + # should produce a single warning + np.testing.assert_equal(len(w), 1) + + written_and_read_again = self.read_dta(path) + written_and_read_again = written_and_read_again.set_index('index') + columns = list(written_and_read_again.columns) + convert_col_name = lambda x: int(x[1]) + written_and_read_again.columns = map(convert_col_name, columns) + tm.assert_frame_equal(original, written_and_read_again) if __name__ == '__main__':