Skip to content

Commit d87c762

Browse files
committed
CLN: StataReader: refactor repeated struct.unpack/read calls to helpers
1 parent cedd122 commit d87c762

File tree

1 file changed

+69
-71
lines changed

1 file changed

+69
-71
lines changed

pandas/io/stata.py

+69-71
Original file line numberDiff line numberDiff line change
@@ -1201,9 +1201,42 @@ def _set_encoding(self) -> None:
12011201
else:
12021202
self._encoding = "utf-8"
12031203

1204+
def _read_int8(self) -> int:
1205+
return struct.unpack("b", self.path_or_buf.read(1))[0]
1206+
1207+
def _read_uint8(self) -> int:
1208+
return struct.unpack("B", self.path_or_buf.read(1))[0]
1209+
1210+
def _read_uint16(self) -> int:
1211+
return struct.unpack(f"{self.byteorder}H", self.path_or_buf.read(2))[0]
1212+
1213+
def _read_uint32(self) -> int:
1214+
return struct.unpack(f"{self.byteorder}I", self.path_or_buf.read(4))[0]
1215+
1216+
def _read_uint64(self) -> int:
1217+
return struct.unpack(f"{self.byteorder}Q", self.path_or_buf.read(8))[0]
1218+
1219+
def _read_int16(self) -> int:
1220+
return struct.unpack(f"{self.byteorder}h", self.path_or_buf.read(2))[0]
1221+
1222+
def _read_int32(self) -> int:
1223+
return struct.unpack(f"{self.byteorder}i", self.path_or_buf.read(4))[0]
1224+
1225+
def _read_int64(self) -> int:
1226+
return struct.unpack(f"{self.byteorder}q", self.path_or_buf.read(8))[0]
1227+
1228+
def _read_char8(self) -> bytes:
1229+
return struct.unpack("c", self.path_or_buf.read(1))[0]
1230+
1231+
def _read_int16_count(self, count: int) -> tuple[int, ...]:
1232+
return struct.unpack(
1233+
f"{self.byteorder}{'h' * count}",
1234+
self.path_or_buf.read(2 * count),
1235+
)
1236+
12041237
def _read_header(self) -> None:
1205-
first_char = self.path_or_buf.read(1)
1206-
if struct.unpack("c", first_char)[0] == b"<":
1238+
first_char = self._read_char8()
1239+
if first_char == b"<":
12071240
self._read_new_header()
12081241
else:
12091242
self._read_old_header(first_char)
@@ -1223,11 +1256,9 @@ def _read_new_header(self) -> None:
12231256
self.path_or_buf.read(21) # </release><byteorder>
12241257
self.byteorder = self.path_or_buf.read(3) == b"MSF" and ">" or "<"
12251258
self.path_or_buf.read(15) # </byteorder><K>
1226-
nvar_type = "H" if self.format_version <= 118 else "I"
1227-
nvar_size = 2 if self.format_version <= 118 else 4
1228-
self.nvar = struct.unpack(
1229-
self.byteorder + nvar_type, self.path_or_buf.read(nvar_size)
1230-
)[0]
1259+
self.nvar = (
1260+
self._read_uint16() if self.format_version <= 118 else self._read_uint32()
1261+
)
12311262
self.path_or_buf.read(7) # </K><N>
12321263

12331264
self.nobs = self._get_nobs()
@@ -1239,46 +1270,27 @@ def _read_new_header(self) -> None:
12391270
self.path_or_buf.read(8) # 0x0000000000000000
12401271
self.path_or_buf.read(8) # position of <map>
12411272

1242-
self._seek_vartypes = (
1243-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 16
1244-
)
1245-
self._seek_varnames = (
1246-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 10
1247-
)
1248-
self._seek_sortlist = (
1249-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 10
1250-
)
1251-
self._seek_formats = (
1252-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 9
1253-
)
1254-
self._seek_value_label_names = (
1255-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 19
1256-
)
1273+
self._seek_vartypes = self._read_int64() + 16
1274+
self._seek_varnames = self._read_int64() + 10
1275+
self._seek_sortlist = self._read_int64() + 10
1276+
self._seek_formats = self._read_int64() + 9
1277+
self._seek_value_label_names = self._read_int64() + 19
12571278

12581279
# Requires version-specific treatment
12591280
self._seek_variable_labels = self._get_seek_variable_labels()
12601281

12611282
self.path_or_buf.read(8) # <characteristics>
1262-
self.data_location = (
1263-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 6
1264-
)
1265-
self.seek_strls = (
1266-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 7
1267-
)
1268-
self.seek_value_labels = (
1269-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 14
1270-
)
1283+
self.data_location = self._read_int64() + 6
1284+
self.seek_strls = self._read_int64() + 7
1285+
self.seek_value_labels = self._read_int64() + 14
12711286

12721287
self.typlist, self.dtyplist = self._get_dtypes(self._seek_vartypes)
12731288

12741289
self.path_or_buf.seek(self._seek_varnames)
12751290
self.varlist = self._get_varlist()
12761291

12771292
self.path_or_buf.seek(self._seek_sortlist)
1278-
self.srtlist = struct.unpack(
1279-
self.byteorder + ("h" * (self.nvar + 1)),
1280-
self.path_or_buf.read(2 * (self.nvar + 1)),
1281-
)[:-1]
1293+
self.srtlist = self._read_int16_count(self.nvar + 1)[:-1]
12821294

12831295
self.path_or_buf.seek(self._seek_formats)
12841296
self.fmtlist = self._get_fmtlist()
@@ -1295,10 +1307,7 @@ def _get_dtypes(
12951307
) -> tuple[list[int | str], list[str | np.dtype]]:
12961308

12971309
self.path_or_buf.seek(seek_vartypes)
1298-
raw_typlist = [
1299-
struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0]
1300-
for _ in range(self.nvar)
1301-
]
1310+
raw_typlist = [self._read_uint16() for _ in range(self.nvar)]
13021311

13031312
def f(typ: int) -> int | str:
13041313
if typ <= 2045:
@@ -1367,16 +1376,16 @@ def _get_variable_labels(self) -> list[str]:
13671376

13681377
def _get_nobs(self) -> int:
13691378
if self.format_version >= 118:
1370-
return struct.unpack(self.byteorder + "Q", self.path_or_buf.read(8))[0]
1379+
return self._read_uint64()
13711380
else:
1372-
return struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
1381+
return self._read_uint32()
13731382

13741383
def _get_data_label(self) -> str:
13751384
if self.format_version >= 118:
1376-
strlen = struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0]
1385+
strlen = self._read_uint16()
13771386
return self._decode(self.path_or_buf.read(strlen))
13781387
elif self.format_version == 117:
1379-
strlen = struct.unpack("b", self.path_or_buf.read(1))[0]
1388+
strlen = self._read_int8()
13801389
return self._decode(self.path_or_buf.read(strlen))
13811390
elif self.format_version > 105:
13821391
return self._decode(self.path_or_buf.read(81))
@@ -1385,10 +1394,10 @@ def _get_data_label(self) -> str:
13851394

13861395
def _get_time_stamp(self) -> str:
13871396
if self.format_version >= 118:
1388-
strlen = struct.unpack("b", self.path_or_buf.read(1))[0]
1397+
strlen = self._read_int8()
13891398
return self.path_or_buf.read(strlen).decode("utf-8")
13901399
elif self.format_version == 117:
1391-
strlen = struct.unpack("b", self.path_or_buf.read(1))[0]
1400+
strlen = self._read_int8()
13921401
return self._decode(self.path_or_buf.read(strlen))
13931402
elif self.format_version > 104:
13941403
return self._decode(self.path_or_buf.read(18))
@@ -1403,22 +1412,20 @@ def _get_seek_variable_labels(self) -> int:
14031412
# variable, 20 for the closing tag and 17 for the opening tag
14041413
return self._seek_value_label_names + (33 * self.nvar) + 20 + 17
14051414
elif self.format_version >= 118:
1406-
return struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 17
1415+
return self._read_int64() + 17
14071416
else:
14081417
raise ValueError()
14091418

14101419
def _read_old_header(self, first_char: bytes) -> None:
1411-
self.format_version = struct.unpack("b", first_char)[0]
1420+
self.format_version = int(first_char[0])
14121421
if self.format_version not in [104, 105, 108, 111, 113, 114, 115]:
14131422
raise ValueError(_version_error.format(version=self.format_version))
14141423
self._set_encoding()
1415-
self.byteorder = (
1416-
struct.unpack("b", self.path_or_buf.read(1))[0] == 0x1 and ">" or "<"
1417-
)
1418-
self.filetype = struct.unpack("b", self.path_or_buf.read(1))[0]
1424+
self.byteorder = self._read_int8() == 0x1 and ">" or "<"
1425+
self.filetype = self._read_int8()
14191426
self.path_or_buf.read(1) # unused
14201427

1421-
self.nvar = struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0]
1428+
self.nvar = self._read_uint16()
14221429
self.nobs = self._get_nobs()
14231430

14241431
self._data_label = self._get_data_label()
@@ -1427,7 +1434,7 @@ def _read_old_header(self, first_char: bytes) -> None:
14271434

14281435
# descriptors
14291436
if self.format_version > 108:
1430-
typlist = [ord(self.path_or_buf.read(1)) for _ in range(self.nvar)]
1437+
typlist = [int(c) for c in self.path_or_buf.read(self.nvar)]
14311438
else:
14321439
buf = self.path_or_buf.read(self.nvar)
14331440
typlistb = np.frombuffer(buf, dtype=np.uint8)
@@ -1457,10 +1464,7 @@ def _read_old_header(self, first_char: bytes) -> None:
14571464
self.varlist = [
14581465
self._decode(self.path_or_buf.read(9)) for _ in range(self.nvar)
14591466
]
1460-
self.srtlist = struct.unpack(
1461-
self.byteorder + ("h" * (self.nvar + 1)),
1462-
self.path_or_buf.read(2 * (self.nvar + 1)),
1463-
)[:-1]
1467+
self.srtlist = self._read_int16_count(self.nvar + 1)[:-1]
14641468

14651469
self.fmtlist = self._get_fmtlist()
14661470

@@ -1475,17 +1479,11 @@ def _read_old_header(self, first_char: bytes) -> None:
14751479

14761480
if self.format_version > 104:
14771481
while True:
1478-
data_type = struct.unpack(
1479-
self.byteorder + "b", self.path_or_buf.read(1)
1480-
)[0]
1482+
data_type = self._read_int8()
14811483
if self.format_version > 108:
1482-
data_len = struct.unpack(
1483-
self.byteorder + "i", self.path_or_buf.read(4)
1484-
)[0]
1484+
data_len = self._read_int32()
14851485
else:
1486-
data_len = struct.unpack(
1487-
self.byteorder + "h", self.path_or_buf.read(2)
1488-
)[0]
1486+
data_len = self._read_int16()
14891487
if data_type == 0:
14901488
break
14911489
self.path_or_buf.read(data_len)
@@ -1569,8 +1567,8 @@ def _read_value_labels(self) -> None:
15691567
labname = self._decode(self.path_or_buf.read(129))
15701568
self.path_or_buf.read(3) # padding
15711569

1572-
n = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
1573-
txtlen = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
1570+
n = self._read_uint32()
1571+
txtlen = self._read_uint32()
15741572
off = np.frombuffer(
15751573
self.path_or_buf.read(4 * n), dtype=self.byteorder + "i4", count=n
15761574
)
@@ -1598,7 +1596,7 @@ def _read_strls(self) -> None:
15981596
break
15991597

16001598
if self.format_version == 117:
1601-
v_o = struct.unpack(self.byteorder + "Q", self.path_or_buf.read(8))[0]
1599+
v_o = self._read_uint64()
16021600
else:
16031601
buf = self.path_or_buf.read(12)
16041602
# Only tested on little endian file on little endian machine.
@@ -1609,8 +1607,8 @@ def _read_strls(self) -> None:
16091607
# This path may not be correct, impossible to test
16101608
buf = buf[0:v_size] + buf[(4 + v_size) :]
16111609
v_o = struct.unpack("Q", buf)[0]
1612-
typ = struct.unpack("B", self.path_or_buf.read(1))[0]
1613-
length = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
1610+
typ = self._read_uint8()
1611+
length = self._read_uint32()
16141612
va = self.path_or_buf.read(length)
16151613
if typ == 130:
16161614
decoded_va = va[0:-1].decode(self._encoding)

0 commit comments

Comments
 (0)