From 204b50e5a30644179188ef06641d4c0e095c5bcf Mon Sep 17 00:00:00 2001 From: Kevin Sheppard Date: Fri, 7 Nov 2014 17:07:28 -0500 Subject: [PATCH] ENH: Add categorical support for Stata export Add support for exporting DataFrames containing categorical data. closes #8633 xref #7621 --- doc/source/io.rst | 6 + doc/source/whatsnew/v0.15.2.txt | 1 + pandas/io/stata.py | 256 ++++++++++++++++++++++++++++---- pandas/io/tests/test_stata.py | 79 +++++++++- 4 files changed, 313 insertions(+), 29 deletions(-) diff --git a/doc/source/io.rst b/doc/source/io.rst index 066a9af472c24..1d83e06a13567 100644 --- a/doc/source/io.rst +++ b/doc/source/io.rst @@ -3626,12 +3626,18 @@ outside of this range, the data is cast to ``int16``. if ``int64`` values are larger than 2**53. .. warning:: + :class:`~pandas.io.stata.StataWriter`` and :func:`~pandas.core.frame.DataFrame.to_stata` only support fixed width strings containing up to 244 characters, a limitation imposed by the version 115 dta file format. Attempting to write *Stata* dta files with strings longer than 244 characters raises a ``ValueError``. +.. warning:: + + *Stata* data files only support text labels for categorical data. Exporting + data frames containing categorical data will convert non-string categorical values + to strings. .. _io.stata_reader: diff --git a/doc/source/whatsnew/v0.15.2.txt b/doc/source/whatsnew/v0.15.2.txt index 97cbecd4bb0e7..8b104d54ed778 100644 --- a/doc/source/whatsnew/v0.15.2.txt +++ b/doc/source/whatsnew/v0.15.2.txt @@ -41,6 +41,7 @@ API changes Enhancements ~~~~~~~~~~~~ +- Added ability to export Categorical data to Stata (:issue:`8633`). .. 
_whatsnew_0152.performance: diff --git a/pandas/io/stata.py b/pandas/io/stata.py index c2542594861c4..ab9d330b48988 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -15,13 +15,13 @@ import struct from dateutil.relativedelta import relativedelta from pandas.core.base import StringMixin +from pandas.core.categorical import Categorical from pandas.core.frame import DataFrame from pandas.core.series import Series -from pandas.core.categorical import Categorical import datetime from pandas import compat, to_timedelta, to_datetime, isnull, DatetimeIndex from pandas.compat import lrange, lmap, lzip, text_type, string_types, range, \ - zip + zip, BytesIO import pandas.core.common as com from pandas.io.common import get_filepath_or_buffer from pandas.lib import max_len_string_array, infer_dtype @@ -336,6 +336,15 @@ class PossiblePrecisionLoss(Warning): conversion range. This may result in a loss of precision in the saved data. """ +class ValueLabelTypeMismatch(Warning): + pass + +value_label_mismatch_doc = """ +Stata value labels (pandas categories) must be strings. Column {0} contains +non-string labels which will be converted to strings. Please check that the +Stata data file created has not lost information due to duplicate labels. 
+""" + class InvalidColumnName(Warning): pass @@ -425,6 +434,131 @@ def _cast_to_stata_types(data): return data +class StataValueLabel(object): + """ + Parse a categorical column and prepare formatted output + + Parameters + ----------- + catarray : Series + Series of category dtype (accessed through ``.cat``) whose + categories become the value labels + + Attributes + ---------- + labname : string + Name of the value label, taken from ``catarray.name`` + value_labels : list of tuples + (code, category) pairs, one per category, sorted by code + text_len : int + Combined length of all label strings, counting one null + terminator per label + off : int32 array + Offset of each label within the text block + val : int32 array + Integer code associated with each label + txt : list of strings + Label text; non-string categories are converted with ``str`` + n : int + Number of labels + len : int + Length in bytes of the value label table (n, textlen, off, + val and txt) + + Raises + ------ + ValueError + If the combined label length exceeds 32,000 characters + + Methods + ------- + generate_value_label + + """ + + def __init__(self, catarray): + + self.labname = catarray.name + + categories = catarray.cat.categories + self.value_labels = list(zip(np.arange(len(categories)), categories)) + self.value_labels.sort(key=lambda x: x[0]) + self.text_len = np.int32(0) + self.off = [] + self.val = [] + self.txt = [] + self.n = 0 + + # Compute lengths and setup lists of offsets and labels + for vl in self.value_labels: + category = vl[1] + if not isinstance(category, string_types): + category = str(category) + import warnings + warnings.warn(value_label_mismatch_doc.format(catarray.name), + ValueLabelTypeMismatch) + + self.off.append(self.text_len) + self.text_len += len(category) + 1 # +1 for the null terminator + self.val.append(vl[0]) + self.txt.append(category) + self.n += 1 + + if self.text_len > 32000: + raise ValueError('Stata value labels for a single variable must ' + 'have a combined length less than 32,000 ' + 'characters.') + + # Ensure int32 + self.off = np.array(self.off, dtype=np.int32) + self.val = np.array(self.val, dtype=np.int32) + + # Total length + self.len = 4 + 4 + 4 * self.n + 4 * self.n + self.text_len + + def _encode(self, s): + """ + Python 3 compatibility shim + """ + if compat.PY3: + return s.encode(self._encoding) + else: + return s + + def generate_value_label(self, byteorder, encoding): + """ + Parameters + ---------- + byteorder : str + Byte order of the output + encoding : str + File encoding + + Returns 
------- + value_label : bytes + Bytes containing the formatted value label + """ + + self._encoding = encoding + bio = BytesIO() + null_string = '\x00' + null_byte = b'\x00' + + # len + bio.write(struct.pack(byteorder + 'i', self.len)) + + # labname + labname = self._encode(_pad_bytes(self.labname[:32], 33)) + bio.write(labname) + + # padding - 3 bytes + for i in range(3): + bio.write(struct.pack('c', null_byte)) + + # value_label_table + # n - int32 + bio.write(struct.pack(byteorder + 'i', self.n)) + + # textlen - int32 + bio.write(struct.pack(byteorder + 'i', self.text_len)) + + # off - int32 array (n elements) + for offset in self.off: + bio.write(struct.pack(byteorder + 'i', offset)) + + # val - int32 array (n elements) + for value in self.val: + bio.write(struct.pack(byteorder + 'i', value)) + + # txt - Text labels, null terminated + for text in self.txt: + bio.write(self._encode(text + null_string)) + + bio.seek(0) + return bio.read() + + class StataMissingValue(StringMixin): """ An observation's missing value. @@ -477,25 +611,31 @@ class StataMissingValue(StringMixin): for i in range(1, 27): MISSING_VALUES[i + b] = '.' 
+ chr(96 + i) - base = b'\x00\x00\x00\x7f' + float32_base = b'\x00\x00\x00\x7f' increment = struct.unpack(' 0: MISSING_VALUES[value] += chr(96 + i) int_value = struct.unpack(' 0: MISSING_VALUES[value] += chr(96 + i) int_value = struct.unpack('q', struct.pack('= get_base_missing_value(dtype): + if dtype == np.int8: + dtype = np.int16 + elif dtype == np.int16: + dtype = np.int32 + else: + dtype = np.float64 + values = np.array(values, dtype=dtype) + + # Replace missing values with Stata missing value for type + values[values == -1] = get_base_missing_value(dtype) + data_formatted.append((col, values, index)) + + else: + data_formatted.append((col, data[col])) + return DataFrame.from_items(data_formatted) def _replace_nans(self, data): # return data @@ -1480,27 +1675,26 @@ def _check_column_names(self, data): def _prepare_pandas(self, data): #NOTE: we might need a different API / class for pandas objects so # we can set different semantics - handle this with a PR to pandas.io - class DataFrameRowIter(object): - def __init__(self, data): - self.data = data - - def __iter__(self): - for row in data.itertuples(): - # First element is index, so remove - yield row[1:] if self._write_index: data = data.reset_index() - # Check columns for compatibility with stata - data = _cast_to_stata_types(data) + # Ensure column names are strings data = self._check_column_names(data) + + # Check columns for compatibility with stata, upcast if necessary + data = _cast_to_stata_types(data) + # Replace NaNs with Stata missing values data = self._replace_nans(data) - self.datarows = DataFrameRowIter(data) + + # Convert categoricals to int data, and strip labels + data = self._prepare_categoricals(data) + self.nobs, self.nvar = data.shape self.data = data self.varlist = data.columns.tolist() + dtypes = data.dtypes if self._convert_dates is not None: self._convert_dates = _maybe_convert_to_int_keys( @@ -1515,6 +1709,7 @@ def __iter__(self): self.fmtlist = [] for col, dtype in 
dtypes.iteritems(): self.fmtlist.append(_dtype_to_default_stata_fmt(dtype, data[col])) + # set the given format for the datetime cols if self._convert_dates is not None: for key in self._convert_dates: @@ -1529,8 +1724,14 @@ def write_file(self): self._write(_pad_bytes("", 5)) self._prepare_data() self._write_data() + self._write_value_labels() self._file.close() + def _write_value_labels(self): + for vl in self._value_labels: + self._file.write(vl.generate_value_label(self._byteorder, + self._encoding)) + def _write_header(self, data_label=None, time_stamp=None): byteorder = self._byteorder # ds_format - just use 114 @@ -1585,9 +1786,15 @@ def _write_descriptors(self, typlist=None, varlist=None, srtlist=None, self._write(_pad_bytes(fmt, 49)) # lbllist, 33*nvar, char array - #NOTE: this is where you could get fancy with pandas categorical type for i in range(nvar): - self._write(_pad_bytes("", 33)) + # Use variable name when categorical + if self._is_col_cat[i]: + name = self.varlist[i] + name = self._null_terminate(name, True) + name = _pad_bytes(name[:32], 33) + self._write(name) + else: # Default is empty label + self._write(_pad_bytes("", 33)) def _write_variable_labels(self, labels=None): nvar = self.nvar @@ -1624,9 +1831,6 @@ def _prepare_data(self): data_cols.append(data[col].values) dtype = np.dtype(dtype) - # 3. Convert to record array - - # data.to_records(index=False, convert_datetime64=False) if has_strings: self.data = np.fromiter(zip(*data_cols), dtype=dtype) else: diff --git a/pandas/io/tests/test_stata.py b/pandas/io/tests/test_stata.py index 2cb7809166be5..d97feaea2658a 100644 --- a/pandas/io/tests/test_stata.py +++ b/pandas/io/tests/test_stata.py @@ -164,8 +164,9 @@ def test_read_dta2(self): parsed_117 = self.read_dta(self.dta2_117) # 113 is buggy due ot limits date format support in Stata # parsed_113 = self.read_dta(self.dta2_113) - tm.assert_equal( - len(w), 1) # should get a warning for that format. + + # should get a warning for that format. 
+ tm.assert_equal(len(w), 1) # buggy test because of the NaT comparison on certain platforms # Format 113 test fails since it does not support tc and tC formats @@ -214,7 +215,7 @@ def test_read_dta4(self): 'labeled_with_missings', 'float_labelled']) # these are all categoricals - expected = pd.concat([ Series(pd.Categorical(value)) for col, value in compat.iteritems(expected)],axis=1) + expected = pd.concat([expected[col].astype('category') for col in expected], axis=1) tm.assert_frame_equal(parsed_113, expected) tm.assert_frame_equal(parsed_114, expected) @@ -744,6 +745,78 @@ def test_drop_column(self): columns = ['byte_', 'int_', 'long_', 'not_found'] read_stata(self.dta15_117, convert_dates=True, columns=columns) + def test_categorical_writing(self): + original = DataFrame.from_records( + [ + ["one", "ten", "one", "one", "one", 1], + ["two", "nine", "two", "two", "two", 2], + ["three", "eight", "three", "three", "three", 3], + ["four", "seven", 4, "four", "four", 4], + ["five", "six", 5, np.nan, "five", 5], + ["six", "five", 6, np.nan, "six", 6], + ["seven", "four", 7, np.nan, "seven", 7], + ["eight", "three", 8, np.nan, "eight", 8], + ["nine", "two", 9, np.nan, "nine", 9], + ["ten", "one", "ten", np.nan, "ten", 10] + ], + columns=['fully_labeled', 'fully_labeled2', 'incompletely_labeled', + 'labeled_with_missings', 'float_labelled', 'unlabeled']) + expected = original.copy() + + # these are all categoricals + original = pd.concat([original[col].astype('category') for col in original], axis=1) + + expected['incompletely_labeled'] = expected['incompletely_labeled'].apply(str) + expected['unlabeled'] = expected['unlabeled'].apply(str) + expected = pd.concat([expected[col].astype('category') for col in expected], axis=1) + expected.index.name = 'index' + + with tm.ensure_clean() as path: + with warnings.catch_warnings(record=True) as w: + # Silence warnings + original.to_stata(path) + written_and_read_again = self.read_dta(path) + 
tm.assert_frame_equal(written_and_read_again.set_index('index'), expected) + + + def test_categorical_warnings_and_errors(self): + # Warning for non-string labels + # Error for labels too long + original = pd.DataFrame.from_records( + [['a' * 10000], + ['b' * 10000], + ['c' * 10000], + ['d' * 10000]], + columns=['Too_long']) + + original = pd.concat([original[col].astype('category') for col in original], axis=1) + with tm.ensure_clean() as path: + tm.assertRaises(ValueError, original.to_stata, path) + + original = pd.DataFrame.from_records( + [['a'], + ['b'], + ['c'], + ['d'], + [1]], + columns=['Too_long']) + original = pd.concat([original[col].astype('category') for col in original], axis=1) + + with warnings.catch_warnings(record=True) as w: + original.to_stata(path) + tm.assert_equal(len(w), 1) # should get a warning for mixed content + + def test_categorical_with_stata_missing_values(self): + values = [['a' + str(i)] for i in range(120)] + values.append([np.nan]) + original = pd.DataFrame.from_records(values, columns=['many_labels']) + original = pd.concat([original[col].astype('category') for col in original], axis=1) + original.index.name = 'index' + with tm.ensure_clean() as path: + original.to_stata(path) + written_and_read_again = self.read_dta(path) + tm.assert_frame_equal(written_and_read_again.set_index('index'), original) + if __name__ == '__main__': nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'], exit=False)