Skip to content

Commit a0f075b

Browse files
committed
ENH: Add pandas accessors .df and .s
1 parent 602aeee commit a0f075b

File tree

4 files changed

+46
-13
lines changed

4 files changed

+46
-13
lines changed

backtesting/_util.py

+25-9
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from typing import Sequence
23
from numbers import Number
34

@@ -68,6 +69,11 @@ def __float__(self):
6869
return super().__float__()
6970

7071
def to_series(self):
    """Deprecated alias of the `.s` accessor.

    Emits a `DeprecationWarning` attributed to the calling code and
    then returns whatever `self.s` evaluates to.
    """
    # An explicit category lets warning filters recognise this as a
    # deprecation; stacklevel=2 reports the caller's line instead of this
    # shim, so users can locate the call they need to update.
    warnings.warn("`.to_series()` is deprecated. For pd.Series conversion, use accessor `.s`",
                  DeprecationWarning, stacklevel=2)
    return self.s
74+
75+
@property
def s(self) -> pd.Series:
    """Return this array as a `pd.Series`.

    The series is indexed by the `index` of the object stored under
    `self._opts['data']` and carries this array's `name`.
    """
    source_index = self._opts['data'].index
    return pd.Series(self, index=source_index, name=self.name)
7278

7379

@@ -82,15 +88,13 @@ class _Data:
8288
and the returned "series" are _not_ `pd.Series` but `np.ndarray`
8389
for performance reasons.
8490
"""
85-
def __init__(self, df):
91+
def __init__(self, df: pd.DataFrame):
92+
self.__df = df
8693
self.__i = len(df)
8794
self.__pip = None
8895
self.__cache = {}
89-
90-
self.__arrays = {col: _Array(arr, data=self)
91-
for col, arr in df.items()}
92-
# Leave index as Series because pd.Timestamp nicer API to work with
93-
self.__arrays['__index'] = df.index.copy()
96+
self.__arrays = None
97+
self._update_arrays()
9498

9599
def __getitem__(self, item):
96100
return getattr(self, item)
@@ -105,17 +109,29 @@ def _set_length(self, i):
105109
self.__i = i
106110
self.__cache.clear()
107111

112+
def _update_arrays(self):
    """Rebuild the per-column array mapping from the backing data frame.

    Every column of the frame is wrapped in an `_Array` bound to this
    object; a copy of the frame's index is stored under `'__index'`.
    """
    wrapped = {}
    for column, values in self.__df.items():
        wrapped[column] = _Array(values, data=self)
    self.__arrays = wrapped
    # The index is kept as a pandas object (not wrapped in `_Array`)
    # because pd.Timestamp offers a nicer API to work with.
    self.__arrays['__index'] = self.__df.index.copy()
117+
108118
def __len__(self):
    """Return the current length, i.e. the stored `__i` counter."""
    current_length = self.__i
    return current_length
110120

121+
@property
def df(self) -> pd.DataFrame:
    """Return the backing data frame limited to the current length.

    When the current length is shorter than the frame, the first
    `__i` rows are returned via `iloc`; otherwise the frame object
    itself is returned without slicing.
    """
    frame = self.__df
    if self.__i < len(frame):
        return frame.iloc[:self.__i]
    return frame
126+
111127
@property
def pip(self):
    """Return the price increment inferred from the `Close` values.

    The value is ten raised to the negative median count of characters
    that follow the decimal point in the string form of each `Close`
    entry. It is computed on first access and cached afterwards.
    """
    if self.__pip is not None:
        return self.__pip
    decimal_counts = [len(text.partition('.')[-1])
                      for text in self.__arrays['Close'].astype(str)]
    self.__pip = 10**-np.median(decimal_counts)
    return self.__pip
117133

118-
def __get_array(self, key):
134+
def __get_array(self, key) -> _Array:
119135
arr = self.__cache.get(key)
120136
if arr is None:
121137
arr = self.__cache[key] = self.__arrays[key][:self.__i]
@@ -142,8 +158,8 @@ def Volume(self):
142158
return self.__get_array('Volume')
143159

144160
@property
145-
def index(self):
146-
return self.__get_array('__index')
161+
def index(self) -> pd.Index:
162+
return self.__get_array('__index') # type: ignore
147163

148164
# Make pickling in Backtest.optimize() work with our catch-all __getattr__
149165
def __getstate__(self):

backtesting/backtesting.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -969,11 +969,12 @@ def run(self, **kwargs) -> pd.Series:
969969
970970
Keyword arguments are interpreted as strategy parameters.
971971
"""
972-
data = _Data(self._data)
972+
data = _Data(self._data.copy(deep=False))
973973
broker = self._broker(data=data) # type: _Broker
974974
strategy = self._strategy(broker, data, kwargs) # type: Strategy
975975

976976
strategy.init()
977+
data._update_arrays() # Strategy.init might have changed/added to data.df
977978

978979
# Indicators used in Strategy.next()
979980
indicator_attrs = {attr: indicator

backtesting/lib.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def SMA(series, n):
209209
assert isinstance(series, _Array), \
210210
'resample_apply() takes either a `pd.Series`, `pd.DataFrame`, ' \
211211
'or a `Strategy.data.*` array'
212-
series = series.to_series()
212+
series = series.s
213213

214214
series = series.copy() # XXX: pandas 1.0.1 bug https://github.com/pandas-dev/pandas/issues/31710 # noqa: E501
215215
resampled = series.resample(rule, label='right').agg(agg).dropna()

backtesting/test/_test.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ def test_plot_heatmaps(self):
555555
def test_SignalStrategy(self):
556556
class S(SignalStrategy):
557557
def init(self):
558-
sma = self.data.Close.to_series().rolling(10).mean()
558+
sma = self.data.Close.s.rolling(10).mean()
559559
self.set_signal(self.data.Close > sma,
560560
self.data.Close < sma)
561561

@@ -568,7 +568,7 @@ def init(self):
568568
super().init()
569569
self.set_atr_periods(40)
570570
self.set_trailing_sl(3)
571-
self.sma = self.I(lambda: self.data.Close.to_series().rolling(10).mean())
571+
self.sma = self.I(lambda: self.data.Close.s.rolling(10).mean())
572572

573573
def next(self):
574574
super().next()
@@ -596,6 +596,22 @@ class Class:
596596
for s in ('Open', 'High', 'Low', 'Close'):
597597
self.assertEqual(_as_str(_Array([1], name=s)), s[0])
598598

599+
def test_pandas_accessors(self):
    """Check that the `.s` and `.df` accessors agree with hand-built Series.

    Runs a short backtest whose strategy asserts, in both `init` and
    `next`, that the accessors match `pd.Series` built from the raw
    arrays, and that a column added to `data.df` during `init` is
    reachable as `data.new_key` during `next`.
    """
    class S(Strategy):
        def init(self):
            close, index = self.data.Close, self.data.index
            expected_close = pd.Series(close, index=index)
            assert close.s.equals(expected_close)
            assert self.data.df['Close'].equals(expected_close)
            # Column added here is expected to be visible in next()
            self.data.df['new_key'] = 2 * close

        def next(self):
            close, index = self.data.Close, self.data.index
            expected_close = pd.Series(close, index=index)
            expected_new = pd.Series(self.data.new_key, index=index)
            assert close.s.equals(expected_close)
            assert self.data.df['Close'].equals(expected_close)
            assert self.data.df['new_key'].equals(expected_new)

    Backtest(GOOG.iloc[:20], S).run()
614+
599615

600616
@unittest.skipUnless(
601617
os.path.isdir(os.path.join(os.path.dirname(__file__),

0 commit comments

Comments
 (0)