Skip to content

Commit b2245ec

Browse files
committed
CLN: StataReader: refactor repeated struct.unpack/read calls to helpers
1 parent 887d2c4 commit b2245ec

File tree

1 file changed

+69
-71
lines changed

1 file changed

+69
-71
lines changed

pandas/io/stata.py

+69-71
Original file line numberDiff line numberDiff line change
@@ -1198,9 +1198,42 @@ def _set_encoding(self) -> None:
11981198
else:
11991199
self._encoding = "utf-8"
12001200

1201+
def _read_int8(self) -> int:
1202+
return struct.unpack("b", self.path_or_buf.read(1))[0]
1203+
1204+
def _read_uint8(self) -> int:
1205+
return struct.unpack("B", self.path_or_buf.read(1))[0]
1206+
1207+
def _read_uint16(self) -> int:
1208+
return struct.unpack(f"{self.byteorder}H", self.path_or_buf.read(2))[0]
1209+
1210+
def _read_uint32(self) -> int:
1211+
return struct.unpack(f"{self.byteorder}I", self.path_or_buf.read(4))[0]
1212+
1213+
def _read_uint64(self) -> int:
1214+
return struct.unpack(f"{self.byteorder}Q", self.path_or_buf.read(8))[0]
1215+
1216+
def _read_int16(self) -> int:
1217+
return struct.unpack(f"{self.byteorder}h", self.path_or_buf.read(2))[0]
1218+
1219+
def _read_int32(self) -> int:
1220+
return struct.unpack(f"{self.byteorder}i", self.path_or_buf.read(4))[0]
1221+
1222+
def _read_int64(self) -> int:
1223+
return struct.unpack(f"{self.byteorder}q", self.path_or_buf.read(8))[0]
1224+
1225+
def _read_char8(self) -> bytes:
1226+
return struct.unpack("c", self.path_or_buf.read(1))[0]
1227+
1228+
def _read_int16_count(self, count: int) -> tuple[int, ...]:
1229+
return struct.unpack(
1230+
f"{self.byteorder}{'h' * count}",
1231+
self.path_or_buf.read(2 * count),
1232+
)
1233+
12011234
def _read_header(self) -> None:
1202-
first_char = self.path_or_buf.read(1)
1203-
if struct.unpack("c", first_char)[0] == b"<":
1235+
first_char = self._read_char8()
1236+
if first_char == b"<":
12041237
self._read_new_header()
12051238
else:
12061239
self._read_old_header(first_char)
@@ -1220,11 +1253,9 @@ def _read_new_header(self) -> None:
12201253
self.path_or_buf.read(21) # </release><byteorder>
12211254
self.byteorder = ">" if self.path_or_buf.read(3) == b"MSF" else "<"
12221255
self.path_or_buf.read(15) # </byteorder><K>
1223-
nvar_type = "H" if self.format_version <= 118 else "I"
1224-
nvar_size = 2 if self.format_version <= 118 else 4
1225-
self.nvar = struct.unpack(
1226-
self.byteorder + nvar_type, self.path_or_buf.read(nvar_size)
1227-
)[0]
1256+
self.nvar = (
1257+
self._read_uint16() if self.format_version <= 118 else self._read_uint32()
1258+
)
12281259
self.path_or_buf.read(7) # </K><N>
12291260

12301261
self.nobs = self._get_nobs()
@@ -1236,46 +1267,27 @@ def _read_new_header(self) -> None:
12361267
self.path_or_buf.read(8) # 0x0000000000000000
12371268
self.path_or_buf.read(8) # position of <map>
12381269

1239-
self._seek_vartypes = (
1240-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 16
1241-
)
1242-
self._seek_varnames = (
1243-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 10
1244-
)
1245-
self._seek_sortlist = (
1246-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 10
1247-
)
1248-
self._seek_formats = (
1249-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 9
1250-
)
1251-
self._seek_value_label_names = (
1252-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 19
1253-
)
1270+
self._seek_vartypes = self._read_int64() + 16
1271+
self._seek_varnames = self._read_int64() + 10
1272+
self._seek_sortlist = self._read_int64() + 10
1273+
self._seek_formats = self._read_int64() + 9
1274+
self._seek_value_label_names = self._read_int64() + 19
12541275

12551276
# Requires version-specific treatment
12561277
self._seek_variable_labels = self._get_seek_variable_labels()
12571278

12581279
self.path_or_buf.read(8) # <characteristics>
1259-
self.data_location = (
1260-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 6
1261-
)
1262-
self.seek_strls = (
1263-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 7
1264-
)
1265-
self.seek_value_labels = (
1266-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 14
1267-
)
1280+
self.data_location = self._read_int64() + 6
1281+
self.seek_strls = self._read_int64() + 7
1282+
self.seek_value_labels = self._read_int64() + 14
12681283

12691284
self.typlist, self.dtyplist = self._get_dtypes(self._seek_vartypes)
12701285

12711286
self.path_or_buf.seek(self._seek_varnames)
12721287
self.varlist = self._get_varlist()
12731288

12741289
self.path_or_buf.seek(self._seek_sortlist)
1275-
self.srtlist = struct.unpack(
1276-
self.byteorder + ("h" * (self.nvar + 1)),
1277-
self.path_or_buf.read(2 * (self.nvar + 1)),
1278-
)[:-1]
1290+
self.srtlist = self._read_int16_count(self.nvar + 1)[:-1]
12791291

12801292
self.path_or_buf.seek(self._seek_formats)
12811293
self.fmtlist = self._get_fmtlist()
@@ -1291,10 +1303,7 @@ def _get_dtypes(
12911303
self, seek_vartypes: int
12921304
) -> tuple[list[int | str], list[str | np.dtype]]:
12931305
self.path_or_buf.seek(seek_vartypes)
1294-
raw_typlist = [
1295-
struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0]
1296-
for _ in range(self.nvar)
1297-
]
1306+
raw_typlist = [self._read_uint16() for _ in range(self.nvar)]
12981307

12991308
def f(typ: int) -> int | str:
13001309
if typ <= 2045:
@@ -1363,16 +1372,16 @@ def _get_variable_labels(self) -> list[str]:
13631372

13641373
def _get_nobs(self) -> int:
13651374
if self.format_version >= 118:
1366-
return struct.unpack(self.byteorder + "Q", self.path_or_buf.read(8))[0]
1375+
return self._read_uint64()
13671376
else:
1368-
return struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
1377+
return self._read_uint32()
13691378

13701379
def _get_data_label(self) -> str:
13711380
if self.format_version >= 118:
1372-
strlen = struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0]
1381+
strlen = self._read_uint16()
13731382
return self._decode(self.path_or_buf.read(strlen))
13741383
elif self.format_version == 117:
1375-
strlen = struct.unpack("b", self.path_or_buf.read(1))[0]
1384+
strlen = self._read_int8()
13761385
return self._decode(self.path_or_buf.read(strlen))
13771386
elif self.format_version > 105:
13781387
return self._decode(self.path_or_buf.read(81))
@@ -1381,10 +1390,10 @@ def _get_data_label(self) -> str:
13811390

13821391
def _get_time_stamp(self) -> str:
13831392
if self.format_version >= 118:
1384-
strlen = struct.unpack("b", self.path_or_buf.read(1))[0]
1393+
strlen = self._read_int8()
13851394
return self.path_or_buf.read(strlen).decode("utf-8")
13861395
elif self.format_version == 117:
1387-
strlen = struct.unpack("b", self.path_or_buf.read(1))[0]
1396+
strlen = self._read_int8()
13881397
return self._decode(self.path_or_buf.read(strlen))
13891398
elif self.format_version > 104:
13901399
return self._decode(self.path_or_buf.read(18))
@@ -1399,22 +1408,20 @@ def _get_seek_variable_labels(self) -> int:
13991408
# variable, 20 for the closing tag and 17 for the opening tag
14001409
return self._seek_value_label_names + (33 * self.nvar) + 20 + 17
14011410
elif self.format_version >= 118:
1402-
return struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 17
1411+
return self._read_int64() + 17
14031412
else:
14041413
raise ValueError()
14051414

14061415
def _read_old_header(self, first_char: bytes) -> None:
1407-
self.format_version = struct.unpack("b", first_char)[0]
1416+
self.format_version = int(first_char[0])
14081417
if self.format_version not in [104, 105, 108, 111, 113, 114, 115]:
14091418
raise ValueError(_version_error.format(version=self.format_version))
14101419
self._set_encoding()
1411-
self.byteorder = (
1412-
">" if struct.unpack("b", self.path_or_buf.read(1))[0] == 0x1 else "<"
1413-
)
1414-
self.filetype = struct.unpack("b", self.path_or_buf.read(1))[0]
1420+
self.byteorder = (">" if self._read_int8() == 0x1 else "<")
1421+
self.filetype = self._read_int8()
14151422
self.path_or_buf.read(1) # unused
14161423

1417-
self.nvar = struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0]
1424+
self.nvar = self._read_uint16()
14181425
self.nobs = self._get_nobs()
14191426

14201427
self._data_label = self._get_data_label()
@@ -1423,7 +1430,7 @@ def _read_old_header(self, first_char: bytes) -> None:
14231430

14241431
# descriptors
14251432
if self.format_version > 108:
1426-
typlist = [ord(self.path_or_buf.read(1)) for _ in range(self.nvar)]
1433+
typlist = [int(c) for c in self.path_or_buf.read(self.nvar)]
14271434
else:
14281435
buf = self.path_or_buf.read(self.nvar)
14291436
typlistb = np.frombuffer(buf, dtype=np.uint8)
@@ -1453,10 +1460,7 @@ def _read_old_header(self, first_char: bytes) -> None:
14531460
self.varlist = [
14541461
self._decode(self.path_or_buf.read(9)) for _ in range(self.nvar)
14551462
]
1456-
self.srtlist = struct.unpack(
1457-
self.byteorder + ("h" * (self.nvar + 1)),
1458-
self.path_or_buf.read(2 * (self.nvar + 1)),
1459-
)[:-1]
1463+
self.srtlist = self._read_int16_count(self.nvar + 1)[:-1]
14601464

14611465
self.fmtlist = self._get_fmtlist()
14621466

@@ -1471,17 +1475,11 @@ def _read_old_header(self, first_char: bytes) -> None:
14711475

14721476
if self.format_version > 104:
14731477
while True:
1474-
data_type = struct.unpack(
1475-
self.byteorder + "b", self.path_or_buf.read(1)
1476-
)[0]
1478+
data_type = self._read_int8()
14771479
if self.format_version > 108:
1478-
data_len = struct.unpack(
1479-
self.byteorder + "i", self.path_or_buf.read(4)
1480-
)[0]
1480+
data_len = self._read_int32()
14811481
else:
1482-
data_len = struct.unpack(
1483-
self.byteorder + "h", self.path_or_buf.read(2)
1484-
)[0]
1482+
data_len = self._read_int16()
14851483
if data_type == 0:
14861484
break
14871485
self.path_or_buf.read(data_len)
@@ -1565,8 +1563,8 @@ def _read_value_labels(self) -> None:
15651563
labname = self._decode(self.path_or_buf.read(129))
15661564
self.path_or_buf.read(3) # padding
15671565

1568-
n = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
1569-
txtlen = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
1566+
n = self._read_uint32()
1567+
txtlen = self._read_uint32()
15701568
off = np.frombuffer(
15711569
self.path_or_buf.read(4 * n), dtype=self.byteorder + "i4", count=n
15721570
)
@@ -1594,7 +1592,7 @@ def _read_strls(self) -> None:
15941592
break
15951593

15961594
if self.format_version == 117:
1597-
v_o = struct.unpack(self.byteorder + "Q", self.path_or_buf.read(8))[0]
1595+
v_o = self._read_uint64()
15981596
else:
15991597
buf = self.path_or_buf.read(12)
16001598
# Only tested on little endian file on little endian machine.
@@ -1605,8 +1603,8 @@ def _read_strls(self) -> None:
16051603
# This path may not be correct, impossible to test
16061604
buf = buf[0:v_size] + buf[(4 + v_size) :]
16071605
v_o = struct.unpack("Q", buf)[0]
1608-
typ = struct.unpack("B", self.path_or_buf.read(1))[0]
1609-
length = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
1606+
typ = self._read_uint8()
1607+
length = self._read_uint32()
16101608
va = self.path_or_buf.read(length)
16111609
if typ == 130:
16121610
decoded_va = va[0:-1].decode(self._encoding)

0 commit comments

Comments
 (0)