Skip to content

Commit a1e5304

Browse files
authored
TYP: sas, stata, style (#36990)
1 parent a9f8bea commit a1e5304

File tree

6 files changed

+88
-44
lines changed

6 files changed

+88
-44
lines changed

pandas/io/formats/format.py

+1
Original file line numberDiff line numberDiff line change
@@ -1407,6 +1407,7 @@ def _value_formatter(
14071407
if float_format:
14081408

14091409
def base_formatter(v):
1410+
assert float_format is not None # for mypy
14101411
return float_format(value=v) if notna(v) else self.na_rep
14111412

14121413
else:

pandas/io/formats/style.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1511,7 +1511,10 @@ def from_custom_template(cls, searchpath, name):
15111511
"""
15121512
loader = jinja2.ChoiceLoader([jinja2.FileSystemLoader(searchpath), cls.loader])
15131513

1514-
class MyStyler(cls):
1514+
# mypy doesnt like dynamically-defined class
1515+
# error: Variable "cls" is not valid as a type [valid-type]
1516+
# error: Invalid base class "cls" [misc]
1517+
class MyStyler(cls): # type:ignore[valid-type,misc]
15151518
env = jinja2.Environment(loader=loader)
15161519
template = env.get_template(name)
15171520

pandas/io/sas/sas7bdat.py

+63-27
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from collections import abc
1717
from datetime import datetime, timedelta
1818
import struct
19+
from typing import IO, Any, Union
1920

2021
import numpy as np
2122

@@ -62,12 +63,42 @@ def _convert_datetimes(sas_datetimes: pd.Series, unit: str) -> pd.Series:
6263
raise ValueError("unit must be 'd' or 's'")
6364

6465

65-
class _subheader_pointer:
66-
pass
66+
class _SubheaderPointer:
67+
offset: int
68+
length: int
69+
compression: int
70+
ptype: int
6771

72+
def __init__(self, offset: int, length: int, compression: int, ptype: int):
73+
self.offset = offset
74+
self.length = length
75+
self.compression = compression
76+
self.ptype = ptype
6877

69-
class _column:
70-
pass
78+
79+
class _Column:
80+
col_id: int
81+
name: Union[str, bytes]
82+
label: Union[str, bytes]
83+
format: Union[str, bytes] # TODO: i think allowing bytes is from py2 days
84+
ctype: bytes
85+
length: int
86+
87+
def __init__(
88+
self,
89+
col_id: int,
90+
name: Union[str, bytes],
91+
label: Union[str, bytes],
92+
format: Union[str, bytes],
93+
ctype: bytes,
94+
length: int,
95+
):
96+
self.col_id = col_id
97+
self.name = name
98+
self.label = label
99+
self.format = format
100+
self.ctype = ctype
101+
self.length = length
71102

72103

73104
# SAS7BDAT represents a SAS data file in SAS7BDAT format.
@@ -100,6 +131,8 @@ class SAS7BDATReader(ReaderBase, abc.Iterator):
100131
bytes.
101132
"""
102133

134+
_path_or_buf: IO[Any]
135+
103136
def __init__(
104137
self,
105138
path_or_buf,
@@ -121,7 +154,7 @@ def __init__(
121154
self.convert_header_text = convert_header_text
122155

123156
self.default_encoding = "latin-1"
124-
self.compression = ""
157+
self.compression = b""
125158
self.column_names_strings = []
126159
self.column_names = []
127160
self.column_formats = []
@@ -137,10 +170,14 @@ def __init__(
137170
self._current_row_on_page_index = 0
138171
self._current_row_in_file_index = 0
139172

140-
self._path_or_buf = get_filepath_or_buffer(path_or_buf).filepath_or_buffer
141-
if isinstance(self._path_or_buf, str):
142-
self._path_or_buf = open(self._path_or_buf, "rb")
143-
self.handle = self._path_or_buf
173+
path_or_buf = get_filepath_or_buffer(path_or_buf).filepath_or_buffer
174+
if isinstance(path_or_buf, str):
175+
buf = open(path_or_buf, "rb")
176+
self.handle = buf
177+
else:
178+
buf = path_or_buf
179+
180+
self._path_or_buf: IO[Any] = buf
144181

145182
try:
146183
self._get_properties()
@@ -319,7 +356,7 @@ def _read_float(self, offset, width):
319356
return struct.unpack(self.byte_order + fd, buf)[0]
320357

321358
# Read a single signed integer of the given width (1, 2, 4 or 8).
322-
def _read_int(self, offset, width):
359+
def _read_int(self, offset: int, width: int) -> int:
323360
if width not in (1, 2, 4, 8):
324361
self.close()
325362
raise ValueError("invalid int width")
@@ -328,7 +365,7 @@ def _read_int(self, offset, width):
328365
iv = struct.unpack(self.byte_order + it, buf)[0]
329366
return iv
330367

331-
def _read_bytes(self, offset, length):
368+
def _read_bytes(self, offset: int, length: int):
332369
if self._cached_page is None:
333370
self._path_or_buf.seek(offset)
334371
buf = self._path_or_buf.read(length)
@@ -400,14 +437,14 @@ def _get_subheader_index(self, signature, compression, ptype):
400437
if index is None:
401438
f1 = (compression == const.compressed_subheader_id) or (compression == 0)
402439
f2 = ptype == const.compressed_subheader_type
403-
if (self.compression != "") and f1 and f2:
440+
if (self.compression != b"") and f1 and f2:
404441
index = const.SASIndex.data_subheader_index
405442
else:
406443
self.close()
407444
raise ValueError("Unknown subheader signature")
408445
return index
409446

410-
def _process_subheader_pointers(self, offset, subheader_pointer_index):
447+
def _process_subheader_pointers(self, offset: int, subheader_pointer_index: int):
411448

412449
subheader_pointer_length = self._subheader_pointer_length
413450
total_offset = offset + subheader_pointer_length * subheader_pointer_index
@@ -423,11 +460,9 @@ def _process_subheader_pointers(self, offset, subheader_pointer_index):
423460

424461
subheader_type = self._read_int(total_offset, 1)
425462

426-
x = _subheader_pointer()
427-
x.offset = subheader_offset
428-
x.length = subheader_length
429-
x.compression = subheader_compression
430-
x.ptype = subheader_type
463+
x = _SubheaderPointer(
464+
subheader_offset, subheader_length, subheader_compression, subheader_type
465+
)
431466

432467
return x
433468

@@ -519,7 +554,7 @@ def _process_columntext_subheader(self, offset, length):
519554
self.column_names_strings.append(cname)
520555

521556
if len(self.column_names_strings) == 1:
522-
compression_literal = ""
557+
compression_literal = b""
523558
for cl in const.compression_literals:
524559
if cl in cname_raw:
525560
compression_literal = cl
@@ -532,7 +567,7 @@ def _process_columntext_subheader(self, offset, length):
532567

533568
buf = self._read_bytes(offset1, self._lcp)
534569
compression_literal = buf.rstrip(b"\x00")
535-
if compression_literal == "":
570+
if compression_literal == b"":
536571
self._lcs = 0
537572
offset1 = offset + 32
538573
if self.U64:
@@ -657,13 +692,14 @@ def _process_format_subheader(self, offset, length):
657692
column_format = format_names[format_start : format_start + format_len]
658693
current_column_number = len(self.columns)
659694

660-
col = _column()
661-
col.col_id = current_column_number
662-
col.name = self.column_names[current_column_number]
663-
col.label = column_label
664-
col.format = column_format
665-
col.ctype = self._column_types[current_column_number]
666-
col.length = self._column_data_lengths[current_column_number]
695+
col = _Column(
696+
current_column_number,
697+
self.column_names[current_column_number],
698+
column_label,
699+
column_format,
700+
self._column_types[current_column_number],
701+
self._column_data_lengths[current_column_number],
702+
)
667703

668704
self.column_formats.append(column_format)
669705
self.columns.append(col)

pandas/io/sas/sas_xport.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -337,16 +337,16 @@ def _read_header(self):
337337
obs_length = 0
338338
while len(fielddata) >= fieldnamelength:
339339
# pull data for one field
340-
field, fielddata = (
340+
fieldbytes, fielddata = (
341341
fielddata[:fieldnamelength],
342342
fielddata[fieldnamelength:],
343343
)
344344

345345
# rest at end gets ignored, so if field is short, pad out
346346
# to match struct pattern below
347-
field = field.ljust(140)
347+
fieldbytes = fieldbytes.ljust(140)
348348

349-
fieldstruct = struct.unpack(">hhhh8s40s8shhh2s8shhl52s", field)
349+
fieldstruct = struct.unpack(">hhhh8s40s8shhh2s8shhl52s", fieldbytes)
350350
field = dict(zip(_fieldkeys, fieldstruct))
351351
del field["_"]
352352
field["ntype"] = types[field["ntype"]]
@@ -408,8 +408,8 @@ def _record_count(self) -> int:
408408
return total_records_length // self.record_length
409409

410410
self.filepath_or_buffer.seek(-80, 2)
411-
last_card = self.filepath_or_buffer.read(80)
412-
last_card = np.frombuffer(last_card, dtype=np.uint64)
411+
last_card_bytes = self.filepath_or_buffer.read(80)
412+
last_card = np.frombuffer(last_card_bytes, dtype=np.uint64)
413413

414414
# 8 byte blank
415415
ix = np.flatnonzero(last_card == 2314885530818453536)
@@ -483,7 +483,7 @@ def read(self, nrows=None):
483483
df[x] = v
484484

485485
if self._index is None:
486-
df.index = range(self._lines_read, self._lines_read + read_lines)
486+
df.index = pd.Index(range(self._lines_read, self._lines_read + read_lines))
487487
else:
488488
df = df.set_index(self._index)
489489

pandas/io/stata.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,18 @@
1616
from pathlib import Path
1717
import struct
1818
import sys
19-
from typing import Any, AnyStr, BinaryIO, Dict, List, Optional, Sequence, Tuple, Union
19+
from typing import (
20+
Any,
21+
AnyStr,
22+
BinaryIO,
23+
Dict,
24+
List,
25+
Optional,
26+
Sequence,
27+
Tuple,
28+
Union,
29+
cast,
30+
)
2031
import warnings
2132

2233
from dateutil.relativedelta import relativedelta
@@ -1389,6 +1400,7 @@ def _setup_dtype(self) -> np.dtype:
13891400
dtypes = [] # Convert struct data types to numpy data type
13901401
for i, typ in enumerate(self.typlist):
13911402
if typ in self.NUMPY_TYPE_MAP:
1403+
typ = cast(str, typ) # only strs in NUMPY_TYPE_MAP
13921404
dtypes.append(("s" + str(i), self.byteorder + self.NUMPY_TYPE_MAP[typ]))
13931405
else:
13941406
dtypes.append(("s" + str(i), "S" + str(typ)))
@@ -1699,6 +1711,7 @@ def _do_convert_missing(self, data: DataFrame, convert_missing: bool) -> DataFra
16991711
if fmt not in self.VALID_RANGE:
17001712
continue
17011713

1714+
fmt = cast(str, fmt) # only strs in VALID_RANGE
17021715
nmin, nmax = self.VALID_RANGE[fmt]
17031716
series = data[colname]
17041717
missing = np.logical_or(series < nmin, series > nmax)

setup.cfg

-9
Original file line numberDiff line numberDiff line change
@@ -226,21 +226,12 @@ check_untyped_defs=False
226226
[mypy-pandas.io.formats.format]
227227
check_untyped_defs=False
228228

229-
[mypy-pandas.io.formats.style]
230-
check_untyped_defs=False
231-
232229
[mypy-pandas.io.parsers]
233230
check_untyped_defs=False
234231

235232
[mypy-pandas.io.pytables]
236233
check_untyped_defs=False
237234

238-
[mypy-pandas.io.sas.sas_xport]
239-
check_untyped_defs=False
240-
241-
[mypy-pandas.io.sas.sas7bdat]
242-
check_untyped_defs=False
243-
244235
[mypy-pandas.io.stata]
245236
check_untyped_defs=False
246237

0 commit comments

Comments
 (0)