29
29
30
30
def read_stata (filepath_or_buffer , convert_dates = True ,
31
31
convert_categoricals = True , encoding = None , index = None ,
32
- convert_missing = False , preserve_dtypes = True ):
32
+ convert_missing = False , preserve_dtypes = True , columns = None ):
33
33
"""
34
34
Read Stata file into DataFrame
35
35
@@ -55,11 +55,14 @@ def read_stata(filepath_or_buffer, convert_dates=True,
55
55
preserve_dtypes : boolean, defaults to True
56
56
Preserve Stata datatypes. If False, numeric data are upcast to pandas
57
57
default types for foreign data (float64 or int64)
58
+ columns : list or None
59
+ Columns to retain. Columns will be returned in the given order. None
60
+ returns all columns
58
61
"""
59
62
reader = StataReader (filepath_or_buffer , encoding )
60
63
61
64
return reader .data (convert_dates , convert_categoricals , index ,
62
- convert_missing , preserve_dtypes )
65
+ convert_missing , preserve_dtypes , columns )
63
66
64
67
_date_formats = ["%tc" , "%tC" , "%td" , "%d" , "%tw" , "%tm" , "%tq" , "%th" , "%ty" ]
65
68
@@ -977,7 +980,7 @@ def _read_strls(self):
977
980
self .path_or_buf .read (1 ) # zero-termination
978
981
979
982
def data (self , convert_dates = True , convert_categoricals = True , index = None ,
980
- convert_missing = False , preserve_dtypes = True ):
983
+ convert_missing = False , preserve_dtypes = True , columns = None ):
981
984
"""
982
985
Reads observations from Stata file, converting them into a dataframe
983
986
@@ -999,6 +1002,10 @@ def data(self, convert_dates=True, convert_categoricals=True, index=None,
999
1002
preserve_dtypes : boolean, defaults to True
1000
1003
Preserve Stata datatypes. If False, numeric data are upcast to
1001
1004
pandas default types for foreign data (float64 or int64)
1005
+ columns : list or None
1006
+ Columns to retain. Columns will be returned in the given order.
1007
+ None returns all columns
1008
+
1002
1009
Returns
1003
1010
-------
1004
1011
y : DataFrame instance
@@ -1034,6 +1041,35 @@ def data(self, convert_dates=True, convert_categoricals=True, index=None,
1034
1041
data = DataFrame .from_records (data , index = index )
1035
1042
data .columns = self .varlist
1036
1043
1044
+ if columns is not None :
1045
+ column_set = set (columns )
1046
+ if len (column_set ) != len (columns ):
1047
+ raise ValueError ('columns contains duplicate entries' )
1048
+ unmatched = column_set .difference (data .columns )
1049
+ if unmatched :
1050
+ raise ValueError ('The following columns were not found in the '
1051
+ 'Stata data set: ' +
1052
+ ', ' .join (list (unmatched )))
1053
+ # Copy information for retained columns for later processing
1054
+ dtyplist = []
1055
+ typlist = []
1056
+ fmtlist = []
1057
+ lbllist = []
1058
+ matched = set ()
1059
+ for i , col in enumerate (data .columns ):
1060
+ if col in column_set :
1061
+ matched .update ([col ])
1062
+ dtyplist .append (self .dtyplist [i ])
1063
+ typlist .append (self .typlist [i ])
1064
+ fmtlist .append (self .fmtlist [i ])
1065
+ lbllist .append (self .lbllist [i ])
1066
+
1067
+ data = data [columns ]
1068
+ self .dtyplist = dtyplist
1069
+ self .typlist = typlist
1070
+ self .fmtlist = fmtlist
1071
+ self .lbllist = lbllist
1072
+
1037
1073
for col , typ in zip (data , self .typlist ):
1038
1074
if type (typ ) is int :
1039
1075
data [col ] = data [col ].apply (self ._null_terminate , convert_dtype = True ,)
0 commit comments