Skip to content

Commit 17ec1d9

Browse files
author
Nick Eubank
committed
amend sample to return copy and align weight axis
1 parent 2396370 commit 17ec1d9

File tree

3 files changed

+54
-6
lines changed

3 files changed

+54
-6
lines changed

doc/source/whatsnew/v0.17.0.txt

+2
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,8 @@ Bug Fixes
628628
- Bug in ``stack`` when index or columns are not unique. (:issue:`10417`)
629629
- Bug in setting a Panel when an axis has a multi-index (:issue:`10360`)
630630
- Bug in ``USFederalHolidayCalendar`` where ``USMemorialDay`` and ``USMartinLutherKingJr`` were incorrect (:issue:`10278` and :issue:`9760`)
631+
- Bug in ``.sample()`` where setting values on the returned object triggered an unnecessary ``SettingWithCopyWarning`` (:issue:`10738`)
632+
- Bug in ``.sample()`` where weights passed as a ``Series`` were not aligned with the sampled axis before being treated positionally, potentially producing incorrect sampling probabilities if the weights' index did not match that of the sampled object. (:issue:`10738`)
631633

632634

633635

pandas/core/generic.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -2004,9 +2004,14 @@ def sample(self, n=None, frac=None, replace=False, weights=None, random_state=No
20042004
Sample with or without replacement. Default = False.
20052005
weights : str or ndarray-like, optional
20062006
Default 'None' results in equal probability weighting.
2007+
If passed a Series, will align with target object on index. Index
2008+
values in weights not found in sampled object will be ignored and
2009+
index values in sampled object not in weights will be assigned
2010+
weights of zero.
20072011
If called on a DataFrame, will accept the name of a column
20082012
when axis = 0.
2009-
Weights must be same length as axis being sampled.
2013+
Unless weights are a Series, weights must be same length as axis
2014+
being sampled.
20102015
If weights do not sum to 1, they will be normalized to sum to 1.
20112016
Missing values in the weights column will be treated as zero.
20122017
inf and -inf values not allowed.
@@ -2019,7 +2024,7 @@ def sample(self, n=None, frac=None, replace=False, weights=None, random_state=No
20192024
20202025
Returns
20212026
-------
2022-
Same type as caller.
2027+
A new object of same type as caller.
20232028
"""
20242029

20252030
if axis is None:
@@ -2034,6 +2039,10 @@ def sample(self, n=None, frac=None, replace=False, weights=None, random_state=No
20342039
# Check weights for compliance
20352040
if weights is not None:
20362041

2042+
# If a series, align with frame
2043+
if isinstance(weights, pd.Series):
2044+
weights = weights.reindex(self.axes[axis])
2045+
20372046
# Strings acceptable if a dataframe and axis = 0
20382047
if isinstance(weights, string_types):
20392048
if isinstance(self, pd.DataFrame):
@@ -2063,7 +2072,10 @@ def sample(self, n=None, frac=None, replace=False, weights=None, random_state=No
20632072

20642073
# Renormalize if don't sum to 1
20652074
if weights.sum() != 1:
2066-
weights = weights / weights.sum()
2075+
if weights.sum() != 0:
2076+
weights = weights / weights.sum()
2077+
else:
2078+
raise ValueError("Invalid weights: weights sum to zero")
20672079

20682080
weights = weights.values
20692081

@@ -2082,7 +2094,8 @@ def sample(self, n=None, frac=None, replace=False, weights=None, random_state=No
20822094
raise ValueError("A negative number of rows requested. Please provide positive value.")
20832095

20842096
locs = rs.choice(axis_length, size=n, replace=replace, p=weights)
2085-
return self.take(locs, axis=axis)
2097+
return self.take(locs, axis=axis, is_copy=False)
2098+
20862099

20872100
_shared_docs['pipe'] = ("""
20882101
Apply func(self, \*args, \*\*kwargs)

pandas/tests/test_generic.py

+35-2
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def test_sample(self):
374374

375375
self._compare(o.sample(frac=0.7,random_state=np.random.RandomState(test)),
376376
o.sample(frac=0.7, random_state=np.random.RandomState(test)))
377-
377+
378378

379379
# Check for error when random_state argument invalid.
380380
with tm.assertRaises(ValueError):
@@ -415,6 +415,10 @@ def test_sample(self):
415415
bad_weights = [0.5]*11
416416
o.sample(n=3, weights=bad_weights)
417417

418+
with tm.assertRaises(ValueError):
419+
bad_weight_series = Series([0,0,0.2])
420+
o.sample(n=4, weights=bad_weight_series)
421+
418422
# Check won't accept negative weights
419423
with tm.assertRaises(ValueError):
420424
bad_weights = [-0.1]*10
@@ -431,6 +435,16 @@ def test_sample(self):
431435
weights_with_ninf[0] = -np.inf
432436
o.sample(n=3, weights=weights_with_ninf)
433437

438+
# All zeros raises errors
439+
zero_weights = [0]*10
440+
with tm.assertRaises(ValueError):
441+
o.sample(n=3, weights=zero_weights)
442+
443+
# All missing weights
444+
nan_weights = [np.nan]*10
445+
with tm.assertRaises(ValueError):
446+
o.sample(n=3, weights=nan_weights)
447+
434448

435449
# A few dataframe test with degenerate weights.
436450
easy_weight_list = [0]*10
@@ -496,7 +510,6 @@ def test_sample(self):
496510
assert_frame_equal(df.sample(n=1, axis='index', weights=weight),
497511
df.iloc[5:6])
498512

499-
500513
# Check out of range axis values
501514
with tm.assertRaises(ValueError):
502515
df.sample(n=1, axis=2)
@@ -527,6 +540,26 @@ def test_sample(self):
527540
assert_panel_equal(p.sample(n=3, random_state=42), p.sample(n=3, axis=1, random_state=42))
528541
assert_frame_equal(df.sample(n=3, random_state=42), df.sample(n=3, axis=0, random_state=42))
529542

543+
# Test that function aligns weights with frame
544+
df = DataFrame({'col1':[5,6,7], 'col2':['a','b','c'], }, index = [9,5,3])
545+
s = Series([1,0,0], index=[3,5,9])
546+
assert_frame_equal(df.loc[[3]], df.sample(1, weights=s))
547+
548+
# Weights have index values to be dropped because not in
549+
# sampled DataFrame
550+
s2 = Series([0.001,0,10000], index=[3,5,10])
551+
assert_frame_equal(df.loc[[3]], df.sample(1, weights=s2))
552+
553+
# Weights have empty values to be filled with zeros
554+
s3 = Series([0.01,0], index=[3,5])
555+
assert_frame_equal(df.loc[[3]], df.sample(1, weights=s3))
556+
557+
# No overlap in weight and sampled DataFrame indices
558+
s4 = Series([1,0], index=[1,2])
559+
with tm.assertRaises(ValueError):
560+
df.sample(1, weights=s4)
561+
562+
530563
def test_size_compat(self):
531564
# GH8846
532565
# size property should be defined

0 commit comments

Comments
 (0)