Skip to content

StataWriter: Replace non-isalnum characters in variable names by _ instead of integral representation of replaced character. Eliminate duplicates created by replacement. #5709

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 18, 2013
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ Improvements to existing features
MultiIndex and Hierarchical Rows. Set the ``merge_cells`` to ``False`` to
restore the previous behaviour. (:issue:`5254`)
- The FRED DataReader now accepts multiple series (:issue:`3413`)
- StataWriter adjusts variable names to Stata's limitations (:issue:`5709`)

API Changes
~~~~~~~~~~~
Expand Down
44 changes: 44 additions & 0 deletions pandas/io/stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,11 +1068,55 @@ def _write_descriptors(self, typlist=None, varlist=None, srtlist=None,
self._write(typ)

# varlist, length 33*nvar, char array, null terminated
converted_names = []
duplicate_var_id = 0
for j, name in enumerate(self.varlist):
orig_name = name
# Replace every character that is not allowed in a .dta variable name (anything other than A-Z, a-z, 0-9 and _) with an underscore.
for c in name:
if (c < 'A' or c > 'Z') and (c < 'a' or c > 'z') and (c < '0' or c > '9') and c != '_':
name = name.replace(c, '_')

# Variable name may not start with a number
if name[0] > '0' and name[0] < '9':
name = '_' + name

name = name[:min(len(name), 32)]

if not name == orig_name:
# check for duplicates
while self.varlist.count(name) > 0:
# prepend ascending number to avoid duplicates
name = '_' + str(duplicate_var_id) + name
name = name[:min(len(name), 32)]
duplicate_var_id += 1

# need to possibly encode the orig name if its unicode
try:
orig_name = orig_name.encode('utf-8')
except:
pass

converted_names.append('{0} -> {1}'.format(orig_name, name))
self.varlist[j] = name

for name in self.varlist:
name = self._null_terminate(name, True)
name = _pad_bytes(name[:32], 33)
self._write(name)

if converted_names:
from warnings import warn
warn("""Not all pandas column names were valid Stata variable names.
Made the following replacements:

{0}

If this is not what you expect, please make sure you have Stata-compliant
column names in your DataFrame (max 32 characters, only alphanumerics and
underscores)/
""".format('\n '.join(converted_names)))

# srtlist, 2*(nvar+1), int array, encoded by byteorder
srtlist = _pad_bytes("", (2*(nvar+1)))
self._write(srtlist)
Expand Down
32 changes: 32 additions & 0 deletions pandas/io/tests/test_stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,38 @@ def test_encoding(self):
self.assert_(result == expected)
self.assert_(isinstance(result, unicode))

def test_read_write_dta11(self):
    # Column names Stata cannot store verbatim: a non-ASCII letter, a
    # leading digit, and a name longer than 32 characters.
    row = (1, 2, 3, 4)
    original = DataFrame(
        [row],
        columns=['good', compat.u('b\u00E4d'), '8number',
                 'astringwithmorethan32characters______'])

    # The frame the test expects to read back once the names were adjusted.
    formatted = DataFrame(
        [row],
        columns=['good', 'b_d', '_8number',
                 'astringwithmorethan32characters_'])
    formatted.index.name = 'index'

    with tm.ensure_clean() as path:
        with warnings.catch_warnings(record=True) as caught:
            original.to_stata(path, None, False)
            # should get a warning for that format.
            np.testing.assert_equal(len(caught), 1)

        # Read the file back while the temporary path still exists.
        round_tripped = self.read_dta(path).set_index('index')
        tm.assert_frame_equal(round_tripped, formatted)

def test_read_write_dta12(self):
    # Column names that collide after adjustment: two long names that share
    # their first 32 characters, and two single-symbol names.
    row = (1, 2, 3, 4)
    original = DataFrame(
        [row],
        columns=['astringwithmorethan32characters_1',
                 'astringwithmorethan32characters_2',
                 '+',
                 '-'])

    # The de-duplicated names the test expects to read back.
    formatted = DataFrame(
        [row],
        columns=['astringwithmorethan32characters_',
                 '_0astringwithmorethan32character',
                 '_',
                 '_1_'])
    formatted.index.name = 'index'

    with tm.ensure_clean() as path:
        with warnings.catch_warnings(record=True) as caught:
            original.to_stata(path, None, False)
            # should get a warning for that format.
            np.testing.assert_equal(len(caught), 1)

        # Read the file back while the temporary path still exists.
        round_tripped = self.read_dta(path).set_index('index')
        tm.assert_frame_equal(round_tripped, formatted)

# Run this module's tests through nose when the file is executed directly.
if __name__ == '__main__':
nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'],
exit=False)