Skip to content

Commit cb5e7dd

Browse files
committed
Merge pull request #2871 from jreback/dtypes1
ENH: implement Block splitting to avoid upcasts where possible (GH #2794)
2 parents 5b5e532 + 9fc888f commit cb5e7dd

File tree

3 files changed

+145
-42
lines changed

3 files changed

+145
-42
lines changed

pandas/core/frame.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -3714,14 +3714,14 @@ def _combine_match_columns(self, other, func, fill_value=None):
37143714
if fill_value is not None:
37153715
raise NotImplementedError
37163716

3717-
new_data = left._data.where(func, right, axes = [left.columns, self.index])
3717+
new_data = left._data.eval(func, right, axes = [left.columns, self.index])
37183718
return self._constructor(new_data)
37193719

37203720
def _combine_const(self, other, func, raise_on_error = True):
37213721
if self.empty:
37223722
return self
37233723

3724-
new_data = self._data.where(func, other, raise_on_error=raise_on_error)
3724+
new_data = self._data.eval(func, other, raise_on_error=raise_on_error)
37253725
return self._constructor(new_data)
37263726

37273727
def _compare_frame(self, other, func):
@@ -5293,8 +5293,7 @@ def where(self, cond, other=NA, inplace=False, try_cast=False, raise_on_error=Tr
52935293
self._data = self._data.putmask(cond,other,inplace=True)
52945294

52955295
else:
5296-
func = lambda values, others, conds: np.where(conds, values, others)
5297-
new_data = self._data.where(func, other, cond, raise_on_error=raise_on_error, try_cast=try_cast)
5296+
new_data = self._data.where(other, cond, raise_on_error=raise_on_error, try_cast=try_cast)
52985297

52995298
return self._constructor(new_data)
53005299

pandas/core/internals.py

+112-31
Original file line numberDiff line numberDiff line change
@@ -304,14 +304,15 @@ def putmask(self, mask, new, inplace=False):
304304
if self._can_hold_element(new):
305305
new = self._try_cast(new)
306306
np.putmask(new_values, mask, new)
307-
# upcast me
308-
else:
307+
308+
# maybe upcast me
309+
elif mask.any():
309310
# type of the new block
310311
if ((isinstance(new, np.ndarray) and issubclass(new.dtype, np.number)) or
311312
isinstance(new, float)):
312-
typ = float
313+
typ = np.float64
313314
else:
314-
typ = object
315+
typ = np.object_
315316

316317
# we need to exiplicty astype here to make a copy
317318
new_values = new_values.astype(typ)
@@ -384,17 +385,16 @@ def shift(self, indexer, periods):
384385
new_values[:, periods:] = fill_value
385386
return make_block(new_values, self.items, self.ref_items)
386387

387-
def where(self, func, other, cond = None, raise_on_error = True, try_cast = False):
388+
def eval(self, func, other, raise_on_error = True, try_cast = False):
388389
"""
389-
evaluate the block; return result block(s) from the result
390+
evaluate the block; return result block from the result
390391
391392
Parameters
392393
----------
393394
func : how to combine self, other
394395
other : a ndarray/object
395-
cond : the condition to respect, optional
396-
raise_on_error : if True, raise when I can't perform the function,
397-
False by default (and just return the data that we had coming in)
396+
raise_on_error : if True, raise when I can't perform the function, False by default (and just return
397+
the data that we had coming in)
398398
399399
Returns
400400
-------
@@ -414,28 +414,7 @@ def where(self, func, other, cond = None, raise_on_error = True, try_cast = Fals
414414
values = values.T
415415
is_transposed = True
416416

417-
# see if we can align cond
418-
if cond is not None:
419-
if not hasattr(cond, 'shape'):
420-
raise ValueError('where must have a condition that is ndarray'
421-
' like')
422-
if hasattr(cond, 'reindex_axis'):
423-
axis = getattr(cond, '_het_axis', 0)
424-
cond = cond.reindex_axis(self.items, axis=axis,
425-
copy=True).values
426-
else:
427-
cond = cond.values
428-
429-
# may need to undo transpose of values
430-
if hasattr(values, 'ndim'):
431-
if (values.ndim != cond.ndim or
432-
values.shape == cond.shape[::-1]):
433-
values = values.T
434-
is_transposed = not is_transposed
435-
436417
args = [ values, other ]
437-
if cond is not None:
438-
args.append(cond)
439418
try:
440419
result = func(*args)
441420
except:
@@ -458,7 +437,106 @@ def where(self, func, other, cond = None, raise_on_error = True, try_cast = Fals
458437
if try_cast:
459438
result = self._try_cast_result(result)
460439

461-
return [ make_block(result, self.items, self.ref_items) ]
440+
return make_block(result, self.items, self.ref_items)
441+
442+
def where(self, other, cond, raise_on_error = True, try_cast = False):
443+
"""
444+
evaluate the block; return result block(s) from the result
445+
446+
Parameters
447+
----------
448+
other : a ndarray/object
449+
cond : the condition to respect
450+
raise_on_error : if True, raise when I can't perform the function, False by default (and just return
451+
the data that we had coming in)
452+
453+
Returns
454+
-------
455+
a new block(s), the result of the func
456+
"""
457+
458+
values = self.values
459+
460+
# see if we can align other
461+
if hasattr(other,'reindex_axis'):
462+
axis = getattr(other,'_het_axis',0)
463+
other = other.reindex_axis(self.items, axis=axis, copy=True).values
464+
465+
# make sure that we can broadcast
466+
is_transposed = False
467+
if hasattr(other, 'ndim') and hasattr(values, 'ndim'):
468+
if values.ndim != other.ndim or values.shape == other.shape[::-1]:
469+
values = values.T
470+
is_transposed = True
471+
472+
# see if we can align cond
473+
if not hasattr(cond,'shape'):
474+
raise ValueError("where must have a condition that is ndarray like")
475+
if hasattr(cond,'reindex_axis'):
476+
axis = getattr(cond,'_het_axis',0)
477+
cond = cond.reindex_axis(self.items, axis=axis, copy=True).values
478+
else:
479+
cond = cond.values
480+
481+
# may need to undo transpose of values
482+
if hasattr(values, 'ndim'):
483+
if values.ndim != cond.ndim or values.shape == cond.shape[::-1]:
484+
values = values.T
485+
is_transposed = not is_transposed
486+
487+
# our where function
488+
def func(c,v,o):
489+
if c.flatten().all():
490+
return v
491+
492+
try:
493+
return np.where(c,v,o)
494+
except:
495+
if raise_on_error:
496+
raise TypeError('Coulnd not operate %s with block values'
497+
% repr(o))
498+
else:
499+
# return the values
500+
result = np.empty(v.shape,dtype='float64')
501+
result.fill(np.nan)
502+
return result
503+
504+
def create_block(result, items, transpose = True):
505+
if not isinstance(result, np.ndarray):
506+
raise TypeError('Could not compare %s with block values'
507+
% repr(other))
508+
509+
if transpose and is_transposed:
510+
result = result.T
511+
512+
# try to cast if requested
513+
if try_cast:
514+
result = self._try_cast_result(result)
515+
516+
return make_block(result, items, self.ref_items)
517+
518+
# see if we can operate on the entire block, or need item-by-item
519+
if not self._can_hold_na:
520+
axis = cond.ndim-1
521+
result_blocks = []
522+
for item in self.items:
523+
loc = self.items.get_loc(item)
524+
item = self.items.take([loc])
525+
v = values.take([loc],axis=axis)
526+
c = cond.take([loc],axis=axis)
527+
o = other.take([loc],axis=axis) if hasattr(other,'shape') else other
528+
529+
result = func(c,v,o)
530+
if len(result) == 1:
531+
result = np.repeat(result,self.shape[1:])
532+
533+
result = result.reshape(((1,) + self.shape[1:]))
534+
result_blocks.append(create_block(result, item, transpose = False))
535+
536+
return result_blocks
537+
else:
538+
result = func(cond,values,other)
539+
return create_block(result, self.items)
462540

463541
def _mask_missing(array, missing_values):
464542
if not isinstance(missing_values, (list, np.ndarray)):
@@ -840,6 +918,9 @@ def apply(self, f, *args, **kwargs):
840918
def where(self, *args, **kwargs):
841919
return self.apply('where', *args, **kwargs)
842920

921+
def eval(self, *args, **kwargs):
922+
return self.apply('eval', *args, **kwargs)
923+
843924
def putmask(self, *args, **kwargs):
844925
return self.apply('putmask', *args, **kwargs)
845926

pandas/tests/test_frame.py

+30-7
Original file line numberDiff line numberDiff line change
@@ -244,20 +244,26 @@ def test_getitem_boolean(self):
244244

245245
def test_getitem_boolean_casting(self):
246246

247-
#### this currently disabled ###
248-
249247
# don't upcast if we don't need to
250248
df = self.tsframe.copy()
251249
df['E'] = 1
252250
df['E'] = df['E'].astype('int32')
251+
df['E1'] = df['E'].copy()
253252
df['F'] = 1
254253
df['F'] = df['F'].astype('int64')
254+
df['F1'] = df['F'].copy()
255+
255256
casted = df[df>0]
256257
result = casted.get_dtype_counts()
257-
#expected = Series({'float64': 4, 'int32' : 1, 'int64' : 1})
258-
expected = Series({'float64': 6 })
258+
expected = Series({'float64': 4, 'int32' : 2, 'int64' : 2})
259259
assert_series_equal(result, expected)
260260

261+
# int block splitting
262+
df.ix[1:3,['E1','F1']] = 0
263+
casted = df[df>0]
264+
result = casted.get_dtype_counts()
265+
expected = Series({'float64': 6, 'int32' : 1, 'int64' : 1})
266+
assert_series_equal(result, expected)
261267

262268
def test_getitem_boolean_list(self):
263269
df = DataFrame(np.arange(12).reshape(3, 4))
@@ -6145,6 +6151,19 @@ def _check_get(df, cond, check_dtypes = True):
61456151
cond = df > 0
61466152
_check_get(df, cond)
61476153

6154+
6155+
# upcasting case (GH # 2794)
6156+
df = DataFrame(dict([ (c,Series([1]*3,dtype=c)) for c in ['int64','int32','float32','float64'] ]))
6157+
df.ix[1,:] = 0
6158+
6159+
result = df.where(df>=0).get_dtype_counts()
6160+
6161+
#### when we don't preserver boolean casts ####
6162+
#expected = Series({ 'float32' : 1, 'float64' : 3 })
6163+
6164+
expected = Series({ 'float32' : 1, 'float64' : 1, 'int32' : 1, 'int64' : 1 })
6165+
assert_series_equal(result, expected)
6166+
61486167
# aligning
61496168
def _check_align(df, cond, other, check_dtypes = True):
61506169
rs = df.where(cond, other)
@@ -6161,10 +6180,12 @@ def _check_align(df, cond, other, check_dtypes = True):
61616180
else:
61626181
o = other[k].values
61636182

6164-
assert_series_equal(v, Series(np.where(c, d, o),index=v.index))
6165-
6183+
new_values = d if c.all() else np.where(c, d, o)
6184+
assert_series_equal(v, Series(new_values,index=v.index))
6185+
61666186
# dtypes
61676187
# can't check dtype when other is an ndarray
6188+
61686189
if check_dtypes and not isinstance(other,np.ndarray):
61696190
self.assert_((rs.dtypes == df.dtypes).all() == True)
61706191

@@ -6200,13 +6221,15 @@ def _check_set(df, cond, check_dtypes = True):
62006221
dfi = df.copy()
62016222
econd = cond.reindex_like(df).fillna(True)
62026223
expected = dfi.mask(~econd)
6224+
6225+
#import pdb; pdb.set_trace()
62036226
dfi.where(cond, np.nan, inplace=True)
62046227
assert_frame_equal(dfi, expected)
62056228

62066229
# dtypes (and confirm upcasts)x
62076230
if check_dtypes:
62086231
for k, v in df.dtypes.iteritems():
6209-
if issubclass(v.type,np.integer):
6232+
if issubclass(v.type,np.integer) and not cond[k].all():
62106233
v = np.dtype('float64')
62116234
self.assert_(dfi[k].dtype == v)
62126235

0 commit comments

Comments
 (0)