Skip to content

Commit c6b9f19

Browse files
committed
ENH: provide boolean indexing with dtype preservation if possible
1 parent fc8de6d commit c6b9f19

File tree

3 files changed

+129
-37
lines changed

3 files changed

+129
-37
lines changed

pandas/core/frame.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -3714,14 +3714,14 @@ def _combine_match_columns(self, other, func, fill_value=None):
37143714
if fill_value is not None:
37153715
raise NotImplementedError
37163716

3717-
new_data = left._data.where(func, right, axes = [left.columns, self.index])
3717+
new_data = left._data.eval(func, right, axes = [left.columns, self.index])
37183718
return self._constructor(new_data)
37193719

37203720
def _combine_const(self, other, func, raise_on_error = True):
37213721
if self.empty:
37223722
return self
37233723

3724-
new_data = self._data.where(func, other, raise_on_error=raise_on_error)
3724+
new_data = self._data.eval(func, other, raise_on_error=raise_on_error)
37253725
return self._constructor(new_data)
37263726

37273727
def _compare_frame(self, other, func):
@@ -5293,8 +5293,7 @@ def where(self, cond, other=NA, inplace=False, try_cast=False, raise_on_error=Tr
52935293
self._data = self._data.putmask(cond,other,inplace=True)
52945294

52955295
else:
5296-
func = lambda values, others, conds: np.where(conds, values, others)
5297-
new_data = self._data.where(func, other, cond, raise_on_error=raise_on_error, try_cast=try_cast)
5296+
new_data = self._data.where(other, cond, raise_on_error=raise_on_error, try_cast=try_cast)
52985297

52995298
return self._constructor(new_data)
53005299

pandas/core/internals.py

+106-27
Original file line numberDiff line numberDiff line change
@@ -384,17 +384,16 @@ def shift(self, indexer, periods):
384384
new_values[:, periods:] = np.nan
385385
return make_block(new_values, self.items, self.ref_items)
386386

387-
def where(self, func, other, cond = None, raise_on_error = True, try_cast = False):
387+
def eval(self, func, other, raise_on_error = True, try_cast = False):
388388
"""
389-
evaluate the block; return result block(s) from the result
389+
evaluate the block; return result block from the result
390390
391391
Parameters
392392
----------
393393
func : how to combine self, other
394394
other : a ndarray/object
395-
cond : the condition to respect, optional
396-
raise_on_error : if True, raise when I can't perform the function,
397-
False by default (and just return the data that we had coming in)
395+
raise_on_error : if True, raise when I can't perform the function, False by default (and just return
396+
the data that we had coming in)
398397
399398
Returns
400399
-------
@@ -414,28 +413,7 @@ def where(self, func, other, cond = None, raise_on_error = True, try_cast = Fals
414413
values = values.T
415414
is_transposed = True
416415

417-
# see if we can align cond
418-
if cond is not None:
419-
if not hasattr(cond, 'shape'):
420-
raise ValueError('where must have a condition that is ndarray'
421-
' like')
422-
if hasattr(cond, 'reindex_axis'):
423-
axis = getattr(cond, '_het_axis', 0)
424-
cond = cond.reindex_axis(self.items, axis=axis,
425-
copy=True).values
426-
else:
427-
cond = cond.values
428-
429-
# may need to undo transpose of values
430-
if hasattr(values, 'ndim'):
431-
if (values.ndim != cond.ndim or
432-
values.shape == cond.shape[::-1]):
433-
values = values.T
434-
is_transposed = not is_transposed
435-
436416
args = [ values, other ]
437-
if cond is not None:
438-
args.append(cond)
439417
try:
440418
result = func(*args)
441419
except:
@@ -458,7 +436,105 @@ def where(self, func, other, cond = None, raise_on_error = True, try_cast = Fals
458436
if try_cast:
459437
result = self._try_cast_result(result)
460438

461-
return [ make_block(result, self.items, self.ref_items) ]
439+
return make_block(result, self.items, self.ref_items)
440+
441+
def where(self, other, cond, raise_on_error=True, try_cast=False):
    """
    Select from this block's values where ``cond`` holds and from ``other``
    elsewhere, returning the result as new block(s).

    Parameters
    ----------
    other : a ndarray/object
        Replacement values. If it exposes ``reindex_axis`` it is first
        reindexed to ``self.items`` and reduced to its ``.values``.
    cond : the condition to respect
        Must have a ``shape`` attribute. If it exposes ``reindex_axis`` it
        is reindexed to ``self.items``; in either case its ``.values`` are
        used.
    raise_on_error : bool, default True
        If True, raise a TypeError when ``np.where`` cannot combine the
        values with ``other``. If False, an object-dtype array filled
        with NaN (same shape as the values) is used as the result instead.
    try_cast : bool, default False
        If True, pass the result through ``self._try_cast_result`` before
        building the block.

    Returns
    -------
    A single new block when the whole block is evaluated at once, or a
    list of one-item blocks when the block is evaluated item-by-item.

    Raises
    ------
    ValueError
        If ``cond`` has no ``shape`` attribute.
    TypeError
        If the computation fails and ``raise_on_error`` is True, or if the
        computed result is not an ndarray.
    """

    values = self.values

    # see if we can align other
    if hasattr(other, 'reindex_axis'):
        axis = getattr(other, '_het_axis', 0)
        other = other.reindex_axis(self.items, axis=axis, copy=True).values

    # make sure that we can broadcast
    is_transposed = False
    if hasattr(other, 'ndim') and hasattr(values, 'ndim'):
        if values.ndim != other.ndim or values.shape == other.shape[::-1]:
            values = values.T
            is_transposed = True

    # see if we can align cond
    if not hasattr(cond, 'shape'):
        raise ValueError("where must have a condition that is ndarray like")
    if hasattr(cond, 'reindex_axis'):
        axis = getattr(cond, '_het_axis', 0)
        cond = cond.reindex_axis(self.items, axis=axis, copy=True).values
    else:
        cond = cond.values

    # may need to undo transpose of values
    if hasattr(values, 'ndim'):
        if values.ndim != cond.ndim or values.shape == cond.shape[::-1]:
            values = values.T
            is_transposed = not is_transposed

    # our where function
    def func(c, v, o):
        # nothing to replace: hand back the values untouched so that
        # their dtype is preserved
        if c.all():
            return v

        try:
            return np.where(c, v, o)
        except Exception:
            # narrower than a bare ``except:`` so that KeyboardInterrupt
            # and SystemExit are not swallowed
            if raise_on_error:
                raise TypeError('Could not operate %s with block values'
                                % repr(o))

            # not raising: fall back to an all-NaN object array
            result = np.empty(v.shape, dtype='O')
            result.fill(np.nan)
            return result

    def create_block(result, items, transpose=True):
        if not isinstance(result, np.ndarray):
            raise TypeError('Could not compare %s with block values'
                            % repr(other))

        if transpose and is_transposed:
            result = result.T

        # try to cast if requested
        if try_cast:
            result = self._try_cast_result(result)

        return make_block(result, items, self.ref_items)

    # see if we can operate on the entire block, or need item-by-item
    # NOTE(review): ``cond.all()`` is called without an axis, so this
    # branch is taken only when every element of ``cond`` is True --
    # confirm whether a per-item reduction was intended.
    if cond.all().any():
        result_blocks = []
        for item in self.items:
            loc = self.items.get_loc(item)
            item = self.items.take([loc])

            # NOTE(review): ``take`` without ``axis`` indexes the
            # flattened array -- confirm this selects the intended item.
            v = values.take([loc])
            c = cond.take([loc])
            o = other.take([loc]) if hasattr(other, 'shape') else other

            result = func(c, v, o)
            if len(result) == 1:
                result = np.repeat(result, self.shape[1:])

            result = result.reshape(((1,) + self.shape[1:]))
            result_blocks.append(create_block(result, item,
                                              transpose=False))

        return result_blocks

    result = func(cond, values, other)
    return create_block(result, self.items)
462538

463539
def _mask_missing(array, missing_values):
464540
if not isinstance(missing_values, (list, np.ndarray)):
@@ -840,6 +916,9 @@ def apply(self, f, *args, **kwargs):
840916
def where(self, *args, **kwargs):
    """
    Forward to ``self.apply`` with the method name 'where'; all positional
    and keyword arguments are passed through and the result of ``apply``
    is returned.
    """
    outcome = self.apply('where', *args, **kwargs)
    return outcome
842918

919+
def eval(self, *args, **kwargs):
    """
    Forward to ``self.apply`` with the method name 'eval'; all positional
    and keyword arguments are passed through and the result of ``apply``
    is returned.
    """
    outcome = self.apply('eval', *args, **kwargs)
    return outcome
921+
843922
def putmask(self, *args, **kwargs):
    """
    Forward to ``self.apply`` with the method name 'putmask'; all
    positional and keyword arguments are passed through and the result of
    ``apply`` is returned.
    """
    outcome = self.apply('putmask', *args, **kwargs)
    return outcome
845924

pandas/tests/test_frame.py

+20-6
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,6 @@ def test_getitem_boolean(self):
244244

245245
def test_getitem_boolean_casting(self):
246246

247-
#### this currently disabled ###
248-
249247
# don't upcast if we don't need to
250248
df = self.tsframe.copy()
251249
df['E'] = 1
@@ -254,8 +252,10 @@ def test_getitem_boolean_casting(self):
254252
df['F'] = df['F'].astype('int64')
255253
casted = df[df>0]
256254
result = casted.get_dtype_counts()
257-
#expected = Series({'float64': 4, 'int32' : 1, 'int64' : 1})
258-
expected = Series({'float64': 6 })
255+
expected = Series({'float64': 4, 'int32' : 1, 'int64' : 1})
256+
257+
### when we always cast here's the result ###
258+
#expected = Series({'float64': 6 })
259259
assert_series_equal(result, expected)
260260

261261

@@ -5997,6 +5997,19 @@ def _check_get(df, cond, check_dtypes = True):
59975997
cond = df > 0
59985998
_check_get(df, cond)
59995999

6000+
6001+
# upcasting case (GH # 2794)
6002+
df = DataFrame(dict([ (c,Series([1]*3,dtype=c)) for c in ['int64','int32','float32','float64'] ]))
6003+
df.ix[1,:] = 0
6004+
6005+
result = df.where(df>=0).get_dtype_counts()
6006+
6007+
#### when we don't preserve boolean casts ####
6008+
#expected = Series({ 'float32' : 1, 'float64' : 3 })
6009+
6010+
expected = Series({ 'float32' : 1, 'float64' : 1, 'int32' : 1, 'int64' : 1 })
6011+
assert_series_equal(result, expected)
6012+
60006013
# aligning
60016014
def _check_align(df, cond, other, check_dtypes = True):
60026015
rs = df.where(cond, other)
@@ -6013,8 +6026,9 @@ def _check_align(df, cond, other, check_dtypes = True):
60136026
else:
60146027
o = other[k].values
60156028

6016-
assert_series_equal(v, Series(np.where(c, d, o),index=v.index))
6017-
6029+
new_values = d if c.all() else np.where(c, d, o)
6030+
assert_series_equal(v, Series(new_values,index=v.index))
6031+
60186032
# dtypes
60196033
# can't check dtype when other is an ndarray
60206034
if check_dtypes and not isinstance(other,np.ndarray):

0 commit comments

Comments
 (0)