|
17 | 17 | from inspect import currentframe
|
18 | 18 | from typing import Sequence, Optional, Union, Callable
|
19 | 19 |
|
| 20 | +import numpy |
20 | 21 | import numpy as np
|
21 | 22 | import pandas as pd
|
| 23 | +import talib |
22 | 24 |
|
23 | 25 | from .backtesting import Strategy
|
24 | 26 | from ._plotting import plot_heatmaps as _plot_heatmaps
|
25 | 27 | from ._util import _Array, _as_str
|
26 | 28 |
|
27 | 29 | __pdoc__ = {}
|
28 | 30 |
|
29 |
| - |
30 | 31 | OHLCV_AGG = OrderedDict((
|
31 | 32 | ('Open', 'first'),
|
32 | 33 | ('High', 'max'),
|
@@ -267,7 +268,7 @@ def func(x, *_, **__):
|
267 | 268 | frame = frame.f_back
|
268 | 269 | level += 1
|
269 | 270 | if isinstance(frame.f_locals.get('self'), Strategy): # type: ignore
|
270 |
| - strategy_I = frame.f_locals['self'].I # type: ignore |
| 271 | + strategy_I = frame.f_locals['self'].I # type: ignore |
271 | 272 | break
|
272 | 273 | else:
|
273 | 274 | def strategy_I(func, *args, **kwargs):
|
@@ -314,6 +315,7 @@ def random_ohlc_data(example_data: pd.DataFrame, *,
|
314 | 315 | >>> next(ohlc_generator) # returns new random data
|
315 | 316 | ...
|
316 | 317 | """
|
| 318 | + |
317 | 319 | def shuffle(x):
|
318 | 320 | return x.sample(frac=frac, replace=frac > 1, random_state=random_state)
|
319 | 321 |
|
@@ -417,53 +419,68 @@ class TrailingStrategy(Strategy):
|
417 | 419 | """
|
418 | 420 | __n_atr = 6.
|
419 | 421 | __atr = None
|
| 422 | + __use_atr = True |
| 423 | + __sl_distance = None |
420 | 424 |
|
421 |
| - def init(self): |
| 425 | + def init(self, use_atr=True): |
422 | 426 | super().init()
|
423 |
| - self.set_atr_periods() |
| 427 | + self.__use_atr = use_atr |
| 428 | + if use_atr: |
| 429 | + self.set_atr_periods() |
424 | 430 |
|
425 | 431 | def set_atr_periods(self, periods: int = 100):
|
426 | 432 | """
|
427 | 433 | Set the lookback period for computing ATR. The default value
|
428 | 434 | of 100 ensures a _stable_ ATR.
|
429 | 435 | """
|
430 |
| - h, l, c_prev = self.data.High, self.data.Low, pd.Series(self.data.Close).shift(1) |
431 |
| - tr = np.max([h - l, (c_prev - h).abs(), (c_prev - l).abs()], axis=0) |
432 |
| - atr = pd.Series(tr).rolling(periods).mean().bfill().values |
433 |
| - self.__atr = atr |
| 436 | + self.__atr: numpy.ndarray = talib.ATR(self.data.High, self.data.Low, self.data.Close, timeperiod=periods) |
| 437 | + np.nan_to_num(self.__atr, copy=False, nan=self.__atr[np.argmax(self.__atr > 0)]) |
434 | 438 |
|
435 |
| - def set_trailing_sl(self, n_atr: float = 6): |
| 439 | + def set_trailing_atr_sl(self, n_atr: float = 6): |
436 | 440 | """
|
437 | 441 | Sets the future trailing stop-loss as some multiple (`n_atr`)
|
438 | 442 | average true bar ranges away from the current price.
|
439 | 443 | """
|
440 | 444 | self.__n_atr = n_atr
|
441 | 445 |
|
| 446 | + def set_trailing_sl(self, sl_distance: float): |
| 447 | + """ |
| 448 | + Sets the future trailing stop-loss as fixed price. |
| 449 | + """ |
| 450 | + self.__sl_distance = sl_distance |
| 451 | + |
442 | 452 | def next(self):
|
443 | 453 | super().next()
|
444 | 454 | # Can't use index=-1 because self.__atr is not an Indicator type
|
445 |
| - index = len(self.data)-1 |
| 455 | + index = len(self.data) - 1 |
446 | 456 | for trade in self.trades:
|
447 | 457 | if trade.is_long:
|
448 |
| - trade.sl = max(trade.sl or -np.inf, |
449 |
| - self.data.Close[index] - self.__atr[index] * self.__n_atr) |
| 458 | + if self.__use_atr: |
| 459 | + trade.sl = max(trade.sl or -np.inf, |
| 460 | + self.data.Close[index] - self.__atr[index] * self.__n_atr) |
| 461 | + else: |
| 462 | + trade.sl = max(trade.sl or -np.inf, |
| 463 | + self.data.Close[index] - self.__sl_distance) |
450 | 464 | else:
|
451 |
| - trade.sl = min(trade.sl or np.inf, |
452 |
| - self.data.Close[index] + self.__atr[index] * self.__n_atr) |
| 465 | + if self.__use_atr: |
| 466 | + trade.sl = min(trade.sl or np.inf, |
| 467 | + self.data.Close[index] + self.__atr[index] * self.__n_atr) |
| 468 | + else: |
| 469 | + trade.sl = min(trade.sl or np.inf, |
| 470 | + self.data.Close[index] + self.__sl_distance) |
453 | 471 |
|
454 | 472 |
|
455 | 473 | # Prevent pdoc3 documenting __init__ signature of Strategy subclasses
|
456 | 474 | for cls in list(globals().values()):
|
457 | 475 | if isinstance(cls, type) and issubclass(cls, Strategy):
|
458 | 476 | __pdoc__[f'{cls.__name__}.__init__'] = False
|
459 | 477 |
|
460 |
| - |
461 | 478 | # NOTE: Don't put anything below this __all__ list
|
462 | 479 |
|
463 | 480 | __all__ = [getattr(v, '__name__', k)
|
464 |
| - for k, v in globals().items() # export |
465 |
| - if ((callable(v) and v.__module__ == __name__ or # callables from this module |
466 |
| - k.isupper()) and # or CONSTANTS |
| 481 | + for k, v in globals().items() # export |
| 482 | + if ((callable(v) and v.__module__ == __name__ or # callables from this module |
| 483 | + k.isupper()) and # or CONSTANTS |
467 | 484 | not getattr(v, '__name__', k).startswith('_'))] # neither marked internal
|
468 | 485 |
|
469 | 486 | # NOTE: Don't put anything below here. See above.
|
0 commit comments