diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index ddcce820f5a19..49c168cd5eb84 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -278,6 +278,7 @@ Other enhancements - Improved error message in ``corr`` and ``cov`` methods on :class:`.Rolling`, :class:`.Expanding`, and :class:`.ExponentialMovingWindow` when ``other`` is not a :class:`DataFrame` or :class:`Series` (:issue:`41741`) - :meth:`Series.between` can now accept ``left`` or ``right`` as arguments to ``inclusive`` to include only the left or right boundary (:issue:`40245`) - :meth:`DataFrame.explode` now supports exploding multiple columns. Its ``column`` argument now also accepts a list of str or tuples for exploding on multiple columns at the same time (:issue:`39240`) +- :meth:`DataFrame.sample` now accepts the ``ignore_index`` argument to reset the index after sampling, similar to :meth:`DataFrame.drop_duplicates` and :meth:`DataFrame.sort_values` (:issue:`38581`) .. --------------------------------------------------------------------------- diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 3669d4df8a76b..7070228feeefa 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -5143,11 +5143,12 @@ def tail(self: FrameOrSeries, n: int = 5) -> FrameOrSeries: def sample( self: FrameOrSeries, n=None, - frac=None, - replace=False, + frac: float | None = None, + replace: bool_t = False, weights=None, random_state=None, - axis=None, + axis: Axis | None = None, + ignore_index: bool_t = False, ) -> FrameOrSeries: """ Return a random sample of items from an axis of object. @@ -5189,6 +5190,10 @@ def sample( axis : {0 or ‘index’, 1 or ‘columns’, None}, default None Axis to sample. Accepts axis number or name. Default is stat axis for given data type (0 for Series and DataFrames). + ignore_index : bool, default False + If True, the resulting index will be labeled 0, 1, …, n - 1. + + .. versionadded:: 1.3.0 Returns ------- @@ -5350,7 +5355,11 @@ def sample( ) locs = rs.choice(axis_length, size=n, replace=replace, p=weights) - return self.take(locs, axis=axis) + result = self.take(locs, axis=axis) + if ignore_index: + result.index = ibase.default_index(len(result)) + + return result @final @doc(klass=_shared_doc_kwargs["klass"]) diff --git a/pandas/tests/frame/methods/test_sample.py b/pandas/tests/frame/methods/test_sample.py index 55ef665c55241..604788ba91633 100644 --- a/pandas/tests/frame/methods/test_sample.py +++ b/pandas/tests/frame/methods/test_sample.py @@ -5,6 +5,7 @@ from pandas import ( DataFrame, + Index, Series, ) import pandas._testing as tm @@ -326,3 +327,12 @@ def test_sample_is_copy(self): with tm.assert_produces_warning(None): df2["d"] = 1 + + def test_sample_ignore_index(self): + # GH 38581 + df = DataFrame( + {"col1": range(10, 20), "col2": range(20, 30), "colString": ["a"] * 10} + ) + result = df.sample(3, ignore_index=True) + expected_index = Index([0, 1, 2]) + tm.assert_index_equal(result.index, expected_index)