-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
Add normalization to crosstab #12578
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -371,7 +371,7 @@ def _convert_by(by): | |
|
||
|
||
def crosstab(index, columns, values=None, rownames=None, colnames=None, | ||
aggfunc=None, margins=False, dropna=True): | ||
aggfunc=None, margins=False, dropna=True, normalize=False): | ||
""" | ||
Compute a simple cross-tabulation of two (or more) factors. By default | ||
computes a frequency table of the factors unless an array of values and an | ||
|
@@ -384,9 +384,10 @@ def crosstab(index, columns, values=None, rownames=None, colnames=None, | |
columns : array-like, Series, or list of arrays/Series | ||
Values to group by in the columns | ||
values : array-like, optional | ||
Array of values to aggregate according to the factors | ||
Array of values to aggregate according to the factors. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is the same thing as having There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As we discussed last week, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. but you are adding a default, that was not discussed. nor do I see a good reason for it. pls justify. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Only kind of -- in the past, if The other option is to not allow users to normalize frequency tables of an arbitrary series of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @nickeubank this is too magical the way you have done it. Pls remove this automatic |
||
Requires `aggfunc` be specified. | ||
aggfunc : function, optional | ||
If no values array is passed, computes a frequency table | ||
If specified, requires `values` be specified as well | ||
rownames : sequence, default None | ||
If passed, must match number of row arrays passed | ||
colnames : sequence, default None | ||
|
@@ -395,6 +396,16 @@ def crosstab(index, columns, values=None, rownames=None, colnames=None, | |
Add row/column margins (subtotals) | ||
dropna : boolean, default True | ||
Do not include columns whose entries are all NaN | ||
normalize : boolean, {'all', 'index', 'columns'}, or {0,1}, default False | ||
Normalize by dividing all values by the sum of values. | ||
|
||
- If passed 'all' or `True`, will normalize over all values. | ||
- If passed 'index' will normalize over each row. | ||
- If passed 'columns' will normalize over each column. | ||
- If margins is `True`, will also normalize margin values. | ||
|
||
.. versionadded:: 0.18.1 | ||
|
||
|
||
Notes | ||
----- | ||
|
@@ -438,18 +449,97 @@ def crosstab(index, columns, values=None, rownames=None, colnames=None, | |
data.update(zip(rownames, index)) | ||
data.update(zip(colnames, columns)) | ||
|
||
if values is None and aggfunc is not None: | ||
raise ValueError("aggfunc cannot be used without values.") | ||
|
||
if values is not None and aggfunc is None: | ||
raise ValueError("values cannot be used without an aggfunc.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you maybe also add a notice of this in the docstring? For example in the explanation of |
||
|
||
if values is None: | ||
df = DataFrame(data) | ||
df['__dummy__'] = 0 | ||
table = df.pivot_table('__dummy__', index=rownames, columns=colnames, | ||
aggfunc=len, margins=margins, dropna=dropna) | ||
return table.fillna(0).astype(np.int64) | ||
table = table.fillna(0).astype(np.int64) | ||
|
||
else: | ||
data['__dummy__'] = values | ||
df = DataFrame(data) | ||
table = df.pivot_table('__dummy__', index=rownames, columns=colnames, | ||
aggfunc=aggfunc, margins=margins, dropna=dropna) | ||
return table | ||
|
||
# Post-process | ||
if normalize is not False: | ||
table = _normalize(table, normalize=normalize, margins=margins) | ||
|
||
return table | ||
|
||
|
||
def _normalize(table, normalize, margins): | ||
|
||
if not isinstance(normalize, bool) and not isinstance(normalize, | ||
compat.string_types): | ||
axis_subs = {0: 'index', 1: 'columns'} | ||
try: | ||
normalize = axis_subs[normalize] | ||
except KeyError: | ||
raise ValueError("Not a valid normalize argument") | ||
|
||
if margins is False: | ||
|
||
# Actual Normalizations | ||
normalizers = { | ||
'all': lambda x: x / x.sum(axis=1).sum(axis=0), | ||
'columns': lambda x: x / x.sum(), | ||
'index': lambda x: x.div(x.sum(axis=1), axis=0) | ||
} | ||
|
||
normalizers[True] = normalizers['all'] | ||
|
||
try: | ||
f = normalizers[normalize] | ||
except KeyError: | ||
raise ValueError("Not a valid normalize argument") | ||
|
||
table = f(table) | ||
table = table.fillna(0) | ||
|
||
elif margins is True: | ||
|
||
column_margin = table.loc[:, 'All'].drop('All') | ||
index_margin = table.loc['All', :].drop('All') | ||
table = table.drop('All', axis=1).drop('All') | ||
|
||
# Normalize core | ||
table = _normalize(table, normalize=normalize, margins=False) | ||
|
||
# Fix Margins | ||
if normalize == 'columns': | ||
column_margin = column_margin / column_margin.sum() | ||
table = concat([table, column_margin], axis=1) | ||
table = table.fillna(0) | ||
|
||
elif normalize == 'index': | ||
index_margin = index_margin / index_margin.sum() | ||
table = table.append(index_margin) | ||
table = table.fillna(0) | ||
|
||
elif normalize == "all" or normalize is True: | ||
column_margin = column_margin / column_margin.sum() | ||
index_margin = index_margin / index_margin.sum() | ||
index_margin.loc['All'] = 1 | ||
table = concat([table, column_margin], axis=1) | ||
table = table.append(index_margin) | ||
|
||
table = table.fillna(0) | ||
|
||
else: | ||
raise ValueError("Not a valid normalize argument") | ||
|
||
else: | ||
raise ValueError("Not a valid margins argument") | ||
|
||
return table | ||
|
||
|
||
def _get_names(arrs, names, prefix='row'): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1021,6 +1021,150 @@ def test_margin_dropna(self): | |
expected.columns = Index(['dull', 'shiny', 'All'], name='c') | ||
tm.assert_frame_equal(actual, expected) | ||
|
||
def test_crosstab_normalize(self): | ||
# Issue 12578 | ||
df = pd.DataFrame({'a': [1, 2, 2, 2, 2], 'b': [3, 3, 4, 4, 4], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. addtl tests: try these with all values for normalize & also do
|
||
'c': [1, 1, np.nan, 1, 1]}) | ||
|
||
rindex = pd.Index([1, 2], name='a') | ||
cindex = pd.Index([3, 4], name='b') | ||
full_normal = pd.DataFrame([[0.2, 0], [0.2, 0.6]], | ||
index=rindex, columns=cindex) | ||
row_normal = pd.DataFrame([[1.0, 0], [0.25, 0.75]], | ||
index=rindex, columns=cindex) | ||
col_normal = pd.DataFrame([[0.5, 0], [0.5, 1.0]], | ||
index=rindex, columns=cindex) | ||
|
||
# Check all normalize args | ||
tm.assert_frame_equal(pd.crosstab(df.a, df.b, normalize='all'), | ||
full_normal) | ||
tm.assert_frame_equal(pd.crosstab(df.a, df.b, normalize=True), | ||
full_normal) | ||
tm.assert_frame_equal(pd.crosstab(df.a, df.b, normalize='index'), | ||
row_normal) | ||
tm.assert_frame_equal(pd.crosstab(df.a, df.b, normalize='columns'), | ||
col_normal) | ||
tm.assert_frame_equal(pd.crosstab(df.a, df.b, normalize=1), | ||
pd.crosstab(df.a, df.b, normalize='columns')) | ||
tm.assert_frame_equal(pd.crosstab(df.a, df.b, normalize=0), | ||
pd.crosstab(df.a, df.b, normalize='index')) | ||
|
||
row_normal_margins = pd.DataFrame([[1.0, 0], | ||
[0.25, 0.75], | ||
[0.4, 0.6]], | ||
index=pd.Index([1, 2, 'All'], | ||
name='a', | ||
dtype='object'), | ||
columns=pd.Index([3, 4], name='b')) | ||
col_normal_margins = pd.DataFrame([[0.5, 0, 0.2], [0.5, 1.0, 0.8]], | ||
index=pd.Index([1, 2], name='a', | ||
dtype='object'), | ||
columns=pd.Index([3, 4, 'All'], | ||
name='b')) | ||
|
||
all_normal_margins = pd.DataFrame([[0.2, 0, 0.2], | ||
[0.2, 0.6, 0.8], | ||
[0.4, 0.6, 1]], | ||
index=pd.Index([1, 2, 'All'], | ||
name='a', | ||
dtype='object'), | ||
columns=pd.Index([3, 4, 'All'], | ||
name='b')) | ||
|
||
tm.assert_frame_equal(pd.crosstab(df.a, df.b, normalize='index', | ||
margins=True), row_normal_margins) | ||
tm.assert_frame_equal(pd.crosstab(df.a, df.b, normalize='columns', | ||
margins=True), col_normal_margins) | ||
tm.assert_frame_equal(pd.crosstab(df.a, df.b, normalize=True, | ||
margins=True), all_normal_margins) | ||
|
||
# Test arrays | ||
pd.crosstab([np.array([1, 1, 2, 2]), np.array([1, 2, 1, 2])], | ||
np.array([1, 2, 1, 2])) | ||
|
||
# Test with aggfunc | ||
norm_counts = pd.DataFrame([[0.25, 0, 0.25], | ||
[0.25, 0.5, 0.75], | ||
[0.5, 0.5, 1]], | ||
index=pd.Index([1, 2, 'All'], | ||
name='a', | ||
dtype='object'), | ||
columns=pd.Index([3, 4, 'All'], | ||
name='b')) | ||
test_case = pd.crosstab(df.a, df.b, df.c, aggfunc='count', | ||
normalize='all', | ||
margins=True) | ||
tm.assert_frame_equal(test_case, norm_counts) | ||
|
||
df = pd.DataFrame({'a': [1, 2, 2, 2, 2], 'b': [3, 3, 4, 4, 4], | ||
'c': [0, 4, np.nan, 3, 3]}) | ||
|
||
norm_sum = pd.DataFrame([[0, 0, 0.], | ||
[0.4, 0.6, 1], | ||
[0.4, 0.6, 1]], | ||
index=pd.Index([1, 2, 'All'], | ||
name='a', | ||
dtype='object'), | ||
columns=pd.Index([3, 4, 'All'], | ||
name='b', | ||
dtype='object')) | ||
test_case = pd.crosstab(df.a, df.b, df.c, aggfunc=np.sum, | ||
normalize='all', | ||
margins=True) | ||
tm.assert_frame_equal(test_case, norm_sum) | ||
|
||
def test_crosstab_with_empties(self): | ||
# Check handling of empties | ||
df = pd.DataFrame({'a': [1, 2, 2, 2, 2], 'b': [3, 3, 4, 4, 4], | ||
'c': [np.nan, np.nan, np.nan, np.nan, np.nan]}) | ||
|
||
empty = pd.DataFrame([[0.0, 0.0], [0.0, 0.0]], | ||
index=pd.Index([1, 2], | ||
name='a', | ||
dtype='int64'), | ||
columns=pd.Index([3, 4], name='b')) | ||
|
||
for i in [True, 'index', 'columns']: | ||
calculated = pd.crosstab(df.a, df.b, values=df.c, aggfunc='count', | ||
normalize=i) | ||
tm.assert_frame_equal(empty, calculated) | ||
|
||
nans = pd.DataFrame([[0.0, np.nan], [0.0, 0.0]], | ||
index=pd.Index([1, 2], | ||
name='a', | ||
dtype='int64'), | ||
columns=pd.Index([3, 4], name='b')) | ||
|
||
calculated = pd.crosstab(df.a, df.b, values=df.c, aggfunc='count', | ||
normalize=False) | ||
tm.assert_frame_equal(nans, calculated) | ||
|
||
def test_crosstab_errors(self): | ||
# Issue 12578 | ||
|
||
df = pd.DataFrame({'a': [1, 2, 2, 2, 2], 'b': [3, 3, 4, 4, 4], | ||
'c': [1, 1, np.nan, 1, 1]}) | ||
|
||
error = 'values cannot be used without an aggfunc.' | ||
with tm.assertRaisesRegexp(ValueError, error): | ||
pd.crosstab(df.a, df.b, values=df.c) | ||
|
||
error = 'aggfunc cannot be used without values' | ||
with tm.assertRaisesRegexp(ValueError, error): | ||
pd.crosstab(df.a, df.b, aggfunc=np.mean) | ||
|
||
error = 'Not a valid normalize argument' | ||
with tm.assertRaisesRegexp(ValueError, error): | ||
pd.crosstab(df.a, df.b, normalize='42') | ||
|
||
with tm.assertRaisesRegexp(ValueError, error): | ||
pd.crosstab(df.a, df.b, normalize=42) | ||
|
||
error = 'Not a valid margins argument' | ||
with tm.assertRaisesRegexp(ValueError, error): | ||
pd.crosstab(df.a, df.b, normalize='all', margins=42) | ||
|
||
|
||
if __name__ == '__main__': | ||
import nose | ||
nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'], | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
versionadded tag here