diff --git a/doc/source/whatsnew/v0.23.0.txt b/doc/source/whatsnew/v0.23.0.txt
index 207589092dd00..e1f6e2c7d12b5 100644
--- a/doc/source/whatsnew/v0.23.0.txt
+++ b/doc/source/whatsnew/v0.23.0.txt
@@ -503,6 +503,7 @@ Other Enhancements
- Updated :meth:`DataFrame.to_gbq` and :meth:`pandas.read_gbq` signature and documentation to reflect changes from
the Pandas-GBQ library version 0.4.0. Adds intersphinx mapping to Pandas-GBQ
library. (:issue:`20564`)
+- Added new writer for exporting Stata dta files in version 117, ``StataWriter117``. This format supports exporting strings with lengths up to 2,000,000 characters (:issue:`16450`)
- :func:`to_hdf` and :func:`read_hdf` now accept an ``errors`` keyword argument to control encoding error handling (:issue:`20835`)
.. _whatsnew_0230.api_breaking:
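
For reference, the new keywords on ``DataFrame.to_stata`` can be exercised as
below; the file names are illustrative.

>>> import pandas as pd
>>> df = pd.DataFrame({'strings': ['a' * 9, 'a' * 5000]})
>>> # The default remains version 114 (Stata 10+), which caps strings at 244
>>> # characters; version 117 (Stata 13+) raises the cap to 2,000,000.
>>> df.to_stata('long_strings.dta', version=117)
>>> # Optionally force a column into the StrL format, which can shrink the
>>> # file when long values repeat.
>>> df.to_stata('strl_strings.dta', version=117, convert_strl=['strings'])
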
diff --git a/pandas/core/frame.py b/pandas/core/frame.py
index d7efd777f4176..f6a57fc5e7ba6 100644
--- a/pandas/core/frame.py
+++ b/pandas/core/frame.py
@@ -1769,27 +1769,28 @@ def to_excel(self, excel_writer, sheet_name='Sheet1', na_rep='',
def to_stata(self, fname, convert_dates=None, write_index=True,
encoding="latin-1", byteorder=None, time_stamp=None,
- data_label=None, variable_labels=None):
+ data_label=None, variable_labels=None, version=114,
+ convert_strl=None):
"""
- A class for writing Stata binary dta files from array-like objects
+ Export Stata binary dta files.
Parameters
----------
fname : str or buffer
- String path of file-like object
+            String path or file-like object.
convert_dates : dict
Dictionary mapping columns containing datetime types to stata
internal format to use when writing the dates. Options are 'tc',
'td', 'tm', 'tw', 'th', 'tq', 'ty'. Column can be either an integer
or a name. Datetime columns that do not have a conversion type
specified will be converted to 'tc'. Raises NotImplementedError if
- a datetime column has timezone information
+ a datetime column has timezone information.
write_index : bool
Write the index to Stata dataset.
encoding : str
- Default is latin-1. Unicode is not supported
+ Default is latin-1. Unicode is not supported.
byteorder : str
- Can be ">", "<", "little", or "big". default is `sys.byteorder`
+            Can be ">", "<", "little", or "big". Default is `sys.byteorder`.
time_stamp : datetime
A datetime to use as file creation date. Default is the current
time.
@@ -1801,6 +1802,23 @@ def to_stata(self, fname, convert_dates=None, write_index=True,
.. versionadded:: 0.19.0
+        version : {114, 117}
+            Version to use in the output dta file. Version 114 can be read
+            by Stata 10 and later. Version 117 can be read by Stata 13 or
+            later. Version 114 limits string variables to 244 characters or
+            fewer while 117 allows strings with lengths up to 2,000,000
+            characters.
+
+ .. versionadded:: 0.23.0
+
+        convert_strl : list, optional
+            List of column names to convert to Stata StrL format. Only
+            available if version is 117. Storing strings in the StrL format
+            can produce smaller dta files if strings have more than 8
+            characters and values are repeated.
+
+ .. versionadded:: 0.23.0
+
Raises
------
NotImplementedError
@@ -1814,6 +1832,12 @@ def to_stata(self, fname, convert_dates=None, write_index=True,
.. versionadded:: 0.19.0
+ See Also
+ --------
+        pandas.read_stata : Import Stata data files.
+        pandas.io.stata.StataWriter : Low-level writer for Stata data files.
+        pandas.io.stata.StataWriter117 : Low-level writer for version 117 files.
+
Examples
--------
>>> data.to_stata('./data_file.dta')
@@ -1832,12 +1856,23 @@ def to_stata(self, fname, convert_dates=None, write_index=True,
>>> writer = StataWriter('./date_data_file.dta', data, {2 : 'tw'})
>>> writer.write_file()
"""
- from pandas.io.stata import StataWriter
- writer = StataWriter(fname, self, convert_dates=convert_dates,
+ kwargs = {}
+ if version not in (114, 117):
+            raise ValueError('Only formats 114 and 117 are supported.')
+ if version == 114:
+ if convert_strl is not None:
+ raise ValueError('strl support is only available when using '
+ 'format 117')
+ from pandas.io.stata import StataWriter as statawriter
+ else:
+ from pandas.io.stata import StataWriter117 as statawriter
+ kwargs['convert_strl'] = convert_strl
+
+ writer = statawriter(fname, self, convert_dates=convert_dates,
encoding=encoding, byteorder=byteorder,
time_stamp=time_stamp, data_label=data_label,
write_index=write_index,
- variable_labels=variable_labels)
+ variable_labels=variable_labels, **kwargs)
writer.write_file()
def to_feather(self, fname):
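
A doctest-style sketch of how the validation above behaves; the error
messages are the ones raised in this hunk.

>>> import pandas as pd
>>> df = pd.DataFrame({'a': ['x']})
>>> df.to_stata('f.dta', version=115)
Traceback (most recent call last):
    ...
ValueError: Only formats 114 and 117 are supported.
>>> df.to_stata('f.dta', version=114, convert_strl=['a'])
Traceback (most recent call last):
    ...
ValueError: strl support is only available when using format 117
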
diff --git a/pandas/io/stata.py b/pandas/io/stata.py
index 9646831cb612c..8f91c7a497e2d 100644
--- a/pandas/io/stata.py
+++ b/pandas/io/stata.py
@@ -25,8 +25,8 @@
from pandas import compat, to_timedelta, to_datetime, isna, DatetimeIndex
from pandas.compat import (lrange, lmap, lzip, text_type, string_types, range,
zip, BytesIO)
-from pandas.core.base import StringMixin
from pandas.core.arrays import Categorical
+from pandas.core.base import StringMixin
from pandas.core.dtypes.common import (is_categorical_dtype, _ensure_object,
is_datetime64_dtype)
from pandas.core.frame import DataFrame
@@ -45,9 +45,9 @@
_statafile_processing_params1 = """\
convert_dates : boolean, defaults to True
- Convert date variables to DataFrame time values
+ Convert date variables to DataFrame time values.
convert_categoricals : boolean, defaults to True
- Read value labels and convert columns to Categorical/Factor variables"""
+ Read value labels and convert columns to Categorical/Factor variables."""
_encoding_params = """\
encoding : string, None or encoding
@@ -55,7 +55,7 @@
_statafile_processing_params2 = """\
index_col : string, optional, default: None
- Column to set as index
+ Column to set as index.
convert_missing : boolean, defaults to False
Flag indicating whether to convert missing values to their Stata
representations. If False, missing values are replaced with nan.
@@ -64,28 +64,29 @@
StataMissingValue objects.
preserve_dtypes : boolean, defaults to True
Preserve Stata datatypes. If False, numeric data are upcast to pandas
- default types for foreign data (float64 or int64)
+ default types for foreign data (float64 or int64).
columns : list or None
Columns to retain. Columns will be returned in the given order. None
- returns all columns
+ returns all columns.
order_categoricals : boolean, defaults to True
Flag indicating whether converted categorical data are ordered."""
_chunksize_params = """\
chunksize : int, default None
Return StataReader object for iterations, returns chunks with
- given number of lines"""
+ given number of lines."""
_iterator_params = """\
iterator : boolean, default False
- Return StataReader object"""
+ Return StataReader object."""
-_read_stata_doc = """Read Stata file into DataFrame
+_read_stata_doc = """
+Read Stata file into DataFrame.
Parameters
----------
filepath_or_buffer : string or file-like object
- Path to .dta file or object implementing a binary read() functions
+    Path to .dta file or object implementing a binary read() function.
%s
%s
%s
@@ -96,17 +97,23 @@
-------
DataFrame or StataReader
+See Also
+--------
+pandas.io.stata.StataReader : Low-level reader for Stata data files.
+pandas.DataFrame.to_stata : Export Stata data files.
+
Examples
--------
Read a Stata dta file:
->>> df = pandas.read_stata('filename.dta')
+>>> import pandas as pd
+>>> df = pd.read_stata('filename.dta')
Read a Stata dta file in 10,000 line chunks:
->>> itr = pandas.read_stata('filename.dta', chunksize=10000)
+>>> itr = pd.read_stata('filename.dta', chunksize=10000)
>>> for chunk in itr:
->>> do_something(chunk)
+... do_something(chunk)
""" % (_statafile_processing_params1, _encoding_params,
_statafile_processing_params2, _chunksize_params,
_iterator_params)
@@ -127,7 +134,6 @@
DataFrame
""" % (_statafile_processing_params1, _statafile_processing_params2)
-
_read_method_doc = """\
Reads observations from Stata file, converting them into a dataframe
@@ -149,8 +155,11 @@
Parameters
----------
-path_or_buf : string or file-like object
- Path to .dta file or object implementing a binary read() functions
+path_or_buf : path (string), buffer or path object
+    String, path object (pathlib.Path or py._path.local.LocalPath) or object
+    implementing a binary read() function.
+
+ .. versionadded:: 0.23.0 support for pathlib, py.path.
%s
%s
%s
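
With the widened ``path_or_buf`` parameter, path objects can now be passed
directly; a minimal sketch, assuming ``filename.dta`` exists:

>>> from pathlib import Path
>>> import pandas as pd
>>> df = pd.read_stata(Path('filename.dta'))
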
@@ -908,10 +917,10 @@ def __init__(self, encoding):
}
self.OLD_TYPE_MAPPING = {
- 98: 251, # byte
+ 98: 251, # byte
105: 252, # int
108: 253, # long
- 102: 254 # float
+ 102: 254 # float
# don't know old code for double
}
@@ -992,7 +1001,7 @@ def __init__(self, path_or_buf, convert_dates=True,
path_or_buf, encoding=self._default_encoding
)
- if isinstance(path_or_buf, (str, compat.text_type, bytes)):
+ if isinstance(path_or_buf, (str, text_type, bytes)):
self.path_or_buf = open(path_or_buf, 'rb')
else:
# Copy to BytesIO, and ensure no encoding
@@ -1041,7 +1050,7 @@ def _read_new_header(self, first_char):
if self.format_version not in [117, 118]:
raise ValueError(_version_error)
        self.path_or_buf.read(21)  # </release><byteorder>
- self.byteorder = self.path_or_buf.read(3) == "MSF" and '>' or '<'
+ self.byteorder = self.path_or_buf.read(3) == b'MSF' and '>' or '<'
        self.path_or_buf.read(15)  # </byteorder><K>
self.nvar = struct.unpack(self.byteorder + 'H',
self.path_or_buf.read(2))[0]
@@ -1805,38 +1814,37 @@ def _dtype_to_stata_type(dtype, column):
the dta spec.
1 - 244 are strings of this length
Pandas Stata
- 251 - chr(251) - for int8 byte
- 252 - chr(252) - for int16 int
- 253 - chr(253) - for int32 long
- 254 - chr(254) - for float32 float
- 255 - chr(255) - for double double
+ 251 - for int8 byte
+ 252 - for int16 int
+ 253 - for int32 long
+ 254 - for float32 float
+ 255 - for double double
If there are dates to convert, then dtype will already have the correct
type inserted.
"""
# TODO: expand to handle datetime to integer conversion
- if dtype.type == np.string_:
- return chr(dtype.itemsize)
- elif dtype.type == np.object_: # try to coerce it to the biggest string
- # not memory efficient, what else could we
- # do?
+ if dtype.type == np.object_: # try to coerce it to the biggest string
+ # not memory efficient, what else could we
+ # do?
itemsize = max_len_string_array(_ensure_object(column.values))
- return chr(max(itemsize, 1))
+ return max(itemsize, 1)
elif dtype == np.float64:
- return chr(255)
+ return 255
elif dtype == np.float32:
- return chr(254)
+ return 254
elif dtype == np.int32:
- return chr(253)
+ return 253
elif dtype == np.int16:
- return chr(252)
+ return 252
elif dtype == np.int8:
- return chr(251)
+ return 251
else: # pragma : no cover
raise NotImplementedError("Data type %s not supported." % dtype)
-def _dtype_to_default_stata_fmt(dtype, column):
+def _dtype_to_default_stata_fmt(dtype, column, dta_version=114,
+ force_strl=False):
"""
Maps numpy dtype to stata's default format for this type. Not terribly
important since users can change this in Stata. Semantics are
@@ -1849,17 +1857,27 @@ def _dtype_to_default_stata_fmt(dtype, column):
int32 -> "%12.0g"
int16 -> "%8.0g"
int8 -> "%8.0g"
+ strl -> "%9s"
"""
# TODO: Refactor to combine type with format
# TODO: expand this to handle a default datetime format?
+ if dta_version < 117:
+ max_str_len = 244
+ else:
+ max_str_len = 2045
+ if force_strl:
+ return '%9s'
if dtype.type == np.object_:
inferred_dtype = infer_dtype(column.dropna())
if not (inferred_dtype in ('string', 'unicode') or
len(column) == 0):
raise ValueError('Writing general object arrays is not supported')
itemsize = max_len_string_array(_ensure_object(column.values))
- if itemsize > 244:
- raise ValueError(excessive_string_length_error % column.name)
+ if itemsize > max_str_len:
+ if dta_version >= 117:
+ return '%9s'
+ else:
+ raise ValueError(excessive_string_length_error % column.name)
return "%" + str(max(itemsize, 1)) + "s"
elif dtype == np.float64:
return "%10.0g"
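
To illustrate the fallback above: an over-long string column raises under the
default ``dta_version=114`` but maps to the StrL format ``'%9s'`` under 117.
A doctest-style sketch using the private helper's new signature:

>>> import pandas as pd
>>> from pandas.io.stata import _dtype_to_default_stata_fmt
>>> col = pd.Series(['a' * 3000])
>>> _dtype_to_default_stata_fmt(col.dtype, col, dta_version=117)
'%9s'
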
@@ -1879,8 +1897,12 @@ class StataWriter(StataParser):
Parameters
----------
- fname : str or buffer
- String path of file-like object
+    fname : path (string), buffer or path object
+        String, path object (pathlib.Path or py._path.local.LocalPath) or
+        object implementing a binary write() function.
+
+ .. versionadded:: 0.23.0 support for pathlib, py.path.
+
data : DataFrame
Input to save
convert_dates : dict
@@ -1898,7 +1920,7 @@ class StataWriter(StataParser):
Can be ">", "<", "little", or "big". default is `sys.byteorder`
time_stamp : datetime
A datetime to use as file creation date. Default is the current time
- dataset_label : str
+ data_label : str
A label for the data set. Must be 80 characters or smaller.
variable_labels : dict
Dictionary containing columns as keys and variable labels as values.
@@ -1937,6 +1959,8 @@ class StataWriter(StataParser):
>>> writer.write_file()
"""
+ _max_string_length = 244
+
def __init__(self, fname, data, convert_dates=None, write_index=True,
encoding="latin-1", byteorder=None, time_stamp=None,
data_label=None, variable_labels=None):
@@ -1954,6 +1978,7 @@ def __init__(self, fname, data, convert_dates=None, write_index=True,
self._byteorder = _set_endianness(byteorder)
self._fname = _stringify_path(fname)
self.type_converters = {253: np.int32, 252: np.int16, 251: np.int8}
+ self._converted_names = {}
def _write(self, to_write):
"""
@@ -2018,6 +2043,10 @@ def _replace_nans(self, data):
return data
+ def _update_strl_names(self):
+ """No-op, forward compatibility"""
+ pass
+
def _check_column_names(self, data):
"""
Checks column names to ensure that they are valid Stata column names.
@@ -2031,7 +2060,7 @@ def _check_column_names(self, data):
dates are exported, the variable name is propagated to the date
conversion dictionary
"""
- converted_names = []
+ converted_names = {}
columns = list(data.columns)
original_columns = columns[:]
@@ -2063,14 +2092,7 @@ def _check_column_names(self, data):
name = '_' + str(duplicate_var_id) + name
name = name[:min(len(name), 32)]
duplicate_var_id += 1
-
- # need to possibly encode the orig name if its unicode
- try:
- orig_name = orig_name.encode('utf-8')
- except:
- pass
- converted_names.append(
- '{0} -> {1}'.format(orig_name, name))
+ converted_names[orig_name] = name
columns[j] = name
@@ -2085,12 +2107,31 @@ def _check_column_names(self, data):
if converted_names:
import warnings
+ conversion_warning = []
+ for orig_name, name in converted_names.items():
+                # need to possibly encode the orig name if it's unicode
+ try:
+ orig_name = orig_name.encode('utf-8')
+ except (UnicodeDecodeError, AttributeError):
+ pass
+ msg = '{0} -> {1}'.format(orig_name, name)
+ conversion_warning.append(msg)
- ws = invalid_name_doc.format('\n '.join(converted_names))
+ ws = invalid_name_doc.format('\n '.join(conversion_warning))
warnings.warn(ws, InvalidColumnName)
+ self._converted_names = converted_names
+ self._update_strl_names()
+
return data
+ def _set_formats_and_types(self, data, dtypes):
+ self.typlist = []
+ self.fmtlist = []
+ for col, dtype in dtypes.iteritems():
+ self.fmtlist.append(_dtype_to_default_stata_fmt(dtype, data[col]))
+ self.typlist.append(_dtype_to_stata_type(dtype, data[col]))
+
def _prepare_pandas(self, data):
# NOTE: we might need a different API / class for pandas objects so
# we can set different semantics - handle this with a PR to pandas.io
@@ -2134,11 +2175,7 @@ def _prepare_pandas(self, data):
)
dtypes[key] = np.dtype(new_type)
- self.typlist = []
- self.fmtlist = []
- for col, dtype in dtypes.iteritems():
- self.fmtlist.append(_dtype_to_default_stata_fmt(dtype, data[col]))
- self.typlist.append(_dtype_to_stata_type(dtype, data[col]))
+ self._set_formats_and_types(data, dtypes)
# set the given format for the datetime cols
if self._convert_dates is not None:
@@ -2152,16 +2189,44 @@ def write_file(self):
try:
self._write_header(time_stamp=self._time_stamp,
data_label=self._data_label)
- self._write_descriptors()
+ self._write_map()
+ self._write_variable_types()
+ self._write_varnames()
+ self._write_sortlist()
+ self._write_formats()
+ self._write_value_label_names()
self._write_variable_labels()
- # write 5 zeros for expansion fields
- self._write(_pad_bytes("", 5))
+ self._write_expansion_fields()
+ self._write_characteristics()
self._prepare_data()
self._write_data()
+ self._write_strls()
self._write_value_labels()
+ self._write_file_close_tag()
+ self._write_map()
finally:
self._file.close()
+ def _write_map(self):
+ """No-op, future compatibility"""
+ pass
+
+ def _write_file_close_tag(self):
+ """No-op, future compatibility"""
+ pass
+
+ def _write_characteristics(self):
+ """No-op, future compatibility"""
+ pass
+
+ def _write_strls(self):
+ """No-op, future compatibility"""
+ pass
+
+ def _write_expansion_fields(self):
+ """Write 5 zeros for expansion fields"""
+ self._write(_pad_bytes("", 5))
+
def _write_value_labels(self):
for vl in self._value_labels:
self._file.write(vl.generate_value_label(self._byteorder,
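
The refactor above turns ``write_file`` into a fixed sequence of per-block
write methods plus no-op hooks, a template-method layout that
``StataWriter117`` overrides later in this patch. A minimal, hypothetical
sketch of the pattern (class and method bodies here are made up)::

    class BaseWriter(object):
        """Fixes the order of the dta blocks; subclasses override hooks."""

        def write_file(self):
            self._write_map()      # no-op for the 114 format
            self._write_data()
            self._write_map()      # 117 rewrites the map with real offsets

        def _write_map(self):
            """Hook: only meaningful in dta 117."""
            pass

        def _write_data(self):
            pass


    class Dta117Writer(BaseWriter):
        def _write_map(self):
            # a real implementation records the file offset of each block
            pass
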
@@ -2204,13 +2269,11 @@ def _write_header(self, data_label=None, time_stamp=None):
time_stamp.strftime(" %Y %H:%M"))
self._file.write(self._null_terminate(ts))
- def _write_descriptors(self, typlist=None, varlist=None, srtlist=None,
- fmtlist=None, lbllist=None):
- nvar = self.nvar
- # typlist, length nvar, format byte array
+ def _write_variable_types(self):
for typ in self.typlist:
- self._write(typ)
+ self._file.write(struct.pack('B', typ))
+ def _write_varnames(self):
# varlist names are checked by _check_column_names
# varlist, requires null terminated
for name in self.varlist:
@@ -2218,16 +2281,19 @@ def _write_descriptors(self, typlist=None, varlist=None, srtlist=None,
name = _pad_bytes(name[:32], 33)
self._write(name)
+ def _write_sortlist(self):
# srtlist, 2*(nvar+1), int array, encoded by byteorder
- srtlist = _pad_bytes("", 2 * (nvar + 1))
+ srtlist = _pad_bytes("", 2 * (self.nvar + 1))
self._write(srtlist)
+ def _write_formats(self):
# fmtlist, 49*nvar, char array
for fmt in self.fmtlist:
self._write(_pad_bytes(fmt, 49))
+ def _write_value_label_names(self):
# lbllist, 33*nvar, char array
- for i in range(nvar):
+ for i in range(self.nvar):
# Use variable name when categorical
if self._is_col_cat[i]:
name = self.varlist[i]
@@ -2261,6 +2327,10 @@ def _write_variable_labels(self):
else:
self._write(blank)
+ def _convert_strls(self, data):
+ """No-op, future compatibility"""
+ return data
+
def _prepare_data(self):
data = self.data
typlist = self.typlist
@@ -2271,27 +2341,34 @@ def _prepare_data(self):
if i in convert_dates:
data[col] = _datetime_to_stata_elapsed_vec(data[col],
self.fmtlist[i])
+ # 2. Convert strls
+ data = self._convert_strls(data)
- # 2. Convert bad string data to '' and pad to correct length
- dtype = []
+ # 3. Convert bad string data to '' and pad to correct length
+ dtypes = []
data_cols = []
has_strings = False
+ native_byteorder = self._byteorder == _set_endianness(sys.byteorder)
for i, col in enumerate(data):
- typ = ord(typlist[i])
- if typ <= 244:
+ typ = typlist[i]
+ if typ <= self._max_string_length:
has_strings = True
data[col] = data[col].fillna('').apply(_pad_bytes, args=(typ,))
stype = 'S%d' % typ
- dtype.append(('c' + str(i), stype))
+ dtypes.append(('c' + str(i), stype))
string = data[col].str.encode(self._encoding)
data_cols.append(string.values.astype(stype))
else:
- dtype.append(('c' + str(i), data[col].dtype))
- data_cols.append(data[col].values)
- dtype = np.dtype(dtype)
-
- if has_strings:
- self.data = np.fromiter(zip(*data_cols), dtype=dtype)
+ values = data[col].values
+ dtype = data[col].dtype
+ if not native_byteorder:
+ dtype = dtype.newbyteorder(self._byteorder)
+ dtypes.append(('c' + str(i), dtype))
+ data_cols.append(values)
+ dtypes = np.dtype(dtypes)
+
+ if has_strings or not native_byteorder:
+ self.data = np.fromiter(zip(*data_cols), dtype=dtypes)
else:
self.data = data.to_records(index=False)
@@ -2307,3 +2384,561 @@ def _null_terminate(self, s, as_string=False):
else:
s += null_byte
return s
+
+
+def _dtype_to_stata_type_117(dtype, column, force_strl):
+ """
+    Converts dtype types to stata types. Returns the numeric type code used
+    in the dta file. See TYPE_MAP and comments for an explanation. This is
+    also explained in the dta spec.
+ 1 - 2045 are strings of this length
+ Pandas Stata
+ 32768 - for object strL
+    65526 - for float64   double
+    65527 - for float32   float
+    65528 - for int32     long
+    65529 - for int16     int
+    65530 - for int8      byte
+
+ If there are dates to convert, then dtype will already have the correct
+ type inserted.
+ """
+ # TODO: expand to handle datetime to integer conversion
+ if force_strl:
+ return 32768
+ if dtype.type == np.object_: # try to coerce it to the biggest string
+ # not memory efficient, what else could we
+ # do?
+ itemsize = max_len_string_array(_ensure_object(column.values))
+ itemsize = max(itemsize, 1)
+ if itemsize <= 2045:
+ return itemsize
+ return 32768
+ elif dtype == np.float64:
+ return 65526
+ elif dtype == np.float32:
+ return 65527
+ elif dtype == np.int32:
+ return 65528
+ elif dtype == np.int16:
+ return 65529
+ elif dtype == np.int8:
+ return 65530
+ else: # pragma : no cover
+ raise NotImplementedError("Data type %s not supported." % dtype)
+
+
+def _bytes(s, encoding):
+ if compat.PY3:
+ return bytes(s, encoding)
+ else:
+ return bytes(s.encode(encoding))
+
+
+def _pad_bytes_new(name, length):
+ """
+    Takes a bytes instance and pads it with null bytes to `length` bytes.
+ """
+ if isinstance(name, string_types):
+ name = _bytes(name, 'utf-8')
+ return name + b'\x00' * (length - len(name))
+
+
+class StataStrLWriter(object):
+ """
+ Converter for Stata StrLs
+
+    Stata StrLs map 8-byte values to strings, which are stored using a
+    dictionary-like format where each string is keyed by two values (v, o).
+
+ Parameters
+ ----------
+ df : DataFrame
+ DataFrame to convert
+ columns : list
+        List of column names to convert to StrL
+ version : int, optional
+ dta version. Currently supports 117, 118 and 119
+ byteorder : str, optional
+ Can be ">", "<", "little", or "big". default is `sys.byteorder`
+
+ Notes
+ -----
+ Supports creation of the StrL block of a dta file for dta versions
+ 117, 118 and 119. These differ in how the GSO is stored. 118 and
+ 119 store the GSO lookup value as a uint32 and a uint64, while 117
+ uses two uint32s. 118 and 119 also encode all strings as unicode
+    which is required by the format. 117 uses 'latin-1', a fixed-width
+    encoding that extends the 7-bit ASCII table with an additional 128
+ characters.
+ """
+
+ def __init__(self, df, columns, version=117, byteorder=None):
+ if version not in (117, 118, 119):
+ raise ValueError('Only dta versions 117, 118 and 119 supported')
+ self._dta_ver = version
+
+ self.df = df
+ self.columns = columns
+ self._gso_table = OrderedDict((('', (0, 0)),))
+ if byteorder is None:
+ byteorder = sys.byteorder
+ self._byteorder = _set_endianness(byteorder)
+
+ gso_v_type = 'I' # uint32
+ gso_o_type = 'Q' # uint64
+ self._encoding = 'utf-8'
+ if version == 117:
+ o_size = 4
+ gso_o_type = 'I' # 117 used uint32
+ self._encoding = 'latin-1'
+ elif version == 118:
+ o_size = 6
+ else: # version == 119
+ o_size = 5
+        self._o_offset = 2 ** (8 * (8 - o_size))
+ self._gso_o_type = gso_o_type
+ self._gso_v_type = gso_v_type
+
+ def _convert_key(self, key):
+ v, o = key
+        return v + self._o_offset * o
+
+ def generate_table(self):
+ """
+        Generates the GSO lookup table for the DataFrame
+
+ Returns
+ -------
+ gso_table : OrderedDict
+ Ordered dictionary using the string found as keys
+ and their lookup position (v,o) as values
+ gso_df : DataFrame
+ DataFrame where strl columns have been converted to
+ (v,o) values
+
+ Notes
+ -----
+ Modifies the DataFrame in-place.
+
+        The DataFrame returned encodes the (v,o) values as uint64s. The
+        encoding depends on the dta version, and can be expressed as
+
+           enc = v + o * 2 ** (8 * (8 - o_size))
+
+        so that v is stored in the lower 8 - o_size bytes and o is in
+        the upper o_size bytes. o_size is
+
+ * 117: 4
+ * 118: 6
+ * 119: 5
+ """
+
+ gso_table = self._gso_table
+ gso_df = self.df
+ columns = list(gso_df.columns)
+ selected = gso_df[self.columns]
+ col_index = [(col, columns.index(col)) for col in self.columns]
+ keys = np.empty(selected.shape, dtype=np.uint64)
+ for o, (idx, row) in enumerate(selected.iterrows()):
+ for j, (col, v) in enumerate(col_index):
+ val = row[col]
+ key = gso_table.get(val, None)
+ if key is None:
+ # Stata prefers human numbers
+ key = (v + 1, o + 1)
+ gso_table[val] = key
+ keys[o, j] = self._convert_key(key)
+ for i, col in enumerate(self.columns):
+ gso_df[col] = keys[:, i]
+
+ return gso_table, gso_df
+
+ def _encode(self, s):
+ """
+ Python 3 compatibility shim
+ """
+ if compat.PY3:
+ return s.encode(self._encoding)
+ else:
+ if isinstance(s, text_type):
+ return s.encode(self._encoding)
+ return s
+
+ def generate_blob(self, gso_table):
+ """
+ Generates the binary blob of GSOs that is written to the dta file.
+
+ Parameters
+ ----------
+ gso_table : OrderedDict
+ Ordered dictionary (str, vo)
+
+ Returns
+ -------
+ gso : bytes
+ Binary content of dta file to be placed between strl tags
+
+ Notes
+ -----
+ Output format depends on dta version. 117 uses two uint32s to
+ express v and o while 118+ uses a uint32 for v and a uint64 for o.
+ """
+ # Format information
+ # Length includes null term
+ # 117
+ # GSOvvvvooootllllxxxxxxxxxxxxxxx...x
+ # 3 u4 u4 u1 u4 string + null term
+ #
+ # 118, 119
+ # GSOvvvvooooooootllllxxxxxxxxxxxxxxx...x
+ # 3 u4 u8 u1 u4 string + null term
+
+ bio = BytesIO()
+ gso = _bytes('GSO', 'ascii')
+ gso_type = struct.pack(self._byteorder + 'B', 130)
+ null = struct.pack(self._byteorder + 'B', 0)
+ v_type = self._byteorder + self._gso_v_type
+ o_type = self._byteorder + self._gso_o_type
+ len_type = self._byteorder + 'I'
+ for strl, vo in gso_table.items():
+ if vo == (0, 0):
+ continue
+ v, o = vo
+
+ # GSO
+ bio.write(gso)
+
+ # vvvv
+ bio.write(struct.pack(v_type, v))
+
+ # oooo / oooooooo
+ bio.write(struct.pack(o_type, o))
+
+ # t
+ bio.write(gso_type)
+
+ # llll
+ encoded = self._encode(strl)
+ bio.write(struct.pack(len_type, len(encoded) + 1))
+
+ # xxx...xxx
+            # write the same bytes whose length was recorded above
+            bio.write(encoded)
+ bio.write(null)
+
+ bio.seek(0)
+ return bio.read()
+
+
+class StataWriter117(StataWriter):
+ """
+ A class for writing Stata binary dta files in Stata 13 format (117)
+
+ .. versionadded:: 0.23.0
+
+ Parameters
+ ----------
+ fname : path (string), buffer or path object
+        String, path object (pathlib.Path or py._path.local.LocalPath) or
+        object implementing a binary write() function.
+ data : DataFrame
+ Input to save
+ convert_dates : dict
+ Dictionary mapping columns containing datetime types to stata internal
+ format to use when writing the dates. Options are 'tc', 'td', 'tm',
+ 'tw', 'th', 'tq', 'ty'. Column can be either an integer or a name.
+ Datetime columns that do not have a conversion type specified will be
+ converted to 'tc'. Raises NotImplementedError if a datetime column has
+ timezone information
+ write_index : bool
+ Write the index to Stata dataset.
+ encoding : str
+ Default is latin-1. Only latin-1 and ascii are supported.
+ byteorder : str
+ Can be ">", "<", "little", or "big". default is `sys.byteorder`
+ time_stamp : datetime
+ A datetime to use as file creation date. Default is the current time
+ data_label : str
+ A label for the data set. Must be 80 characters or smaller.
+ variable_labels : dict
+ Dictionary containing columns as keys and variable labels as values.
+ Each label must be 80 characters or smaller.
+ convert_strl : list
+        List of column names to convert to Stata StrL format. Columns with
+        more than 2045 characters are automatically written as StrL.
+ Smaller columns can be converted by including the column name. Using
+ StrLs can reduce output file size when strings are longer than 8
+ characters, and either frequently repeated or sparse.
+
+ Returns
+ -------
+ writer : StataWriter117 instance
+ The StataWriter117 instance has a write_file method, which will
+ write the file to the given `fname`.
+
+ Raises
+ ------
+ NotImplementedError
+ * If datetimes contain timezone information
+ ValueError
+ * Columns listed in convert_dates are neither datetime64[ns]
+ or datetime.datetime
+ * Column dtype is not representable in Stata
+ * Column listed in convert_dates is not in DataFrame
+ * Categorical label contains more than 32,000 characters
+
+ Examples
+ --------
+ >>> import pandas as pd
+ >>> from pandas.io.stata import StataWriter117
+ >>> data = pd.DataFrame([[1.0, 1, 'a']], columns=['a', 'b', 'c'])
+ >>> writer = StataWriter117('./data_file.dta', data)
+ >>> writer.write_file()
+
+ Or with long strings stored in strl format
+
+ >>> data = pd.DataFrame([['A relatively long string'], [''], ['']],
+ ... columns=['strls'])
+ >>> writer = StataWriter117('./data_file_with_long_strings.dta', data,
+ ... convert_strl=['strls'])
+ >>> writer.write_file()
+ """
+
+ _max_string_length = 2045
+
+ def __init__(self, fname, data, convert_dates=None, write_index=True,
+ encoding="latin-1", byteorder=None, time_stamp=None,
+ data_label=None, variable_labels=None, convert_strl=None):
+ # Shallow copy since convert_strl might be modified later
+ self._convert_strl = [] if convert_strl is None else convert_strl[:]
+
+ super(StataWriter117, self).__init__(fname, data, convert_dates,
+ write_index, encoding, byteorder,
+ time_stamp, data_label,
+ variable_labels)
+ self._map = None
+ self._strl_blob = None
+
+ @staticmethod
+ def _tag(val, tag):
+        """Surround val with <tag></tag>"""
+ if isinstance(val, str) and compat.PY3:
+ val = _bytes(val, 'utf-8')
+ return (_bytes('<' + tag + '>', 'utf-8') + val +
+                _bytes('</' + tag + '>', 'utf-8'))
+
+ def _update_map(self, tag):
+ """Update map location for tag with file position"""
+ self._map[tag] = self._file.tell()
+
+ def _write_header(self, data_label=None, time_stamp=None):
+ """Write the file header"""
+ byteorder = self._byteorder
+        self._file.write(_bytes('<stata_dta>', 'utf-8'))
+ bio = BytesIO()
+ # ds_format - 117
+ bio.write(self._tag(_bytes('117', 'utf-8'), 'release'))
+ # byteorder
+ bio.write(self._tag(byteorder == ">" and "MSF" or "LSF", 'byteorder'))
+ # number of vars, 2 bytes
+ assert self.nvar < 2 ** 16
+ bio.write(self._tag(struct.pack(byteorder + "H", self.nvar), 'K'))
+ # number of obs, 4 bytes
+ bio.write(self._tag(struct.pack(byteorder + "I", self.nobs), 'N'))
+ # data label 81 bytes, char, null terminated
+ label = data_label[:80] if data_label is not None else ''
+ label_len = struct.pack(byteorder + "B", len(label))
+ label = label_len + _bytes(label, 'utf-8')
+ bio.write(self._tag(label, 'label'))
+ # time stamp, 18 bytes, char, null terminated
+ # format dd Mon yyyy hh:mm
+ if time_stamp is None:
+ time_stamp = datetime.datetime.now()
+ elif not isinstance(time_stamp, datetime.datetime):
+ raise ValueError("time_stamp should be datetime type")
+ # Avoid locale-specific month conversion
+ months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug',
+ 'Sep', 'Oct', 'Nov', 'Dec']
+ month_lookup = {i + 1: month for i, month in enumerate(months)}
+ ts = (time_stamp.strftime("%d ") +
+ month_lookup[time_stamp.month] +
+ time_stamp.strftime(" %Y %H:%M"))
+ # '\x11' added due to inspection of Stata file
+ ts = b'\x11' + _bytes(ts, 'utf8')
+ bio.write(self._tag(ts, 'timestamp'))
+ bio.seek(0)
+ self._file.write(self._tag(bio.read(), 'header'))
+
+ def _write_map(self):
+        """Called twice during file write. The first call populates the
+        values in the map with 0s. The second call writes the final map
+        locations when all blocks have been written."""
+ if self._map is None:
+ self._map = OrderedDict((('stata_data', 0),
+ ('map', self._file.tell()),
+ ('variable_types', 0),
+ ('varnames', 0),
+ ('sortlist', 0),
+ ('formats', 0),
+ ('value_label_names', 0),
+ ('variable_labels', 0),
+ ('characteristics', 0),
+ ('data', 0),
+ ('strls', 0),
+ ('value_labels', 0),
+ ('stata_data_close', 0),
+ ('end-of-file', 0)))
+ # Move to start of map
+ self._file.seek(self._map['map'])
+ bio = BytesIO()
+ for val in self._map.values():
+ bio.write(struct.pack(self._byteorder + 'Q', val))
+ bio.seek(0)
+ self._file.write(self._tag(bio.read(), 'map'))
+
+ def _write_variable_types(self):
+ self._update_map('variable_types')
+ bio = BytesIO()
+ for typ in self.typlist:
+ bio.write(struct.pack(self._byteorder + 'H', typ))
+ bio.seek(0)
+ self._file.write(self._tag(bio.read(), 'variable_types'))
+
+ def _write_varnames(self):
+ self._update_map('varnames')
+ bio = BytesIO()
+ for name in self.varlist:
+ name = self._null_terminate(name, True)
+ name = _pad_bytes_new(name[:32], 33)
+ bio.write(name)
+ bio.seek(0)
+ self._file.write(self._tag(bio.read(), 'varnames'))
+
+ def _write_sortlist(self):
+ self._update_map('sortlist')
+        self._file.write(self._tag(b'\x00\x00' * (self.nvar + 1), 'sortlist'))
+
+ def _write_formats(self):
+ self._update_map('formats')
+ bio = BytesIO()
+ for fmt in self.fmtlist:
+ bio.write(_pad_bytes_new(fmt, 49))
+ bio.seek(0)
+ self._file.write(self._tag(bio.read(), 'formats'))
+
+ def _write_value_label_names(self):
+ self._update_map('value_label_names')
+ bio = BytesIO()
+ for i in range(self.nvar):
+ # Use variable name when categorical
+ name = '' # default name
+ if self._is_col_cat[i]:
+ name = self.varlist[i]
+ name = self._null_terminate(name, True)
+ name = _pad_bytes_new(name[:32], 33)
+ bio.write(name)
+ bio.seek(0)
+ self._file.write(self._tag(bio.read(), 'value_label_names'))
+
+ def _write_variable_labels(self):
+ # Missing labels are 80 blank characters plus null termination
+ self._update_map('variable_labels')
+ bio = BytesIO()
+ blank = _pad_bytes_new('', 81)
+
+ if self._variable_labels is None:
+ for _ in range(self.nvar):
+ bio.write(blank)
+ bio.seek(0)
+ self._file.write(self._tag(bio.read(), 'variable_labels'))
+ return
+
+ for col in self.data:
+ if col in self._variable_labels:
+ label = self._variable_labels[col]
+ if len(label) > 80:
+ raise ValueError('Variable labels must be 80 characters '
+ 'or fewer')
+ is_latin1 = all(ord(c) < 256 for c in label)
+ if not is_latin1:
+ raise ValueError('Variable labels must contain only '
+ 'characters that can be encoded in '
+ 'Latin-1')
+ bio.write(_pad_bytes_new(label, 81))
+ else:
+ bio.write(blank)
+ bio.seek(0)
+ self._file.write(self._tag(bio.read(), 'variable_labels'))
+
+ def _write_characteristics(self):
+ self._update_map('characteristics')
+ self._file.write(self._tag(b'', 'characteristics'))
+
+ def _write_data(self):
+ self._update_map('data')
+ data = self.data
+        self._file.write(b'<data>')
+        data.tofile(self._file)
+        self._file.write(b'</data>')
+
+ def _write_strls(self):
+ self._update_map('strls')
+ strls = b''
+ if self._strl_blob is not None:
+ strls = self._strl_blob
+ self._file.write(self._tag(strls, 'strls'))
+
+ def _write_expansion_fields(self):
+ """No-op in dta 117+"""
+ pass
+
+ def _write_value_labels(self):
+ self._update_map('value_labels')
+ bio = BytesIO()
+ for vl in self._value_labels:
+ lab = vl.generate_value_label(self._byteorder, self._encoding)
+ lab = self._tag(lab, 'lbl')
+ bio.write(lab)
+ bio.seek(0)
+ self._file.write(self._tag(bio.read(), 'value_labels'))
+
+ def _write_file_close_tag(self):
+ self._update_map('stata_data_close')
+        self._file.write(_bytes('</stata_dta>', 'utf-8'))
+ self._update_map('end-of-file')
+
+ def _update_strl_names(self):
+ """Update column names for conversion to strl if they might have been
+ changed to comply with Stata naming rules"""
+ # Update convert_strl if names changed
+ for orig, new in self._converted_names.items():
+ if orig in self._convert_strl:
+ idx = self._convert_strl.index(orig)
+ self._convert_strl[idx] = new
+
+ def _convert_strls(self, data):
+ """Convert columns to StrLs if either very large or in the
+ convert_strl variable"""
+ convert_cols = []
+ for i, col in enumerate(data):
+ if self.typlist[i] == 32768 or col in self._convert_strl:
+ convert_cols.append(col)
+ if convert_cols:
+ ssw = StataStrLWriter(data, convert_cols)
+ tab, new_data = ssw.generate_table()
+ data = new_data
+ self._strl_blob = ssw.generate_blob(tab)
+ return data
+
+ def _set_formats_and_types(self, data, dtypes):
+ self.typlist = []
+ self.fmtlist = []
+ for col, dtype in dtypes.iteritems():
+ force_strl = col in self._convert_strl
+ fmt = _dtype_to_default_stata_fmt(dtype, data[col],
+ dta_version=117,
+ force_strl=force_strl)
+ self.fmtlist.append(fmt)
+ self.typlist.append(_dtype_to_stata_type_117(dtype, data[col],
+ force_strl))
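
A worked example of the (v, o) -> uint64 key packing performed by
``StataStrLWriter._convert_key`` above. For dta version 117, ``o_size`` is 4
and the offset is ``2 ** 32``, so ``v`` lands in the low four bytes and ``o``
in the high four:

>>> v, o = 2, 3              # variable 2, observation 3 (1-based)
>>> v + o * 2 ** 32          # == 0x0000000300000002
12884901890
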
diff --git a/pandas/tests/io/test_stata.py b/pandas/tests/io/test_stata.py
index 972a47ef91c05..110b790a65037 100644
--- a/pandas/tests/io/test_stata.py
+++ b/pandas/tests/io/test_stata.py
@@ -5,21 +5,21 @@
import os
import struct
import warnings
-from datetime import datetime
from collections import OrderedDict
+from datetime import datetime
import numpy as np
+import pytest
+
import pandas as pd
import pandas.util.testing as tm
-import pytest
from pandas import compat
-from pandas._libs.tslib import NaT
from pandas.compat import iterkeys
from pandas.core.dtypes.common import is_categorical_dtype
from pandas.core.frame import DataFrame, Series
from pandas.io.parsers import read_csv
-from pandas.io.stata import (read_stata, StataReader, InvalidColumnName,
- PossiblePrecisionLoss, StataMissingValue)
+from pandas.io.stata import (InvalidColumnName, PossiblePrecisionLoss,
+ StataMissingValue, StataReader, read_stata)
@pytest.fixture
@@ -104,11 +104,12 @@ def read_dta(self, file):
def read_csv(self, file):
return read_csv(file, parse_dates=True)
- def test_read_empty_dta(self):
+ @pytest.mark.parametrize('version', [114, 117])
+ def test_read_empty_dta(self, version):
empty_ds = DataFrame(columns=['unit'])
# GH 7369, make sure can read a 0-obs dta file
with tm.ensure_clean() as path:
- empty_ds.to_stata(path, write_index=False)
+ empty_ds.to_stata(path, write_index=False, version=version)
empty_ds2 = read_stata(path)
tm.assert_frame_equal(empty_ds, empty_ds2)
@@ -319,7 +320,8 @@ def test_write_dta6(self):
tm.assert_frame_equal(written_and_read_again.set_index('index'),
original, check_index_type=False)
- def test_read_write_dta10(self):
+ @pytest.mark.parametrize('version', [114, 117])
+ def test_read_write_dta10(self, version):
original = DataFrame(data=[["string", "object", 1, 1.1,
np.datetime64('2003-12-25')]],
columns=['string', 'object', 'integer',
@@ -330,7 +332,7 @@ def test_read_write_dta10(self):
original['integer'] = original['integer'].astype(np.int32)
with tm.ensure_clean() as path:
- original.to_stata(path, {'datetime': 'tc'})
+ original.to_stata(path, {'datetime': 'tc'}, version=version)
written_and_read_again = self.read_dta(path)
# original.index is np.int32, read index is np.int64
tm.assert_frame_equal(written_and_read_again.set_index('index'),
@@ -351,7 +353,8 @@ def test_write_preserves_original(self):
df.to_stata(path, write_index=False)
tm.assert_frame_equal(df, df_copy)
- def test_encoding(self):
+ @pytest.mark.parametrize('version', [114, 117])
+ def test_encoding(self, version):
# GH 4626, proper encoding handling
raw = read_stata(self.dta_encoding)
@@ -368,7 +371,8 @@ def test_encoding(self):
assert isinstance(result, unicode) # noqa
with tm.ensure_clean() as path:
- encoded.to_stata(path, encoding='latin-1', write_index=False)
+ encoded.to_stata(path, encoding='latin-1',
+ write_index=False, version=version)
reread_encoded = read_stata(path, encoding='latin-1')
tm.assert_frame_equal(encoded, reread_encoded)
@@ -392,7 +396,8 @@ def test_read_write_dta11(self):
tm.assert_frame_equal(
written_and_read_again.set_index('index'), formatted)
- def test_read_write_dta12(self):
+ @pytest.mark.parametrize('version', [114, 117])
+ def test_read_write_dta12(self, version):
original = DataFrame([(1, 2, 3, 4, 5, 6)],
columns=['astringwithmorethan32characters_1',
'astringwithmorethan32characters_2',
@@ -412,7 +417,8 @@ def test_read_write_dta12(self):
with tm.ensure_clean() as path:
with warnings.catch_warnings(record=True) as w:
- original.to_stata(path, None)
+ warnings.simplefilter('always', InvalidColumnName)
+ original.to_stata(path, None, version=version)
# should get a warning for that format.
assert len(w) == 1
@@ -436,9 +442,10 @@ def test_read_write_dta13(self):
tm.assert_frame_equal(written_and_read_again.set_index('index'),
formatted)
+ @pytest.mark.parametrize('version', [114, 117])
@pytest.mark.parametrize(
'file', ['dta14_113', 'dta14_114', 'dta14_115', 'dta14_117'])
- def test_read_write_reread_dta14(self, file, parsed_114):
+ def test_read_write_reread_dta14(self, file, parsed_114, version):
file = getattr(self, file)
parsed = self.read_dta(file)
parsed.index.name = 'index'
@@ -454,7 +461,7 @@ def test_read_write_reread_dta14(self, file, parsed_114):
tm.assert_frame_equal(parsed_114, parsed)
with tm.ensure_clean() as path:
- parsed_114.to_stata(path, {'date_td': 'td'})
+ parsed_114.to_stata(path, {'date_td': 'td'}, version=version)
written_and_read_again = self.read_dta(path)
tm.assert_frame_equal(
written_and_read_again.set_index('index'), parsed_114)
@@ -477,18 +484,29 @@ def test_read_write_reread_dta15(self, file):
tm.assert_frame_equal(expected, parsed)
- def test_timestamp_and_label(self):
+ @pytest.mark.parametrize('version', [114, 117])
+ def test_timestamp_and_label(self, version):
original = DataFrame([(1,)], columns=['variable'])
time_stamp = datetime(2000, 2, 29, 14, 21)
data_label = 'This is a data file.'
with tm.ensure_clean() as path:
original.to_stata(path, time_stamp=time_stamp,
- data_label=data_label)
+ data_label=data_label,
+ version=version)
with StataReader(path) as reader:
assert reader.time_stamp == '29 Feb 2000 14:21'
assert reader.data_label == data_label
+ @pytest.mark.parametrize('version', [114, 117])
+ def test_invalid_timestamp(self, version):
+ original = DataFrame([(1,)], columns=['variable'])
+ time_stamp = '01 Jan 2000, 00:00:00'
+ with tm.ensure_clean() as path:
+ with pytest.raises(ValueError):
+ original.to_stata(path, time_stamp=time_stamp,
+ version=version)
+
def test_numeric_column_names(self):
original = DataFrame(np.reshape(np.arange(25.0), (5, 5)))
original.index.name = 'index'
@@ -504,7 +522,8 @@ def test_numeric_column_names(self):
written_and_read_again.columns = map(convert_col_name, columns)
tm.assert_frame_equal(original, written_and_read_again)
- def test_nan_to_missing_value(self):
+ @pytest.mark.parametrize('version', [114, 117])
+ def test_nan_to_missing_value(self, version):
s1 = Series(np.arange(4.0), dtype=np.float32)
s2 = Series(np.arange(4.0), dtype=np.float64)
s1[::2] = np.nan
@@ -512,7 +531,7 @@ def test_nan_to_missing_value(self):
original = DataFrame({'s1': s1, 's2': s2})
original.index.name = 'index'
with tm.ensure_clean() as path:
- original.to_stata(path)
+ original.to_stata(path, version=version)
written_and_read_again = self.read_dta(path)
written_and_read_again = written_and_read_again.set_index('index')
tm.assert_frame_equal(written_and_read_again, original)
@@ -627,7 +646,9 @@ def test_write_missing_strings(self):
tm.assert_frame_equal(written_and_read_again.set_index('index'),
expected)
- def test_bool_uint(self):
+ @pytest.mark.parametrize('version', [114, 117])
+ @pytest.mark.parametrize('byteorder', ['>', '<'])
+ def test_bool_uint(self, byteorder, version):
s0 = Series([0, 1, True], dtype=np.bool)
s1 = Series([0, 1, 100], dtype=np.uint8)
s2 = Series([0, 1, 255], dtype=np.uint8)
@@ -646,7 +667,7 @@ def test_bool_uint(self):
expected[c] = expected[c].astype(t)
with tm.ensure_clean() as path:
- original.to_stata(path)
+ original.to_stata(path, byteorder=byteorder, version=version)
written_and_read_again = self.read_dta(path)
written_and_read_again = written_and_read_again.set_index('index')
tm.assert_frame_equal(written_and_read_again, expected)
@@ -757,7 +778,7 @@ def test_big_dates(self):
else:
row.append(datetime(yr[i], mo[i], dd[i]))
expected.append(row)
- expected.append([NaT] * 7)
+ expected.append([pd.NaT] * 7)
columns = ['date_tc', 'date_td', 'date_tw', 'date_tm', 'date_tq',
'date_th', 'date_ty']
@@ -848,7 +869,8 @@ def test_drop_column(self):
columns = ['byte_', 'int_', 'long_', 'not_found']
read_stata(self.dta15_117, convert_dates=True, columns=columns)
- def test_categorical_writing(self):
+ @pytest.mark.parametrize('version', [114, 117])
+ def test_categorical_writing(self, version):
original = DataFrame.from_records(
[
["one", "ten", "one", "one", "one", 1],
@@ -880,7 +902,7 @@ def test_categorical_writing(self):
with tm.ensure_clean() as path:
with warnings.catch_warnings(record=True) as w: # noqa
# Silence warnings
- original.to_stata(path)
+ original.to_stata(path, version=version)
written_and_read_again = self.read_dta(path)
res = written_and_read_again.set_index('index')
tm.assert_frame_equal(res, expected, check_categorical=False)
@@ -915,7 +937,8 @@ def test_categorical_warnings_and_errors(self):
# should get a warning for mixed content
assert len(w) == 1
- def test_categorical_with_stata_missing_values(self):
+ @pytest.mark.parametrize('version', [114, 117])
+ def test_categorical_with_stata_missing_values(self, version):
values = [['a' + str(i)] for i in range(120)]
values.append([np.nan])
original = pd.DataFrame.from_records(values, columns=['many_labels'])
@@ -923,7 +946,7 @@ def test_categorical_with_stata_missing_values(self):
for col in original], axis=1)
original.index.name = 'index'
with tm.ensure_clean() as path:
- original.to_stata(path)
+ original.to_stata(path, version=version)
written_and_read_again = self.read_dta(path)
res = written_and_read_again.set_index('index')
tm.assert_frame_equal(res, original, check_categorical=False)
@@ -1129,7 +1152,8 @@ def test_read_chunks_columns(self):
tm.assert_frame_equal(from_frame, chunk, check_dtype=False)
pos += chunksize
- def test_write_variable_labels(self):
+ @pytest.mark.parametrize('version', [114, 117])
+ def test_write_variable_labels(self, version):
# GH 13631, add support for writing variable labels
original = pd.DataFrame({'a': [1, 2, 3, 4],
'b': [1.0, 3.0, 27.0, 81.0],
@@ -1138,7 +1162,9 @@ def test_write_variable_labels(self):
original.index.name = 'index'
variable_labels = {'a': 'City Rank', 'b': 'City Exponent', 'c': 'City'}
with tm.ensure_clean() as path:
- original.to_stata(path, variable_labels=variable_labels)
+ original.to_stata(path,
+ variable_labels=variable_labels,
+ version=version)
with StataReader(path) as sr:
read_labels = sr.variable_labels()
expected_labels = {'index': '',
@@ -1149,11 +1175,36 @@ def test_write_variable_labels(self):
variable_labels['index'] = 'The Index'
with tm.ensure_clean() as path:
- original.to_stata(path, variable_labels=variable_labels)
+ original.to_stata(path,
+ variable_labels=variable_labels,
+ version=version)
with StataReader(path) as sr:
read_labels = sr.variable_labels()
assert read_labels == variable_labels
+ @pytest.mark.parametrize('version', [114, 117])
+ def test_invalid_variable_labels(self, version):
+ original = pd.DataFrame({'a': [1, 2, 3, 4],
+ 'b': [1.0, 3.0, 27.0, 81.0],
+ 'c': ['Atlanta', 'Birmingham',
+ 'Cincinnati', 'Detroit']})
+ original.index.name = 'index'
+ variable_labels = {'a': 'very long' * 10,
+ 'b': 'City Exponent',
+ 'c': 'City'}
+ with tm.ensure_clean() as path:
+ with pytest.raises(ValueError):
+ original.to_stata(path,
+ variable_labels=variable_labels,
+ version=version)
+
+ variable_labels['a'] = u'invalid character Œ'
+ with tm.ensure_clean() as path:
+ with pytest.raises(ValueError):
+ original.to_stata(path,
+ variable_labels=variable_labels,
+ version=version)
+
def test_write_variable_label_errors(self):
original = pd.DataFrame({'a': [1, 2, 3, 4],
'b': [1.0, 3.0, 27.0, 81.0],
@@ -1201,6 +1252,13 @@ def test_default_date_conversion(self):
direct = read_stata(path, convert_dates=True)
tm.assert_frame_equal(reread, direct)
+ dates_idx = original.columns.tolist().index('dates')
+ original.to_stata(path,
+ write_index=False,
+ convert_dates={dates_idx: 'tc'})
+ direct = read_stata(path, convert_dates=True)
+ tm.assert_frame_equal(reread, direct)
+
def test_unsupported_type(self):
original = pd.DataFrame({'a': [1 + 2j, 2 + 4j]})
@@ -1355,3 +1413,63 @@ def test_date_parsing_ignores_format_details(self, column):
unformatted = df.loc[0, column]
formatted = df.loc[0, column + "_fmt"]
assert unformatted == formatted
+
+ def test_writer_117(self):
+ original = DataFrame(data=[['string', 'object', 1, 1, 1, 1.1, 1.1,
+ np.datetime64('2003-12-25'),
+ 'a', 'a' * 2045, 'a' * 5000, 'a'],
+ ['string-1', 'object-1', 1, 1, 1, 1.1, 1.1,
+ np.datetime64('2003-12-26'),
+ 'b', 'b' * 2045, '', '']
+ ],
+ columns=['string', 'object', 'int8', 'int16',
+ 'int32', 'float32', 'float64',
+ 'datetime',
+                                      's1', 's2045', 'strl', 'forced_strl'])
+ original['object'] = Series(original['object'], dtype=object)
+ original['int8'] = Series(original['int8'], dtype=np.int8)
+ original['int16'] = Series(original['int16'], dtype=np.int16)
+ original['int32'] = original['int32'].astype(np.int32)
+ original['float32'] = Series(original['float32'], dtype=np.float32)
+ original.index.name = 'index'
+ original.index = original.index.astype(np.int32)
+ copy = original.copy()
+ with tm.ensure_clean() as path:
+ original.to_stata(path,
+ convert_dates={'datetime': 'tc'},
+ convert_strl=['forced_strl'],
+ version=117)
+ written_and_read_again = self.read_dta(path)
+ # original.index is np.int32, read index is np.int64
+ tm.assert_frame_equal(written_and_read_again.set_index('index'),
+ original, check_index_type=False)
+ tm.assert_frame_equal(original, copy)
+
+ def test_convert_strl_name_swap(self):
+ original = DataFrame([['a' * 3000, 'A', 'apple'],
+ ['b' * 1000, 'B', 'banana']],
+ columns=['long1' * 10, 'long', 1])
+ original.index.name = 'index'
+
+ with warnings.catch_warnings(record=True) as w: # noqa
+ with tm.ensure_clean() as path:
+ original.to_stata(path, convert_strl=['long', 1], version=117)
+ reread = self.read_dta(path)
+ reread = reread.set_index('index')
+ reread.columns = original.columns
+ tm.assert_frame_equal(reread, original,
+ check_index_type=False)
+
+ def test_invalid_date_conversion(self):
+ # GH 12259
+        dates = [datetime(1999, 12, 31, 12, 12, 12, 12000),
+                 datetime(2012, 12, 21, 12, 21, 12, 21000),
+                 datetime(1776, 7, 4, 7, 4, 7, 4000)]
+ original = pd.DataFrame({'nums': [1.0, 2.0, 3.0],
+ 'strs': ['apple', 'banana', 'cherry'],
+ 'dates': dates})
+
+ with tm.ensure_clean() as path:
+ with pytest.raises(ValueError):
+ original.to_stata(path,
+ convert_dates={'wrong_name': 'tc'})