-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
added random_split in generic.py, for DataFrames etc. #11253
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
add1acb
ef350cd
8b2b06c
81d8ba5
ed261ee
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
import datetime | ||
import collections | ||
import warnings | ||
from numbers import Real | ||
|
||
from pandas.compat import( | ||
zip, builtins, range, long, lzip, | ||
|
@@ -296,6 +297,62 @@ def groups(self): | |
return self.grouper.groups | ||
|
||
|
||
class Partitioner(Grouper):
    """
    Grouper that labels positions along an axis with consecutive partition
    numbers (0, 1, 2, ...) sized according to relative proportions.

    Parameters
    ----------
    proportions : sequence of numbers, default (1, 1)
        Relative weight of each partition. At least two weights are
        required and each must be a finite, strictly positive float or
        integer. The weights are normalised to fractions summing to 1.
    axis : axis name or number, default None
        Axis to partition. When None, ``obj._stat_axis_number`` of the
        object passed to ``_get_grouper`` is used.

    Raises
    ------
    ValueError
        If fewer than two weights are given, or a weight is not a finite,
        strictly positive real number.

    Notes
    -----
    Partition boundaries are obtained by rounding the *cumulative*
    proportions, so exactly one label is produced per position on the axis.
    A partition may be empty when the axis is shorter than the number of
    partitions.
    """

    def __init__(self, proportions=(1,1), axis=None):
        self._proportions = proportions
        self._axis = axis
        self.key = None

        # validate the weights before normalising them
        if len(self._proportions) < 2:
            raise ValueError("must split into more than 1 partition")
        for w in self._proportions:
            is_real = com.is_float(w) or com.is_integer(w)
            # NaN compares False against 0 and inf would normalise to NaN,
            # so non-finite weights are rejected explicitly
            if not is_real or not np.isfinite(w) or w <= 0:
                raise ValueError("weights must be strictly positive real numbers")

        # compute proportions as fractions
        self._proportions = np.asarray(self._proportions, dtype="float64")
        self._proportions = self._proportions/self._proportions.sum()
        super(Partitioner, self).__init__()

    def _get_grouper(self, obj):
        """
        Build the grouper that assigns each position of ``obj``'s axis to
        a partition.

        Parameters
        ----------
        obj : object exposing ``_stat_axis_number``, ``_get_axis_number``,
            ``_get_axis`` and ``shape``

        Returns
        -------
        tuple
            ``(None, BaseGrouper, obj)``
        """
        # resolve the axis per call instead of overwriting self._axis, so
        # the same instance can be applied to more than one object
        axis = self._axis
        if axis is None:
            axis = obj._stat_axis_number
        axis = obj._get_axis_number(axis)
        axis_length = obj.shape[axis]

        # Round cumulative boundaries rather than individual sizes: rounding
        # each size separately can yield more labels than positions
        # (e.g. (1, 1) over 3 rows) or none at all (e.g. (1, 1, 1) over 1 row).
        bounds = np.rint(np.cumsum(self._proportions) * axis_length)
        bounds = np.minimum(bounds.astype("int64"), axis_length)
        bounds[-1] = axis_length  # the last partition always ends the axis
        counts = np.diff(np.concatenate(([0], bounds)))

        # plain list of ints, one partition number per position
        newcol = np.repeat(np.arange(len(counts)), counts).tolist()

        # hook for subclasses; expected to modify newcol in place
        self._transform(newcol)

        grouping = Grouping(obj._get_axis(axis), grouper=Series(newcol),
                            obj=obj, sort=True, in_axis=True)

        return None, BaseGrouper(axis, [grouping]), obj

    def _transform(self, newcol):
        """Hook called with the label list; the base class leaves it as is."""
        pass
|
||
class RandomPartitioner(Partitioner):
    """
    Partitioner whose partition labels are shuffled by a random state.

    Parameters
    ----------
    proportions : sequence, default (1, 1)
        Forwarded unchanged to ``Partitioner``.
    axis : default None
        Forwarded unchanged to ``Partitioner``.
    random : default None
        Handed to ``com._random_state``; the returned object is stored
        on the instance as ``rs``.
    """

    def __init__(self, proportions=(1,1), axis=None, random=None):
        # the random state is resolved before the parent constructor runs
        random_state = com._random_state(random)
        self.rs = random_state
        super(RandomPartitioner, self).__init__(proportions, axis)

    def _transform(self, newcol):
        """
        Shuffle ``newcol`` with ``self.rs``.

        The result of ``shuffle`` is discarded, so the reordering has to
        take effect on ``newcol`` itself.
        """
        shuffler = self.rs
        shuffler.shuffle(newcol)
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not sure what this is for. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Which part? |
||
|
||
|
||
class GroupByPlot(PandasObject): | ||
""" | ||
Class implementing the .plot attribute for groupby objects | ||
|
@@ -658,6 +715,10 @@ def __iter__(self): | |
""" | ||
return self.grouper.get_iterator(self.obj, axis=self.axis) | ||
|
||
def split(self):
    """
    Collect the pieces produced by iterating over this object.

    Iteration is expected to yield two-item pairs; the first item of each
    pair is ignored and the second items are returned in iteration order.

    Returns
    -------
    tuple
        One entry per pair yielded by the iteration.
    """
    return tuple(piece for _key, piece in self)
|
||
def apply(self, func, *args, **kwargs): | ||
""" | ||
Apply function and combine results together in an intelligent way. The | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -399,8 +399,13 @@ def test_grouper_multilevel_freq(self): | |
pd.Grouper(level=1, freq='W')]).sum() | ||
assert_frame_equal(result, expected) | ||
|
||
def test_grouper_creation_bug(self): | ||
def test_grouper_random(self):
    # A 1:2 split of six rows must give partitions of 2 and 4 rows that
    # together contain every original row exactly once.
    df = DataFrame({"A": [0, 1, 2, 3, 4, 5], "b": [10, 11, 12, 13, 14, 15]})
    g = df.groupby(pd.RandomPartitioner((1, 2)))
    a, b = g.split()

    assert len(a) == 2
    assert len(b) == 4

    # the shuffle only changes which rows land where; undoing it by
    # sorting on the index must give back the original frame
    result = pd.concat([a, b]).sort_index()
    assert_frame_equal(result, df)
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Your test is not testing anything; obviously it needs lots more tests. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, I know, this was just a placeholder. |
||
|
||
def test_grouper_creation_bug(self): | ||
# GH 8795 | ||
df = DataFrame({'A':[0,0,1,1,2,2], 'B':[1,2,3,4,5,6]}) | ||
g = df.groupby('A') | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
obviously add a doc-string :)