Skip to content

Commit 74debcc

Browse files
author
Nick Eubank
committed
weight tweaks
1 parent 9d2b3e1 commit 74debcc

File tree

3 files changed

+36
-31
lines changed

3 files changed

+36
-31
lines changed

doc/source/whatsnew/v0.16.1.txt

+7-3
Original file line numberDiff line numberDiff line change
@@ -168,14 +168,18 @@ for passing in a column for weights for non-uniform sampling, and for setting se
168168
example_series.sample(n=1, weights=example_weights2)
169169

170170

171-
When applied to a DataFrame, one may pass the name of a column to specify sampling weights,
172-
although note that the value of the weights column must sum to one.
171+
When applied to a DataFrame, one may pass the name of a column to specify sampling weights
172+
when sampling from rows (thought row names may not be passed to sample from rows).
173173

174174
.. ipython :: python
175175

176-
df = DataFrame({'col1':[9,8,7,6], 'weight_column':[0.5, 0.4, 0.1, 0]})
176+
df = DataFrame({'col1':[9,8,7,6], 'weight_column':[0.5, 0.4, 0.1, 0]}, index=['a', 'b', 'c', 'd'])
177177
df.sample(n=3, weights='weight_column')
178178

179+
df.sample(n=3, weights='weight_column', axis = )
180+
181+
182+
179183
.. _whatsnew_0161.api:
180184

181185
API changes

pandas/core/generic.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -1964,8 +1964,9 @@ def sample(self, n=None, frac=None, replace=False, weights=None, random_state=No
19641964
Sample with or without replacement. Default = False.
19651965
weights : str or ndarray-like, optional
19661966
Default 'None' results in equal probability weighting.
1967-
If called on a DataFrame or Panel, will also accept the name of a
1968-
column as a string. Must be same length as index.
1967+
If called on a DataFrame, will accept the name of a column
1968+
when axis = 0.
1969+
Weights must be same length as axis being sampled.
19691970
If weights do not sum to 1, they will be normalized to sum to 1.
19701971
Missing values in the weights column will be treated as zero.
19711972
inf and -inf values not allowed.
@@ -2003,17 +2004,18 @@ def sample(self, n=None, frac=None, replace=False, weights=None, random_state=No
20032004
# Check weights for compliance
20042005
if weights is not None:
20052006

2006-
# Strings acceptable if not a series
2007+
# Strings acceptable if a dataframe and axis = 0
20072008
if isinstance(weights, string_types):
2008-
2009-
if self.ndim > 1 :
2010-
try:
2011-
weights = self[weights]
2012-
except KeyError:
2013-
raise KeyError("String passed to weights not a valid name for an item in specified axis")
2014-
2009+
if isinstance(self, pd.DataFrame):
2010+
if axis == 0:
2011+
try:
2012+
weights = self[weights]
2013+
except KeyError:
2014+
raise KeyError("String passed to weights not a valid column")
2015+
else:
2016+
raise ValueError("Strings can only be passed to weights when sampling from rows on a DataFrame")
20152017
else:
2016-
raise ValueError("Strings cannot be passed as weights when sampling from a Series.")
2018+
raise ValueError("Strings cannot be passed as weights when sampling from a Series or Panel.")
20172019

20182020
#normalize format of weights to ndarray.
20192021
weights = pd.Series(weights, dtype = 'float64')

pandas/tests/test_generic.py

+16-17
Original file line numberDiff line numberDiff line change
@@ -431,10 +431,6 @@ def test_sample(self):
431431
weights_with_ninf[0] = -np.inf
432432
o.sample(n=3, weights=weights_with_ninf)
433433

434-
# Ensure proper error if string given as weight for Series
435-
s = Series(range(10))
436-
with tm.assertRaises(ValueError):
437-
s.sample(n=3, weights='weight_column')
438434

439435
# A few dataframe test with degenerate weights.
440436
easy_weight_list = [0]*10
@@ -443,33 +439,36 @@ def test_sample(self):
443439
df = pd.DataFrame({'col1':range(10,20),
444440
'col2':range(20,30),
445441
'colString': ['a']*10,
446-
'easyweights':easy_weight_list})
442+
'easyweights':easy_weight_list}, index = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'])
447443
sample1 = df.sample(n=1, weights='easyweights')
448444
assert_frame_equal(sample1, df.iloc[5:6])
449445

446+
# Ensure proper error if string given as weight for Series, panel, or
447+
# DataFrame with axis = 1.
448+
s = Series(range(10))
449+
with tm.assertRaises(ValueError):
450+
s.sample(n=3, weights='weight_column')
451+
452+
panel = pd.Panel(items = [0,1,2], major_axis = [2,3,4], minor_axis = [3,4,5])
453+
with tm.assertRaises(ValueError):
454+
panel.sample(n=1, weights='weight_column')
455+
456+
with tm.assertRaises(ValueError):
457+
df.sample(n=1, weights='weight_column', axis = 1)
458+
450459
# Check weighting key error
451460
with tm.assertRaises(KeyError):
452461
df.sample(n=3, weights='not_a_real_column_name')
453462

454463
# Check np.nan are replaced by zeros.
455464
weights_with_nan = [np.nan]*10
456465
weights_with_nan[5] = 0.5
457-
458-
sampled_df = df.sample(n=1, weights = weights_with_nan)
459-
tm.assert_frame_equal(sampled_df, df.iloc[5:6])
460-
461-
sampled_s = s.sample(n=1, weights = weights_with_nan)
462-
tm.assert_series_equal(sampled_s, s.iloc[5:6])
466+
self._compare(o.sample(n=1, weights=weights_with_nan), o.iloc[5:6])
463467

464468
# Check None are also replaced by zeros.
465469
weights_with_None = [None]*10
466470
weights_with_None[5] = 0.5
467-
468-
sampled_df2 = df.sample(n=1, weights = weights_with_None)
469-
tm.assert_frame_equal(sampled_df2, df.iloc[5:6])
470-
471-
sampled_s2 = s.sample(n=1, weights = weights_with_None)
472-
tm.assert_series_equal(sampled_s2, s.iloc[5:6])
471+
self._compare(o.sample(n=1, weights=weights_with_None), o.iloc[5:6])
473472

474473
# Check that re-normalizes weights that don't sum to one.
475474
weights_less_than_1 = [0]*10

0 commit comments

Comments
 (0)