Skip to content

Commit 8a2fcb0

Browse files
OXPHOS authored and OXPHOS committed
Fix GH 14072: pivot_table dropna
1 parent d50b162 commit 8a2fcb0

File tree

7 files changed: 73 additions (+73) and 32 deletions (-32)

pandas/_libs/hashtable_class_helper.pxi.in

+9-5
Original file line number | Diff line number | Diff line change
@@ -330,7 +330,7 @@ cdef class {{name}}HashTable(HashTable):
330330
@cython.boundscheck(False)
331331
def get_labels(self, {{dtype}}_t[:] values, {{name}}Vector uniques,
332332
Py_ssize_t count_prior, Py_ssize_t na_sentinel,
333-
bint check_null=True):
333+
bint check_null=True, bint dropna=True):
334334
cdef:
335335
Py_ssize_t i, n = len(values)
336336
int64_t[:] labels
@@ -642,7 +642,7 @@ cdef class StringHashTable(HashTable):
642642
@cython.boundscheck(False)
643643
def get_labels(self, ndarray[object] values, ObjectVector uniques,
644644
Py_ssize_t count_prior, int64_t na_sentinel,
645-
bint check_null=1):
645+
bint check_null=1, bint dropna=True):
646646
cdef:
647647
Py_ssize_t i, n = len(values)
648648
int64_t[:] labels
@@ -815,7 +815,7 @@ cdef class PyObjectHashTable(HashTable):
815815

816816
def get_labels(self, ndarray[object] values, ObjectVector uniques,
817817
Py_ssize_t count_prior, int64_t na_sentinel,
818-
bint check_null=True):
818+
bint check_null=True, bint dropna=True):
819819
cdef:
820820
Py_ssize_t i, n = len(values)
821821
int64_t[:] labels
@@ -830,7 +830,11 @@ cdef class PyObjectHashTable(HashTable):
830830
val = values[i]
831831
hash(val)
832832

833-
if check_null and val != val or val is None:
833+
if check_null and val != val:
834+
labels[i] = na_sentinel
835+
continue
836+
837+
if dropna and val is None:
834838
labels[i] = na_sentinel
835839
continue
836840

@@ -968,5 +972,5 @@ cdef class MultiIndexHashTable(HashTable):
968972

969973
def get_labels(self, object mi, ObjectVector uniques,
970974
Py_ssize_t count_prior, int64_t na_sentinel,
971-
bint check_null=True):
975+
bint check_null=True, bint dropna=True):
972976
raise NotImplementedError

pandas/core/algorithms.py

+4-2
Original file line number | Diff line number | Diff line change
@@ -519,7 +519,8 @@ def sort_mixed(values):
519519
return ordered, _ensure_platform_int(new_labels)
520520

521521

522-
def factorize(values, sort=False, order=None, na_sentinel=-1, size_hint=None):
522+
def factorize(values, sort=False, order=None, na_sentinel=-1, size_hint=None,
523+
dropna=False):
523524
"""
524525
Encode input values as an enumerated type or categorical variable
525526
@@ -552,7 +553,8 @@ def factorize(values, sort=False, order=None, na_sentinel=-1, size_hint=None):
552553
table = hash_klass(size_hint or len(values))
553554
uniques = vec_klass()
554555
check_nulls = not is_integer_dtype(original)
555-
labels = table.get_labels(values, uniques, 0, na_sentinel, check_nulls)
556+
labels = table.get_labels(values, uniques, 0, na_sentinel, check_nulls,
557+
dropna)
556558

557559
labels = _ensure_platform_int(labels)
558560
uniques = uniques.to_array()

pandas/core/categorical.py

-4
Original file line number | Diff line number | Diff line change
@@ -548,10 +548,6 @@ def _validate_categories(cls, categories, fastpath=False):
548548

549549
if not fastpath:
550550

551-
# Categories cannot contain NaN.
552-
if categories.hasnans:
553-
raise ValueError('Categorial categories cannot be null')
554-
555551
# Categories must be unique.
556552
if not categories.is_unique:
557553
raise ValueError('Categorical categories must be unique')

pandas/core/generic.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -4216,7 +4216,7 @@ def clip_lower(self, threshold, axis=None):
42164216
return self.where(subset, threshold, axis=axis)
42174217

42184218
def groupby(self, by=None, axis=0, level=None, as_index=True, sort=True,
4219-
group_keys=True, squeeze=False, **kwargs):
4219+
group_keys=True, squeeze=False, dropna=True, **kwargs):
42204220
"""
42214221
Group series using mapper (dict or key function, apply given function
42224222
to group, return result as series) or by a series of columns.
@@ -4272,7 +4272,7 @@ def groupby(self, by=None, axis=0, level=None, as_index=True, sort=True,
42724272
axis = self._get_axis_number(axis)
42734273
return groupby(self, by=by, axis=axis, level=level, as_index=as_index,
42744274
sort=sort, group_keys=group_keys, squeeze=squeeze,
4275-
**kwargs)
4275+
dropna=dropna, **kwargs)
42764276

42774277
def asfreq(self, freq, method=None, how=None, normalize=False,
42784278
fill_value=None):

pandas/core/groupby.py

+12-8
Original file line number | Diff line number | Diff line change
@@ -361,7 +361,8 @@ class _GroupBy(PandasObject, SelectionMixin):
361361

362362
def __init__(self, obj, keys=None, axis=0, level=None,
363363
grouper=None, exclusions=None, selection=None, as_index=True,
364-
sort=True, group_keys=True, squeeze=False, **kwargs):
364+
sort=True, group_keys=True, squeeze=False, dropna=True,
365+
**kwargs):
365366

366367
self._selection = selection
367368

@@ -388,7 +389,8 @@ def __init__(self, obj, keys=None, axis=0, level=None,
388389
axis=axis,
389390
level=level,
390391
sort=sort,
391-
mutated=self.mutated)
392+
mutated=self.mutated,
393+
dropna=dropna)
392394

393395
self.obj = obj
394396
self.axis = obj._get_axis_number(axis)
@@ -1614,15 +1616,15 @@ def tail(self, n=5):
16141616

16151617

16161618
@Appender(GroupBy.__doc__)
1617-
def groupby(obj, by, **kwds):
1619+
def groupby(obj, by, dropna=True, **kwds):
16181620
if isinstance(obj, Series):
16191621
klass = SeriesGroupBy
16201622
elif isinstance(obj, DataFrame):
16211623
klass = DataFrameGroupBy
16221624
else: # pragma: no cover
16231625
raise TypeError('invalid type: %s' % type(obj))
16241626

1625-
return klass(obj, by, **kwds)
1627+
return klass(obj, by, dropna=dropna, **kwds)
16261628

16271629

16281630
def _get_axes(group):
@@ -2339,7 +2341,7 @@ class Grouping(object):
23392341
"""
23402342

23412343
def __init__(self, index, grouper=None, obj=None, name=None, level=None,
2342-
sort=True, in_axis=False):
2344+
sort=True, in_axis=False, dropna=True):
23432345

23442346
self.name = name
23452347
self.level = level
@@ -2348,6 +2350,7 @@ def __init__(self, index, grouper=None, obj=None, name=None, level=None,
23482350
self.sort = sort
23492351
self.obj = obj
23502352
self.in_axis = in_axis
2353+
self.dropna = dropna
23512354

23522355
# right place for this?
23532356
if isinstance(grouper, (Series, Index)) and name is None:
@@ -2468,7 +2471,7 @@ def group_index(self):
24682471
def _make_labels(self):
24692472
if self._labels is None or self._group_index is None:
24702473
labels, uniques = algorithms.factorize(
2471-
self.grouper, sort=self.sort)
2474+
self.grouper, sort=self.sort, dropna=self.dropna)
24722475
uniques = Index(uniques, name=self.name)
24732476
self._labels = labels
24742477
self._group_index = uniques
@@ -2480,7 +2483,7 @@ def groups(self):
24802483

24812484

24822485
def _get_grouper(obj, key=None, axis=0, level=None, sort=True,
2483-
mutated=False):
2486+
mutated=False, dropna=True):
24842487
"""
24852488
create and return a BaseGrouper, which is an internal
24862489
mapping of how to create the grouper indexers.
@@ -2633,7 +2636,8 @@ def is_in_obj(gpr):
26332636
name=name,
26342637
level=level,
26352638
sort=sort,
2636-
in_axis=in_axis) \
2639+
in_axis=in_axis,
2640+
dropna=dropna) \
26372641
if not isinstance(gpr, Grouping) else gpr
26382642

26392643
groupings.append(ping)

pandas/core/reshape/pivot.py

+13-11
Original file line number | Diff line number | Diff line change
@@ -132,7 +132,7 @@ def pivot_table(data, values=None, index=None, columns=None, aggfunc='mean',
132132
pass
133133
values = list(values)
134134

135-
grouped = data.groupby(keys)
135+
grouped = data.groupby(keys, dropna=dropna)
136136
agged = grouped.agg(aggfunc)
137137

138138
table = agged
@@ -159,15 +159,15 @@ def pivot_table(data, values=None, index=None, columns=None, aggfunc='mean',
159159
if isinstance(table, DataFrame):
160160
table = table.sort_index(axis=1)
161161

162-
if fill_value is not None:
163-
table = table.fillna(value=fill_value, downcast='infer')
164-
165162
if margins:
166163
if dropna:
167164
data = data[data.notnull().all(axis=1)]
168165
table = _add_margins(table, data, values, rows=index,
169166
cols=columns, aggfunc=aggfunc,
170-
margins_name=margins_name)
167+
margins_name=margins_name, dropna=dropna)
168+
169+
if fill_value is not None:
170+
table = table.fillna(value=fill_value, downcast='infer')
171171

172172
# discard the top level
173173
if values_passed and not values_multi and not table.empty and \
@@ -188,7 +188,7 @@ def pivot_table(data, values=None, index=None, columns=None, aggfunc='mean',
188188

189189

190190
def _add_margins(table, data, values, rows, cols, aggfunc,
191-
margins_name='All'):
191+
margins_name='All', dropna=True):
192192
if not isinstance(margins_name, compat.string_types):
193193
raise ValueError('margins_name argument must be a string')
194194

@@ -219,7 +219,8 @@ def _add_margins(table, data, values, rows, cols, aggfunc,
219219
marginal_result_set = _generate_marginal_results(table, data, values,
220220
rows, cols, aggfunc,
221221
grand_margin,
222-
margins_name)
222+
margins_name,
223+
dropna=dropna)
223224
if not isinstance(marginal_result_set, tuple):
224225
return marginal_result_set
225226
result, margin_keys, row_margin = marginal_result_set
@@ -277,8 +278,7 @@ def _compute_grand_margin(data, values, aggfunc,
277278

278279

279280
def _generate_marginal_results(table, data, values, rows, cols, aggfunc,
280-
grand_margin,
281-
margins_name='All'):
281+
grand_margin, margins_name='All', dropna=True):
282282
if len(cols) > 0:
283283
# need to "interleave" the margins
284284
table_pieces = []
@@ -288,7 +288,8 @@ def _all_key(key):
288288
return (key, margins_name) + ('',) * (len(cols) - 1)
289289

290290
if len(rows) > 0:
291-
margin = data[rows + values].groupby(rows).agg(aggfunc)
291+
margin = data[rows +
292+
values].groupby(rows, dropna=dropna).agg(aggfunc)
292293
cat_axis = 1
293294

294295
for key, piece in table.groupby(level=0, axis=cat_axis):
@@ -325,7 +326,8 @@ def _all_key(key):
325326
margin_keys = table.columns
326327

327328
if len(cols) > 0:
328-
row_margin = data[cols + values].groupby(cols).agg(aggfunc)
329+
row_margin = data[cols +
330+
values].groupby(cols, dropna=dropna).agg(aggfunc)
329331
row_margin = row_margin.stack()
330332

331333
# slight hack

pandas/tests/reshape/test_pivot.py

+33
Original file line number | Diff line number | Diff line change
@@ -90,6 +90,39 @@ def test_pivot_table_dropna(self):
9090
tm.assert_index_equal(pv_col.columns, m)
9191
tm.assert_index_equal(pv_ind.index, m)
9292

93+
def test_pivot_table_dropna_margins(self):
94+
# GH 14072
95+
df = DataFrame([
96+
[1, 'a', 'A'],
97+
[1, 'b', 'B'],
98+
[1, 'c', None]],
99+
columns=['x', 'y', 'z'])
100+
101+
result_false = df.pivot_table(values='x', index='y', columns='z',
102+
aggfunc='sum', fill_value=0,
103+
margins=True, dropna=False)
104+
expected_index = Series(['a', 'b', 'c', 'All'], name='y')
105+
expected_columns = Series([None, 'A', 'B', 'All'], name='z')
106+
expected_false = DataFrame([[0, 1, 0, 1],
107+
[0, 0, 1, 1],
108+
[1, 0, 0, 1],
109+
[1, 1, 1, 3]],
110+
index=expected_index,
111+
columns=expected_columns)
112+
tm.assert_frame_equal(expected_false, result_false)
113+
114+
result_true = df.pivot_table(values='x', index='y', columns='z',
115+
aggfunc='sum', fill_value=0,
116+
margins=True, dropna=True)
117+
expected_index = Series(['a', 'b', 'All'], name='y')
118+
expected_columns = Series(['A', 'B', 'All'], name='z')
119+
expected_true = DataFrame([[1, 0, 1],
120+
[0, 1, 1],
121+
[1, 1, 2]],
122+
index=expected_index,
123+
columns=expected_columns)
124+
tm.assert_frame_equal(expected_true, result_true)
125+
93126
def test_pivot_table_dropna_categoricals(self):
94127
# GH 15193
95128
categories = ['a', 'b', 'c', 'd']

0 commit comments

Comments
 (0)