@@ -375,6 +375,18 @@ def __init__(self, encoding):
375
375
'd' : np .float64 (struct .unpack ('<d' , b'\x00 \x00 \x00 \x00 \x00 \x00 \xe0 \x7f ' )[0 ])
376
376
}
377
377
378
+ # Reserved words cannot be used as variable names
379
+ self .RESERVED_WORDS = ('aggregate' , 'array' , 'boolean' , 'break' ,
380
+ 'byte' , 'case' , 'catch' , 'class' , 'colvector' ,
381
+ 'complex' , 'const' , 'continue' , 'default' ,
382
+ 'delegate' , 'delete' , 'do' , 'double' , 'else' ,
383
+ 'eltypedef' , 'end' , 'enum' , 'explicit' ,
384
+ 'export' , 'external' , 'float' , 'for' , 'friend' ,
385
+ 'function' , 'global' , 'goto' , 'if' , 'inline' ,
386
+ 'int' , 'local' , 'long' , 'NULL' , 'pragma' ,
387
+ 'protected' , 'quad' , 'rowvector' , 'short' ,
388
+ 'typedef' , 'typename' , 'virtual' )
389
+
378
390
def _decode_bytes (self , str , errors = None ):
379
391
if compat .PY3 or self ._encoding is not None :
380
392
return str .decode (self ._encoding , errors )
@@ -449,10 +461,10 @@ def _read_header(self):
449
461
self .path_or_buf .read (4 ))[0 ]
450
462
self .path_or_buf .read (11 ) # </N><label>
451
463
strlen = struct .unpack ('b' , self .path_or_buf .read (1 ))[0 ]
452
- self .data_label = self .path_or_buf .read (strlen )
464
+ self .data_label = self ._null_terminate ( self . path_or_buf .read (strlen ) )
453
465
self .path_or_buf .read (19 ) # </label><timestamp>
454
466
strlen = struct .unpack ('b' , self .path_or_buf .read (1 ))[0 ]
455
- self .time_stamp = self .path_or_buf .read (strlen )
467
+ self .time_stamp = self ._null_terminate ( self . path_or_buf .read (strlen ) )
456
468
self .path_or_buf .read (26 ) # </timestamp></header><map>
457
469
self .path_or_buf .read (8 ) # 0x0000000000000000
458
470
self .path_or_buf .read (8 ) # position of <map>
@@ -543,11 +555,11 @@ def _read_header(self):
543
555
self .nobs = struct .unpack (self .byteorder + 'I' ,
544
556
self .path_or_buf .read (4 ))[0 ]
545
557
if self .format_version > 105 :
546
- self .data_label = self .path_or_buf .read (81 )
558
+ self .data_label = self ._null_terminate ( self . path_or_buf .read (81 ) )
547
559
else :
548
- self .data_label = self .path_or_buf .read (32 )
560
+ self .data_label = self ._null_terminate ( self . path_or_buf .read (32 ) )
549
561
if self .format_version > 104 :
550
- self .time_stamp = self .path_or_buf .read (18 )
562
+ self .time_stamp = self ._null_terminate ( self . path_or_buf .read (18 ) )
551
563
552
564
# descriptors
553
565
if self .format_version > 108 :
@@ -1029,6 +1041,11 @@ class StataWriter(StataParser):
1029
1041
byteorder : str
1030
1042
Can be ">", "<", "little", or "big". The default is None which uses
1031
1043
`sys.byteorder`
1044
+ time_stamp : datetime
1045
+ A date time to use when writing the file. Can be None, in which
1046
+ case the current time is used.
1047
+ data_label : str
1048
+ A label for the data set. Should be 80 characters or fewer.
1032
1049
1033
1050
Returns
1034
1051
-------
@@ -1047,10 +1064,13 @@ class StataWriter(StataParser):
1047
1064
>>> writer.write_file()
1048
1065
"""
1049
1066
def __init__ (self , fname , data , convert_dates = None , write_index = True ,
1050
- encoding = "latin-1" , byteorder = None ):
1067
+ encoding = "latin-1" , byteorder = None , time_stamp = None ,
1068
+ data_label = None ):
1051
1069
super (StataWriter , self ).__init__ (encoding )
1052
1070
self ._convert_dates = convert_dates
1053
1071
self ._write_index = write_index
1072
+ self ._time_stamp = time_stamp
1073
+ self ._data_label = data_label
1054
1074
# attach nobs, nvars, data, varlist, typlist
1055
1075
self ._prepare_pandas (data )
1056
1076
@@ -1086,7 +1106,7 @@ def __iter__(self):
1086
1106
1087
1107
if self ._write_index :
1088
1108
data = data .reset_index ()
1089
- # Check columns for compatbaility with stata
1109
+ # Check columns for compatibility with stata
1090
1110
data = _cast_to_stata_types (data )
1091
1111
self .datarows = DataFrameRowIter (data )
1092
1112
self .nobs , self .nvar = data .shape
@@ -1110,7 +1130,8 @@ def __iter__(self):
1110
1130
self .fmtlist [key ] = self ._convert_dates [key ]
1111
1131
1112
1132
def write_file (self ):
1113
- self ._write_header ()
1133
+ self ._write_header (time_stamp = self ._time_stamp ,
1134
+ data_label = self ._data_label )
1114
1135
self ._write_descriptors ()
1115
1136
self ._write_variable_labels ()
1116
1137
# write 5 zeros for expansion fields
@@ -1147,7 +1168,7 @@ def _write_header(self, data_label=None, time_stamp=None):
1147
1168
# format dd Mon yyyy hh:mm
1148
1169
if time_stamp is None :
1149
1170
time_stamp = datetime .datetime .now ()
1150
- elif not isinstance (time_stamp , datetime ):
1171
+ elif not isinstance (time_stamp , datetime . datetime ):
1151
1172
raise ValueError ("time_stamp should be datetime type" )
1152
1173
self ._file .write (
1153
1174
self ._null_terminate (time_stamp .strftime ("%d %b %Y %H:%M" ))
@@ -1169,7 +1190,9 @@ def _write_descriptors(self, typlist=None, varlist=None, srtlist=None,
1169
1190
for c in name :
1170
1191
if (c < 'A' or c > 'Z' ) and (c < 'a' or c > 'z' ) and (c < '0' or c > '9' ) and c != '_' :
1171
1192
name = name .replace (c , '_' )
1172
-
1193
+ # Variable name must not be a reserved word
1194
+ if name in self .RESERVED_WORDS :
1195
+ name = '_' + name
1173
1196
# Variable name may not start with a number
1174
1197
if name [0 ] > '0' and name [0 ] < '9' :
1175
1198
name = '_' + name
0 commit comments