Skip to content

Commit a2c36c1

Browse files
committed
Merge pull request #8577 from bashtage/stata-columns
ENH: Allow columns selection in read_stata
2 parents 8336e36 + 902179d commit a2c36c1

File tree

3 files changed

+61
-6
lines changed

3 files changed

+61
-6
lines changed

doc/source/v0.15.1.txt

+2
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,8 @@ API changes
2626
Enhancements
2727
~~~~~~~~~~~~
2828

29+
- Added option to select columns when importing Stata files (:issue:`7935`)
30+
2931

3032
.. _whatsnew_0151.performance:
3133

pandas/io/stata.py

+39-3
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@
2929

3030
def read_stata(filepath_or_buffer, convert_dates=True,
3131
convert_categoricals=True, encoding=None, index=None,
32-
convert_missing=False, preserve_dtypes=True):
32+
convert_missing=False, preserve_dtypes=True, columns=None):
3333
"""
3434
Read Stata file into DataFrame
3535
@@ -55,11 +55,14 @@ def read_stata(filepath_or_buffer, convert_dates=True,
5555
preserve_dtypes : boolean, defaults to True
5656
Preserve Stata datatypes. If False, numeric data are upcast to pandas
5757
default types for foreign data (float64 or int64)
58+
columns : list or None
59+
Columns to retain. Columns will be returned in the given order. None
60+
returns all columns
5861
"""
5962
reader = StataReader(filepath_or_buffer, encoding)
6063

6164
return reader.data(convert_dates, convert_categoricals, index,
62-
convert_missing, preserve_dtypes)
65+
convert_missing, preserve_dtypes, columns)
6366

6467
_date_formats = ["%tc", "%tC", "%td", "%d", "%tw", "%tm", "%tq", "%th", "%ty"]
6568

@@ -977,7 +980,7 @@ def _read_strls(self):
977980
self.path_or_buf.read(1) # zero-termination
978981

979982
def data(self, convert_dates=True, convert_categoricals=True, index=None,
980-
convert_missing=False, preserve_dtypes=True):
983+
convert_missing=False, preserve_dtypes=True, columns=None):
981984
"""
982985
Reads observations from Stata file, converting them into a dataframe
983986
@@ -999,6 +1002,10 @@ def data(self, convert_dates=True, convert_categoricals=True, index=None,
9991002
preserve_dtypes : boolean, defaults to True
10001003
Preserve Stata datatypes. If False, numeric data are upcast to
10011004
pandas default types for foreign data (float64 or int64)
1005+
columns : list or None
1006+
Columns to retain. Columns will be returned in the given order.
1007+
None returns all columns
1008+
10021009
Returns
10031010
-------
10041011
y : DataFrame instance
@@ -1034,6 +1041,35 @@ def data(self, convert_dates=True, convert_categoricals=True, index=None,
10341041
data = DataFrame.from_records(data, index=index)
10351042
data.columns = self.varlist
10361043

1044+
if columns is not None:
1045+
column_set = set(columns)
1046+
if len(column_set) != len(columns):
1047+
raise ValueError('columns contains duplicate entries')
1048+
unmatched = column_set.difference(data.columns)
1049+
if unmatched:
1050+
raise ValueError('The following columns were not found in the '
1051+
'Stata data set: ' +
1052+
', '.join(list(unmatched)))
1053+
# Copy information for retained columns for later processing
1054+
dtyplist = []
1055+
typlist = []
1056+
fmtlist = []
1057+
lbllist = []
1058+
matched = set()
1059+
for i, col in enumerate(data.columns):
1060+
if col in column_set:
1061+
matched.update([col])
1062+
dtyplist.append(self.dtyplist[i])
1063+
typlist.append(self.typlist[i])
1064+
fmtlist.append(self.fmtlist[i])
1065+
lbllist.append(self.lbllist[i])
1066+
1067+
data = data[columns]
1068+
self.dtyplist = dtyplist
1069+
self.typlist = typlist
1070+
self.fmtlist = fmtlist
1071+
self.lbllist = lbllist
1072+
10371073
for col, typ in zip(data, self.typlist):
10381074
if type(typ) is int:
10391075
data[col] = data[col].apply(self._null_terminate, convert_dtype=True,)

pandas/io/tests/test_stata.py

+20-3
Original file line number | Diff line number | Diff line change
@@ -720,12 +720,29 @@ def test_dtype_conversion(self):
720720

721721
tm.assert_frame_equal(expected, conversion)
722722

723+
def test_drop_column(self):
724+
expected = self.read_csv(self.csv15)
725+
expected['byte_'] = expected['byte_'].astype(np.int8)
726+
expected['int_'] = expected['int_'].astype(np.int16)
727+
expected['long_'] = expected['long_'].astype(np.int32)
728+
expected['float_'] = expected['float_'].astype(np.float32)
729+
expected['double_'] = expected['double_'].astype(np.float64)
730+
expected['date_td'] = expected['date_td'].apply(datetime.strptime,
731+
args=('%Y-%m-%d',))
723732

733+
columns = ['byte_', 'int_', 'long_']
734+
expected = expected[columns]
735+
dropped = read_stata(self.dta15_117, convert_dates=True,
736+
columns=columns)
724737

738+
tm.assert_frame_equal(expected, dropped)
739+
with tm.assertRaises(ValueError):
740+
columns = ['byte_', 'byte_']
741+
read_stata(self.dta15_117, convert_dates=True, columns=columns)
725742

726-
727-
728-
743+
with tm.assertRaises(ValueError):
744+
columns = ['byte_', 'int_', 'long_', 'not_found']
745+
read_stata(self.dta15_117, convert_dates=True, columns=columns)
729746

730747
if __name__ == '__main__':
731748
nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'],

0 commit comments

Comments
 (0)