From 0e2c938d7d8f3ecffa93cf901d160533ed38d4b2 Mon Sep 17 00:00:00 2001 From: Kevin Sheppard Date: Thu, 13 Mar 2014 00:36:26 +0000 Subject: [PATCH] BUG: Error in to_stata when DataFrame contains non-string column names to_stata does not work correctly when used with non-string names. Since Stata requires string names, the proposed fix attempts to rename columns using the string representation of the column name used. The main method that reformats column names was refactored to handle this case. Patch includes additional fixes for detecting invalid names. Patch includes some minor documentation fixes. --- doc/source/release.rst | 1 + pandas/io/stata.py | 142 ++++++++++++++++++++++------------ pandas/io/tests/test_stata.py | 26 +++++-- 3 files changed, 115 insertions(+), 54 deletions(-) diff --git a/doc/source/release.rst b/doc/source/release.rst index 4890f22e98468..12a83f48706e5 100644 --- a/doc/source/release.rst +++ b/doc/source/release.rst @@ -233,6 +233,7 @@ Bug Fixes - Bug in popping from a Series (:issue:`6600`) - Bug in ``iloc`` indexing when positional indexer matched Int64Index of corresponding axis no reordering happened (:issue:`6612`) - Bug in ``fillna`` with ``limit`` and ``value`` specified +- Bug in ``DataFrame.to_stata`` when columns have non-string names (:issue:`4558`) pandas 0.13.1 ------------- diff --git a/pandas/io/stata.py b/pandas/io/stata.py index 7d9d272eea1b6..4bb61e385a75c 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -20,7 +20,7 @@ from pandas.core.categorical import Categorical import datetime from pandas import compat -from pandas.compat import long, lrange, lmap, lzip +from pandas.compat import long, lrange, lmap, lzip, text_type, string_types from pandas import isnull from pandas.io.common import get_filepath_or_buffer from pandas.tslib import NaT @@ -191,6 +191,21 @@ class PossiblePrecisionLoss(Warning): """ +class InvalidColumnName(Warning): + pass + + +invalid_name_doc = """ +Not all pandas column names were valid 
Stata variable names. +The following replacements have been made: + + {0} + +If this is not what you expect, please make sure you have Stata-compliant +column names in your DataFrame (strings only, max 32 characters, only alphanumerics and +underscores, no Stata reserved words) +""" + def _cast_to_stata_types(data): """Checks the dtypes of the columns of a pandas DataFrame for compatibility with the data types and ranges supported by Stata, and @@ -942,7 +957,7 @@ def _maybe_convert_to_int_keys(convert_dates, varlist): else: if not isinstance(key, int): raise ValueError( - "convery_dates key is not in varlist and is not an int" + "convert_dates key is not in varlist and is not an int" ) new_dict.update({key: convert_dates[key]}) return new_dict @@ -1092,6 +1107,78 @@ def _write(self, to_write): else: self._file.write(to_write) + + def _check_column_names(self, data): + """Checks column names to ensure that they are valid Stata column names. + This includes checks for: + * Non-string names + * Stata keywords + * Variables that start with numbers + * Variables with names that are too long + + When an illegal variable name is detected, it is converted, and if dates + are exported, the variable name is propagated to the date conversion + dictionary + """ + converted_names = [] + columns = list(data.columns) + original_columns = columns[:] + + duplicate_var_id = 0 + for j, name in enumerate(columns): + orig_name = name + if not isinstance(name, string_types): + name = text_type(name) + + for c in name: + if (c < 'A' or c > 'Z') and (c < 'a' or c > 'z') and \ + (c < '0' or c > '9') and c != '_': + name = name.replace(c, '_') + + # Variable name must not be a reserved word + if name in self.RESERVED_WORDS: + name = '_' + name + + # Variable name may not start with a number + if name[0] >= '0' and name[0] <= '9': + name = '_' + name + + name = name[:min(len(name), 32)] + + if not name == orig_name: + # check for duplicates + while columns.count(name) > 0: + # prepend 
ascending number to avoid duplicates + name = '_' + str(duplicate_var_id) + name + name = name[:min(len(name), 32)] + duplicate_var_id += 1 + + # need to possibly encode the orig name if its unicode + try: + orig_name = orig_name.encode('utf-8') + except: + pass + converted_names.append('{0} -> {1}'.format(orig_name, name)) + + columns[j] = name + + data.columns = columns + + # Check date conversion, and fix key if needed + if self._convert_dates: + for c, o in zip(columns, original_columns): + if c != o: + self._convert_dates[c] = self._convert_dates[o] + del self._convert_dates[o] + + if converted_names: + import warnings + + ws = invalid_name_doc.format('\n '.join(converted_names)) + warnings.warn(ws, InvalidColumnName) + + return data + def _prepare_pandas(self, data): #NOTE: we might need a different API / class for pandas objects so # we can set different semantics - handle this with a PR to pandas.io @@ -1108,6 +1195,8 @@ def __iter__(self): data = data.reset_index() # Check columns for compatibility with stata data = _cast_to_stata_types(data) + # Ensure column names are strings + data = self._check_column_names(data) self.datarows = DataFrameRowIter(data) self.nobs, self.nvar = data.shape self.data = data @@ -1181,58 +1270,13 @@ def _write_descriptors(self, typlist=None, varlist=None, srtlist=None, for typ in self.typlist: self._write(typ) - # varlist, length 33*nvar, char array, null terminated - converted_names = [] - duplicate_var_id = 0 - for j, name in enumerate(self.varlist): - orig_name = name - # Replaces all characters disallowed in .dta format by their integral representation. 
- for c in name: - if (c < 'A' or c > 'Z') and (c < 'a' or c > 'z') and (c < '0' or c > '9') and c != '_': - name = name.replace(c, '_') - # Variable name must not be a reserved word - if name in self.RESERVED_WORDS: - name = '_' + name - # Variable name may not start with a number - if name[0] > '0' and name[0] < '9': - name = '_' + name - - name = name[:min(len(name), 32)] - - if not name == orig_name: - # check for duplicates - while self.varlist.count(name) > 0: - # prepend ascending number to avoid duplicates - name = '_' + str(duplicate_var_id) + name - name = name[:min(len(name), 32)] - duplicate_var_id += 1 - - # need to possibly encode the orig name if its unicode - try: - orig_name = orig_name.encode('utf-8') - except: - pass - - converted_names.append('{0} -> {1}'.format(orig_name, name)) - self.varlist[j] = name - + # varlist names are checked by _check_column_names + # varlist, requires null terminated for name in self.varlist: name = self._null_terminate(name, True) name = _pad_bytes(name[:32], 33) self._write(name) - if converted_names: - from warnings import warn - warn("""Not all pandas column names were valid Stata variable names. 
- Made the following replacements: - - {0} - - If this is not what you expect, please make sure you have Stata-compliant - column names in your DataFrame (max 32 characters, only alphanumerics and - underscores)/ - """.format('\n '.join(converted_names))) - # srtlist, 2*(nvar+1), int array, encoded by byteorder srtlist = _pad_bytes("", (2*(nvar+1))) self._write(srtlist) diff --git a/pandas/io/tests/test_stata.py b/pandas/io/tests/test_stata.py index a99420493d047..fe79bf20615bb 100644 --- a/pandas/io/tests/test_stata.py +++ b/pandas/io/tests/test_stata.py @@ -13,7 +13,7 @@ import pandas as pd from pandas.core.frame import DataFrame, Series from pandas.io.parsers import read_csv -from pandas.io.stata import read_stata, StataReader +from pandas.io.stata import read_stata, StataReader, InvalidColumnName import pandas.util.testing as tm from pandas.util.misc import is_little_endian from pandas import compat @@ -332,10 +332,10 @@ def test_read_write_dta12(self): tm.assert_frame_equal(written_and_read_again.set_index('index'), formatted) def test_read_write_dta13(self): - s1 = Series(2**9,dtype=np.int16) - s2 = Series(2**17,dtype=np.int32) - s3 = Series(2**33,dtype=np.int64) - original = DataFrame({'int16':s1,'int32':s2,'int64':s3}) + s1 = Series(2**9, dtype=np.int16) + s2 = Series(2**17, dtype=np.int32) + s3 = Series(2**33, dtype=np.int64) + original = DataFrame({'int16': s1, 'int32': s2, 'int64': s3}) original.index.name = 'index' formatted = original @@ -398,6 +398,22 @@ def test_timestamp_and_label(self): assert parsed_time_stamp == time_stamp assert reader.data_label == data_label + def test_numeric_column_names(self): + original = DataFrame(np.reshape(np.arange(25.0), (5, 5))) + original.index.name = 'index' + with tm.ensure_clean() as path: + # should get a warning for that format. 
+ with warnings.catch_warnings(record=True) as w: + tm.assert_produces_warning(original.to_stata(path), InvalidColumnName) + # should produce a single warning + np.testing.assert_equal(len(w), 1) + + written_and_read_again = self.read_dta(path) + written_and_read_again = written_and_read_again.set_index('index') + columns = list(written_and_read_again.columns) + convert_col_name = lambda x: int(x[1]) + written_and_read_again.columns = map(convert_col_name, columns) + tm.assert_frame_equal(original, written_and_read_again) if __name__ == '__main__':