Skip to content

Commit aa0a90a

Browse files
mroeschkejorisvandenbossche
authored andcommitted
BUG: Avoid RangeIndex conversion in read_csv if dtype is specified (pandas-dev#59316)
* BUG: Avoid RangeIndex conversion in read_csv if dtype is specified * Undo change * Typing
1 parent 284e359 commit aa0a90a

File tree

2 files changed

+42
-12
lines changed

2 files changed

+42
-12
lines changed

pandas/io/parsers/base_parser.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,11 @@ def _agg_index(self, index, try_parse_dates: bool = True) -> Index:
464464
arrays = []
465465
converters = self._clean_mapping(self.converters)
466466

467-
for i, arr in enumerate(index):
467+
if self.index_names is not None:
468+
names: Iterable = self.index_names
469+
else:
470+
names = itertools.cycle([None])
471+
for i, (arr, name) in enumerate(zip(index, names)):
468472
if try_parse_dates and self._should_parse_dates(i):
469473
arr = self._date_conv(
470474
arr,
@@ -504,12 +508,17 @@ def _agg_index(self, index, try_parse_dates: bool = True) -> Index:
504508
arr, _ = self._infer_types(
505509
arr, col_na_values | col_na_fvalues, cast_type is None, try_num_bool
506510
)
507-
arrays.append(arr)
508-
509-
names = self.index_names
510-
index = ensure_index_from_sequences(arrays, names)
511+
if cast_type is not None:
512+
# Don't perform RangeIndex inference
513+
idx = Index(arr, name=name, dtype=cast_type)
514+
else:
515+
idx = ensure_index_from_sequences([arr], [name])
516+
arrays.append(idx)
511517

512-
return index
518+
if len(arrays) == 1:
519+
return arrays[0]
520+
else:
521+
return MultiIndex.from_arrays(arrays)
513522

514523
@final
515524
def _convert_to_ndarrays(
@@ -1084,12 +1093,11 @@ def _get_empty_meta(self, columns, dtype: DtypeArg | None = None):
10841093
dtype_dict: defaultdict[Hashable, Any]
10851094
if not is_dict_like(dtype):
10861095
# if dtype == None, default will be object.
1087-
default_dtype = dtype or object
1088-
dtype_dict = defaultdict(lambda: default_dtype)
1096+
dtype_dict = defaultdict(lambda: dtype)
10891097
else:
10901098
dtype = cast(dict, dtype)
10911099
dtype_dict = defaultdict(
1092-
lambda: object,
1100+
lambda: None,
10931101
{columns[k] if is_integer(k) else k: v for k, v in dtype.items()},
10941102
)
10951103

@@ -1106,8 +1114,14 @@ def _get_empty_meta(self, columns, dtype: DtypeArg | None = None):
11061114
if (index_col is None or index_col is False) or index_names is None:
11071115
index = default_index(0)
11081116
else:
1109-
data = [Series([], dtype=dtype_dict[name]) for name in index_names]
1110-
index = ensure_index_from_sequences(data, names=index_names)
1117+
# TODO: We could return default_index(0) if dtype_dict[name] is None
1118+
data = [
1119+
Index([], name=name, dtype=dtype_dict[name]) for name in index_names
1120+
]
1121+
if len(data) == 1:
1122+
index = data[0]
1123+
else:
1124+
index = MultiIndex.from_arrays(data)
11111125
index_col.sort()
11121126

11131127
for i, n in enumerate(index_col):

pandas/tests/io/parser/dtypes/test_dtypes_basic.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
"ignore:Passing a BlockManager to DataFrame:DeprecationWarning"
2323
)
2424

25+
xfail_pyarrow = pytest.mark.usefixtures("pyarrow_xfail")
26+
2527

2628
@pytest.mark.parametrize("dtype", [str, object])
2729
@pytest.mark.parametrize("check_orig", [True, False])
@@ -594,6 +596,7 @@ def test_string_inference_object_dtype(all_parsers, dtype, using_infer_string):
594596
tm.assert_frame_equal(result, expected)
595597

596598

599+
@xfail_pyarrow
597600
def test_accurate_parsing_of_large_integers(all_parsers):
598601
# GH#52505
599602
data = """SYMBOL,MOMENT,ID,ID_DEAL
@@ -604,7 +607,7 @@ def test_accurate_parsing_of_large_integers(all_parsers):
604607
AMZN,20230301181139587,2023552585717889759,2023552585717263360
605608
MSFT,20230301181139587,2023552585717889863,2023552585717263361
606609
NVDA,20230301181139587,2023552585717889827,2023552585717263361"""
607-
orders = pd.read_csv(StringIO(data), dtype={"ID_DEAL": pd.Int64Dtype()})
610+
orders = all_parsers.read_csv(StringIO(data), dtype={"ID_DEAL": pd.Int64Dtype()})
608611
assert len(orders.loc[orders["ID_DEAL"] == 2023552585717263358, "ID_DEAL"]) == 1
609612
assert len(orders.loc[orders["ID_DEAL"] == 2023552585717263359, "ID_DEAL"]) == 1
610613
assert len(orders.loc[orders["ID_DEAL"] == 2023552585717263360, "ID_DEAL"]) == 2
@@ -626,3 +629,16 @@ def test_dtypes_with_usecols(all_parsers):
626629
values = ["1", "4"]
627630
expected = DataFrame({"a": pd.Series(values, dtype=object), "c": [3, 6]})
628631
tm.assert_frame_equal(result, expected)
632+
633+
634+
def test_index_col_with_dtype_no_rangeindex(all_parsers):
635+
data = StringIO("345.5,519.5,0\n519.5,726.5,1")
636+
result = all_parsers.read_csv(
637+
data,
638+
header=None,
639+
names=["start", "stop", "bin_id"],
640+
dtype={"start": np.float32, "stop": np.float32, "bin_id": np.uint32},
641+
index_col="bin_id",
642+
).index
643+
expected = pd.Index([0, 1], dtype=np.uint32, name="bin_id")
644+
tm.assert_index_equal(result, expected)

0 commit comments

Comments
 (0)