Skip to content

Commit dd83e7b

Browse files
committed
BUG: Set index when reading Stata file
Ensures index is set when requested during reading of a Stata dta file. Renames index to index_col for API consistency. Closes pandas-dev#16342.
1 parent 473a7f3 commit dd83e7b

File tree

3 files changed

+25
-10
lines changed

3 files changed

+25
-10
lines changed

doc/source/whatsnew/v0.21.0.txt

+2
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ Other API Changes
293293
- :func:`Series.argmin` and :func:`Series.argmax` will now raise a ``TypeError`` when used with ``object`` dtypes, instead of a ``ValueError`` (:issue:`13595`)
294294
- :class:`Period` is now immutable, and will now raise an ``AttributeError`` when a user tries to assign a new value to the ``ordinal`` or ``freq`` attributes (:issue:`17116`).
295295
- :func:`to_datetime` when passed a tz-aware ``origin=`` kwarg will now raise a more informative ``ValueError`` rather than a ``TypeError`` (:issue:`16842`)
296+
- Renamed non-functional ``index`` to ``index_col`` in :func:`read_stata` to improve API consistency (:issue:`16342`)
296297

297298

298299
.. _whatsnew_0210.deprecations:
@@ -369,6 +370,7 @@ I/O
369370
- Bug in :func:`read_csv` when called with ``low_memory=False`` in which a CSV with at least one column > 2GB in size would incorrectly raise a ``MemoryError`` (:issue:`16798`).
370371
- Bug in :func:`read_csv` when called with a single-element list ``header`` would return a ``DataFrame`` of all NaN values (:issue:`7757`)
371372
- Bug in :func:`read_stata` where value labels could not be read when using an iterator (:issue:`16923`)
373+
- Bug in :func:`read_stata` where the index was not set when requested via ``index_col`` (:issue:`16342`)
372374
- Bug in :func:`read_html` where import check fails when run in multiple threads (:issue:`16928`)
373375

374376
Plotting

pandas/io/stata.py

+15-10
Original file line numberDiff line numberDiff line change
@@ -160,14 +160,14 @@
160160

161161
@Appender(_read_stata_doc)
162162
def read_stata(filepath_or_buffer, convert_dates=True,
163-
convert_categoricals=True, encoding=None, index=None,
163+
convert_categoricals=True, encoding=None, index_col=None,
164164
convert_missing=False, preserve_dtypes=True, columns=None,
165165
order_categoricals=True, chunksize=None, iterator=False):
166166

167167
reader = StataReader(filepath_or_buffer,
168168
convert_dates=convert_dates,
169169
convert_categoricals=convert_categoricals,
170-
index=index, convert_missing=convert_missing,
170+
index_col=index_col, convert_missing=convert_missing,
171171
preserve_dtypes=preserve_dtypes,
172172
columns=columns,
173173
order_categoricals=order_categoricals,
@@ -945,7 +945,7 @@ class StataReader(StataParser, BaseIterator):
945945
__doc__ = _stata_reader_doc
946946

947947
def __init__(self, path_or_buf, convert_dates=True,
948-
convert_categoricals=True, index=None,
948+
convert_categoricals=True, index_col=None,
949949
convert_missing=False, preserve_dtypes=True,
950950
columns=None, order_categoricals=True,
951951
encoding='latin-1', chunksize=None):
@@ -956,7 +956,7 @@ def __init__(self, path_or_buf, convert_dates=True,
956956
# calls to read).
957957
self._convert_dates = convert_dates
958958
self._convert_categoricals = convert_categoricals
959-
self._index = index
959+
self._index_col = index_col
960960
self._convert_missing = convert_missing
961961
self._preserve_dtypes = preserve_dtypes
962962
self._columns = columns
@@ -1461,7 +1461,7 @@ def get_chunk(self, size=None):
14611461

14621462
@Appender(_read_method_doc)
14631463
def read(self, nrows=None, convert_dates=None,
1464-
convert_categoricals=None, index=None,
1464+
convert_categoricals=None, index_col=None,
14651465
convert_missing=None, preserve_dtypes=None,
14661466
columns=None, order_categoricals=None):
14671467
# Handle empty file or chunk. If reading incrementally raise
@@ -1486,6 +1486,8 @@ def read(self, nrows=None, convert_dates=None,
14861486
columns = self._columns
14871487
if order_categoricals is None:
14881488
order_categoricals = self._order_categoricals
1489+
if index_col is None:
1490+
index_col = self._index_col
14891491

14901492
if nrows is None:
14911493
nrows = self.nobs
@@ -1524,14 +1526,14 @@ def read(self, nrows=None, convert_dates=None,
15241526
self._read_value_labels()
15251527

15261528
if len(data) == 0:
1527-
data = DataFrame(columns=self.varlist, index=index)
1529+
data = DataFrame(columns=self.varlist)
15281530
else:
1529-
data = DataFrame.from_records(data, index=index)
1531+
data = DataFrame.from_records(data)
15301532
data.columns = self.varlist
15311533

15321534
# If index is not specified, use actual row number rather than
15331535
# restarting at 0 for each chunk.
1534-
if index is None:
1536+
if index_col is None:
15351537
ix = np.arange(self._lines_read - read_lines, self._lines_read)
15361538
data = data.set_index(ix)
15371539

@@ -1553,7 +1555,7 @@ def read(self, nrows=None, convert_dates=None,
15531555
cols_ = np.where(self.dtyplist)[0]
15541556

15551557
# Convert columns (if needed) to match input type
1556-
index = data.index
1558+
ix = data.index
15571559
requires_type_conversion = False
15581560
data_formatted = []
15591561
for i in cols_:
@@ -1563,7 +1565,7 @@ def read(self, nrows=None, convert_dates=None,
15631565
if dtype != np.dtype(object) and dtype != self.dtyplist[i]:
15641566
requires_type_conversion = True
15651567
data_formatted.append(
1566-
(col, Series(data[col], index, self.dtyplist[i])))
1568+
(col, Series(data[col], ix, self.dtyplist[i])))
15671569
else:
15681570
data_formatted.append((col, data[col]))
15691571
if requires_type_conversion:
@@ -1606,6 +1608,9 @@ def read(self, nrows=None, convert_dates=None,
16061608
if convert:
16071609
data = DataFrame.from_items(retyped_data)
16081610

1611+
if index_col is not None:
1612+
data = data.set_index(data.pop(index_col))
1613+
16091614
return data
16101615

16111616
def _do_convert_missing(self, data, convert_missing):

pandas/tests/io/test_stata.py

+8
Original file line numberDiff line numberDiff line change
@@ -1309,3 +1309,11 @@ def test_value_labels_iterator(self, write_index):
13091309
dta_iter = pd.read_stata(path, iterator=True)
13101310
value_labels = dta_iter.value_labels()
13111311
assert value_labels == {'A': {0: 'A', 1: 'B', 2: 'C', 3: 'E'}}
1312+
1313+
def test_set_index(self):
1314+
df = tm.makeDataFrame()
1315+
df.index.name = 'index'
1316+
with tm.ensure_clean() as path:
1317+
df.to_stata(path)
1318+
reread = pd.read_stata(path, index_col='index')
1319+
tm.assert_frame_equal(df, reread)

0 commit comments

Comments
 (0)