diff --git a/doc/source/whatsnew/v0.20.0.txt b/doc/source/whatsnew/v0.20.0.txt index a0bf2f9b3758a..9d475390175b2 100644 --- a/doc/source/whatsnew/v0.20.0.txt +++ b/doc/source/whatsnew/v0.20.0.txt @@ -515,7 +515,6 @@ Other Enhancements - Options added to allow one to turn on/off using ``bottleneck`` and ``numexpr``, see :ref:`here ` (:issue:`16157`) - ``DataFrame.style.bar()`` now accepts two more options to further customize the bar chart. Bar alignment is set with ``align='left'|'mid'|'zero'``, the default is "left", which is backward compatible; You can now pass a list of ``color=[color_negative, color_positive]``. (:issue:`14757`) - .. _ISO 8601 duration: https://en.wikipedia.org/wiki/ISO_8601#Durations diff --git a/doc/source/whatsnew/v0.21.0.txt b/doc/source/whatsnew/v0.21.0.txt index b4ca3f011a81d..a6b6d704737bd 100644 --- a/doc/source/whatsnew/v0.21.0.txt +++ b/doc/source/whatsnew/v0.21.0.txt @@ -37,6 +37,7 @@ Other Enhancements - :func:`api.types.infer_dtype` now infers decimals. (:issue: `15690`) - :func:`read_feather` has gained the ``nthreads`` parameter for multi-threaded operations (:issue:`16359`) - :func:`DataFrame.clip()` and :func: `Series.cip()` have gained an inplace argument. (:issue: `15388`) +- :func:`crosstab` has gained a ``margins_name`` parameter to define the name of the row / column that will contain the totals when margins=True. (:issue:`15972`) .. _whatsnew_0210.api_breaking: diff --git a/pandas/core/reshape/pivot.py b/pandas/core/reshape/pivot.py index 74dbbfc00cb11..b562f8a32f5c9 100644 --- a/pandas/core/reshape/pivot.py +++ b/pandas/core/reshape/pivot.py @@ -388,7 +388,8 @@ def _convert_by(by): def crosstab(index, columns, values=None, rownames=None, colnames=None, - aggfunc=None, margins=False, dropna=True, normalize=False): + aggfunc=None, margins=False, margins_name='All', dropna=True, + normalize=False): """ Compute a simple cross-tabulation of two (or more) factors. By default computes a frequency table of the factors unless an array of values and an @@ -411,6 +412,12 @@ def crosstab(index, columns, values=None, rownames=None, colnames=None, If passed, must match number of column arrays passed margins : boolean, default False Add row/column margins (subtotals) + margins_name : string, default 'All' + Name of the row / column that will contain the totals + when margins is True. + + .. versionadded:: 0.21.0 + dropna : boolean, default True Do not include columns whose entries are all NaN normalize : boolean, {'all', 'index', 'columns'}, or {0,1}, default False @@ -490,23 +497,26 @@ def crosstab(index, columns, values=None, rownames=None, colnames=None, df = DataFrame(data) df['__dummy__'] = 0 table = df.pivot_table('__dummy__', index=rownames, columns=colnames, - aggfunc=len, margins=margins, dropna=dropna) + aggfunc=len, margins=margins, + margins_name=margins_name, dropna=dropna) table = table.fillna(0).astype(np.int64) else: data['__dummy__'] = values df = DataFrame(data) table = df.pivot_table('__dummy__', index=rownames, columns=colnames, - aggfunc=aggfunc, margins=margins, dropna=dropna) + aggfunc=aggfunc, margins=margins, + margins_name=margins_name, dropna=dropna) # Post-process if normalize is not False: - table = _normalize(table, normalize=normalize, margins=margins) + table = _normalize(table, normalize=normalize, margins=margins, + margins_name=margins_name) return table -def _normalize(table, normalize, margins): +def _normalize(table, normalize, margins, margins_name='All'): if not isinstance(normalize, bool) and not isinstance(normalize, compat.string_types): @@ -537,9 +547,9 @@ def _normalize(table, normalize, margins): elif margins is True: - column_margin = table.loc[:, 'All'].drop('All') - index_margin = table.loc['All', :].drop('All') - table = table.drop('All', axis=1).drop('All') + column_margin = table.loc[:, margins_name].drop(margins_name) + index_margin = table.loc[margins_name, :].drop(margins_name) + table = table.drop(margins_name, axis=1).drop(margins_name) # to keep index and columns names table_index_names = table.index.names table_columns_names = table.columns.names @@ -561,7 +571,7 @@ def _normalize(table, normalize, margins): elif normalize == "all" or normalize is True: column_margin = column_margin / column_margin.sum() index_margin = index_margin / index_margin.sum() - index_margin.loc['All'] = 1 + index_margin.loc[margins_name] = 1 table = concat([table, column_margin], axis=1) table = table.append(index_margin) diff --git a/pandas/tests/reshape/test_pivot.py b/pandas/tests/reshape/test_pivot.py index 270a93e4ae382..fc5a2eb468d4f 100644 --- a/pandas/tests/reshape/test_pivot.py +++ b/pandas/tests/reshape/test_pivot.py @@ -1071,6 +1071,43 @@ def test_crosstab_margins(self): exp_rows = exp_rows.fillna(0).astype(np.int64) tm.assert_series_equal(all_rows, exp_rows) + def test_crosstab_margins_set_margin_name(self): + # GH 15972 + a = np.random.randint(0, 7, size=100) + b = np.random.randint(0, 3, size=100) + c = np.random.randint(0, 5, size=100) + + df = DataFrame({'a': a, 'b': b, 'c': c}) + + result = crosstab(a, [b, c], rownames=['a'], colnames=('b', 'c'), + margins=True, margins_name='TOTAL') + + assert result.index.names == ('a',) + assert result.columns.names == ['b', 'c'] + + all_cols = result['TOTAL', ''] + exp_cols = df.groupby(['a']).size().astype('i8') + # to keep index.name + exp_margin = Series([len(df)], index=Index(['TOTAL'], name='a')) + exp_cols = exp_cols.append(exp_margin) + exp_cols.name = ('TOTAL', '') + + tm.assert_series_equal(all_cols, exp_cols) + + all_rows = result.loc['TOTAL'] + exp_rows = df.groupby(['b', 'c']).size().astype('i8') + exp_rows = exp_rows.append(Series([len(df)], index=[('TOTAL', '')])) + exp_rows.name = 'TOTAL' + + exp_rows = exp_rows.reindex(all_rows.index) + exp_rows = exp_rows.fillna(0).astype(np.int64) + tm.assert_series_equal(all_rows, exp_rows) + + for margins_name in [666, None, ['a', 'b']]: + with pytest.raises(ValueError): + crosstab(a, [b, c], rownames=['a'], colnames=('b', 'c'), + margins=True, margins_name=margins_name) + def test_crosstab_pass_values(self): a = np.random.randint(0, 7, size=100) b = np.random.randint(0, 3, size=100)