diff --git a/doc/source/v0.15.1.txt b/doc/source/v0.15.1.txt index 89447121930fb..2bd88531f92d9 100644 --- a/doc/source/v0.15.1.txt +++ b/doc/source/v0.15.1.txt @@ -26,6 +26,8 @@ API changes Enhancements ~~~~~~~~~~~~ +- Added option to select columns when importing Stata files (:issue:`7935`) + .. _whatsnew_0151.performance: diff --git a/pandas/io/stata.py b/pandas/io/stata.py index 8bf1c596b62cf..c2542594861c4 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -29,7 +29,7 @@ def read_stata(filepath_or_buffer, convert_dates=True, convert_categoricals=True, encoding=None, index=None, - convert_missing=False, preserve_dtypes=True): + convert_missing=False, preserve_dtypes=True, columns=None): """ Read Stata file into DataFrame @@ -55,11 +55,14 @@ def read_stata(filepath_or_buffer, convert_dates=True, preserve_dtypes : boolean, defaults to True Preserve Stata datatypes. If False, numeric data are upcast to pandas default types for foreign data (float64 or int64) + columns : list or None + Columns to retain. Columns will be returned in the given order. None + returns all columns """ reader = StataReader(filepath_or_buffer, encoding) return reader.data(convert_dates, convert_categoricals, index, - convert_missing, preserve_dtypes) + convert_missing, preserve_dtypes, columns) _date_formats = ["%tc", "%tC", "%td", "%d", "%tw", "%tm", "%tq", "%th", "%ty"] @@ -977,7 +980,7 @@ def _read_strls(self): self.path_or_buf.read(1) # zero-termination def data(self, convert_dates=True, convert_categoricals=True, index=None, - convert_missing=False, preserve_dtypes=True): + convert_missing=False, preserve_dtypes=True, columns=None): """ Reads observations from Stata file, converting them into a dataframe @@ -999,6 +1002,10 @@ def data(self, convert_dates=True, convert_categoricals=True, index=None, preserve_dtypes : boolean, defaults to True Preserve Stata datatypes. 
If False, numeric data are upcast to pandas default types for foreign data (float64 or int64) + columns : list or None + Columns to retain. Columns will be returned in the given order. + None returns all columns + Returns ------- y : DataFrame instance @@ -1034,6 +1041,34 @@ def data(self, convert_dates=True, convert_categoricals=True, index=None, data = DataFrame.from_records(data, index=index) data.columns = self.varlist + if columns is not None: + column_set = set(columns) + if len(column_set) != len(columns): + raise ValueError('columns contains duplicate entries') + unmatched = column_set.difference(data.columns) + if unmatched: + raise ValueError('The following columns were not found in the ' + 'Stata data set: ' + + ', '.join(unmatched)) + # Collect per-column metadata in the order requested by `columns` so + # it stays aligned with the reordered frame that data[columns] returns + dtyplist = [] + typlist = [] + fmtlist = [] + lbllist = [] + for col in columns: + i = data.columns.get_loc(col) + dtyplist.append(self.dtyplist[i]) + typlist.append(self.typlist[i]) + fmtlist.append(self.fmtlist[i]) + lbllist.append(self.lbllist[i]) + + data = data[columns] + self.dtyplist = dtyplist + self.typlist = typlist + self.fmtlist = fmtlist + self.lbllist = lbllist + + for col, typ in zip(data, self.typlist): if type(typ) is int: data[col] = data[col].apply(self._null_terminate, convert_dtype=True,) diff --git a/pandas/io/tests/test_stata.py b/pandas/io/tests/test_stata.py index c5727a5579b79..2cb7809166be5 100644 --- a/pandas/io/tests/test_stata.py +++ b/pandas/io/tests/test_stata.py @@ -720,12 +720,35 @@ def test_dtype_conversion(self): tm.assert_frame_equal(expected, conversion) + def test_drop_column(self): + expected = self.read_csv(self.csv15) + expected['byte_'] = expected['byte_'].astype(np.int8) + expected['int_'] = expected['int_'].astype(np.int16) + expected['long_'] = expected['long_'].astype(np.int32) + expected['float_'] = expected['float_'].astype(np.float32) + expected['double_']
= expected['double_'].astype(np.float64) + expected['date_td'] = expected['date_td'].apply(datetime.strptime, + args=('%Y-%m-%d',)) + columns = ['byte_', 'int_', 'long_'] + expected = expected[columns] + dropped = read_stata(self.dta15_117, convert_dates=True, + columns=columns) + tm.assert_frame_equal(expected, dropped) + # Columns requested out of file order must keep their own metadata + columns = ['int_', 'long_', 'byte_'] + expected = expected[columns] + reordered = read_stata(self.dta15_117, convert_dates=True, + columns=columns) + tm.assert_frame_equal(expected, reordered) + with tm.assertRaises(ValueError): + columns = ['byte_', 'byte_'] + read_stata(self.dta15_117, convert_dates=True, columns=columns) - - - + with tm.assertRaises(ValueError): + columns = ['byte_', 'int_', 'long_', 'not_found'] + read_stata(self.dta15_117, convert_dates=True, columns=columns) if __name__ == '__main__': nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'],