Skip to content

ENH: Add categorical support for Stata export #8767

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 13, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/source/io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3626,12 +3626,18 @@ outside of this range, the data is cast to ``int16``.
if ``int64`` values are larger than 2**53.

.. warning::

:class:`~pandas.io.stata.StataWriter` and
:func:`~pandas.core.frame.DataFrame.to_stata` only support fixed width
strings containing up to 244 characters, a limitation imposed by the version
115 dta file format. Attempting to write *Stata* dta files with strings
longer than 244 characters raises a ``ValueError``.

.. warning::

*Stata* data files only support text labels for categorical data. Exporting
data frames containing categorical data will convert non-string categorical values
to strings.

.. _io.stata_reader:

Expand Down
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.15.2.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ API changes
Enhancements
~~~~~~~~~~~~

- Added ability to export Categorical data to Stata (:issue:`8633`).

.. _whatsnew_0152.performance:

Expand Down
256 changes: 230 additions & 26 deletions pandas/io/stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
import struct
from dateutil.relativedelta import relativedelta
from pandas.core.base import StringMixin
from pandas.core.categorical import Categorical
from pandas.core.frame import DataFrame
from pandas.core.series import Series
from pandas.core.categorical import Categorical
import datetime
from pandas import compat, to_timedelta, to_datetime, isnull, DatetimeIndex
from pandas.compat import lrange, lmap, lzip, text_type, string_types, range, \
zip
zip, BytesIO
import pandas.core.common as com
from pandas.io.common import get_filepath_or_buffer
from pandas.lib import max_len_string_array, infer_dtype
Expand Down Expand Up @@ -336,6 +336,15 @@ class PossiblePrecisionLoss(Warning):
conversion range. This may result in a loss of precision in the saved data.
"""

class ValueLabelTypeMismatch(Warning):
    """
    Warning category used when a categorical column holds non-string
    categories that are converted to strings to serve as Stata value labels.
    """

value_label_mismatch_doc = """
Stata value labels (pandas categories) must be strings. Column {0} contains
non-string labels which will be converted to strings. Please check that the
Stata data file created has not lost information due to duplicate labels.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the connection between the need for checking for duplicate data and the fact that the labels are stringified?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

String representations could conflict with either each other, if two objects have the same str, or a stringified value could conflict with an existing cat label. For example, if categorical labels contained 1 and "1" .

"""


# Warning category; the code that issues it is not part of this view.
# NOTE(review): presumably raised when column names are not valid Stata
# variable names -- confirm against the writer's column-name check.
class InvalidColumnName(Warning):
    pass
Expand Down Expand Up @@ -425,6 +434,131 @@ def _cast_to_stata_types(data):
return data


class StataValueLabel(object):
    """
    Parse a categorical column and prepare formatted output

    Parameters
    -----------
    catarray : Series
        Categorical data, read through ``catarray.name`` and
        ``catarray.cat.categories``.

    Attributes
    ----------
    labname : object
        Name of ``catarray``; used as the name of the value label
    value_labels : list of tuple
        ``(code, category)`` pairs, where code is the position of the
        category in ``catarray.cat.categories``
    n : int
        Number of labels
    text_len : int
        Combined length of all label texts, counting one terminating null
        per label
    off : ndarray of int32
        Offset of each label's text within the text block
    val : ndarray of int32
        Integer code of each label
    txt : list of str
        Label texts; non-string categories are converted with ``str``
    len : int
        Size of the value label table: ``n``, ``textlen``, ``off``, ``val``
        and ``txt`` (4 + 4 + 4 * n + 4 * n + text_len)

    Methods
    -------
    generate_value_label

    """

    def __init__(self, catarray):
        # Imported once here instead of once per non-string label
        import warnings

        self.labname = catarray.name

        categories = catarray.cat.categories
        # Codes are the positions 0..n-1, so the pairs are already in
        # ascending code order and need no sorting
        self.value_labels = list(zip(np.arange(len(categories)), categories))
        self.text_len = np.int32(0)
        self.off = []
        self.val = []
        self.txt = []
        self.n = 0

        # Compute lengths and setup lists of offsets and labels
        for code, category in self.value_labels:
            if not isinstance(category, string_types):
                category = str(category)
                warnings.warn(value_label_mismatch_doc.format(catarray.name),
                              ValueLabelTypeMismatch)

            self.off.append(self.text_len)
            # NOTE(review): the length is counted in characters; this matches
            # the number of bytes written only for single-byte encodings --
            # confirm against the encodings the writer allows.
            self.text_len += len(category) + 1  # +1 for the padding
            self.val.append(code)
            self.txt.append(category)
            self.n += 1

        if self.text_len > 32000:
            raise ValueError('Stata value labels for a single variable must '
                             'have a combined length less than 32,000 '
                             'characters.')

        # Ensure int32
        self.off = np.array(self.off, dtype=np.int32)
        self.val = np.array(self.val, dtype=np.int32)

        # Total length
        self.len = 4 + 4 + 4 * self.n + 4 * self.n + self.text_len

    def _encode(self, s):
        """
        Python 3 compatibility shim

        Encodes ``s`` with ``self._encoding`` on Python 3 and returns it
        unchanged otherwise. ``self._encoding`` is only set by
        ``generate_value_label``.
        """
        if compat.PY3:
            return s.encode(self._encoding)
        else:
            return s

    def generate_value_label(self, byteorder, encoding):
        """
        Build the binary representation of the value label

        The output consists of, in order: the table length (int32), the
        label name truncated to 32 characters and padded to 33, 3 null
        padding bytes, ``n`` (int32), ``textlen`` (int32), the ``off`` and
        ``val`` int32 arrays, and the null terminated label texts.

        Parameters
        ----------
        byteorder : str
            Byte order of the output
        encoding : str
            File encoding

        Returns
        -------
        value_label : bytes
            Bytes containing the formatted value label
        """

        # Stored on the instance because _encode reads it
        self._encoding = encoding
        bio = BytesIO()
        null_string = '\x00'
        int32_fmt = byteorder + 'i'

        # len
        bio.write(struct.pack(int32_fmt, self.len))

        # labname
        labname = self._encode(_pad_bytes(self.labname[:32], 33))
        bio.write(labname)

        # padding - 3 bytes
        bio.write(b'\x00' * 3)

        # value_label_table
        # n - int32
        bio.write(struct.pack(int32_fmt, self.n))

        # textlen  - int32
        bio.write(struct.pack(int32_fmt, self.text_len))

        # off - int32 array (n elements)
        for offset in self.off:
            bio.write(struct.pack(int32_fmt, offset))

        # val - int32 array (n elements)
        for value in self.val:
            bio.write(struct.pack(int32_fmt, value))

        # txt - Text labels, null terminated
        for text in self.txt:
            bio.write(self._encode(text + null_string))

        return bio.getvalue()


class StataMissingValue(StringMixin):
"""
An observation's missing value.
Expand Down Expand Up @@ -477,25 +611,31 @@ class StataMissingValue(StringMixin):
for i in range(1, 27):
MISSING_VALUES[i + b] = '.' + chr(96 + i)

base = b'\x00\x00\x00\x7f'
float32_base = b'\x00\x00\x00\x7f'
increment = struct.unpack('<i', b'\x00\x08\x00\x00')[0]
for i in range(27):
value = struct.unpack('<f', base)[0]
value = struct.unpack('<f', float32_base)[0]
MISSING_VALUES[value] = '.'
if i > 0:
MISSING_VALUES[value] += chr(96 + i)
int_value = struct.unpack('<i', struct.pack('<f', value))[0] + increment
base = struct.pack('<i', int_value)
float32_base = struct.pack('<i', int_value)

base = b'\x00\x00\x00\x00\x00\x00\xe0\x7f'
float64_base = b'\x00\x00\x00\x00\x00\x00\xe0\x7f'
increment = struct.unpack('q', b'\x00\x00\x00\x00\x00\x01\x00\x00')[0]
for i in range(27):
value = struct.unpack('<d', base)[0]
value = struct.unpack('<d', float64_base)[0]
MISSING_VALUES[value] = '.'
if i > 0:
MISSING_VALUES[value] += chr(96 + i)
int_value = struct.unpack('q', struct.pack('<d', value))[0] + increment
base = struct.pack('q', int_value)
float64_base = struct.pack('q', int_value)

BASE_MISSING_VALUES = {'int8': 101,
'int16': 32741,
'int32': 2147483621,
'float32': struct.unpack('<f', float32_base)[0],
'float64': struct.unpack('<d', float64_base)[0]}

def __init__(self, value):
self._value = value
Expand All @@ -518,6 +658,22 @@ def __eq__(self, other):
return (isinstance(other, self.__class__)
and self.string == other.string and self.value == other.value)

@classmethod
def get_base_missing_value(cls, dtype):
    """
    Look up the entry of ``BASE_MISSING_VALUES`` that belongs to ``dtype``

    Parameters
    ----------
    dtype : dtype
        Compared, in order, against np.int8, np.int16, np.int32,
        np.float32 and np.float64

    Returns
    -------
    value
        ``cls.BASE_MISSING_VALUES`` entry for the first matching type

    Raises
    ------
    ValueError
        If ``dtype`` equals none of the five supported types
    """
    supported = ('int8', 'int16', 'int32', 'float32', 'float64')
    for type_name in supported:
        if dtype == getattr(np, type_name):
            return cls.BASE_MISSING_VALUES[type_name]
    raise ValueError('Unsupported dtype')


class StataParser(object):
_default_encoding = 'cp1252'
Expand Down Expand Up @@ -1111,10 +1267,10 @@ def data(self, convert_dates=True, convert_categoricals=True, index=None,
umissing, umissing_loc = np.unique(series[missing],
return_inverse=True)
replacement = Series(series, dtype=np.object)
for i, um in enumerate(umissing):
for j, um in enumerate(umissing):
missing_value = StataMissingValue(um)

loc = missing_loc[umissing_loc == i]
loc = missing_loc[umissing_loc == j]
replacement.iloc[loc] = missing_value
else: # All replacements are identical
dtype = series.dtype
Expand Down Expand Up @@ -1390,6 +1546,45 @@ def _write(self, to_write):
else:
self._file.write(to_write)

def _prepare_categoricals(self, data):
    """
    Check for categorical columns, retain categorical information for
    Stata file and convert categorical data to int

    Sets ``self._is_col_cat`` (one flag per column) and
    ``self._value_labels`` (one ``StataValueLabel`` per categorical column).

    Parameters
    ----------
    data : DataFrame
        Frame that is about to be written

    Returns
    -------
    DataFrame
        ``data`` itself when it has no categorical column. Otherwise a new
        frame in which each categorical column is replaced by its integer
        codes, with code -1 replaced by the base missing value of the
        column's (possibly upcast) dtype.

    Raises
    ------
    ValueError
        If the codes of a categorical column are stored as int64
    """
    is_cat = [com.is_categorical_dtype(data[col]) for col in data]
    self._is_col_cat = is_cat
    self._value_labels = []
    if not any(is_cat):
        return data

    get_base_missing_value = StataMissingValue.get_base_missing_value
    index = data.index
    data_formatted = []
    for col, col_is_cat in zip(data, is_cat):
        if not col_is_cat:
            data_formatted.append((col, data[col]))
            continue

        self._value_labels.append(StataValueLabel(data[col]))
        codes = data[col].cat.codes
        dtype = codes.dtype
        if dtype == np.int64:
            raise ValueError('It is not possible to export int64-based '
                             'categorical data to Stata.')
        values = codes.values.copy()

        # Upcast if needed so that correct missing values can be set.
        # The size check comes first because max() of an empty array raises.
        if values.size and values.max() >= get_base_missing_value(dtype):
            if dtype == np.int8:
                dtype = np.int16
            elif dtype == np.int16:
                dtype = np.int32
            else:
                dtype = np.float64
            values = np.array(values, dtype=dtype)

        # Replace missing values with Stata missing value for type
        values[values == -1] = get_base_missing_value(dtype)

        # Every item must be a (name, values) pair: from_items unpacks the
        # transposed items into exactly two sequences, so a third element
        # breaks frames in which every column is categorical. The index is
        # attached to the values instead.
        data_formatted.append((col, Series(values, index=index)))

    return DataFrame.from_items(data_formatted)

def _replace_nans(self, data):
# return data
Expand Down Expand Up @@ -1480,27 +1675,26 @@ def _check_column_names(self, data):
def _prepare_pandas(self, data):
#NOTE: we might need a different API / class for pandas objects so
# we can set different semantics - handle this with a PR to pandas.io
class DataFrameRowIter(object):
def __init__(self, data):
self.data = data

def __iter__(self):
for row in data.itertuples():
# First element is index, so remove
yield row[1:]

if self._write_index:
data = data.reset_index()
# Check columns for compatibility with stata
data = _cast_to_stata_types(data)

# Ensure column names are strings
data = self._check_column_names(data)

# Check columns for compatibility with stata, upcast if necessary
data = _cast_to_stata_types(data)

# Replace NaNs with Stata missing values
data = self._replace_nans(data)
self.datarows = DataFrameRowIter(data)

# Convert categoricals to int data, and strip labels
data = self._prepare_categoricals(data)

self.nobs, self.nvar = data.shape
self.data = data
self.varlist = data.columns.tolist()

dtypes = data.dtypes
if self._convert_dates is not None:
self._convert_dates = _maybe_convert_to_int_keys(
Expand All @@ -1515,6 +1709,7 @@ def __iter__(self):
self.fmtlist = []
for col, dtype in dtypes.iteritems():
self.fmtlist.append(_dtype_to_default_stata_fmt(dtype, data[col]))

# set the given format for the datetime cols
if self._convert_dates is not None:
for key in self._convert_dates:
Expand All @@ -1529,8 +1724,14 @@ def write_file(self):
self._write(_pad_bytes("", 5))
self._prepare_data()
self._write_data()
self._write_value_labels()
self._file.close()

def _write_value_labels(self):
for vl in self._value_labels:
self._file.write(vl.generate_value_label(self._byteorder,
self._encoding))

def _write_header(self, data_label=None, time_stamp=None):
byteorder = self._byteorder
# ds_format - just use 114
Expand Down Expand Up @@ -1585,9 +1786,15 @@ def _write_descriptors(self, typlist=None, varlist=None, srtlist=None,
self._write(_pad_bytes(fmt, 49))

# lbllist, 33*nvar, char array
#NOTE: this is where you could get fancy with pandas categorical type
for i in range(nvar):
self._write(_pad_bytes("", 33))
# Use variable name when categorical
if self._is_col_cat[i]:
name = self.varlist[i]
name = self._null_terminate(name, True)
name = _pad_bytes(name[:32], 33)
self._write(name)
else: # Default is empty label
self._write(_pad_bytes("", 33))

def _write_variable_labels(self, labels=None):
nvar = self.nvar
Expand Down Expand Up @@ -1624,9 +1831,6 @@ def _prepare_data(self):
data_cols.append(data[col].values)
dtype = np.dtype(dtype)

# 3. Convert to record array

# data.to_records(index=False, convert_datetime64=False)
if has_strings:
self.data = np.fromiter(zip(*data_cols), dtype=dtype)
else:
Expand Down
Loading