Skip to content

Commit 204b50e

Browse files
committed
ENH: Add categorical support for Stata export
Add support for exporting DataFrames containing categorical data. closes #8633 xref #7621
1 parent 54e237b commit 204b50e

File tree

4 files changed

+313
-29
lines changed

4 files changed

+313
-29
lines changed

doc/source/io.rst

+6
Original file line numberDiff line numberDiff line change
@@ -3626,12 +3626,18 @@ outside of this range, the data is cast to ``int16``.
36263626
if ``int64`` values are larger than 2**53.
36273627

36283628
.. warning::
3629+
36293630
:class:`~pandas.io.stata.StataWriter`` and
36303631
:func:`~pandas.core.frame.DataFrame.to_stata` only support fixed width
36313632
strings containing up to 244 characters, a limitation imposed by the version
36323633
115 dta file format. Attempting to write *Stata* dta files with strings
36333634
longer than 244 characters raises a ``ValueError``.
36343635

3636+
.. warning::
3637+
3638+
*Stata* data files only support text labels for categorical data. Exporting
3639+
data frames containing categorical data will convert non-string categorical values
3640+
to strings.
36353641

36363642
.. _io.stata_reader:
36373643

doc/source/whatsnew/v0.15.2.txt

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ API changes
4141
Enhancements
4242
~~~~~~~~~~~~
4343

44+
- Added ability to export Categorical data to Stata (:issue:`8633`).
4445

4546
.. _whatsnew_0152.performance:
4647

pandas/io/stata.py

+230-26
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
import struct
1616
from dateutil.relativedelta import relativedelta
1717
from pandas.core.base import StringMixin
18+
from pandas.core.categorical import Categorical
1819
from pandas.core.frame import DataFrame
1920
from pandas.core.series import Series
20-
from pandas.core.categorical import Categorical
2121
import datetime
2222
from pandas import compat, to_timedelta, to_datetime, isnull, DatetimeIndex
2323
from pandas.compat import lrange, lmap, lzip, text_type, string_types, range, \
24-
zip
24+
zip, BytesIO
2525
import pandas.core.common as com
2626
from pandas.io.common import get_filepath_or_buffer
2727
from pandas.lib import max_len_string_array, infer_dtype
@@ -336,6 +336,15 @@ class PossiblePrecisionLoss(Warning):
336336
conversion range. This may result in a loss of precision in the saved data.
337337
"""
338338

339+
class ValueLabelTypeMismatch(Warning):
340+
pass
341+
342+
value_label_mismatch_doc = """
343+
Stata value labels (pandas categories) must be strings. Column {0} contains
344+
non-string labels which will be converted to strings. Please check that the
345+
Stata data file created has not lost information due to duplicate labels.
346+
"""
347+
339348

340349
class InvalidColumnName(Warning):
341350
pass
@@ -425,6 +434,131 @@ def _cast_to_stata_types(data):
425434
return data
426435

427436

437+
class StataValueLabel(object):
438+
"""
439+
Parse a categorical column and prepare formatted output
440+
441+
Parameters
442+
-----------
443+
value : int8, int16, int32, float32 or float64
444+
The Stata missing value code
445+
446+
Attributes
447+
----------
448+
string : string
449+
String representation of the Stata missing value
450+
value : int8, int16, int32, float32 or float64
451+
The original encoded missing value
452+
453+
Methods
454+
-------
455+
generate_value_label
456+
457+
"""
458+
459+
def __init__(self, catarray):
460+
461+
self.labname = catarray.name
462+
463+
categories = catarray.cat.categories
464+
self.value_labels = list(zip(np.arange(len(categories)), categories))
465+
self.value_labels.sort(key=lambda x: x[0])
466+
self.text_len = np.int32(0)
467+
self.off = []
468+
self.val = []
469+
self.txt = []
470+
self.n = 0
471+
472+
# Compute lengths and setup lists of offsets and labels
473+
for vl in self.value_labels:
474+
category = vl[1]
475+
if not isinstance(category, string_types):
476+
category = str(category)
477+
import warnings
478+
warnings.warn(value_label_mismatch_doc.format(catarray.name),
479+
ValueLabelTypeMismatch)
480+
481+
self.off.append(self.text_len)
482+
self.text_len += len(category) + 1 # +1 for the padding
483+
self.val.append(vl[0])
484+
self.txt.append(category)
485+
self.n += 1
486+
487+
if self.text_len > 32000:
488+
raise ValueError('Stata value labels for a single variable must '
489+
'have a combined length less than 32,000 '
490+
'characters.')
491+
492+
# Ensure int32
493+
self.off = np.array(self.off, dtype=np.int32)
494+
self.val = np.array(self.val, dtype=np.int32)
495+
496+
# Total length
497+
self.len = 4 + 4 + 4 * self.n + 4 * self.n + self.text_len
498+
499+
def _encode(self, s):
500+
"""
501+
Python 3 compatability shim
502+
"""
503+
if compat.PY3:
504+
return s.encode(self._encoding)
505+
else:
506+
return s
507+
508+
def generate_value_label(self, byteorder, encoding):
509+
"""
510+
Parameters
511+
----------
512+
byteorder : str
513+
Byte order of the output
514+
encoding : str
515+
File encoding
516+
517+
Returns
518+
-------
519+
value_label : bytes
520+
Bytes containing the formatted value label
521+
"""
522+
523+
self._encoding = encoding
524+
bio = BytesIO()
525+
null_string = '\x00'
526+
null_byte = b'\x00'
527+
528+
# len
529+
bio.write(struct.pack(byteorder + 'i', self.len))
530+
531+
# labname
532+
labname = self._encode(_pad_bytes(self.labname[:32], 33))
533+
bio.write(labname)
534+
535+
# padding - 3 bytes
536+
for i in range(3):
537+
bio.write(struct.pack('c', null_byte))
538+
539+
# value_label_table
540+
# n - int32
541+
bio.write(struct.pack(byteorder + 'i', self.n))
542+
543+
# textlen - int32
544+
bio.write(struct.pack(byteorder + 'i', self.text_len))
545+
546+
# off - int32 array (n elements)
547+
for offset in self.off:
548+
bio.write(struct.pack(byteorder + 'i', offset))
549+
550+
# val - int32 array (n elements)
551+
for value in self.val:
552+
bio.write(struct.pack(byteorder + 'i', value))
553+
554+
# txt - Text labels, null terminated
555+
for text in self.txt:
556+
bio.write(self._encode(text + null_string))
557+
558+
bio.seek(0)
559+
return bio.read()
560+
561+
428562
class StataMissingValue(StringMixin):
429563
"""
430564
An observation's missing value.
@@ -477,25 +611,31 @@ class StataMissingValue(StringMixin):
477611
for i in range(1, 27):
478612
MISSING_VALUES[i + b] = '.' + chr(96 + i)
479613

480-
base = b'\x00\x00\x00\x7f'
614+
float32_base = b'\x00\x00\x00\x7f'
481615
increment = struct.unpack('<i', b'\x00\x08\x00\x00')[0]
482616
for i in range(27):
483-
value = struct.unpack('<f', base)[0]
617+
value = struct.unpack('<f', float32_base)[0]
484618
MISSING_VALUES[value] = '.'
485619
if i > 0:
486620
MISSING_VALUES[value] += chr(96 + i)
487621
int_value = struct.unpack('<i', struct.pack('<f', value))[0] + increment
488-
base = struct.pack('<i', int_value)
622+
float32_base = struct.pack('<i', int_value)
489623

490-
base = b'\x00\x00\x00\x00\x00\x00\xe0\x7f'
624+
float64_base = b'\x00\x00\x00\x00\x00\x00\xe0\x7f'
491625
increment = struct.unpack('q', b'\x00\x00\x00\x00\x00\x01\x00\x00')[0]
492626
for i in range(27):
493-
value = struct.unpack('<d', base)[0]
627+
value = struct.unpack('<d', float64_base)[0]
494628
MISSING_VALUES[value] = '.'
495629
if i > 0:
496630
MISSING_VALUES[value] += chr(96 + i)
497631
int_value = struct.unpack('q', struct.pack('<d', value))[0] + increment
498-
base = struct.pack('q', int_value)
632+
float64_base = struct.pack('q', int_value)
633+
634+
BASE_MISSING_VALUES = {'int8': 101,
635+
'int16': 32741,
636+
'int32': 2147483621,
637+
'float32': struct.unpack('<f', float32_base)[0],
638+
'float64': struct.unpack('<d', float64_base)[0]}
499639

500640
def __init__(self, value):
501641
self._value = value
@@ -518,6 +658,22 @@ def __eq__(self, other):
518658
return (isinstance(other, self.__class__)
519659
and self.string == other.string and self.value == other.value)
520660

661+
@classmethod
662+
def get_base_missing_value(cls, dtype):
663+
if dtype == np.int8:
664+
value = cls.BASE_MISSING_VALUES['int8']
665+
elif dtype == np.int16:
666+
value = cls.BASE_MISSING_VALUES['int16']
667+
elif dtype == np.int32:
668+
value = cls.BASE_MISSING_VALUES['int32']
669+
elif dtype == np.float32:
670+
value = cls.BASE_MISSING_VALUES['float32']
671+
elif dtype == np.float64:
672+
value = cls.BASE_MISSING_VALUES['float64']
673+
else:
674+
raise ValueError('Unsupported dtype')
675+
return value
676+
521677

522678
class StataParser(object):
523679
_default_encoding = 'cp1252'
@@ -1111,10 +1267,10 @@ def data(self, convert_dates=True, convert_categoricals=True, index=None,
11111267
umissing, umissing_loc = np.unique(series[missing],
11121268
return_inverse=True)
11131269
replacement = Series(series, dtype=np.object)
1114-
for i, um in enumerate(umissing):
1270+
for j, um in enumerate(umissing):
11151271
missing_value = StataMissingValue(um)
11161272

1117-
loc = missing_loc[umissing_loc == i]
1273+
loc = missing_loc[umissing_loc == j]
11181274
replacement.iloc[loc] = missing_value
11191275
else: # All replacements are identical
11201276
dtype = series.dtype
@@ -1390,6 +1546,45 @@ def _write(self, to_write):
13901546
else:
13911547
self._file.write(to_write)
13921548

1549+
def _prepare_categoricals(self, data):
1550+
"""Check for categorigal columns, retain categorical information for
1551+
Stata file and convert categorical data to int"""
1552+
1553+
is_cat = [com.is_categorical_dtype(data[col]) for col in data]
1554+
self._is_col_cat = is_cat
1555+
self._value_labels = []
1556+
if not any(is_cat):
1557+
return data
1558+
1559+
get_base_missing_value = StataMissingValue.get_base_missing_value
1560+
index = data.index
1561+
data_formatted = []
1562+
for col, col_is_cat in zip(data, is_cat):
1563+
if col_is_cat:
1564+
self._value_labels.append(StataValueLabel(data[col]))
1565+
dtype = data[col].cat.codes.dtype
1566+
if dtype == np.int64:
1567+
raise ValueError('It is not possible to export int64-based '
1568+
'categorical data to Stata.')
1569+
values = data[col].cat.codes.values.copy()
1570+
1571+
# Upcast if needed so that correct missing values can be set
1572+
if values.max() >= get_base_missing_value(dtype):
1573+
if dtype == np.int8:
1574+
dtype = np.int16
1575+
elif dtype == np.int16:
1576+
dtype = np.int32
1577+
else:
1578+
dtype = np.float64
1579+
values = np.array(values, dtype=dtype)
1580+
1581+
# Replace missing values with Stata missing value for type
1582+
values[values == -1] = get_base_missing_value(dtype)
1583+
data_formatted.append((col, values, index))
1584+
1585+
else:
1586+
data_formatted.append((col, data[col]))
1587+
return DataFrame.from_items(data_formatted)
13931588

13941589
def _replace_nans(self, data):
13951590
# return data
@@ -1480,27 +1675,26 @@ def _check_column_names(self, data):
14801675
def _prepare_pandas(self, data):
14811676
#NOTE: we might need a different API / class for pandas objects so
14821677
# we can set different semantics - handle this with a PR to pandas.io
1483-
class DataFrameRowIter(object):
1484-
def __init__(self, data):
1485-
self.data = data
1486-
1487-
def __iter__(self):
1488-
for row in data.itertuples():
1489-
# First element is index, so remove
1490-
yield row[1:]
14911678

14921679
if self._write_index:
14931680
data = data.reset_index()
1494-
# Check columns for compatibility with stata
1495-
data = _cast_to_stata_types(data)
1681+
14961682
# Ensure column names are strings
14971683
data = self._check_column_names(data)
1684+
1685+
# Check columns for compatibility with stata, upcast if necessary
1686+
data = _cast_to_stata_types(data)
1687+
14981688
# Replace NaNs with Stata missing values
14991689
data = self._replace_nans(data)
1500-
self.datarows = DataFrameRowIter(data)
1690+
1691+
# Convert categoricals to int data, and strip labels
1692+
data = self._prepare_categoricals(data)
1693+
15011694
self.nobs, self.nvar = data.shape
15021695
self.data = data
15031696
self.varlist = data.columns.tolist()
1697+
15041698
dtypes = data.dtypes
15051699
if self._convert_dates is not None:
15061700
self._convert_dates = _maybe_convert_to_int_keys(
@@ -1515,6 +1709,7 @@ def __iter__(self):
15151709
self.fmtlist = []
15161710
for col, dtype in dtypes.iteritems():
15171711
self.fmtlist.append(_dtype_to_default_stata_fmt(dtype, data[col]))
1712+
15181713
# set the given format for the datetime cols
15191714
if self._convert_dates is not None:
15201715
for key in self._convert_dates:
@@ -1529,8 +1724,14 @@ def write_file(self):
15291724
self._write(_pad_bytes("", 5))
15301725
self._prepare_data()
15311726
self._write_data()
1727+
self._write_value_labels()
15321728
self._file.close()
15331729

1730+
def _write_value_labels(self):
1731+
for vl in self._value_labels:
1732+
self._file.write(vl.generate_value_label(self._byteorder,
1733+
self._encoding))
1734+
15341735
def _write_header(self, data_label=None, time_stamp=None):
15351736
byteorder = self._byteorder
15361737
# ds_format - just use 114
@@ -1585,9 +1786,15 @@ def _write_descriptors(self, typlist=None, varlist=None, srtlist=None,
15851786
self._write(_pad_bytes(fmt, 49))
15861787

15871788
# lbllist, 33*nvar, char array
1588-
#NOTE: this is where you could get fancy with pandas categorical type
15891789
for i in range(nvar):
1590-
self._write(_pad_bytes("", 33))
1790+
# Use variable name when categorical
1791+
if self._is_col_cat[i]:
1792+
name = self.varlist[i]
1793+
name = self._null_terminate(name, True)
1794+
name = _pad_bytes(name[:32], 33)
1795+
self._write(name)
1796+
else: # Default is empty label
1797+
self._write(_pad_bytes("", 33))
15911798

15921799
def _write_variable_labels(self, labels=None):
15931800
nvar = self.nvar
@@ -1624,9 +1831,6 @@ def _prepare_data(self):
16241831
data_cols.append(data[col].values)
16251832
dtype = np.dtype(dtype)
16261833

1627-
# 3. Convert to record array
1628-
1629-
# data.to_records(index=False, convert_datetime64=False)
16301834
if has_strings:
16311835
self.data = np.fromiter(zip(*data_cols), dtype=dtype)
16321836
else:

0 commit comments

Comments
 (0)