Skip to content

Commit c98dcdf

Browse files
committed
Merge pull request #10354 from behzadnouri/cat-reduce
BUG: closes bug in apply when function returns categorical
2 parents 37fa925 + 47c0695 commit c98dcdf

File tree

4 files changed

+32
-33
lines changed

4 files changed

+32
-33
lines changed

doc/source/whatsnew/v0.17.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,4 @@ Performance Improvements
5858

5959
Bug Fixes
6060
~~~~~~~~~
61+
- Bug in ``DataFrame.apply`` when the applied function returns a categorical series (:issue:`9573`).

pandas/core/internals.py

+3
Original file line numberDiff line numberDiff line change
@@ -1670,6 +1670,9 @@ def is_view(self):
16701670
def to_dense(self):
16711671
return self.values.to_dense().view()
16721672

1673+
# Return this block wrapped in a one-element list. When ``copy`` is true the
# list holds ``self.copy()``; otherwise it holds this same object. Any extra
# keyword arguments are accepted and not used here.
def convert(self, copy=True, **kwargs):
1674+
return [self.copy() if copy else self]
1675+
16731676
@property
16741677
def shape(self):
16751678
return (len(self.mgr_locs), len(self.values))

pandas/src/reduce.pyx

+21-33
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,18 @@ from distutils.version import LooseVersion
66

77
is_numpy_prior_1_6_2 = LooseVersion(np.__version__) < '1.6.2'
88

9+
# Check the first value produced by the applied function and allocate the
# output array for a reduction.
#   obj  -- the value to check
#   size -- length of the object array to allocate
#   cnt  -- length that marks ``obj`` as non-reduced (callers in this file
#           pass ``len(self.dummy)`` / ``len(self.dummy_arr)``)
# Raises ValueError('function does not reduce') when ``obj`` is rejected;
# otherwise returns a new object-dtype array of length ``size``.
cdef _get_result_array(object obj,
10+
Py_ssize_t size,
11+
Py_ssize_t cnt):
12+
13+
# ``and`` binds tighter than ``or``: the length test applies only to the
# list case. Any ndarray is rejected, and so is any object whose ``shape``
# attribute equals ``(cnt,)``.
if isinstance(obj, np.ndarray) \
14+
or isinstance(obj, list) and len(obj) == cnt \
15+
or getattr(obj, 'shape', None) == (cnt,):
16+
raise ValueError('function does not reduce')
17+
18+
# The array is only allocated here; callers store the results themselves.
return np.empty(size, dtype='O')
19+
20+
921
cdef class Reducer:
1022
'''
1123
Performs generic reduction operation on a C or Fortran-contiguous ndarray
@@ -124,7 +136,9 @@ cdef class Reducer:
124136
if hasattr(res,'values'):
125137
res = res.values
126138
if i == 0:
127-
result = self._get_result_array(res)
139+
result = _get_result_array(res,
140+
self.nresults,
141+
len(self.dummy))
128142
it = <flatiter> PyArray_IterNew(result)
129143

130144
PyArray_SETITEM(result, PyArray_ITER_DATA(it), res)
@@ -143,17 +157,6 @@ cdef class Reducer:
143157

144158
return result
145159

146-
def _get_result_array(self, object res):
147-
try:
148-
assert(not isinstance(res, np.ndarray))
149-
assert(not (isinstance(res, list) and len(res) == len(self.dummy)))
150-
151-
result = np.empty(self.nresults, dtype='O')
152-
result[0] = res
153-
except Exception:
154-
raise ValueError('function does not reduce')
155-
return result
156-
157160

158161
cdef class SeriesBinGrouper:
159162
'''
@@ -257,8 +260,10 @@ cdef class SeriesBinGrouper:
257260
res = self.f(cached_typ)
258261
res = _extract_result(res)
259262
if not initialized:
260-
result = self._get_result_array(res)
261263
initialized = 1
264+
result = _get_result_array(res,
265+
self.ngroups,
266+
len(self.dummy_arr))
262267

263268
util.assign_value_1d(result, i, res)
264269

@@ -277,16 +282,6 @@ cdef class SeriesBinGrouper:
277282

278283
return result, counts
279284

280-
def _get_result_array(self, object res):
281-
try:
282-
assert(not isinstance(res, np.ndarray))
283-
assert(not (isinstance(res, list) and len(res) == len(self.dummy_arr)))
284-
285-
result = np.empty(self.ngroups, dtype='O')
286-
except Exception:
287-
raise ValueError('function does not reduce')
288-
return result
289-
290285

291286
cdef class SeriesGrouper:
292287
'''
@@ -388,8 +383,10 @@ cdef class SeriesGrouper:
388383
res = self.f(cached_typ)
389384
res = _extract_result(res)
390385
if not initialized:
391-
result = self._get_result_array(res)
392386
initialized = 1
387+
result = _get_result_array(res,
388+
self.ngroups,
389+
len(self.dummy_arr))
393390

394391
util.assign_value_1d(result, lab, res)
395392
counts[lab] = group_size
@@ -410,15 +407,6 @@ cdef class SeriesGrouper:
410407

411408
return result, counts
412409

413-
def _get_result_array(self, object res):
414-
try:
415-
assert(not isinstance(res, np.ndarray))
416-
assert(not (isinstance(res, list) and len(res) == len(self.dummy_arr)))
417-
418-
result = np.empty(self.ngroups, dtype='O')
419-
except Exception:
420-
raise ValueError('function does not reduce')
421-
return result
422410

423411
cdef inline _extract_result(object res):
424412
''' extract the result object, it might be a 0-dim ndarray

pandas/tests/test_frame.py

+7
Original file line numberDiff line numberDiff line change
@@ -10382,6 +10382,13 @@ def test_apply(self):
1038210382
[[1, 2, 3], [4, 5, 6], [7, 8, 9]], index=['a', 'a', 'c'])
1038310383
self.assertRaises(ValueError, df.apply, lambda x: x, 2)
1038410384

10385+
# GH 9573: apply with a function that returns a categorical series
10386+
df = DataFrame({'c0':['A','A','B','B'], 'c1':['C','C','D','D']})
10387+
df = df.apply(lambda ts: ts.astype('category'))
10388+
self.assertEqual(df.shape, (4, 2))
10389+
self.assertTrue(isinstance(df['c0'].dtype, com.CategoricalDtype))
10390+
self.assertTrue(isinstance(df['c1'].dtype, com.CategoricalDtype))
10391+
1038510392
def test_apply_mixed_datetimelike(self):
1038610393
# mixed datetimelike
1038710394
# GH 7778

0 commit comments

Comments
 (0)