From 88c4c5526b38194f00a00c207ee5388dfd508092 Mon Sep 17 00:00:00 2001 From: Kevin Sheppard Date: Fri, 21 Mar 2014 22:21:19 +0000 Subject: [PATCH] BUG: NaN values not converted to Stata missing values Stata does not correctly handle NaNs, and so these must be replaced with Stata missing values (. by default). The fix checks floating point columns for nan and replaces these with the Stata numeric code for (.). One of the code paths which writes files correctly handled this case, and this last-minute check was removed. The write_index option was also being ignored by omission. This has been fixed and numerous tests which were not correct have been fixed. Also contains some additional tests which were uncovered edges cases related to fix. --- doc/source/release.rst | 1 + pandas/core/frame.py | 3 +- pandas/io/stata.py | 23 +++++-- pandas/io/tests/test_stata.py | 124 ++++++++++++++++++++++++++++++---- 4 files changed, 129 insertions(+), 22 deletions(-) diff --git a/doc/source/release.rst b/doc/source/release.rst index 3937b4b30fa0e..55861a0f1b0f0 100644 --- a/doc/source/release.rst +++ b/doc/source/release.rst @@ -269,6 +269,7 @@ Bug Fixes - Bug in ``DataFrame.to_stata`` when columns have non-string names (:issue:`4558`) - Bug in compat with ``np.compress``, surfaced in (:issue:`6658`) - Bug in binary operations with a rhs of a Series not aligning (:issue:`6681`) +- Bug in ``DataFrame.to_stata`` which incorrectly handles nan values and ignores 'with_index' keyword argument (:issue:`6685`) pandas 0.13.1 ------------- diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 253b9ac2c7a16..8cf164ba76c21 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -1258,7 +1258,8 @@ def to_stata( from pandas.io.stata import StataWriter writer = StataWriter(fname, self, convert_dates=convert_dates, encoding=encoding, byteorder=byteorder, - time_stamp=time_stamp, data_label=data_label) + time_stamp=time_stamp, data_label=data_label, + write_index=write_index) writer.write_file() @Appender(fmt.docstring_to_string, indents=1) diff --git a/pandas/io/stata.py b/pandas/io/stata.py index 4bb61e385a75c..fd41961109511 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -990,8 +990,6 @@ def _dtype_to_stata_type(dtype): return chr(255) elif dtype == np.float32: return chr(254) - elif dtype == np.int64: - return chr(253) elif dtype == np.int32: return chr(253) elif dtype == np.int16: @@ -1025,8 +1023,6 @@ def _dtype_to_default_stata_fmt(dtype): return "%10.0g" elif dtype == np.float32: return "%9.0g" - elif dtype == np.int64: - return "%9.0g" elif dtype == np.int32: return "%12.0g" elif dtype == np.int8 or dtype == np.int16: @@ -1108,6 +1104,21 @@ def _write(self, to_write): self._file.write(to_write) + def _replace_nans(self, data): + # return data + """Checks floating point data columns for nans, and replaces these with + the generic Stata for missing value (.)""" + for c in data: + dtype = data[c].dtype + if dtype in (np.float32, np.float64): + if dtype == np.float32: + replacement = self.MISSING_VALUES['f'] + else: + replacement = self.MISSING_VALUES['d'] + data[c] = data[c].fillna(replacement) + + return data + def _check_column_names(self, data): """Checks column names to ensure that they are valid Stata column names. This includes checks for: @@ -1197,6 +1208,8 @@ def __iter__(self): data = _cast_to_stata_types(data) # Ensure column names are strings data = self._check_column_names(data) + # Replace NaNs with Stata missing values + data = self._replace_nans(data) self.datarows = DataFrameRowIter(data) self.nobs, self.nvar = data.shape self.data = data @@ -1340,8 +1353,6 @@ def _write_data_dates(self): var = _pad_bytes(var, typ) self._write(var) else: - if isnull(var): # this only matters for floats - var = MISSING_VALUES[TYPE_MAP[typ]] self._file.write(struct.pack(byteorder+TYPE_MAP[typ], var)) def _null_terminate(self, s, as_string=False): diff --git a/pandas/io/tests/test_stata.py b/pandas/io/tests/test_stata.py index fe79bf20615bb..c5debed6654af 100644 --- a/pandas/io/tests/test_stata.py +++ b/pandas/io/tests/test_stata.py @@ -13,7 +13,8 @@ import pandas as pd from pandas.core.frame import DataFrame, Series from pandas.io.parsers import read_csv -from pandas.io.stata import read_stata, StataReader, InvalidColumnName +from pandas.io.stata import (read_stata, StataReader, InvalidColumnName, + PossiblePrecisionLoss) import pandas.util.testing as tm from pandas.util.misc import is_little_endian from pandas import compat @@ -142,8 +143,7 @@ def test_read_dta2(self): parsed_117 = self.read_dta(self.dta2_117) # 113 is buggy due ot limits date format support in Stata # parsed_113 = self.read_dta(self.dta2_113) - - np.testing.assert_equal( + tm.assert_equal( len(w), 1) # should get a warning for that format. # buggy test because of the NaT comparison on certain platforms @@ -206,7 +206,7 @@ def test_read_write_dta5(self): original.index.name = 'index' with tm.ensure_clean() as path: - original.to_stata(path, None, False) + original.to_stata(path, None) written_and_read_again = self.read_dta(path) tm.assert_frame_equal(written_and_read_again.set_index('index'), original) @@ -221,7 +221,7 @@ def test_write_dta6(self): original['quarter'] = original['quarter'].astype(np.int32) with tm.ensure_clean() as path: - original.to_stata(path, None, False) + original.to_stata(path, None) written_and_read_again = self.read_dta(path) tm.assert_frame_equal(written_and_read_again.set_index('index'), original) @@ -257,7 +257,7 @@ def test_read_write_dta10(self): original['integer'] = original['integer'].astype(np.int32) with tm.ensure_clean() as path: - original.to_stata(path, {'datetime': 'tc'}, False) + original.to_stata(path, {'datetime': 'tc'}) written_and_read_again = self.read_dta(path) tm.assert_frame_equal(written_and_read_again.set_index('index'), original) @@ -295,9 +295,9 @@ def test_read_write_dta11(self): with tm.ensure_clean() as path: with warnings.catch_warnings(record=True) as w: - original.to_stata(path, None, False) - np.testing.assert_equal( - len(w), 1) # should get a warning for that format. + original.to_stata(path, None) + # should get a warning for that format. + tm.assert_equal(len(w), 1) written_and_read_again = self.read_dta(path) tm.assert_frame_equal(written_and_read_again.set_index('index'), formatted) @@ -324,13 +324,12 @@ def test_read_write_dta12(self): with tm.ensure_clean() as path: with warnings.catch_warnings(record=True) as w: - original.to_stata(path, None, False) - np.testing.assert_equal( - len(w), 1) # should get a warning for that format. + original.to_stata(path, None) + tm.assert_equal(len(w), 1) # should get a warning for that format. written_and_read_again = self.read_dta(path) tm.assert_frame_equal(written_and_read_again.set_index('index'), formatted) - + def test_read_write_dta13(self): s1 = Series(2**9, dtype=np.int16) s2 = Series(2**17, dtype=np.int32) @@ -366,7 +365,7 @@ def test_read_write_reread_dta14(self): tm.assert_frame_equal(parsed_114, parsed_115) with tm.ensure_clean() as path: - parsed_114.to_stata(path, {'date_td': 'td'}, write_index=False) + parsed_114.to_stata(path, {'date_td': 'td'}) written_and_read_again = self.read_dta(path) tm.assert_frame_equal(written_and_read_again.set_index('index'), parsed_114) @@ -406,7 +405,7 @@ def test_numeric_column_names(self): with warnings.catch_warnings(record=True) as w: tm.assert_produces_warning(original.to_stata(path), InvalidColumnName) # should produce a single warning - np.testing.assert_equal(len(w), 1) + tm.assert_equal(len(w), 1) written_and_read_again = self.read_dta(path) written_and_read_again = written_and_read_again.set_index('index') @@ -415,7 +414,102 @@ def test_numeric_column_names(self): written_and_read_again.columns = map(convert_col_name, columns) tm.assert_frame_equal(original, written_and_read_again) + def test_nan_to_missing_value(self): + s1 = Series(np.arange(4.0), dtype=np.float32) + s2 = Series(np.arange(4.0), dtype=np.float64) + s1[::2] = np.nan + s2[1::2] = np.nan + original = DataFrame({'s1': s1, 's2': s2}) + original.index.name = 'index' + with tm.ensure_clean() as path: + original.to_stata(path) + written_and_read_again = self.read_dta(path) + written_and_read_again = written_and_read_again.set_index('index') + tm.assert_frame_equal(written_and_read_again, original) + + def test_no_index(self): + columns = ['x', 'y'] + original = DataFrame(np.reshape(np.arange(10.0), (5, 2)), + columns=columns) + original.index.name = 'index_not_written' + with tm.ensure_clean() as path: + original.to_stata(path, write_index=False) + written_and_read_again = self.read_dta(path) + tm.assertRaises(KeyError, + lambda: written_and_read_again['index_not_written']) + + def test_string_no_dates(self): + s1 = Series(['a', 'A longer string']) + s2 = Series([1.0, 2.0], dtype=np.float64) + original = DataFrame({'s1': s1, 's2': s2}) + original.index.name = 'index' + with tm.ensure_clean() as path: + original.to_stata(path) + written_and_read_again = self.read_dta(path) + tm.assert_frame_equal(written_and_read_again.set_index('index'), + original) + + def test_large_value_conversion(self): + s0 = Series([1, 99], dtype=np.int8) + s1 = Series([1, 127], dtype=np.int8) + s2 = Series([1, 2 ** 15 - 1], dtype=np.int16) + s3 = Series([1, 2 ** 63 - 1], dtype=np.int64) + original = DataFrame({'s0': s0, 's1': s1, 's2': s2, 's3': s3}) + original.index.name = 'index' + with tm.ensure_clean() as path: + with warnings.catch_warnings(record=True) as w: + tm.assert_produces_warning(original.to_stata(path), + PossiblePrecisionLoss) + # should produce a single warning + tm.assert_equal(len(w), 1) + + written_and_read_again = self.read_dta(path) + modified = original.copy() + modified['s1'] = Series(modified['s1'], dtype=np.int16) + modified['s2'] = Series(modified['s2'], dtype=np.int32) + modified['s3'] = Series(modified['s3'], dtype=np.float64) + tm.assert_frame_equal(written_and_read_again.set_index('index'), + modified) + + def test_dates_invalid_column(self): + original = DataFrame([datetime(2006, 11, 19, 23, 13, 20)]) + original.index.name = 'index' + with tm.ensure_clean() as path: + with warnings.catch_warnings(record=True) as w: + tm.assert_produces_warning(original.to_stata(path, {0: 'tc'}), + InvalidColumnName) + tm.assert_equal(len(w), 1) + + written_and_read_again = self.read_dta(path) + modified = original.copy() + modified.columns = ['_0'] + tm.assert_frame_equal(written_and_read_again.set_index('index'), + modified) + + def test_date_export_formats(self): + columns = ['tc', 'td', 'tw', 'tm', 'tq', 'th', 'ty'] + conversions = dict(((c, c) for c in columns)) + data = [datetime(2006, 11, 20, 23, 13, 20)] * len(columns) + original = DataFrame([data], columns=columns) + original.index.name = 'index' + expected_values = [datetime(2006, 11, 20, 23, 13, 20), # Time + datetime(2006, 11, 20), # Day + datetime(2006, 11, 19), # Week + datetime(2006, 11, 1), # Month + datetime(2006, 10, 1), # Quarter year + datetime(2006, 7, 1), # Half year + datetime(2006, 1, 1)] # Year + + expected = DataFrame([expected_values], columns=columns) + expected.index.name = 'index' + with tm.ensure_clean() as path: + original.to_stata(path, conversions) + written_and_read_again = self.read_dta(path) + tm.assert_frame_equal(written_and_read_again.set_index('index'), + expected) + if __name__ == '__main__': nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'], exit=False) +