Skip to content

Commit b817246

Browse files
committed
CLN: StataReader: refactor repeated struct.unpack/read calls to helpers
1 parent 890d097 commit b817246

File tree

1 file changed

+70
-71
lines changed

1 file changed

+70
-71
lines changed

pandas/io/stata.py

+70-71
Original file line number | Diff line number | Diff line change
@@ -1202,9 +1202,42 @@ def _set_encoding(self) -> None:
12021202
else:
12031203
self._encoding = "utf-8"
12041204

1205+
def _read_int8(self) -> int:
1206+
return struct.unpack("b", self.path_or_buf.read(1))[0]
1207+
1208+
def _read_uint8(self) -> int:
1209+
return struct.unpack("B", self.path_or_buf.read(1))[0]
1210+
1211+
def _read_uint16(self) -> int:
1212+
return struct.unpack(f"{self.byteorder}H", self.path_or_buf.read(2))[0]
1213+
1214+
def _read_uint32(self) -> int:
1215+
return struct.unpack(f"{self.byteorder}I", self.path_or_buf.read(4))[0]
1216+
1217+
def _read_uint64(self) -> int:
1218+
return struct.unpack(f"{self.byteorder}Q", self.path_or_buf.read(8))[0]
1219+
1220+
def _read_int16(self) -> int:
1221+
return struct.unpack(f"{self.byteorder}h", self.path_or_buf.read(2))[0]
1222+
1223+
def _read_int32(self) -> int:
1224+
return struct.unpack(f"{self.byteorder}i", self.path_or_buf.read(4))[0]
1225+
1226+
def _read_int64(self) -> int:
1227+
return struct.unpack(f"{self.byteorder}q", self.path_or_buf.read(8))[0]
1228+
1229+
def _read_char8(self) -> bytes:
1230+
return struct.unpack("c", self.path_or_buf.read(1))[0]
1231+
1232+
def _read_int16_count(self, count: int) -> tuple[int, ...]:
1233+
return struct.unpack(
1234+
f"{self.byteorder}{'h' * count}",
1235+
self.path_or_buf.read(2 * count),
1236+
)
1237+
12051238
def _read_header(self) -> None:
1206-
first_char = self.path_or_buf.read(1)
1207-
if struct.unpack("c", first_char)[0] == b"<":
1239+
first_char = self._read_char8()
1240+
if first_char == b"<":
12081241
self._read_new_header()
12091242
else:
12101243
self._read_old_header(first_char)
@@ -1224,11 +1257,10 @@ def _read_new_header(self) -> None:
12241257
self.path_or_buf.read(21) # </release><byteorder>
12251258
self.byteorder = self.path_or_buf.read(3) == b"MSF" and ">" or "<"
12261259
self.path_or_buf.read(15) # </byteorder><K>
1227-
nvar_type = "H" if self.format_version <= 118 else "I"
1228-
nvar_size = 2 if self.format_version <= 118 else 4
1229-
self.nvar = struct.unpack(
1230-
self.byteorder + nvar_type, self.path_or_buf.read(nvar_size)
1231-
)[0]
1260+
if self.format_version <= 118:
1261+
self.nvar = self._read_uint16()
1262+
else:
1263+
self.nvar = self._read_uint32()
12321264
self.path_or_buf.read(7) # </K><N>
12331265

12341266
self.nobs = self._get_nobs()
@@ -1240,46 +1272,27 @@ def _read_new_header(self) -> None:
12401272
self.path_or_buf.read(8) # 0x0000000000000000
12411273
self.path_or_buf.read(8) # position of <map>
12421274

1243-
self._seek_vartypes = (
1244-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 16
1245-
)
1246-
self._seek_varnames = (
1247-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 10
1248-
)
1249-
self._seek_sortlist = (
1250-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 10
1251-
)
1252-
self._seek_formats = (
1253-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 9
1254-
)
1255-
self._seek_value_label_names = (
1256-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 19
1257-
)
1275+
self._seek_vartypes = self._read_int64() + 16
1276+
self._seek_varnames = self._read_int64() + 10
1277+
self._seek_sortlist = self._read_int64() + 10
1278+
self._seek_formats = self._read_int64() + 9
1279+
self._seek_value_label_names = self._read_int64() + 19
12581280

12591281
# Requires version-specific treatment
12601282
self._seek_variable_labels = self._get_seek_variable_labels()
12611283

12621284
self.path_or_buf.read(8) # <characteristics>
1263-
self.data_location = (
1264-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 6
1265-
)
1266-
self.seek_strls = (
1267-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 7
1268-
)
1269-
self.seek_value_labels = (
1270-
struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 14
1271-
)
1285+
self.data_location = self._read_int64() + 6
1286+
self.seek_strls = self._read_int64() + 7
1287+
self.seek_value_labels = self._read_int64() + 14
12721288

12731289
self.typlist, self.dtyplist = self._get_dtypes(self._seek_vartypes)
12741290

12751291
self.path_or_buf.seek(self._seek_varnames)
12761292
self.varlist = self._get_varlist()
12771293

12781294
self.path_or_buf.seek(self._seek_sortlist)
1279-
self.srtlist = struct.unpack(
1280-
self.byteorder + ("h" * (self.nvar + 1)),
1281-
self.path_or_buf.read(2 * (self.nvar + 1)),
1282-
)[:-1]
1295+
self.srtlist = self._read_int16_count(self.nvar + 1)[:-1]
12831296

12841297
self.path_or_buf.seek(self._seek_formats)
12851298
self.fmtlist = self._get_fmtlist()
@@ -1296,10 +1309,7 @@ def _get_dtypes(
12961309
) -> tuple[list[int | str], list[str | np.dtype]]:
12971310

12981311
self.path_or_buf.seek(seek_vartypes)
1299-
raw_typlist = [
1300-
struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0]
1301-
for _ in range(self.nvar)
1302-
]
1312+
raw_typlist = [self._read_uint16() for _ in range(self.nvar)]
13031313

13041314
def f(typ: int) -> int | str:
13051315
if typ <= 2045:
@@ -1368,16 +1378,16 @@ def _get_variable_labels(self) -> list[str]:
13681378

13691379
def _get_nobs(self) -> int:
13701380
if self.format_version >= 118:
1371-
return struct.unpack(self.byteorder + "Q", self.path_or_buf.read(8))[0]
1381+
return self._read_uint64()
13721382
else:
1373-
return struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
1383+
return self._read_uint32()
13741384

13751385
def _get_data_label(self) -> str:
13761386
if self.format_version >= 118:
1377-
strlen = struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0]
1387+
strlen = self._read_uint16()
13781388
return self._decode(self.path_or_buf.read(strlen))
13791389
elif self.format_version == 117:
1380-
strlen = struct.unpack("b", self.path_or_buf.read(1))[0]
1390+
strlen = self._read_int8()
13811391
return self._decode(self.path_or_buf.read(strlen))
13821392
elif self.format_version > 105:
13831393
return self._decode(self.path_or_buf.read(81))
@@ -1386,10 +1396,10 @@ def _get_data_label(self) -> str:
13861396

13871397
def _get_time_stamp(self) -> str:
13881398
if self.format_version >= 118:
1389-
strlen = struct.unpack("b", self.path_or_buf.read(1))[0]
1399+
strlen = self._read_int8()
13901400
return self.path_or_buf.read(strlen).decode("utf-8")
13911401
elif self.format_version == 117:
1392-
strlen = struct.unpack("b", self.path_or_buf.read(1))[0]
1402+
strlen = self._read_int8()
13931403
return self._decode(self.path_or_buf.read(strlen))
13941404
elif self.format_version > 104:
13951405
return self._decode(self.path_or_buf.read(18))
@@ -1404,22 +1414,20 @@ def _get_seek_variable_labels(self) -> int:
14041414
# variable, 20 for the closing tag and 17 for the opening tag
14051415
return self._seek_value_label_names + (33 * self.nvar) + 20 + 17
14061416
elif self.format_version >= 118:
1407-
return struct.unpack(self.byteorder + "q", self.path_or_buf.read(8))[0] + 17
1417+
return self._read_int64() + 17
14081418
else:
14091419
raise ValueError()
14101420

14111421
def _read_old_header(self, first_char: bytes) -> None:
1412-
self.format_version = struct.unpack("b", first_char)[0]
1422+
self.format_version = int(first_char[0])
14131423
if self.format_version not in [104, 105, 108, 111, 113, 114, 115]:
14141424
raise ValueError(_version_error.format(version=self.format_version))
14151425
self._set_encoding()
1416-
self.byteorder = (
1417-
struct.unpack("b", self.path_or_buf.read(1))[0] == 0x1 and ">" or "<"
1418-
)
1419-
self.filetype = struct.unpack("b", self.path_or_buf.read(1))[0]
1426+
self.byteorder = self._read_int8() == 0x1 and ">" or "<"
1427+
self.filetype = self._read_int8()
14201428
self.path_or_buf.read(1) # unused
14211429

1422-
self.nvar = struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0]
1430+
self.nvar = self._read_uint16()
14231431
self.nobs = self._get_nobs()
14241432

14251433
self._data_label = self._get_data_label()
@@ -1428,7 +1436,7 @@ def _read_old_header(self, first_char: bytes) -> None:
14281436

14291437
# descriptors
14301438
if self.format_version > 108:
1431-
typlist = [ord(self.path_or_buf.read(1)) for _ in range(self.nvar)]
1439+
typlist = [int(c) for c in self.path_or_buf.read(self.nvar)]
14321440
else:
14331441
buf = self.path_or_buf.read(self.nvar)
14341442
typlistb = np.frombuffer(buf, dtype=np.uint8)
@@ -1458,10 +1466,7 @@ def _read_old_header(self, first_char: bytes) -> None:
14581466
self.varlist = [
14591467
self._decode(self.path_or_buf.read(9)) for _ in range(self.nvar)
14601468
]
1461-
self.srtlist = struct.unpack(
1462-
self.byteorder + ("h" * (self.nvar + 1)),
1463-
self.path_or_buf.read(2 * (self.nvar + 1)),
1464-
)[:-1]
1469+
self.srtlist = self._read_int16_count(self.nvar + 1)[:-1]
14651470

14661471
self.fmtlist = self._get_fmtlist()
14671472

@@ -1476,17 +1481,11 @@ def _read_old_header(self, first_char: bytes) -> None:
14761481

14771482
if self.format_version > 104:
14781483
while True:
1479-
data_type = struct.unpack(
1480-
self.byteorder + "b", self.path_or_buf.read(1)
1481-
)[0]
1484+
data_type = self._read_int8()
14821485
if self.format_version > 108:
1483-
data_len = struct.unpack(
1484-
self.byteorder + "i", self.path_or_buf.read(4)
1485-
)[0]
1486+
data_len = self._read_int32()
14861487
else:
1487-
data_len = struct.unpack(
1488-
self.byteorder + "h", self.path_or_buf.read(2)
1489-
)[0]
1488+
data_len = self._read_int16()
14901489
if data_type == 0:
14911490
break
14921491
self.path_or_buf.read(data_len)
@@ -1570,8 +1569,8 @@ def _read_value_labels(self) -> None:
15701569
labname = self._decode(self.path_or_buf.read(129))
15711570
self.path_or_buf.read(3) # padding
15721571

1573-
n = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
1574-
txtlen = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
1572+
n = self._read_uint32()
1573+
txtlen = self._read_uint32()
15751574
off = np.frombuffer(
15761575
self.path_or_buf.read(4 * n), dtype=self.byteorder + "i4", count=n
15771576
)
@@ -1599,7 +1598,7 @@ def _read_strls(self) -> None:
15991598
break
16001599

16011600
if self.format_version == 117:
1602-
v_o = struct.unpack(self.byteorder + "Q", self.path_or_buf.read(8))[0]
1601+
v_o = self._read_uint64()
16031602
else:
16041603
buf = self.path_or_buf.read(12)
16051604
# Only tested on little endian file on little endian machine.
@@ -1610,8 +1609,8 @@ def _read_strls(self) -> None:
16101609
# This path may not be correct, impossible to test
16111610
buf = buf[0:v_size] + buf[(4 + v_size) :]
16121611
v_o = struct.unpack("Q", buf)[0]
1613-
typ = struct.unpack("B", self.path_or_buf.read(1))[0]
1614-
length = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
1612+
typ = self._read_uint8()
1613+
length = self._read_uint32()
16151614
va = self.path_or_buf.read(length)
16161615
if typ == 130:
16171616
decoded_va = va[0:-1].decode(self._encoding)

0 commit comments

Comments
 (0)