16
16
from collections import abc
17
17
from datetime import datetime , timedelta
18
18
import struct
19
+ from typing import IO , Any , Union
19
20
20
21
import numpy as np
21
22
@@ -62,12 +63,42 @@ def _convert_datetimes(sas_datetimes: pd.Series, unit: str) -> pd.Series:
62
63
raise ValueError ("unit must be 'd' or 's'" )
63
64
64
65
65
- class _subheader_pointer :
66
- pass
66
+ class _SubheaderPointer :
67
+ offset : int
68
+ length : int
69
+ compression : int
70
+ ptype : int
67
71
72
+ def __init__ (self , offset : int , length : int , compression : int , ptype : int ):
73
+ self .offset = offset
74
+ self .length = length
75
+ self .compression = compression
76
+ self .ptype = ptype
68
77
69
- class _column :
70
- pass
78
+
79
+ class _Column :
80
+ col_id : int
81
+ name : Union [str , bytes ]
82
+ label : Union [str , bytes ]
83
+ format : Union [str , bytes ] # TODO: i think allowing bytes is from py2 days
84
+ ctype : bytes
85
+ length : int
86
+
87
+ def __init__ (
88
+ self ,
89
+ col_id : int ,
90
+ name : Union [str , bytes ],
91
+ label : Union [str , bytes ],
92
+ format : Union [str , bytes ],
93
+ ctype : bytes ,
94
+ length : int ,
95
+ ):
96
+ self .col_id = col_id
97
+ self .name = name
98
+ self .label = label
99
+ self .format = format
100
+ self .ctype = ctype
101
+ self .length = length
71
102
72
103
73
104
# SAS7BDAT represents a SAS data file in SAS7BDAT format.
@@ -100,6 +131,8 @@ class SAS7BDATReader(ReaderBase, abc.Iterator):
100
131
bytes.
101
132
"""
102
133
134
+ _path_or_buf : IO [Any ]
135
+
103
136
def __init__ (
104
137
self ,
105
138
path_or_buf ,
@@ -121,7 +154,7 @@ def __init__(
121
154
self .convert_header_text = convert_header_text
122
155
123
156
self .default_encoding = "latin-1"
124
- self .compression = ""
157
+ self .compression = b ""
125
158
self .column_names_strings = []
126
159
self .column_names = []
127
160
self .column_formats = []
@@ -137,10 +170,14 @@ def __init__(
137
170
self ._current_row_on_page_index = 0
138
171
self ._current_row_in_file_index = 0
139
172
140
- self ._path_or_buf = get_filepath_or_buffer (path_or_buf ).filepath_or_buffer
141
- if isinstance (self ._path_or_buf , str ):
142
- self ._path_or_buf = open (self ._path_or_buf , "rb" )
143
- self .handle = self ._path_or_buf
173
+ path_or_buf = get_filepath_or_buffer (path_or_buf ).filepath_or_buffer
174
+ if isinstance (path_or_buf , str ):
175
+ buf = open (path_or_buf , "rb" )
176
+ self .handle = buf
177
+ else :
178
+ buf = path_or_buf
179
+
180
+ self ._path_or_buf : IO [Any ] = buf
144
181
145
182
try :
146
183
self ._get_properties ()
@@ -319,7 +356,7 @@ def _read_float(self, offset, width):
319
356
return struct .unpack (self .byte_order + fd , buf )[0 ]
320
357
321
358
# Read a single signed integer of the given width (1, 2, 4 or 8).
322
- def _read_int (self , offset , width ) :
359
+ def _read_int (self , offset : int , width : int ) -> int :
323
360
if width not in (1 , 2 , 4 , 8 ):
324
361
self .close ()
325
362
raise ValueError ("invalid int width" )
@@ -328,7 +365,7 @@ def _read_int(self, offset, width):
328
365
iv = struct .unpack (self .byte_order + it , buf )[0 ]
329
366
return iv
330
367
331
- def _read_bytes (self , offset , length ):
368
+ def _read_bytes (self , offset : int , length : int ):
332
369
if self ._cached_page is None :
333
370
self ._path_or_buf .seek (offset )
334
371
buf = self ._path_or_buf .read (length )
@@ -400,14 +437,14 @@ def _get_subheader_index(self, signature, compression, ptype):
400
437
if index is None :
401
438
f1 = (compression == const .compressed_subheader_id ) or (compression == 0 )
402
439
f2 = ptype == const .compressed_subheader_type
403
- if (self .compression != "" ) and f1 and f2 :
440
+ if (self .compression != b "" ) and f1 and f2 :
404
441
index = const .SASIndex .data_subheader_index
405
442
else :
406
443
self .close ()
407
444
raise ValueError ("Unknown subheader signature" )
408
445
return index
409
446
410
- def _process_subheader_pointers (self , offset , subheader_pointer_index ):
447
+ def _process_subheader_pointers (self , offset : int , subheader_pointer_index : int ):
411
448
412
449
subheader_pointer_length = self ._subheader_pointer_length
413
450
total_offset = offset + subheader_pointer_length * subheader_pointer_index
@@ -423,11 +460,9 @@ def _process_subheader_pointers(self, offset, subheader_pointer_index):
423
460
424
461
subheader_type = self ._read_int (total_offset , 1 )
425
462
426
- x = _subheader_pointer ()
427
- x .offset = subheader_offset
428
- x .length = subheader_length
429
- x .compression = subheader_compression
430
- x .ptype = subheader_type
463
+ x = _SubheaderPointer (
464
+ subheader_offset , subheader_length , subheader_compression , subheader_type
465
+ )
431
466
432
467
return x
433
468
@@ -519,7 +554,7 @@ def _process_columntext_subheader(self, offset, length):
519
554
self .column_names_strings .append (cname )
520
555
521
556
if len (self .column_names_strings ) == 1 :
522
- compression_literal = ""
557
+ compression_literal = b ""
523
558
for cl in const .compression_literals :
524
559
if cl in cname_raw :
525
560
compression_literal = cl
@@ -532,7 +567,7 @@ def _process_columntext_subheader(self, offset, length):
532
567
533
568
buf = self ._read_bytes (offset1 , self ._lcp )
534
569
compression_literal = buf .rstrip (b"\x00 " )
535
- if compression_literal == "" :
570
+ if compression_literal == b "" :
536
571
self ._lcs = 0
537
572
offset1 = offset + 32
538
573
if self .U64 :
@@ -657,13 +692,14 @@ def _process_format_subheader(self, offset, length):
657
692
column_format = format_names [format_start : format_start + format_len ]
658
693
current_column_number = len (self .columns )
659
694
660
- col = _column ()
661
- col .col_id = current_column_number
662
- col .name = self .column_names [current_column_number ]
663
- col .label = column_label
664
- col .format = column_format
665
- col .ctype = self ._column_types [current_column_number ]
666
- col .length = self ._column_data_lengths [current_column_number ]
695
+ col = _Column (
696
+ current_column_number ,
697
+ self .column_names [current_column_number ],
698
+ column_label ,
699
+ column_format ,
700
+ self ._column_types [current_column_number ],
701
+ self ._column_data_lengths [current_column_number ],
702
+ )
667
703
668
704
self .column_formats .append (column_format )
669
705
self .columns .append (col )
0 commit comments