Skip to content

Commit f54fd76

Browse files
committed
REF: lib.resample_apply() to better handle Series/DataFrame
1 parent 469a930 commit f54fd76

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

backtesting/lib.py

+18-8
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def quantile(series, quantile=None):
133133

134134

135135
def resample_apply(rule: str,
136-
func: Callable,
136+
func: Optional[Callable[..., Sequence]],
137137
series,
138138
*args,
139139
agg='last',
@@ -205,6 +205,10 @@ def SMA(series, n):
205205
self.sma = self.I(SMA, daily, 10, plot=False)
206206
207207
"""
208+
if func is None:
209+
def func(x, *_, **__):
210+
return x
211+
208212
if not isinstance(series, (pd.Series, pd.DataFrame)):
209213
assert isinstance(series, _Array), \
210214
'resample_apply() takes either a `pd.Series`, `pd.DataFrame`, ' \
@@ -228,14 +232,20 @@ def SMA(series, n):
228232
def strategy_I(func, *args, **kwargs):
229233
return func(*args, **kwargs)
230234

231-
# Resample back to data index
232235
def wrap_func(resampled, *args, **kwargs):
233-
ind = pd.Series(np.asarray(func(resampled, *args, **kwargs)),
234-
index=resampled.index,
235-
name=resampled.name)
236-
ind = ind.reindex(index=series.index | resampled.index,
237-
method='ffill').reindex(series.index)
238-
return ind
236+
result = func(resampled, *args, **kwargs)
237+
if not isinstance(result, pd.DataFrame) and not isinstance(result, pd.Series):
238+
result = np.asarray(result)
239+
if result.ndim == 1:
240+
result = pd.Series(result, name=resampled.name)
241+
elif result.ndim == 2:
242+
result = pd.DataFrame(result.T)
243+
# Resample back to data index
244+
if not result.index.is_all_dates:
245+
result.index = resampled.index
246+
result = result.reindex(index=series.index | resampled.index,
247+
method='ffill').reindex(series.index)
248+
return result
239249

240250
wrap_func.__name__ = func.__name__
241251

backtesting/test/_test.py

+6
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,12 @@ def resets_index(*args):
530530
self.assertTrue((res.dropna() == res2.dropna()).all())
531531
self.assertTrue((res.index == res2.index).all())
532532

533+
res3 = resample_apply('D', None, EURUSD)
534+
self.assertIn('Volume', res3)
535+
536+
res3 = resample_apply('D', lambda df: (df.Close, df.Close), EURUSD)
537+
self.assertIsInstance(res3, pd.DataFrame)
538+
533539
def test_plot_heatmaps(self):
534540
bt = Backtest(GOOG, SmaCross)
535541
stats, heatmap = bt.optimize(fast=range(2, 7, 2),

0 commit comments

Comments
 (0)