Skip to content

Commit 4c42f12

Browse files
committed
implemented distance based trailing stop loss
1 parent 0a76e96 commit 4c42f12

File tree

3 files changed

+56
-19
lines changed

3 files changed

+56
-19
lines changed

backtesting/lib.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,17 @@
1717
from inspect import currentframe
1818
from typing import Sequence, Optional, Union, Callable
1919

20+
import numpy
2021
import numpy as np
2122
import pandas as pd
23+
import talib
2224

2325
from .backtesting import Strategy
2426
from ._plotting import plot_heatmaps as _plot_heatmaps
2527
from ._util import _Array, _as_str
2628

2729
__pdoc__ = {}
2830

29-
3031
OHLCV_AGG = OrderedDict((
3132
('Open', 'first'),
3233
('High', 'max'),
@@ -267,7 +268,7 @@ def func(x, *_, **__):
267268
frame = frame.f_back
268269
level += 1
269270
if isinstance(frame.f_locals.get('self'), Strategy): # type: ignore
270-
strategy_I = frame.f_locals['self'].I # type: ignore
271+
strategy_I = frame.f_locals['self'].I # type: ignore
271272
break
272273
else:
273274
def strategy_I(func, *args, **kwargs):
@@ -314,6 +315,7 @@ def random_ohlc_data(example_data: pd.DataFrame, *,
314315
>>> next(ohlc_generator) # returns new random data
315316
...
316317
"""
318+
317319
def shuffle(x):
318320
return x.sample(frac=frac, replace=frac > 1, random_state=random_state)
319321

@@ -417,53 +419,68 @@ class TrailingStrategy(Strategy):
417419
"""
418420
__n_atr = 6.
419421
__atr = None
422+
__use_atr = True
423+
__sl_distance = None
420424

421-
def init(self):
425+
def init(self, use_atr=True):
422426
super().init()
423-
self.set_atr_periods()
427+
self.__use_atr = use_atr
428+
if use_atr:
429+
self.set_atr_periods()
424430

425431
def set_atr_periods(self, periods: int = 100):
426432
"""
427433
Set the lookback period for computing ATR. The default value
428434
of 100 ensures a _stable_ ATR.
429435
"""
430-
h, l, c_prev = self.data.High, self.data.Low, pd.Series(self.data.Close).shift(1)
431-
tr = np.max([h - l, (c_prev - h).abs(), (c_prev - l).abs()], axis=0)
432-
atr = pd.Series(tr).rolling(periods).mean().bfill().values
433-
self.__atr = atr
436+
self.__atr: numpy.ndarray = talib.ATR(self.data.High, self.data.Low, self.data.Close, timeperiod=periods)
437+
np.nan_to_num(self.__atr, copy=False, nan=self.__atr[np.argmax(self.__atr > 0)])
434438

435-
def set_trailing_sl(self, n_atr: float = 6):
439+
def set_trailing_atr_sl(self, n_atr: float = 6):
436440
"""
437441
Sets the future trailing stop-loss as some multiple (`n_atr`)
438442
average true bar ranges away from the current price.
439443
"""
440444
self.__n_atr = n_atr
441445

446+
def set_trailing_sl(self, sl_distance: float):
447+
"""
448+
Sets the future trailing stop-loss as fixed price.
449+
"""
450+
self.__sl_distance = sl_distance
451+
442452
def next(self):
443453
super().next()
444454
# Can't use index=-1 because self.__atr is not an Indicator type
445-
index = len(self.data)-1
455+
index = len(self.data) - 1
446456
for trade in self.trades:
447457
if trade.is_long:
448-
trade.sl = max(trade.sl or -np.inf,
449-
self.data.Close[index] - self.__atr[index] * self.__n_atr)
458+
if self.__use_atr:
459+
trade.sl = max(trade.sl or -np.inf,
460+
self.data.Close[index] - self.__atr[index] * self.__n_atr)
461+
else:
462+
trade.sl = max(trade.sl or -np.inf,
463+
self.data.Close[index] - self.__sl_distance)
450464
else:
451-
trade.sl = min(trade.sl or np.inf,
452-
self.data.Close[index] + self.__atr[index] * self.__n_atr)
465+
if self.__use_atr:
466+
trade.sl = min(trade.sl or np.inf,
467+
self.data.Close[index] + self.__atr[index] * self.__n_atr)
468+
else:
469+
trade.sl = min(trade.sl or np.inf,
470+
self.data.Close[index] + self.__sl_distance)
453471

454472

455473
# Prevent pdoc3 documenting __init__ signature of Strategy subclasses
456474
for cls in list(globals().values()):
457475
if isinstance(cls, type) and issubclass(cls, Strategy):
458476
__pdoc__[f'{cls.__name__}.__init__'] = False
459477

460-
461478
# NOTE: Don't put anything below this __all__ list
462479

463480
__all__ = [getattr(v, '__name__', k)
464-
for k, v in globals().items() # export
465-
if ((callable(v) and v.__module__ == __name__ or # callables from this module
466-
k.isupper()) and # or CONSTANTS
481+
for k, v in globals().items() # export
482+
if ((callable(v) and v.__module__ == __name__ or # callables from this module
483+
k.isupper()) and # or CONSTANTS
467484
not getattr(v, '__name__', k).startswith('_'))] # neither marked internal
468485

469486
# NOTE: Don't put anything below here. See above.

backtesting/test/_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,22 @@ def next(self):
862862
stats = Backtest(GOOG, S).run()
863863
self.assertEqual(stats['# Trades'], 57)
864864

865+
def test_TrailingStrategyDistanceSL(self):
866+
class S(TrailingStrategy):
867+
def init(self):
868+
super().init(use_atr=False)
869+
self.set_trailing_sl(30)
870+
self.sma = self.I(lambda: self.data.Close.s.rolling(10).mean())
871+
872+
def next(self):
873+
super().next()
874+
if not self.position and self.data.Close > self.sma:
875+
self.buy()
876+
bt = Backtest(GOOG, S)
877+
stats = bt.run()
878+
bt.plot()
879+
self.assertEqual(stats['# Trades'], 66)
880+
865881

866882
class TestUtil(TestCase):
867883
def test_as_str(self):

requirements.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,6 @@
11
# To run example notebooks, install required and test dependencies
2-
.[test]
2+
TA-Lib==0.4.20
3+
seaborn
4+
scikit-optimize
5+
scikit-learn
6+
.[test]

0 commit comments

Comments
 (0)