Skip to content

BUG: Error in to_stata when DataFrame contains non-string column names #6622

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 13, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ Bug Fixes
- Bug in popping from a Series (:issue:`6600`)
- Bug in ``iloc`` indexing when positional indexer matched Int64Index of corresponding axis no reordering happened (:issue:`6612`)
- Bug in ``fillna`` with ``limit`` and ``value`` specified
- Bug in ``DataFrame.to_stata`` when columns have non-string names (:issue:`4558`)

pandas 0.13.1
-------------
Expand Down
142 changes: 93 additions & 49 deletions pandas/io/stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pandas.core.categorical import Categorical
import datetime
from pandas import compat
from pandas.compat import long, lrange, lmap, lzip
from pandas.compat import long, lrange, lmap, lzip, text_type, string_types
from pandas import isnull
from pandas.io.common import get_filepath_or_buffer
from pandas.tslib import NaT
Expand Down Expand Up @@ -191,6 +191,21 @@ class PossiblePrecisionLoss(Warning):
"""


class InvalidColumnName(Warning):
    """Warning category used when column names are replaced on export.

    Issued with the ``invalid_name_doc`` message, which lists each original
    column name together with the Stata-compliant name that replaced it.
    """


# Message template for the InvalidColumnName warning.  ``{0}`` is filled in
# with the joined "original -> replacement" entries, one per renamed column.
invalid_name_doc = """
Not all pandas column names were valid Stata variable names.
The following replacements have been made:
{0}
If this is not what you expect, please make sure you have Stata-compliant
column names in your DataFrame (strings only, max 32 characters, only alphanumerics and
underscores, no Stata reserved words)
"""

def _cast_to_stata_types(data):
"""Checks the dtypes of the columns of a pandas DataFrame for
compatibility with the data types and ranges supported by Stata, and
Expand Down Expand Up @@ -942,7 +957,7 @@ def _maybe_convert_to_int_keys(convert_dates, varlist):
else:
if not isinstance(key, int):
raise ValueError(
"convery_dates key is not in varlist and is not an int"
"convert_dates key is not in varlist and is not an int"
)
new_dict.update({key: convert_dates[key]})
return new_dict
Expand Down Expand Up @@ -1092,6 +1107,78 @@ def _write(self, to_write):
else:
self._file.write(to_write)


def _check_column_names(self, data):
    """Rename the columns of ``data`` so they are valid Stata variable names.

    Every column name is normalised, in this order:

    * non-string names are converted with ``text_type``
    * each character other than ``A-Z``, ``a-z``, ``0-9`` and ``_`` is
      replaced by ``_``
    * a name found in ``self.RESERVED_WORDS`` gets a leading ``_``
    * a name starting with a digit gets a leading ``_``
    * the name is truncated to 32 characters

    If the normalised name differs from the original and collides with
    another column name, ``_<n>`` (``n`` an ascending counter) is prepended
    until it is unique, re-truncating to 32 characters each time.

    When a column is renamed and it has an entry in ``self._convert_dates``,
    that entry is propagated to the new name.

    Parameters
    ----------
    data : DataFrame
        Frame whose ``columns`` attribute is reassigned in place.

    Returns
    -------
    DataFrame
        The same ``data`` object, with the new column names.

    Warns
    -----
    InvalidColumnName
        If at least one column name had to be replaced; the message lists
        every ``original -> replacement`` pair.
    """
    def _is_valid_char(c):
        # Only ASCII letters, digits and underscore are allowed.
        return ('A' <= c <= 'Z' or 'a' <= c <= 'z' or
                '0' <= c <= '9' or c == '_')

    converted_names = []
    columns = list(data.columns)
    original_columns = columns[:]

    duplicate_var_id = 0
    for j, orig_name in enumerate(columns):
        name = orig_name
        if not isinstance(name, string_types):
            name = text_type(name)

        # Replace every disallowed character in a single pass.
        name = ''.join(c if _is_valid_char(c) else '_' for c in name)

        # Variable name must not be a reserved word
        if name in self.RESERVED_WORDS:
            name = '_' + name

        # Variable name may not start with a number
        if '0' <= name[0] <= '9':
            name = '_' + name

        name = name[:32]

        if name != orig_name:
            # check for duplicates
            while columns.count(name) > 0:
                # prepend ascending number to avoid duplicates
                name = ('_' + str(duplicate_var_id) + name)[:32]
                duplicate_var_id += 1

            # Best effort: encode the original name for the message if it
            # is a text object.  Names without ``encode`` (e.g. ints) or
            # that cannot be encoded are reported as they are.
            # NOTE(review): on Python 3 this makes str names show up as a
            # bytes repr in the warning text - confirm whether intended.
            try:
                orig_name = orig_name.encode('utf-8')
            except (AttributeError, UnicodeError):
                pass
            converted_names.append('{0} -> {1}'.format(orig_name, name))

        columns[j] = name

    data.columns = columns

    # Check date conversion, and fix key if needed
    if self._convert_dates:
        for new_name, old_name in zip(columns, original_columns):
            # Only move entries that exist: a renamed column that has no
            # date conversion must not raise KeyError.
            if new_name != old_name and old_name in self._convert_dates:
                self._convert_dates[new_name] = \
                    self._convert_dates.pop(old_name)

    if converted_names:
        import warnings

        ws = invalid_name_doc.format('\n '.join(converted_names))
        warnings.warn(ws, InvalidColumnName)

    return data

def _prepare_pandas(self, data):
#NOTE: we might need a different API / class for pandas objects so
# we can set different semantics - handle this with a PR to pandas.io
Expand All @@ -1108,6 +1195,8 @@ def __iter__(self):
data = data.reset_index()
# Check columns for compatibility with stata
data = _cast_to_stata_types(data)
# Ensure column names are strings
data = self._check_column_names(data)
self.datarows = DataFrameRowIter(data)
self.nobs, self.nvar = data.shape
self.data = data
Expand Down Expand Up @@ -1181,58 +1270,13 @@ def _write_descriptors(self, typlist=None, varlist=None, srtlist=None,
for typ in self.typlist:
self._write(typ)

# varlist, length 33*nvar, char array, null terminated
converted_names = []
duplicate_var_id = 0
for j, name in enumerate(self.varlist):
orig_name = name
# Replaces all characters disallowed in .dta format by their integral representation.
for c in name:
if (c < 'A' or c > 'Z') and (c < 'a' or c > 'z') and (c < '0' or c > '9') and c != '_':
name = name.replace(c, '_')
# Variable name must not be a reserved word
if name in self.RESERVED_WORDS:
name = '_' + name
# Variable name may not start with a number
if name[0] > '0' and name[0] < '9':
name = '_' + name

name = name[:min(len(name), 32)]

if not name == orig_name:
# check for duplicates
while self.varlist.count(name) > 0:
# prepend ascending number to avoid duplicates
name = '_' + str(duplicate_var_id) + name
name = name[:min(len(name), 32)]
duplicate_var_id += 1

# need to possibly encode the orig name if its unicode
try:
orig_name = orig_name.encode('utf-8')
except:
pass

converted_names.append('{0} -> {1}'.format(orig_name, name))
self.varlist[j] = name

# varlist names are checked by _check_column_names
# varlist, requires null terminated
for name in self.varlist:
name = self._null_terminate(name, True)
name = _pad_bytes(name[:32], 33)
self._write(name)

if converted_names:
from warnings import warn
warn("""Not all pandas column names were valid Stata variable names.
Made the following replacements:
{0}
If this is not what you expect, please make sure you have Stata-compliant
column names in your DataFrame (max 32 characters, only alphanumerics and
underscores)/
""".format('\n '.join(converted_names)))

# srtlist, 2*(nvar+1), int array, encoded by byteorder
srtlist = _pad_bytes("", (2*(nvar+1)))
self._write(srtlist)
Expand Down
26 changes: 21 additions & 5 deletions pandas/io/tests/test_stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import pandas as pd
from pandas.core.frame import DataFrame, Series
from pandas.io.parsers import read_csv
from pandas.io.stata import read_stata, StataReader
from pandas.io.stata import read_stata, StataReader, InvalidColumnName
import pandas.util.testing as tm
from pandas.util.misc import is_little_endian
from pandas import compat
Expand Down Expand Up @@ -332,10 +332,10 @@ def test_read_write_dta12(self):
tm.assert_frame_equal(written_and_read_again.set_index('index'), formatted)

def test_read_write_dta13(self):
s1 = Series(2**9,dtype=np.int16)
s2 = Series(2**17,dtype=np.int32)
s3 = Series(2**33,dtype=np.int64)
original = DataFrame({'int16':s1,'int32':s2,'int64':s3})
s1 = Series(2**9, dtype=np.int16)
s2 = Series(2**17, dtype=np.int32)
s3 = Series(2**33, dtype=np.int64)
original = DataFrame({'int16': s1, 'int32': s2, 'int64': s3})
original.index.name = 'index'

formatted = original
Expand Down Expand Up @@ -398,6 +398,22 @@ def test_timestamp_and_label(self):
assert parsed_time_stamp == time_stamp
assert reader.data_label == data_label

def test_numeric_column_names(self):
    """Round-trip a frame with integer column names through ``to_stata``.

    Writing must emit exactly one ``InvalidColumnName`` warning, and the
    data read back must equal the original once the written column names
    are converted back to integers.
    """
    original = DataFrame(np.reshape(np.arange(25.0), (5, 5)))
    original.index.name = 'index'
    with tm.ensure_clean() as path:
        # should get a warning for that format.
        with warnings.catch_warnings(record=True) as w:
            # Record the warning even if an identical one was already
            # raised earlier in this process and would otherwise be
            # suppressed by the default filter.
            warnings.simplefilter('always')
            original.to_stata(path)
        name_warnings = [x for x in w
                         if issubclass(x.category, InvalidColumnName)]
        # should produce a single warning
        np.testing.assert_equal(len(name_warnings), 1)

        written_and_read_again = self.read_dta(path)
        written_and_read_again = written_and_read_again.set_index('index')
        # Drop the leading character of each written name and convert the
        # rest back to the original integer label.
        written_and_read_again.columns = [
            int(name[1:]) for name in written_and_read_again.columns]
        tm.assert_frame_equal(original, written_and_read_again)


if __name__ == '__main__':
Expand Down