Skip to content

Commit 403f778

Browse files
committed
Merge pull request #6622 from bashtage/stata-timestamps
BUG: Error in to_stata when DataFrame contains non-string column names
2 parents e19b2eb + 0e2c938 commit 403f778

File tree

3 files changed

+115
-54
lines changed

3 files changed

+115
-54
lines changed

doc/source/release.rst

+1
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ Bug Fixes
233233
- Bug in popping from a Series (:issue:`6600`)
234234
- Bug in ``iloc`` indexing where no reordering happened when the positional indexer matched the Int64Index of the corresponding axis (:issue:`6612`)
235235
- Bug in ``fillna`` with ``limit`` and ``value`` specified
236+
- Bug in ``DataFrame.to_stata`` when columns have non-string names (:issue:`4558`)
236237

237238
pandas 0.13.1
238239
-------------

pandas/io/stata.py

+93-49
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from pandas.core.categorical import Categorical
2121
import datetime
2222
from pandas import compat
23-
from pandas.compat import long, lrange, lmap, lzip
23+
from pandas.compat import long, lrange, lmap, lzip, text_type, string_types
2424
from pandas import isnull
2525
from pandas.io.common import get_filepath_or_buffer
2626
from pandas.tslib import NaT
@@ -191,6 +191,21 @@ class PossiblePrecisionLoss(Warning):
191191
"""
192192

193193

194+
class InvalidColumnName(Warning):
195+
pass
196+
197+
198+
invalid_name_doc = """
199+
Not all pandas column names were valid Stata variable names.
200+
The following replacements have been made:
201+
202+
{0}
203+
204+
If this is not what you expect, please make sure you have Stata-compliant
205+
column names in your DataFrame (strings only, max 32 characters, only alphanumerics and
206+
underscores, no Stata reserved words)
207+
"""
208+
194209
def _cast_to_stata_types(data):
195210
"""Checks the dtypes of the columns of a pandas DataFrame for
196211
compatibility with the data types and ranges supported by Stata, and
@@ -942,7 +957,7 @@ def _maybe_convert_to_int_keys(convert_dates, varlist):
942957
else:
943958
if not isinstance(key, int):
944959
raise ValueError(
945-
"convery_dates key is not in varlist and is not an int"
960+
"convert_dates key is not in varlist and is not an int"
946961
)
947962
new_dict.update({key: convert_dates[key]})
948963
return new_dict
@@ -1092,6 +1107,78 @@ def _write(self, to_write):
10921107
else:
10931108
self._file.write(to_write)
10941109

1110+
1111+
def _check_column_names(self, data):
1112+
"""Checks column names to ensure that they are valid Stata column names.
1113+
This includes checks for:
1114+
* Non-string names
1115+
* Stata keywords
1116+
* Variables that start with numbers
1117+
* Variables with names that are too long
1118+
1119+
When an illegal variable name is detected, it is converted, and if dates
1120+
are exported, the variable name is propagated to the date conversion
1121+
dictionary
1122+
"""
1123+
converted_names = []
1124+
columns = list(data.columns)
1125+
original_columns = columns[:]
1126+
1127+
duplicate_var_id = 0
1128+
for j, name in enumerate(columns):
1129+
orig_name = name
1130+
if not isinstance(name, string_types):
1131+
name = text_type(name)
1132+
1133+
for c in name:
1134+
if (c < 'A' or c > 'Z') and (c < 'a' or c > 'z') and \
1135+
(c < '0' or c > '9') and c != '_':
1136+
name = name.replace(c, '_')
1137+
1138+
# Variable name must not be a reserved word
1139+
if name in self.RESERVED_WORDS:
1140+
name = '_' + name
1141+
1142+
# Variable name may not start with a number
1143+
if name[0] >= '0' and name[0] <= '9':
1144+
name = '_' + name
1145+
1146+
name = name[:min(len(name), 32)]
1147+
1148+
if not name == orig_name:
1149+
# check for duplicates
1150+
while columns.count(name) > 0:
1151+
# prepend ascending number to avoid duplicates
1152+
name = '_' + str(duplicate_var_id) + name
1153+
name = name[:min(len(name), 32)]
1154+
duplicate_var_id += 1
1155+
1156+
# need to possibly encode the orig name if it's unicode
1157+
try:
1158+
orig_name = orig_name.encode('utf-8')
1159+
except:
1160+
pass
1161+
converted_names.append('{0} -> {1}'.format(orig_name, name))
1162+
1163+
columns[j] = name
1164+
1165+
data.columns = columns
1166+
1167+
# Check date conversion, and fix key if needed
1168+
if self._convert_dates:
1169+
for c, o in zip(columns, original_columns):
1170+
if c != o:
1171+
self._convert_dates[c] = self._convert_dates[o]
1172+
del self._convert_dates[o]
1173+
1174+
if converted_names:
1175+
import warnings
1176+
1177+
ws = invalid_name_doc.format('\n '.join(converted_names))
1178+
warnings.warn(ws, InvalidColumnName)
1179+
1180+
return data
1181+
10951182
def _prepare_pandas(self, data):
10961183
#NOTE: we might need a different API / class for pandas objects so
10971184
# we can set different semantics - handle this with a PR to pandas.io
@@ -1108,6 +1195,8 @@ def __iter__(self):
11081195
data = data.reset_index()
11091196
# Check columns for compatibility with stata
11101197
data = _cast_to_stata_types(data)
1198+
# Ensure column names are strings and valid Stata variable names
1199+
data = self._check_column_names(data)
11111200
self.datarows = DataFrameRowIter(data)
11121201
self.nobs, self.nvar = data.shape
11131202
self.data = data
@@ -1181,58 +1270,13 @@ def _write_descriptors(self, typlist=None, varlist=None, srtlist=None,
11811270
for typ in self.typlist:
11821271
self._write(typ)
11831272

1184-
# varlist, length 33*nvar, char array, null terminated
1185-
converted_names = []
1186-
duplicate_var_id = 0
1187-
for j, name in enumerate(self.varlist):
1188-
orig_name = name
1189-
# Replaces all characters disallowed in .dta format by their integral representation.
1190-
for c in name:
1191-
if (c < 'A' or c > 'Z') and (c < 'a' or c > 'z') and (c < '0' or c > '9') and c != '_':
1192-
name = name.replace(c, '_')
1193-
# Variable name must not be a reserved word
1194-
if name in self.RESERVED_WORDS:
1195-
name = '_' + name
1196-
# Variable name may not start with a number
1197-
if name[0] > '0' and name[0] < '9':
1198-
name = '_' + name
1199-
1200-
name = name[:min(len(name), 32)]
1201-
1202-
if not name == orig_name:
1203-
# check for duplicates
1204-
while self.varlist.count(name) > 0:
1205-
# prepend ascending number to avoid duplicates
1206-
name = '_' + str(duplicate_var_id) + name
1207-
name = name[:min(len(name), 32)]
1208-
duplicate_var_id += 1
1209-
1210-
# need to possibly encode the orig name if its unicode
1211-
try:
1212-
orig_name = orig_name.encode('utf-8')
1213-
except:
1214-
pass
1215-
1216-
converted_names.append('{0} -> {1}'.format(orig_name, name))
1217-
self.varlist[j] = name
1218-
1273+
# varlist names are checked by _check_column_names
1274+
# varlist: each name must be null-terminated
12191275
for name in self.varlist:
12201276
name = self._null_terminate(name, True)
12211277
name = _pad_bytes(name[:32], 33)
12221278
self._write(name)
12231279

1224-
if converted_names:
1225-
from warnings import warn
1226-
warn("""Not all pandas column names were valid Stata variable names.
1227-
Made the following replacements:
1228-
1229-
{0}
1230-
1231-
If this is not what you expect, please make sure you have Stata-compliant
1232-
column names in your DataFrame (max 32 characters, only alphanumerics and
1233-
underscores)/
1234-
""".format('\n '.join(converted_names)))
1235-
12361280
# srtlist, 2*(nvar+1), int array, encoded by byteorder
12371281
srtlist = _pad_bytes("", (2*(nvar+1)))
12381282
self._write(srtlist)

pandas/io/tests/test_stata.py

+21-5
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import pandas as pd
1414
from pandas.core.frame import DataFrame, Series
1515
from pandas.io.parsers import read_csv
16-
from pandas.io.stata import read_stata, StataReader
16+
from pandas.io.stata import read_stata, StataReader, InvalidColumnName
1717
import pandas.util.testing as tm
1818
from pandas.util.misc import is_little_endian
1919
from pandas import compat
@@ -332,10 +332,10 @@ def test_read_write_dta12(self):
332332
tm.assert_frame_equal(written_and_read_again.set_index('index'), formatted)
333333

334334
def test_read_write_dta13(self):
335-
s1 = Series(2**9,dtype=np.int16)
336-
s2 = Series(2**17,dtype=np.int32)
337-
s3 = Series(2**33,dtype=np.int64)
338-
original = DataFrame({'int16':s1,'int32':s2,'int64':s3})
335+
s1 = Series(2**9, dtype=np.int16)
336+
s2 = Series(2**17, dtype=np.int32)
337+
s3 = Series(2**33, dtype=np.int64)
338+
original = DataFrame({'int16': s1, 'int32': s2, 'int64': s3})
339339
original.index.name = 'index'
340340

341341
formatted = original
@@ -398,6 +398,22 @@ def test_timestamp_and_label(self):
398398
assert parsed_time_stamp == time_stamp
399399
assert reader.data_label == data_label
400400

401+
def test_numeric_column_names(self):
402+
original = DataFrame(np.reshape(np.arange(25.0), (5, 5)))
403+
original.index.name = 'index'
404+
with tm.ensure_clean() as path:
405+
# should get a warning because the column names are integers, not strings
406+
with warnings.catch_warnings(record=True) as w:
407+
tm.assert_produces_warning(original.to_stata(path), InvalidColumnName)
408+
# should produce a single warning
409+
np.testing.assert_equal(len(w), 1)
410+
411+
written_and_read_again = self.read_dta(path)
412+
written_and_read_again = written_and_read_again.set_index('index')
413+
columns = list(written_and_read_again.columns)
414+
convert_col_name = lambda x: int(x[1])
415+
written_and_read_again.columns = map(convert_col_name, columns)
416+
tm.assert_frame_equal(original, written_and_read_again)
401417

402418

403419
if __name__ == '__main__':

0 commit comments

Comments
 (0)