Skip to content

Commit a1bd69c

Browse files
committed
TST: check appropriate tests for apply
PERF: allow apply to use the fast-path in mixed type frames except where datelike are present
1 parent 9e37a7d commit a1bd69c

File tree

4 files changed

+74
-33
lines changed

4 files changed

+74
-33
lines changed

pandas/core/frame.py

+18-13
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from pandas.core.common import (isnull, notnull, PandasError, _try_sort,
2525
_default_index, _maybe_upcast, _is_sequence,
2626
_infer_dtype_from_scalar, _values_from_object,
27-
_DATELIKE_DTYPES, is_list_like)
27+
is_list_like)
2828
from pandas.core.generic import NDFrame, _shared_docs
2929
from pandas.core.index import Index, MultiIndex, _ensure_index
3030
from pandas.core.indexing import (_maybe_droplevels,
@@ -1581,7 +1581,7 @@ def _ixs(self, i, axis=0, copy=False):
15811581
else:
15821582
new_values, copy = self._data.fast_2d_xs(i, copy=copy)
15831583
result = Series(new_values, index=self.columns,
1584-
name=self.index[i])
1584+
name=self.index[i], dtype=new_values.dtype)
15851585
result.is_copy=copy
15861586
return result
15871587

@@ -3324,10 +3324,9 @@ def _apply_standard(self, func, axis, ignore_failures=False, reduce=True):
33243324
if reduce:
33253325
try:
33263326

3327-
# can only work with numeric data in the fast path
3328-
numeric = self._get_numeric_data()
3329-
values = numeric.values
3330-
dummy = Series(NA, index=numeric._get_axis(axis),
3327+
# this is the fast-path
3328+
values = self.values
3329+
dummy = Series(NA, index=self._get_axis(axis),
33313330
dtype=values.dtype)
33323331

33333332
labels = self._get_agg_axis(axis)
@@ -3393,12 +3392,12 @@ def _apply_standard(self, func, axis, ignore_failures=False, reduce=True):
33933392
result = result.T
33943393
result = result.convert_objects(copy=False)
33953394

3396-
return result
33973395
else:
3398-
s = Series(results)
3399-
s.index = res_index
34003396

3401-
return s
3397+
result = Series(results)
3398+
result.index = res_index
3399+
3400+
return result
34023401

34033402
def _apply_broadcast(self, func, axis):
34043403
if axis == 0:
@@ -3932,8 +3931,7 @@ def _reduce(self, op, axis=0, skipna=True, numeric_only=None,
39323931
labels = self._get_agg_axis(axis)
39333932

39343933
# exclude timedelta/datetime unless we are uniform types
3935-
if axis == 1 and self._is_mixed_type and len(set(self.dtypes) &
3936-
_DATELIKE_DTYPES):
3934+
if axis == 1 and self._is_mixed_type and self._is_datelike_mixed_type:
39373935
numeric_only = True
39383936

39393937
if numeric_only is None:
@@ -3945,7 +3943,14 @@ def _reduce(self, op, axis=0, skipna=True, numeric_only=None,
39453943
# try by-column first
39463944
if filter_type is None and axis == 0:
39473945
try:
3948-
return self.apply(f).iloc[0]
3946+
3947+
# this can end up with a non-reduction
3948+
# but not always. if the types are mixed
3949+
# with datelike then we need to make sure a Series is returned
3950+
result = self.apply(f,reduce=False)
3951+
if result.ndim == self.ndim:
3952+
result = result.iloc[0]
3953+
return result
39493954
except:
39503955
pass
39513956

pandas/core/generic.py

+5
Original file line numberDiff line numberDiff line change
@@ -1837,6 +1837,11 @@ def _is_numeric_mixed_type(self):
18371837
f = lambda: self._data.is_numeric_mixed_type
18381838
return self._protect_consolidate(f)
18391839

1840+
@property
1841+
def _is_datelike_mixed_type(self):
1842+
f = lambda: self._data.is_datelike_mixed_type
1843+
return self._protect_consolidate(f)
1844+
18401845
def _protect_consolidate(self, f):
18411846
blocks_before = len(self._data.blocks)
18421847
result = f()

pandas/core/internals.py

+11
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ def _consolidate_key(self):
8383
def _is_single_block(self):
8484
return self.ndim == 1
8585

86+
@property
87+
def is_datelike(self):
88+
""" return True if I am datelike (a datetime or timedelta block) """
89+
return self.is_datetime or self.is_timedelta
90+
8691
@property
8792
def fill_value(self):
8893
return np.nan
@@ -2439,6 +2444,12 @@ def is_numeric_mixed_type(self):
24392444
self._consolidate_inplace()
24402445
return all([block.is_numeric for block in self.blocks])
24412446

2447+
@property
2448+
def is_datelike_mixed_type(self):
2449+
# Warning, consolidation needs to get checked upstairs
2450+
self._consolidate_inplace()
2451+
return any([block.is_datelike for block in self.blocks])
2452+
24422453
def get_block_map(self, copy=False, typ=None, columns=None,
24432454
is_numeric=False, is_bool=False):
24442455
""" return a dictionary mapping the ftype -> block list

pandas/tests/test_frame.py

+40-20
Original file line numberDiff line numberDiff line change
@@ -9035,6 +9035,16 @@ def test_apply_mixed_dtype_corner(self):
90359035
expected = Series(np.nan, index=[])
90369036
assert_series_equal(result, expected)
90379037

9038+
df = DataFrame({'A': ['foo'],
9039+
'B': [1.]})
9040+
result = df.apply(lambda x: x['A'], axis=1)
9041+
expected = Series(['foo'],index=[0])
9042+
assert_series_equal(result, expected)
9043+
9044+
result = df.apply(lambda x: x['B'], axis=1)
9045+
expected = Series([1.],index=[0])
9046+
assert_series_equal(result, expected)
9047+
90389048
def test_apply_empty_infer_type(self):
90399049
no_cols = DataFrame(index=['a', 'b', 'c'])
90409050
no_index = DataFrame(columns=['a', 'b', 'c'])
@@ -9970,7 +9980,8 @@ def test_count(self):
99709980
self._check_stat_op('count', f,
99719981
has_skipna=False,
99729982
has_numeric_only=True,
9973-
check_dtypes=False)
9983+
check_dtype=False,
9984+
check_dates=True)
99749985

99759986
# corner case
99769987
frame = DataFrame()
@@ -9999,10 +10010,9 @@ def test_count(self):
999910010
def test_sum(self):
1000010011
self._check_stat_op('sum', np.sum, has_numeric_only=True)
1000110012

10002-
def test_sum_mixed_numeric(self):
10003-
raise nose.SkipTest("skipping for now")
10004-
# mixed types
10005-
self._check_stat_op('sum', np.sum, frame = self.mixed_float, has_numeric_only=True)
10013+
# mixed types (with upcasting happening)
10014+
self._check_stat_op('sum', np.sum, frame=self.mixed_float.astype('float32'),
10015+
has_numeric_only=True, check_dtype=False, check_less_precise=True)
1000610016

1000710017
def test_stat_operators_attempt_obj_array(self):
1000810018
data = {
@@ -10028,7 +10038,7 @@ def test_stat_operators_attempt_obj_array(self):
1002810038
assert_series_equal(result, expected)
1002910039

1003010040
def test_mean(self):
10031-
self._check_stat_op('mean', np.mean)
10041+
self._check_stat_op('mean', np.mean, check_dates=True)
1003210042

1003310043
def test_product(self):
1003410044
self._check_stat_op('product', np.prod)
@@ -10039,10 +10049,10 @@ def wrapper(x):
1003910049
return np.nan
1004010050
return np.median(x)
1004110051

10042-
self._check_stat_op('median', wrapper)
10052+
self._check_stat_op('median', wrapper, check_dates=True)
1004310053

1004410054
def test_min(self):
10045-
self._check_stat_op('min', np.min)
10055+
self._check_stat_op('min', np.min, check_dates=True)
1004610056
self._check_stat_op('min', np.min, frame=self.intframe)
1004710057

1004810058
def test_cummin(self):
@@ -10092,7 +10102,7 @@ def test_cummax(self):
1009210102
self.assertEqual(np.shape(cummax_xs), np.shape(self.tsframe))
1009310103

1009410104
def test_max(self):
10095-
self._check_stat_op('max', np.max)
10105+
self._check_stat_op('max', np.max, check_dates=True)
1009610106
self._check_stat_op('max', np.max, frame=self.intframe)
1009710107

1009810108
def test_mad(self):
@@ -10154,7 +10164,8 @@ def alt(x):
1015410164
assert_series_equal(df.kurt(), df.kurt(level=0).xs('bar'))
1015510165

1015610166
def _check_stat_op(self, name, alternative, frame=None, has_skipna=True,
10157-
has_numeric_only=False, check_dtypes=True):
10167+
has_numeric_only=False, check_dtype=True, check_dates=False,
10168+
check_less_precise=False):
1015810169
if frame is None:
1015910170
frame = self.frame
1016010171
# set some NAs
@@ -10163,14 +10174,16 @@ def _check_stat_op(self, name, alternative, frame=None, has_skipna=True,
1016310174

1016410175
f = getattr(frame, name)
1016510176

10166-
if not ('max' in name or 'min' in name or 'count' in name):
10177+
if check_dates:
1016710178
df = DataFrame({'b': date_range('1/1/2001', periods=2)})
1016810179
_f = getattr(df, name)
10169-
#print(df)
10170-
self.assertFalse(len(_f()))
10180+
result = _f()
10181+
self.assert_(isinstance(result, Series))
1017110182

1017210183
df['a'] = lrange(len(df))
10173-
self.assert_(len(getattr(df, name)()))
10184+
result = getattr(df, name)()
10185+
self.assert_(isinstance(result, Series))
10186+
self.assert_(len(result))
1017410187

1017510188
if has_skipna:
1017610189
def skipna_wrapper(x):
@@ -10184,21 +10197,27 @@ def wrapper(x):
1018410197

1018510198
result0 = f(axis=0, skipna=False)
1018610199
result1 = f(axis=1, skipna=False)
10187-
assert_series_equal(result0, frame.apply(wrapper))
10200+
assert_series_equal(result0, frame.apply(wrapper),
10201+
check_dtype=check_dtype,
10202+
check_less_precise=check_less_precise)
1018810203
assert_series_equal(result1, frame.apply(wrapper, axis=1),
10189-
check_dtype=False) # HACK: win32
10204+
check_dtype=False,
10205+
check_less_precise=check_less_precise) # HACK: win32
1019010206
else:
1019110207
skipna_wrapper = alternative
1019210208
wrapper = alternative
1019310209

1019410210
result0 = f(axis=0)
1019510211
result1 = f(axis=1)
10196-
assert_series_equal(result0, frame.apply(skipna_wrapper))
10212+
assert_series_equal(result0, frame.apply(skipna_wrapper),
10213+
check_dtype=check_dtype,
10214+
check_less_precise=check_less_precise)
1019710215
assert_series_equal(result1, frame.apply(skipna_wrapper, axis=1),
10198-
check_dtype=False)
10216+
check_dtype=False,
10217+
check_less_precise=check_less_precise)
1019910218

1020010219
# check dtypes
10201-
if check_dtypes:
10220+
if check_dtype:
1020210221
lcd_dtype = frame.values.dtype
1020310222
self.assert_(lcd_dtype == result0.dtype)
1020410223
self.assert_(lcd_dtype == result1.dtype)
@@ -10331,7 +10350,8 @@ def wrapper(x):
1033110350
return np.nan
1033210351
return np.median(x)
1033310352

10334-
self._check_stat_op('median', wrapper, frame=self.intframe, check_dtypes=False)
10353+
self._check_stat_op('median', wrapper, frame=self.intframe,
10354+
check_dtype=False, check_dates=True)
1033510355

1033610356
def test_quantile(self):
1033710357
from pandas.compat.scipy import scoreatpercentile

0 commit comments

Comments
 (0)