Skip to content

Commit 506ebc6

Browse files
committed
ENH: Add pandas accessors .df and .s
1 parent fcf45bb commit 506ebc6

File tree

4 files changed

+65
-21
lines changed

4 files changed

+65
-21
lines changed

backtesting/_util.py

+41 -13
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from typing import Sequence
23
from numbers import Number
34

@@ -44,12 +45,10 @@ class _Array(np.ndarray):
4445
ndarray extended to supply .name and other arbitrary properties
4546
in ._opts dict.
4647
"""
47-
def __new__(cls, array, *, name=None, write=False, **kwargs):
48+
def __new__(cls, array, *, name=None, **kwargs):
4849
obj = np.asarray(array).view(cls)
4950
obj.name = name or array.name
5051
obj._opts = kwargs
51-
if not write:
52-
obj.setflags(write=False)
5352
return obj
5453

5554
def __array_finalize__(self, obj):
@@ -70,7 +69,20 @@ def __float__(self):
7069
return super().__float__()
7170

7271
def to_series(self):
73-
return pd.Series(self, index=self._opts['data'].index, name=self.name)
72+
warnings.warn("`.to_series()` is deprecated. For pd.Series conversion, use accessor `.s`")
73+
return self.s
74+
75+
@property
76+
def s(self) -> pd.Series:
77+
values = np.atleast_2d(self)
78+
return pd.Series(values[0], index=self._opts['data'].index, name=self.name)
79+
80+
@property
81+
def df(self) -> pd.DataFrame:
82+
values = np.atleast_2d(np.asarray(self))
83+
df = pd.DataFrame(values.T, index=self._opts['data'].index,
84+
columns=[self.name] * len(values))
85+
return df
7486

7587

7688
class _Indicator(_Array):
@@ -84,15 +96,13 @@ class _Data:
8496
and the returned "series" are _not_ `pd.Series` but `np.ndarray`
8597
for performance reasons.
8698
"""
87-
def __init__(self, df):
99+
def __init__(self, df: pd.DataFrame):
100+
self.__df = df
88101
self.__i = len(df)
89102
self.__pip = None
90103
self.__cache = {}
91-
92-
self.__arrays = {col: _Array(arr, data=self)
93-
for col, arr in df.items()}
94-
# Leave index as Series because pd.Timestamp nicer API to work with
95-
self.__arrays['__index'] = df.index.copy()
104+
self.__arrays = None
105+
self._update()
96106

97107
def __getitem__(self, item):
98108
return self.__get_array(item)
@@ -107,17 +117,35 @@ def _set_length(self, i):
107117
self.__i = i
108118
self.__cache.clear()
109119

120+
def _update(self):
121+
self.__arrays = {col: _Array(arr, data=self)
122+
for col, arr in self.__df.items()}
123+
# Leave index as Series because pd.Timestamp nicer API to work with
124+
self.__arrays['__index'] = self.__df.index.copy()
125+
126+
def __repr__(self):
127+
i = min(self.__i, len(self.__df) - 1)
128+
return '<Data i={} ({}) {}>'.format(i, self.__arrays['__index'][i],
129+
', '.join('{}={}'.format(k, v)
130+
for k, v in self.__df.iloc[i].items()))
131+
110132
def __len__(self):
111133
return self.__i
112134

135+
@property
136+
def df(self) -> pd.DataFrame:
137+
return (self.__df.iloc[:self.__i]
138+
if self.__i < len(self.__df)
139+
else self.__df)
140+
113141
@property
114142
def pip(self):
115143
if self.__pip is None:
116144
self.__pip = 10**-np.median([len(s.partition('.')[-1])
117145
for s in self.__arrays['Close'].astype(str)])
118146
return self.__pip
119147

120-
def __get_array(self, key):
148+
def __get_array(self, key) -> _Array:
121149
arr = self.__cache.get(key)
122150
if arr is None:
123151
arr = self.__cache[key] = self.__arrays[key][:self.__i]
@@ -144,8 +172,8 @@ def Volume(self):
144172
return self.__get_array('Volume')
145173

146174
@property
147-
def index(self):
148-
return self.__get_array('__index')
175+
def index(self) -> pd.DatetimeIndex:
176+
return self.__get_array('__index') # type: ignore
149177

150178
# Make pickling in Backtest.optimize() work with our catch-all __getattr__
151179
def __getstate__(self):

backtesting/backtesting.py

+4 -3
Original file line number | Diff line number | Diff line change
@@ -243,8 +243,8 @@ def data(self) -> _Data:
243243
the last array value (e.g. `data.Close[-1]`)
244244
is always the _most recent_ value.
245245
* If you need data arrays (e.g. `data.Close`) to be indexed
246-
Pandas series, you can call their `.to_series()` method
247-
(e.g. `data.Close.to_series()`).
246+
Pandas series, you can call their `.s` accessor
247+
(e.g. `data.Close.s`).
248248
"""
249249
return self._data
250250

@@ -994,11 +994,12 @@ def run(self, **kwargs) -> pd.Series:
994994
995995
Keyword arguments are interpreted as strategy parameters.
996996
"""
997-
data = _Data(self._data)
997+
data = _Data(self._data.copy(deep=False))
998998
broker = self._broker(data=data) # type: _Broker
999999
strategy = self._strategy(broker, data, kwargs) # type: Strategy
10001000

10011001
strategy.init()
1002+
data._update() # Strategy.init might have changed/added to data.df
10021003

10031004
# Indicators used in Strategy.next()
10041005
indicator_attrs = {attr: indicator

backtesting/lib.py

+2 -3
Original file line number | Diff line number | Diff line change
@@ -183,7 +183,7 @@ class System(Strategy):
183183
def init(self):
184184
# Strategy exposes `self.data` as raw NumPy arrays.
185185
# Let's convert closing prices back to pandas Series.
186-
close = self.data.Close.to_series()
186+
close = self.data.Close.s
187187
188188
# Resample to daily resolution. Aggregate groups
189189
# using their last value (i.e. closing price at the end
@@ -213,9 +213,8 @@ def func(x, *_, **__):
213213
assert isinstance(series, _Array), \
214214
'resample_apply() takes either a `pd.Series`, `pd.DataFrame`, ' \
215215
'or a `Strategy.data.*` array'
216-
series = series.to_series()
216+
series = series.s
217217

218-
series = series.copy() # XXX: pandas 1.0.1 bug https://github.com/pandas-dev/pandas/issues/31710 # noqa: E501
219218
resampled = series.resample(rule, label='right').agg(agg).dropna()
220219
resampled.name = _as_str(series) + '[' + rule + ']'
221220

backtesting/test/_test.py

+18 -2
Original file line number | Diff line number | Diff line change
@@ -560,7 +560,7 @@ def test_plot_heatmaps(self):
560560
def test_SignalStrategy(self):
561561
class S(SignalStrategy):
562562
def init(self):
563-
sma = self.data.Close.to_series().rolling(10).mean()
563+
sma = self.data.Close.s.rolling(10).mean()
564564
self.set_signal(self.data.Close > sma,
565565
self.data.Close < sma)
566566

@@ -573,7 +573,7 @@ def init(self):
573573
super().init()
574574
self.set_atr_periods(40)
575575
self.set_trailing_sl(3)
576-
self.sma = self.I(lambda: self.data.Close.to_series().rolling(10).mean())
576+
self.sma = self.I(lambda: self.data.Close.s.rolling(10).mean())
577577

578578
def next(self):
579579
super().next()
@@ -603,6 +603,22 @@ class Class:
603603
for s in ('Open', 'High', 'Low', 'Close', 'Volume'):
604604
self.assertEqual(_as_str(_Array([1], name=s)), s[0])
605605

606+
def test_pandas_accessors(self):
607+
class S(Strategy):
608+
def init(self):
609+
close, index = self.data.Close, self.data.index
610+
assert close.s.equals(pd.Series(close, index=index))
611+
assert self.data.df['Close'].equals(pd.Series(close, index=index))
612+
self.data.df['new_key'] = 2 * close
613+
614+
def next(self):
615+
close, index = self.data.Close, self.data.index
616+
assert close.s.equals(pd.Series(close, index=index))
617+
assert self.data.df['Close'].equals(pd.Series(close, index=index))
618+
assert self.data.df['new_key'].equals(pd.Series(self.data.new_key, index=index))
619+
620+
Backtest(GOOG.iloc[:20], S).run()
621+
606622

607623
@unittest.skipUnless(
608624
os.path.isdir(os.path.join(os.path.dirname(__file__),

0 commit comments

Comments
 (0)