diff --git a/doc/source/io.rst b/doc/source/io.rst index 70d5c195233c3..ae07e0af10c0e 100644 --- a/doc/source/io.rst +++ b/doc/source/io.rst @@ -3654,10 +3654,15 @@ missing values are represented as ``np.nan``. If ``True``, missing values are represented using ``StataMissingValue`` objects, and columns containing missing values will have ``dtype`` set to ``object``. - The StataReader supports .dta Formats 104, 105, 108, 113-115 and 117. Alternatively, the function :func:`~pandas.io.stata.read_stata` can be used +.. note:: + + Setting ``preserve_dtypes=False`` will upcast all integer data types to + ``int64`` and all floating point data types to ``float64``. By default, + the Stata data types are preserved when importing. + .. ipython:: python :suppress: diff --git a/doc/source/v0.15.0.txt b/doc/source/v0.15.0.txt index f8a245f750068..a40e1d87776d9 100644 --- a/doc/source/v0.15.0.txt +++ b/doc/source/v0.15.0.txt @@ -864,6 +864,7 @@ Enhancements - Added support for writing datetime64 columns with ``to_sql`` for all database flavors (:issue:`7103`). - Added support for bool, uint8, uint16 and uint32 datatypes in ``to_stata`` (:issue:`7097`, :issue:`7365`) +- Added ``preserve_dtypes`` option to ``read_stata`` and ``StataReader.data`` to control whether Stata numeric types are preserved (the default) or upcast to ``int64``/``float64`` when importing Stata files (:issue:`8527`) - Added ``layout`` keyword to ``DataFrame.plot``. You can pass a tuple of ``(rows, columns)``, one of which can be ``-1`` to automatically infer (:issue:`6667`, :issue:`8071`). 
- Allow to pass multiple axes to ``DataFrame.plot``, ``hist`` and ``boxplot`` (:issue:`5353`, :issue:`6970`, :issue:`7069`) diff --git a/pandas/io/stata.py b/pandas/io/stata.py index 246465153c611..8bf1c596b62cf 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -29,7 +29,7 @@ def read_stata(filepath_or_buffer, convert_dates=True, convert_categoricals=True, encoding=None, index=None, - convert_missing=False): + convert_missing=False, preserve_dtypes=True): """ Read Stata file into DataFrame @@ -52,13 +52,14 @@ def read_stata(filepath_or_buffer, convert_dates=True, If True, columns containing missing values are returned with object data types and missing values are represented by StataMissingValue objects. + preserve_dtypes : boolean, defaults to True + Preserve Stata datatypes. If False, numeric data are upcast to pandas + default types for foreign data (float64 or int64) """ reader = StataReader(filepath_or_buffer, encoding) - return reader.data(convert_dates, - convert_categoricals, - index, - convert_missing) + return reader.data(convert_dates, convert_categoricals, index, + convert_missing, preserve_dtypes) _date_formats = ["%tc", "%tC", "%td", "%d", "%tw", "%tm", "%tq", "%th", "%ty"] @@ -976,7 +977,7 @@ def _read_strls(self): self.path_or_buf.read(1) # zero-termination def data(self, convert_dates=True, convert_categoricals=True, index=None, - convert_missing=False): + convert_missing=False, preserve_dtypes=True): """ Reads observations from Stata file, converting them into a dataframe @@ -995,7 +996,9 @@ def data(self, convert_dates=True, convert_categoricals=True, index=None, nans. If True, columns containing missing values are returned with object data types and missing values are represented by StataMissingValue objects. - + preserve_dtypes : boolean, defaults to True + Preserve Stata datatypes. 
If False, numeric data are upcast to + pandas default types for foreign data (float64 or int64) Returns ------- y : DataFrame instance @@ -1107,6 +1110,21 @@ def data(self, convert_dates=True, convert_categoricals=True, index=None, labeled_data[(data[col] == k).values] = v data[col] = Categorical.from_array(labeled_data) + if not preserve_dtypes: + retyped_data = [] + convert = False + for col in data: + dtype = data[col].dtype + if dtype in (np.float16, np.float32): + dtype = np.float64 + convert = True + elif dtype in (np.int8, np.int16, np.int32): + dtype = np.int64 + convert = True + retyped_data.append((col, data[col].astype(dtype))) + if convert: + data = DataFrame.from_items(retyped_data) + return data def data_label(self): diff --git a/pandas/io/tests/test_stata.py b/pandas/io/tests/test_stata.py index c458688b3d2d2..c5727a5579b79 100644 --- a/pandas/io/tests/test_stata.py +++ b/pandas/io/tests/test_stata.py @@ -83,6 +83,7 @@ def setUp(self): def read_dta(self, file): + # Legacy default reader configuration return read_stata(file, convert_dates=True) def read_csv(self, file): @@ -694,6 +695,35 @@ def test_big_dates(self): tm.assert_frame_equal(written_and_read_again.set_index('index'), expected) + def test_dtype_conversion(self): + expected = self.read_csv(self.csv15) + expected['byte_'] = expected['byte_'].astype(np.int8) + expected['int_'] = expected['int_'].astype(np.int16) + expected['long_'] = expected['long_'].astype(np.int32) + expected['float_'] = expected['float_'].astype(np.float32) + expected['double_'] = expected['double_'].astype(np.float64) + expected['date_td'] = expected['date_td'].apply(datetime.strptime, + args=('%Y-%m-%d',)) + + no_conversion = read_stata(self.dta15_117, + convert_dates=True) + tm.assert_frame_equal(expected, no_conversion) + + conversion = read_stata(self.dta15_117, + convert_dates=True, + preserve_dtypes=False) + + # read_csv types are the same + expected = self.read_csv(self.csv15) + expected['date_td'] = 
expected['date_td'].apply(datetime.strptime, + args=('%Y-%m-%d',)) + + tm.assert_frame_equal(expected, conversion) + + + + +