Skip to content

Commit b8e36ac

Browse files
committed
BUG: Set index when reading stata file
Ensures the index is set when requested while reading a Stata dta file. Closes pandas-dev#16342
1 parent 96f92eb commit b8e36ac

File tree

3 files changed

+17
-3
lines changed

3 files changed

+17
-3
lines changed

doc/source/whatsnew/v0.21.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,7 @@ I/O
369369
- Bug in :func:`read_csv` when called with ``low_memory=False`` in which a CSV with at least one column > 2GB in size would incorrectly raise a ``MemoryError`` (:issue:`16798`).
370370
- Bug in :func:`read_csv` when called with a single-element list ``header`` would return a ``DataFrame`` of all NaN values (:issue:`7757`)
371371
- Bug in :func:`read_stata` where value labels could not be read when using an iterator (:issue:`16923`)
372+
- Bug in :func:`read_stata` where the index was not set (:issue:`16342`)
372373
- Bug in :func:`read_html` where import check fails when run in multiple threads (:issue:`16928`)
373374

374375
Plotting

pandas/io/stata.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -1486,6 +1486,8 @@ def read(self, nrows=None, convert_dates=None,
14861486
columns = self._columns
14871487
if order_categoricals is None:
14881488
order_categoricals = self._order_categoricals
1489+
if index is None:
1490+
index = self._index
14891491

14901492
if nrows is None:
14911493
nrows = self.nobs
@@ -1526,7 +1528,7 @@ def read(self, nrows=None, convert_dates=None,
15261528
if len(data) == 0:
15271529
data = DataFrame(columns=self.varlist, index=index)
15281530
else:
1529-
data = DataFrame.from_records(data, index=index)
1531+
data = DataFrame.from_records(data)
15301532
data.columns = self.varlist
15311533

15321534
# If index is not specified, use actual row number rather than
@@ -1553,7 +1555,7 @@ def read(self, nrows=None, convert_dates=None,
15531555
cols_ = np.where(self.dtyplist)[0]
15541556

15551557
# Convert columns (if needed) to match input type
1556-
index = data.index
1558+
ix = data.index
15571559
requires_type_conversion = False
15581560
data_formatted = []
15591561
for i in cols_:
@@ -1563,7 +1565,7 @@ def read(self, nrows=None, convert_dates=None,
15631565
if dtype != np.dtype(object) and dtype != self.dtyplist[i]:
15641566
requires_type_conversion = True
15651567
data_formatted.append(
1566-
(col, Series(data[col], index, self.dtyplist[i])))
1568+
(col, Series(data[col], ix, self.dtyplist[i])))
15671569
else:
15681570
data_formatted.append((col, data[col]))
15691571
if requires_type_conversion:
@@ -1606,6 +1608,9 @@ def read(self, nrows=None, convert_dates=None,
16061608
if convert:
16071609
data = DataFrame.from_items(retyped_data)
16081610

1611+
if index is not None:
1612+
data = data.set_index(data.pop(index))
1613+
16091614
return data
16101615

16111616
def _do_convert_missing(self, data, convert_missing):

pandas/tests/io/test_stata.py

+8
Original file line numberDiff line numberDiff line change
@@ -1309,3 +1309,11 @@ def test_value_labels_iterator(self, write_index):
13091309
dta_iter = pd.read_stata(path, iterator=True)
13101310
value_labels = dta_iter.value_labels()
13111311
assert value_labels == {'A': {0: 'A', 1: 'B', 2: 'C', 3: 'E'}}
1312+
1313+
def test_set_index(self):
1314+
df = tm.makeDataFrame()
1315+
df.index.name = 'index'
1316+
with tm.ensure_clean() as path:
1317+
df.to_stata(path)
1318+
reread = pd.read_stata(path, index='index')
1319+
tm.assert_frame_equal(df, reread)

0 commit comments

Comments
 (0)