Skip to content

REF: Allow Index._with_infer to also return RangeIndex #58143

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,12 +667,25 @@ def _simple_new(
return result

@classmethod
def _with_infer(cls, *args, **kwargs):
def _with_infer(
cls,
data=None,
dtype=None,
copy: bool = False,
name=None,
tupleize_cols: bool = True,
):
"""
Constructor that uses the 1.0.x behavior inferring numeric dtypes
for ndarray[object] inputs.
"""
result = cls(*args, **kwargs)
result = cls(
data=maybe_sequence_to_range(data),
dtype=dtype,
copy=copy,
name=name,
tupleize_cols=tupleize_cols,
)

if result.dtype == _dtype_obj and not result._is_multi:
# error: Argument 1 to "maybe_convert_objects" has incompatible type
Expand Down Expand Up @@ -7140,7 +7153,9 @@ def maybe_sequence_to_range(sequence) -> Any | range:
"""
if isinstance(sequence, (range, ExtensionArray)):
return sequence
elif len(sequence) == 1 or lib.infer_dtype(sequence, skipna=False) != "integer":
elif isinstance(sequence, abc.Generator):
sequence = list(sequence)
if len(sequence) == 1 or lib.infer_dtype(sequence, skipna=False) != "integer":
return sequence
elif isinstance(sequence, (ABCSeries, Index)) and not (
isinstance(sequence.dtype, np.dtype) and sequence.dtype.kind == "i"
Expand Down
5 changes: 5 additions & 0 deletions pandas/tests/arrays/categorical/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,3 +794,8 @@ def test_range_values_preserves_rangeindex_categories(self, values, categories):
result = Categorical(values=values, categories=categories).categories
expected = RangeIndex(range(5))
tm.assert_index_equal(result, expected, exact=True)

def test_categoricaldtype_numeric_object_to_rangeindex_categories(self):
result = CategoricalDtype(np.array([1, 2], dtype=object)).categories
expected = RangeIndex(1, 3)
tm.assert_index_equal(result, expected, exact=True)
14 changes: 12 additions & 2 deletions pandas/tests/groupby/aggregate/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
DataFrame,
Index,
MultiIndex,
RangeIndex,
Series,
concat,
to_datetime,
Expand Down Expand Up @@ -517,7 +518,7 @@ def test_callable_result_dtype_frame(
df["c"] = df["c"].astype(input_dtype)
op = getattr(df.groupby(keys)[["c"]], method)
result = op(lambda x: x.astype(result_dtype).iloc[0])
expected_index = pd.RangeIndex(0, 1) if method == "transform" else agg_index
expected_index = RangeIndex(0, 1) if method == "transform" else agg_index
expected = DataFrame({"c": [df["c"].iloc[0]]}, index=expected_index).astype(
result_dtype
)
Expand All @@ -541,7 +542,7 @@ def test_callable_result_dtype_series(keys, agg_index, input, dtype, method):
df = DataFrame({"a": [1], "b": [2], "c": [input]})
op = getattr(df.groupby(keys)["c"], method)
result = op(lambda x: x.astype(dtype).iloc[0])
expected_index = pd.RangeIndex(0, 1) if method == "transform" else agg_index
expected_index = RangeIndex(0, 1) if method == "transform" else agg_index
expected = Series([df["c"].iloc[0]], index=expected_index, name="c").astype(dtype)
tm.assert_series_equal(result, expected)

Expand Down Expand Up @@ -1663,3 +1664,12 @@ def func(x):
msg = "length must not be 0"
with pytest.raises(ValueError, match=msg):
df.groupby("A", observed=False).agg(func)


def test_agg_groups_returns_rangeindex():
df = DataFrame({"group": [1, 1, 2], "value": [1, 2, 3]})
result = df.groupby("group").agg(max)
expected = DataFrame(
[2, 3], index=RangeIndex(1, 3, name="group"), columns=["value"]
)
tm.assert_frame_equal(result, expected, check_index_type=True)
13 changes: 10 additions & 3 deletions pandas/tests/groupby/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ def test_groupby_nonobject_dtype(multiindex_dataframe_random_data):
result = grouped.sum()

expected = multiindex_dataframe_random_data.groupby(key.astype("O")).sum()
assert result.index.dtype == np.int8
assert expected.index.dtype == np.int64
tm.assert_frame_equal(result, expected, check_index_type=False)
tm.assert_frame_equal(result, expected, check_index_type=True)


def test_groupby_nonobject_dtype_mixed():
Expand Down Expand Up @@ -2955,3 +2953,12 @@ def test_groupby_dropna_with_nunique_unique():
)

tm.assert_frame_equal(result, expected)


def test_groupby_groups_returns_rangeindex():
df = DataFrame({"group": [1, 1, 2], "value": [1, 2, 3]})
result = df.groupby("group").max()
expected = DataFrame(
[2, 3], index=RangeIndex(1, 3, name="group"), columns=["value"]
)
tm.assert_frame_equal(result, expected, check_index_type=True)
10 changes: 9 additions & 1 deletion pandas/tests/groupby/transform/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
DataFrame,
Index,
MultiIndex,
RangeIndex,
Series,
Timestamp,
concat,
Expand Down Expand Up @@ -290,7 +291,7 @@ def test_transform_casting():
),
"DATETIME": pd.to_datetime([f"2014-10-08 {time}" for time in times]),
},
index=pd.RangeIndex(11, name="idx"),
index=RangeIndex(11, name="idx"),
)

result = df.groupby("ID3")["DATETIME"].transform(lambda x: x.diff())
Expand Down Expand Up @@ -1535,3 +1536,10 @@ def test_transform_sum_one_column_with_matching_labels_and_missing_labels():
result = df.groupby(series, as_index=False).transform("sum")
expected = DataFrame({"X": [-93203.0, -93203.0, np.nan]})
tm.assert_frame_equal(result, expected)


def test_transform_groups_returns_rangeindex():
df = DataFrame({"group": [1, 1, 2], "value": [1, 2, 3]})
result = df.groupby("group").transform(lambda x: x + 1)
expected = DataFrame([2, 3, 4], index=RangeIndex(0, 3), columns=["value"])
tm.assert_frame_equal(result, expected, check_index_type=True)
Loading