Skip to content

Commit a5f1653

Browse files
committed
BUG: Fix bugs in stata
Fix incorrect skipping in strl writer. Fix incorrect byteorder when exporting bigendian. Fix incorrect byteorder parsing when importing bigendian. Improve test coverage for errors.
1 parent 900c9f7 commit a5f1653

File tree

2 files changed

+84
-27
lines changed

2 files changed

+84
-27
lines changed

pandas/io/stata.py

+26-24
Original file line numberDiff line numberDiff line change
@@ -1050,7 +1050,7 @@ def _read_new_header(self, first_char):
10501050
if self.format_version not in [117, 118]:
10511051
raise ValueError(_version_error)
10521052
self.path_or_buf.read(21) # </release><byteorder>
1053-
self.byteorder = self.path_or_buf.read(3) == "MSF" and '>' or '<'
1053+
self.byteorder = self.path_or_buf.read(3) == b'MSF' and '>' or '<'
10541054
self.path_or_buf.read(15) # </byteorder><K>
10551055
self.nvar = struct.unpack(self.byteorder + 'H',
10561056
self.path_or_buf.read(2))[0]
@@ -1824,9 +1824,7 @@ def _dtype_to_stata_type(dtype, column):
18241824
type inserted.
18251825
"""
18261826
# TODO: expand to handle datetime to integer conversion
1827-
if dtype.type == np.string_:
1828-
return dtype.itemsize
1829-
elif dtype.type == np.object_: # try to coerce it to the biggest string
1827+
if dtype.type == np.object_: # try to coerce it to the biggest string
18301828
# not memory efficient, what else could we
18311829
# do?
18321830
itemsize = max_len_string_array(_ensure_object(column.values))
@@ -2347,25 +2345,30 @@ def _prepare_data(self):
23472345
data = self._convert_strls(data)
23482346

23492347
# 3. Convert bad string data to '' and pad to correct length
2350-
dtype = []
2348+
dtypes = []
23512349
data_cols = []
23522350
has_strings = False
2351+
native_byteorder = self._byteorder == _set_endianness(sys.byteorder)
23532352
for i, col in enumerate(data):
23542353
typ = typlist[i]
23552354
if typ <= self._max_string_length:
23562355
has_strings = True
23572356
data[col] = data[col].fillna('').apply(_pad_bytes, args=(typ,))
23582357
stype = 'S%d' % typ
2359-
dtype.append(('c' + str(i), stype))
2358+
dtypes.append(('c' + str(i), stype))
23602359
string = data[col].str.encode(self._encoding)
23612360
data_cols.append(string.values.astype(stype))
23622361
else:
2363-
dtype.append(('c' + str(i), data[col].dtype))
2364-
data_cols.append(data[col].values)
2365-
dtype = np.dtype(dtype)
2366-
2367-
if has_strings:
2368-
self.data = np.fromiter(zip(*data_cols), dtype=dtype)
2362+
values = data[col].values
2363+
dtype = data[col].dtype
2364+
if not native_byteorder:
2365+
dtype = dtype.newbyteorder(self._byteorder)
2366+
dtypes.append(('c' + str(i), dtype))
2367+
data_cols.append(values)
2368+
dtypes = np.dtype(dtypes)
2369+
2370+
if has_strings or not native_byteorder:
2371+
self.data = np.fromiter(zip(*data_cols), dtype=dtypes)
23692372
else:
23702373
self.data = data.to_records(index=False)
23712374

@@ -2403,9 +2406,7 @@ def _dtype_to_stata_type_117(dtype, column, force_strl):
24032406
# TODO: expand to handle datetime to integer conversion
24042407
if force_strl:
24052408
return 32768
2406-
if dtype.type == np.string_:
2407-
return chr(dtype.itemsize)
2408-
elif dtype.type == np.object_: # try to coerce it to the biggest string
2409+
if dtype.type == np.object_: # try to coerce it to the biggest string
24092410
# not memory efficient, what else could we
24102411
# do?
24112412
itemsize = max_len_string_array(_ensure_object(column.values))
@@ -2513,11 +2514,13 @@ def generate_table(self):
25132514
Ordered dictionary using the string found as keys
25142515
and their lookup position (v,o) as values
25152516
gso_df : DataFrame
2516-
Copy of DataFrame where strl columns have been converted
2517-
to encoded (v,o) values
2517+
DataFrame where strl columns have been converted to
2518+
(v,o) values
25182519
25192520
Notes
25202521
-----
2522+
Modifies the DataFrame in-place.
2523+
25212524
The DataFrame returned encodes the (v,o) values as uint64s. The
25222525
encoding depends on the dta version, and can be expressed as
25232526
@@ -2532,10 +2535,9 @@ def generate_table(self):
25322535
"""
25332536

25342537
gso_table = self._gso_table
2535-
df_out = self.df.copy()
2536-
df = self.df
2537-
columns = list(df.columns)
2538-
selected = df[self.columns]
2538+
gso_df = self.df
2539+
columns = list(gso_df.columns)
2540+
selected = gso_df[self.columns]
25392541
col_index = [(col, columns.index(col)) for col in self.columns]
25402542
keys = np.empty(selected.shape, dtype=np.uint64)
25412543
for o, (idx, row) in enumerate(selected.iterrows()):
@@ -2548,9 +2550,9 @@ def generate_table(self):
25482550
gso_table[val] = key
25492551
keys[o, j] = self._convert_key(key)
25502552
for i, col in enumerate(self.columns):
2551-
df_out[col] = keys[:, i]
2553+
gso_df[col] = keys[:, i]
25522554

2553-
return gso_table, df_out
2555+
return gso_table, gso_df
25542556

25552557
def _encode(self, s):
25562558
"""
@@ -2599,7 +2601,7 @@ def generate_blob(self, table):
25992601
o_type = self._byteorder + self._gso_o_type
26002602
len_type = self._byteorder + 'I'
26012603
for strl, vo in table.items():
2602-
if vo == 0:
2604+
if vo == (0, 0):
26032605
continue
26042606
v, o = vo
26052607
# GSO

pandas/tests/io/test_stata.py

+58-3
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,15 @@ def test_timestamp_and_label(self, version):
499499
assert reader.time_stamp == '29 Feb 2000 14:21'
500500
assert reader.data_label == data_label
501501

502+
@pytest.mark.parametrize('version', [114, 117])
503+
def test_invalid_timestamp(self, version):
504+
original = DataFrame([(1,)], columns=['variable'])
505+
time_stamp = '01 Jan 2000, 00:00:00'
506+
with tm.ensure_clean() as path:
507+
with pytest.raises(ValueError):
508+
original.to_stata(path, time_stamp=time_stamp,
509+
version=version)
510+
502511
def test_numeric_column_names(self):
503512
original = DataFrame(np.reshape(np.arange(25.0), (5, 5)))
504513
original.index.name = 'index'
@@ -639,7 +648,8 @@ def test_write_missing_strings(self):
639648
expected)
640649

641650
@pytest.mark.parametrize('version', [114, 117])
642-
def test_bool_uint(self, version):
651+
@pytest.mark.parametrize('byteorder', ['>', '<'])
652+
def test_bool_uint(self, byteorder, version):
643653
s0 = Series([0, 1, True], dtype=np.bool)
644654
s1 = Series([0, 1, 100], dtype=np.uint8)
645655
s2 = Series([0, 1, 255], dtype=np.uint8)
@@ -658,7 +668,7 @@ def test_bool_uint(self, version):
658668
expected[c] = expected[c].astype(t)
659669

660670
with tm.ensure_clean() as path:
661-
original.to_stata(path, version=version)
671+
original.to_stata(path, byteorder=byteorder, version=version)
662672
written_and_read_again = self.read_dta(path)
663673
written_and_read_again = written_and_read_again.set_index('index')
664674
tm.assert_frame_equal(written_and_read_again, expected)
@@ -1173,6 +1183,29 @@ def test_write_variable_labels(self, version):
11731183
read_labels = sr.variable_labels()
11741184
assert read_labels == variable_labels
11751185

1186+
@pytest.mark.parametrize('version', [114, 117])
1187+
def test_invalid_variable_labels(self, version):
1188+
original = pd.DataFrame({'a': [1, 2, 3, 4],
1189+
'b': [1.0, 3.0, 27.0, 81.0],
1190+
'c': ['Atlanta', 'Birmingham',
1191+
'Cincinnati', 'Detroit']})
1192+
original.index.name = 'index'
1193+
variable_labels = {'a': 'very long' * 10,
1194+
'b': 'City Exponent',
1195+
'c': 'City'}
1196+
with tm.ensure_clean() as path:
1197+
with pytest.raises(ValueError):
1198+
original.to_stata(path,
1199+
variable_labels=variable_labels,
1200+
version=version)
1201+
1202+
variable_labels['a'] = u'invalid character Œ'
1203+
with tm.ensure_clean() as path:
1204+
with pytest.raises(ValueError):
1205+
original.to_stata(path,
1206+
variable_labels=variable_labels,
1207+
version=version)
1208+
11761209
def test_write_variable_label_errors(self):
11771210
original = pd.DataFrame({'a': [1, 2, 3, 4],
11781211
'b': [1.0, 3.0, 27.0, 81.0],
@@ -1220,6 +1253,13 @@ def test_default_date_conversion(self):
12201253
direct = read_stata(path, convert_dates=True)
12211254
tm.assert_frame_equal(reread, direct)
12221255

1256+
dates_idx = original.columns.tolist().index('dates')
1257+
original.to_stata(path,
1258+
write_index=False,
1259+
convert_dates={dates_idx: 'tc'})
1260+
direct = read_stata(path, convert_dates=True)
1261+
tm.assert_frame_equal(reread, direct)
1262+
12231263
def test_unsupported_type(self):
12241264
original = pd.DataFrame({'a': [1 + 2j, 2 + 4j]})
12251265

@@ -1394,7 +1434,7 @@ def test_writer_117(self):
13941434
original['float32'] = Series(original['float32'], dtype=np.float32)
13951435
original.index.name = 'index'
13961436
original.index = original.index.astype(np.int32)
1397-
1437+
copy = original.copy()
13981438
with tm.ensure_clean() as path:
13991439
original.to_stata(path,
14001440
convert_dates={'datetime': 'tc'},
@@ -1404,6 +1444,7 @@ def test_writer_117(self):
14041444
# original.index is np.int32, read index is np.int64
14051445
tm.assert_frame_equal(written_and_read_again.set_index('index'),
14061446
original, check_index_type=False)
1447+
tm.assert_frame_equal(original, copy)
14071448

14081449
def test_convert_strl_name_swap(self):
14091450
original = DataFrame([['a' * 3000, 'A', 'apple'],
@@ -1419,3 +1460,17 @@ def test_convert_strl_name_swap(self):
14191460
reread.columns = original.columns
14201461
tm.assert_frame_equal(reread, original,
14211462
check_index_type=False)
1463+
1464+
def test_invalid_date_conversion(self):
1465+
# GH 12259
1466+
dates = [dt.datetime(1999, 12, 31, 12, 12, 12, 12000),
1467+
dt.datetime(2012, 12, 21, 12, 21, 12, 21000),
1468+
dt.datetime(1776, 7, 4, 7, 4, 7, 4000)]
1469+
original = pd.DataFrame({'nums': [1.0, 2.0, 3.0],
1470+
'strs': ['apple', 'banana', 'cherry'],
1471+
'dates': dates})
1472+
1473+
with tm.ensure_clean() as path:
1474+
with pytest.raises(ValueError):
1475+
original.to_stata(path,
1476+
convert_dates={'wrong_name': 'tc'})

0 commit comments

Comments
 (0)