Skip to content

Commit ddc806f

Browse files
committed
MAINT: Refactor decode
Refactor decode and null terminate to use file encoding
1 parent 2aff757 commit ddc806f

File tree

2 files changed

+31
-30
lines changed

2 files changed

+31
-30
lines changed

pandas/io/stata.py

+24-29
Original file line numberDiff line numberDiff line change
@@ -1136,7 +1136,7 @@ def _get_varlist(self):
11361136
elif self.format_version == 118:
11371137
b = 129
11381138

1139-
return [self._null_terminate(self.path_or_buf.read(b))
1139+
return [self._decode(self.path_or_buf.read(b))
11401140
for i in range(self.nvar)]
11411141

11421142
# Returns the format list
@@ -1150,7 +1150,7 @@ def _get_fmtlist(self):
11501150
else:
11511151
b = 7
11521152

1153-
return [self._null_terminate(self.path_or_buf.read(b))
1153+
return [self._decode(self.path_or_buf.read(b))
11541154
for i in range(self.nvar)]
11551155

11561156
# Returns the label list
@@ -1161,18 +1161,18 @@ def _get_lbllist(self):
11611161
b = 33
11621162
else:
11631163
b = 9
1164-
return [self._null_terminate(self.path_or_buf.read(b))
1164+
return [self._decode(self.path_or_buf.read(b))
11651165
for i in range(self.nvar)]
11661166

11671167
def _get_variable_labels(self):
11681168
if self.format_version == 118:
11691169
vlblist = [self._decode(self.path_or_buf.read(321))
11701170
for i in range(self.nvar)]
11711171
elif self.format_version > 105:
1172-
vlblist = [self._null_terminate(self.path_or_buf.read(81))
1172+
vlblist = [self._decode(self.path_or_buf.read(81))
11731173
for i in range(self.nvar)]
11741174
else:
1175-
vlblist = [self._null_terminate(self.path_or_buf.read(32))
1175+
vlblist = [self._decode(self.path_or_buf.read(32))
11761176
for i in range(self.nvar)]
11771177
return vlblist
11781178

@@ -1191,21 +1191,21 @@ def _get_data_label(self):
11911191
return self._decode(self.path_or_buf.read(strlen))
11921192
elif self.format_version == 117:
11931193
strlen = struct.unpack('b', self.path_or_buf.read(1))[0]
1194-
return self._null_terminate(self.path_or_buf.read(strlen))
1194+
return self._decode(self.path_or_buf.read(strlen))
11951195
elif self.format_version > 105:
1196-
return self._null_terminate(self.path_or_buf.read(81))
1196+
return self._decode(self.path_or_buf.read(81))
11971197
else:
1198-
return self._null_terminate(self.path_or_buf.read(32))
1198+
return self._decode(self.path_or_buf.read(32))
11991199

12001200
def _get_time_stamp(self):
12011201
if self.format_version == 118:
12021202
strlen = struct.unpack('b', self.path_or_buf.read(1))[0]
12031203
return self.path_or_buf.read(strlen).decode("utf-8")
12041204
elif self.format_version == 117:
12051205
strlen = struct.unpack('b', self.path_or_buf.read(1))[0]
1206-
return self._null_terminate(self.path_or_buf.read(strlen))
1206+
return self._decode(self.path_or_buf.read(strlen))
12071207
elif self.format_version > 104:
1208-
return self._null_terminate(self.path_or_buf.read(18))
1208+
return self._decode(self.path_or_buf.read(18))
12091209
else:
12101210
raise ValueError()
12111211

@@ -1266,10 +1266,10 @@ def _read_old_header(self, first_char):
12661266
.format(','.join(str(x) for x in typlist)))
12671267

12681268
if self.format_version > 108:
1269-
self.varlist = [self._null_terminate(self.path_or_buf.read(33))
1269+
self.varlist = [self._decode(self.path_or_buf.read(33))
12701270
for i in range(self.nvar)]
12711271
else:
1272-
self.varlist = [self._null_terminate(self.path_or_buf.read(9))
1272+
self.varlist = [self._decode(self.path_or_buf.read(9))
12731273
for i in range(self.nvar)]
12741274
self.srtlist = struct.unpack(
12751275
self.byteorder + ('h' * (self.nvar + 1)),
@@ -1326,20 +1326,19 @@ def _calcsize(self, fmt):
13261326
struct.calcsize(self.byteorder + fmt))
13271327

13281328
def _decode(self, s):
1329-
s = s.partition(b"\0")[0]
1330-
try:
1331-
return s.decode('utf-8')
1332-
except UnicodeDecodeError:
1333-
# GH 25960
1334-
return s.decode('latin-1')
1335-
1336-
def _null_terminate(self, s):
13371329
# have bytes not strings, so must decode
13381330
s = s.partition(b"\0")[0]
13391331
try:
13401332
return s.decode(self._encoding)
13411333
except UnicodeDecodeError:
1342-
# GH 25960
1334+
# GH 25960, fallback to handle incorrect format produced when 117
1335+
# files are converted to 118 files in Stata
1336+
msg = """
1337+
One or more strings in the dta file could not be decoded using {encoding}, and
1338+
so the fallback encoding of latin-1 is being used. This can happen when a file
1339+
has been incorrectly encoded by Stata or some other software. You should verify
1340+
the string values returned are correct."""
1341+
warnings.warn(msg.format(encoding=self._encoding), UnicodeWarning)
13431342
return s.decode('latin-1')
13441343

13451344
def _read_value_labels(self):
@@ -1370,7 +1369,7 @@ def _read_value_labels(self):
13701369
if not slength:
13711370
break # end of value label table (format < 117)
13721371
if self.format_version <= 117:
1373-
labname = self._null_terminate(self.path_or_buf.read(33))
1372+
labname = self._decode(self.path_or_buf.read(33))
13741373
else:
13751374
labname = self._decode(self.path_or_buf.read(129))
13761375
self.path_or_buf.read(3) # padding
@@ -1392,12 +1391,8 @@ def _read_value_labels(self):
13921391
self.value_label_dict[labname] = dict()
13931392
for i in range(n):
13941393
end = off[i + 1] if i < n - 1 else txtlen
1395-
if self.format_version <= 117:
1396-
self.value_label_dict[labname][val[i]] = (
1397-
self._null_terminate(txt[off[i]:end]))
1398-
else:
1399-
self.value_label_dict[labname][val[i]] = (
1400-
self._decode(txt[off[i]:end]))
1394+
self.value_label_dict[labname][val[i]] = \
1395+
self._decode(txt[off[i]:end])
14011396
if self.format_version >= 117:
14021397
self.path_or_buf.read(6) # </lbl>
14031398
self._value_labels_read = True
@@ -1552,7 +1547,7 @@ def read(self, nrows=None, convert_dates=None,
15521547
for col, typ in zip(data, self.typlist):
15531548
if type(typ) is int:
15541549
data[col] = data[col].apply(
1555-
self._null_terminate, convert_dtype=True)
1550+
self._decode, convert_dtype=True)
15561551

15571552
data = self._insert_strls(data)
15581553

pandas/tests/io/test_stata.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -1613,6 +1613,12 @@ def test_strl_latin1(self):
16131613

16141614
def test_encoding_latin1_118(self):
16151615
# GH 25960
1616-
encoded = read_stata(self.dta_encoding_118)
1616+
msg = """
1617+
One or more strings in the dta file could not be decoded using utf-8, and
1618+
so the fallback encoding of latin-1 is being used. This can happen when a file
1619+
has been incorrectly encoded by Stata or some other software. You should verify
1620+
the string values returned are correct."""
1621+
with pytest.warns(UnicodeWarning, match=msg):
1622+
encoded = read_stata(self.dta_encoding_118)
16171623
expected = pd.DataFrame([['Düsseldorf']] * 151, columns=['kreis1849'])
16181624
tm.assert_frame_equal(encoded, expected)

0 commit comments

Comments
 (0)