|
137 | 137 | from pandas.core.missing import find_valid_index
|
138 | 138 | from pandas.core.ops import align_method_FRAME
|
139 | 139 | from pandas.core.reshape.concat import concat
|
| 140 | +import pandas.core.sample as sample |
140 | 141 | from pandas.core.shared_docs import _shared_docs
|
141 | 142 | from pandas.core.sorting import get_indexer_indexer
|
142 | 143 | from pandas.core.window import (
|
@@ -5146,7 +5147,7 @@ def tail(self: FrameOrSeries, n: int = 5) -> FrameOrSeries:
|
5146 | 5147 | @final
|
5147 | 5148 | def sample(
|
5148 | 5149 | self: FrameOrSeries,
|
5149 |
| - n=None, |
| 5150 | + n: int | None = None, |
5150 | 5151 | frac: float | None = None,
|
5151 | 5152 | replace: bool_t = False,
|
5152 | 5153 | weights=None,
|
@@ -5273,92 +5274,22 @@ def sample(
|
5273 | 5274 | axis = self._stat_axis_number
|
5274 | 5275 |
|
5275 | 5276 | axis = self._get_axis_number(axis)
|
5276 |
| - axis_length = self.shape[axis] |
| 5277 | + obj_len = self.shape[axis] |
5277 | 5278 |
|
5278 | 5279 | # Process random_state argument
|
5279 | 5280 | rs = com.random_state(random_state)
|
5280 | 5281 |
|
5281 |
| - # Check weights for compliance |
5282 |
| - if weights is not None: |
5283 |
| - |
5284 |
| - # If a series, align with frame |
5285 |
| - if isinstance(weights, ABCSeries): |
5286 |
| - weights = weights.reindex(self.axes[axis]) |
5287 |
| - |
5288 |
| - # Strings acceptable if a dataframe and axis = 0 |
5289 |
| - if isinstance(weights, str): |
5290 |
| - if isinstance(self, ABCDataFrame): |
5291 |
| - if axis == 0: |
5292 |
| - try: |
5293 |
| - weights = self[weights] |
5294 |
| - except KeyError as err: |
5295 |
| - raise KeyError( |
5296 |
| - "String passed to weights not a valid column" |
5297 |
| - ) from err |
5298 |
| - else: |
5299 |
| - raise ValueError( |
5300 |
| - "Strings can only be passed to " |
5301 |
| - "weights when sampling from rows on " |
5302 |
| - "a DataFrame" |
5303 |
| - ) |
5304 |
| - else: |
5305 |
| - raise ValueError( |
5306 |
| - "Strings cannot be passed as weights " |
5307 |
| - "when sampling from a Series." |
5308 |
| - ) |
5309 |
| - |
5310 |
| - if isinstance(self, ABCSeries): |
5311 |
| - func = self._constructor |
5312 |
| - else: |
5313 |
| - func = self._constructor_sliced |
5314 |
| - weights = func(weights, dtype="float64") |
| 5282 | + size = sample.process_sampling_size(n, frac, replace) |
| 5283 | + if size is None: |
| 5284 | + assert frac is not None |
| 5285 | + size = round(frac * obj_len) |
5315 | 5286 |
|
5316 |
| - if len(weights) != axis_length: |
5317 |
| - raise ValueError( |
5318 |
| - "Weights and axis to be sampled must be of same length" |
5319 |
| - ) |
5320 |
| - |
5321 |
| - if (weights == np.inf).any() or (weights == -np.inf).any(): |
5322 |
| - raise ValueError("weight vector may not include `inf` values") |
5323 |
| - |
5324 |
| - if (weights < 0).any(): |
5325 |
| - raise ValueError("weight vector many not include negative values") |
5326 |
| - |
5327 |
| - # If has nan, set to zero. |
5328 |
| - weights = weights.fillna(0) |
5329 |
| - |
5330 |
| - # Renormalize if don't sum to 1 |
5331 |
| - if weights.sum() != 1: |
5332 |
| - if weights.sum() != 0: |
5333 |
| - weights = weights / weights.sum() |
5334 |
| - else: |
5335 |
| - raise ValueError("Invalid weights: weights sum to zero") |
5336 |
| - |
5337 |
| - weights = weights._values |
| 5287 | + if weights is not None: |
| 5288 | + weights = sample.preprocess_weights(self, weights, axis) |
5338 | 5289 |
|
5339 |
| - # If no frac or n, default to n=1. |
5340 |
| - if n is None and frac is None: |
5341 |
| - n = 1 |
5342 |
| - elif frac is not None and frac > 1 and not replace: |
5343 |
| - raise ValueError( |
5344 |
| - "Replace has to be set to `True` when " |
5345 |
| - "upsampling the population `frac` > 1." |
5346 |
| - ) |
5347 |
| - elif frac is None and n % 1 != 0: |
5348 |
| - raise ValueError("Only integers accepted as `n` values") |
5349 |
| - elif n is None and frac is not None: |
5350 |
| - n = round(frac * axis_length) |
5351 |
| - elif frac is not None: |
5352 |
| - raise ValueError("Please enter a value for `frac` OR `n`, not both") |
5353 |
| - |
5354 |
| - # Check for negative sizes |
5355 |
| - if n < 0: |
5356 |
| - raise ValueError( |
5357 |
| - "A negative number of rows requested. Please provide positive value." |
5358 |
| - ) |
| 5290 | + sampled_indices = sample.sample(obj_len, size, replace, weights, rs) |
| 5291 | + result = self.take(sampled_indices, axis=axis) |
5359 | 5292 |
|
5360 |
| - locs = rs.choice(axis_length, size=n, replace=replace, p=weights) |
5361 |
| - result = self.take(locs, axis=axis) |
5362 | 5293 | if ignore_index:
|
5363 | 5294 | result.index = ibase.default_index(len(result))
|
5364 | 5295 |
|
|
0 commit comments