Skip to content

Commit 8ffa2a9

Browse files
authored
BUG: read_csv not applying dtype to index col (#44632)
1 parent d8068e5 commit 8ffa2a9

File tree

4 files changed

+41
-18
lines changed

4 files changed

+41
-18
lines changed

doc/source/whatsnew/v1.4.0.rst

+1
Original file line number | Diff line number | Diff line change
@@ -663,6 +663,7 @@ I/O
663663
- Bug in :func:`json_normalize` where multi-character ``sep`` parameter is incorrectly prefixed to every key (:issue:`43831`)
664664
- Bug in :func:`json_normalize` where reading data with missing multi-level metadata would not respect ``errors="ignore"`` (:issue:`44312`)
665665
- Bug in :func:`read_csv` with :code:`float_precision="round_trip"` which did not skip initial/trailing whitespace (:issue:`43713`)
666+
- Bug in :func:`read_csv` not applying dtype for ``index_col`` (:issue:`9435`)
666667
- Bug in dumping/loading a :class:`DataFrame` with ``yaml.dump(frame)`` (:issue:`42748`)
667668
- Bug in :class:`ExcelWriter`, where ``engine_kwargs`` were not passed through to all engines (:issue:`43442`)
668669
- Bug in :func:`read_csv` raising ``ValueError`` when ``parse_dates`` was used with ``MultiIndex`` columns (:issue:`8991`)

pandas/io/parsers/base_parser.py

+27-1
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from collections import defaultdict
4+
from copy import copy
45
import csv
56
import datetime
67
from enum import Enum
@@ -149,6 +150,8 @@ def __init__(self, kwds):
149150
self.na_filter = kwds.get("na_filter", False)
150151
self.keep_default_na = kwds.get("keep_default_na", True)
151152

153+
self.dtype = copy(kwds.get("dtype", None))
154+
152155
self.true_values = kwds.get("true_values")
153156
self.false_values = kwds.get("false_values")
154157
self.mangle_dupe_cols = kwds.get("mangle_dupe_cols", True)
@@ -511,6 +514,19 @@ def _get_name(icol):
511514

512515
return index
513516

517+
def _clean_mapping(self, mapping):
518+
"""converts col numbers to names"""
519+
if not isinstance(mapping, dict):
520+
return mapping
521+
clean = {}
522+
for col, v in mapping.items():
523+
# for mypy
524+
assert self.orig_names is not None
525+
if isinstance(col, int) and col not in self.orig_names:
526+
col = self.orig_names[col]
527+
clean[col] = v
528+
return clean
529+
514530
@final
515531
def _agg_index(self, index, try_parse_dates: bool = True) -> Index:
516532
arrays = []
@@ -535,7 +551,17 @@ def _agg_index(self, index, try_parse_dates: bool = True) -> Index:
535551
col_name, self.na_values, self.na_fvalues, self.keep_default_na
536552
)
537553

538-
arr, _ = self._infer_types(arr, col_na_values | col_na_fvalues)
554+
clean_dtypes = self._clean_mapping(self.dtype)
555+
556+
cast_type = None
557+
if isinstance(clean_dtypes, dict) and self.index_names is not None:
558+
cast_type = clean_dtypes.get(self.index_names[i], None)
559+
560+
try_num_bool = not (cast_type and is_string_dtype(cast_type))
561+
562+
arr, _ = self._infer_types(
563+
arr, col_na_values | col_na_fvalues, try_num_bool
564+
)
539565
arrays.append(arr)
540566

541567
names = self.index_names

pandas/io/parsers/python_parser.py

+2-17
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44
abc,
55
defaultdict,
66
)
7-
from copy import copy
87
import csv
98
from io import StringIO
109
import re
@@ -89,7 +88,6 @@ def __init__(
8988
self.verbose = kwds["verbose"]
9089
self.converters = kwds["converters"]
9190

92-
self.dtype = copy(kwds["dtype"])
9391
self.thousands = kwds["thousands"]
9492
self.decimal = kwds["decimal"]
9593

@@ -308,21 +306,8 @@ def get_chunk(self, size=None):
308306

309307
def _convert_data(self, data):
310308
# apply converters
311-
def _clean_mapping(mapping):
312-
"""converts col numbers to names"""
313-
clean = {}
314-
for col, v in mapping.items():
315-
if isinstance(col, int) and col not in self.orig_names:
316-
col = self.orig_names[col]
317-
clean[col] = v
318-
return clean
319-
320-
clean_conv = _clean_mapping(self.converters)
321-
if not isinstance(self.dtype, dict):
322-
# handles single dtype applied to all columns
323-
clean_dtypes = self.dtype
324-
else:
325-
clean_dtypes = _clean_mapping(self.dtype)
309+
clean_conv = self._clean_mapping(self.converters)
310+
clean_dtypes = self._clean_mapping(self.dtype)
326311

327312
# Apply NA values.
328313
clean_na_values = {}

pandas/tests/io/parser/test_index_col.py

+11
Original file line number | Diff line number | Diff line change
@@ -321,3 +321,14 @@ def test_infer_types_boolean_sum(all_parsers):
321321
# index column of dtype 'object', and the Python parser will return a
322322
# index column of dtype 'int64'.
323323
tm.assert_frame_equal(result, expected, check_index_type=False)
324+
325+
326+
@skip_pyarrow
327+
@pytest.mark.parametrize("dtype, val", [(object, "01"), ("int64", 1)])
328+
def test_specify_dtype_for_index_col(all_parsers, dtype, val):
329+
# GH#9435
330+
data = "a,b\n01,2"
331+
parser = all_parsers
332+
result = parser.read_csv(StringIO(data), index_col="a", dtype={"a": dtype})
333+
expected = DataFrame({"b": [2]}, index=Index([val], name="a"))
334+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)