4
4
import warnings
5
5
from itertools import cycle , combinations
6
6
from functools import partial
7
+ from typing import List
7
8
8
9
import numpy as np
9
10
import pandas as pd
31
32
from bokeh .palettes import Category10
32
33
from bokeh .transform import factor_cmap
33
34
34
- from backtesting ._util import _data_period , _as_list
35
-
35
+ from backtesting ._util import _data_period , _as_list , _Indicator
36
36
37
37
with open (os .path .join (os .path .dirname (__file__ ), 'autoscale_cb.js' ),
38
38
encoding = 'utf-8' ) as _f :
@@ -85,10 +85,13 @@ def lightness(color, lightness=.94):
85
85
return color .to_rgb ()
86
86
87
87
88
- def plot (* , results , df , indicators , filename = '' , plot_width = None ,
88
+ def plot (* , results : pd .Series ,
89
+ df : pd .DataFrame ,
90
+ indicators : List [_Indicator ],
91
+ filename = '' , plot_width = None ,
89
92
plot_equity = True , plot_pl = True ,
90
93
plot_volume = True , plot_drawdown = False ,
91
- smooth_equity = False , relative_equity = True , omit_missing = True ,
94
+ smooth_equity = False , relative_equity = True ,
92
95
superimpose = True , show_legend = True , open_browser = True ):
93
96
"""
94
97
Like much of GUI code everywhere, this is a mess.
@@ -101,41 +104,29 @@ def plot(*, results, df, indicators, filename='', plot_width=None,
101
104
_bokeh_reset (filename )
102
105
103
106
COLORS = [BEAR_COLOR , BULL_COLOR ]
107
+ BAR_WIDTH = .8
104
108
105
- equity_data = results ['_equity_curve' ].copy (False )
109
+ assert df .index .equals (results ['_equity_curve' ].index )
110
+ equity_data = results ['_equity_curve' ].copy (deep = False )
106
111
trades = results ['_trades' ]
107
112
108
- orig_df = df = df .copy (False )
109
- df .index .name = None # Provides source name @index
110
- index = df .index
111
- assert df .index .equals (equity_data .index )
112
- time_resolution = getattr (index , 'resolution' , None )
113
- is_datetime_index = index .is_all_dates
114
-
115
- # If all Volume is NaN, don't plot volume
116
113
plot_volume = plot_volume and not df .Volume .isnull ().all ()
114
+ time_resolution = getattr (df .index , 'resolution' , None )
115
+ is_datetime_index = df .index .is_all_dates
117
116
118
- # OHLC vbar width in msec.
119
- # +1 will work in case of non-datetime index where vbar width should just be =1
120
- bar_width = 1 + dict (day = 86400 ,
121
- hour = 3600 ,
122
- minute = 60 ,
123
- second = 1 ).get (time_resolution , 0 ) * 1000 * .85
124
-
125
- if is_datetime_index :
126
- # Add index as a separate data source column because true .index is offset to align vbars
127
- df ['datetime' ] = index
128
- df .index = df .index + pd .Timedelta (bar_width / 2 , unit = 'ms' )
117
+ from .lib import OHLCV_AGG
118
+ # ohlc df may contain many columns. We're only interested in, and pass on to Bokeh, these
119
+ df = df [list (OHLCV_AGG .keys ())].copy (deep = False )
120
+ df .index .name = None # Provides source name @index
121
+ df ['datetime' ] = df .index # Save original, maybe datetime index
129
122
130
- if omit_missing :
131
- bar_width = .8
132
- df = df .reset_index (drop = True )
133
- equity_data = equity_data .reset_index (drop = True )
134
- index = df .index
123
+ df = df .reset_index (drop = True )
124
+ equity_data = equity_data .reset_index (drop = True )
125
+ index = df .index
135
126
136
127
new_bokeh_figure = partial (
137
128
_figure ,
138
- x_axis_type = 'datetime' if is_datetime_index and not omit_missing else ' linear' ,
129
+ x_axis_type = 'linear' ,
139
130
plot_width = plot_width ,
140
131
plot_height = 400 ,
141
132
tools = "xpan,xwheel_zoom,box_zoom,undo,redo,reset,crosshair,save" ,
@@ -152,12 +143,9 @@ def plot(*, results, df, indicators, filename='', plot_width=None,
152
143
153
144
source = ColumnDataSource (df )
154
145
source .add ((df .Close >= df .Open ).values .astype (np .uint8 ).astype (str ), 'inc' )
155
- trades_index = trades ['ExitBar' ]
156
- if not omit_missing :
157
- trades_index = index [trades_index .astype (int )]
158
146
159
147
trade_source = ColumnDataSource (dict (
160
- index = trades_index ,
148
+ index = trades [ 'ExitBar' ] ,
161
149
datetime = trades ['ExitTime' ],
162
150
exit_price = trades ['ExitPrice' ],
163
151
size = trades ['Size' ],
@@ -170,7 +158,7 @@ def plot(*, results, df, indicators, filename='', plot_width=None,
170
158
lightness (BULL_COLOR , .35 )]
171
159
trades_cmap = factor_cmap ('returns_positive' , colors_darker , ['0' , '1' ])
172
160
173
- if is_datetime_index and omit_missing :
161
+ if is_datetime_index :
174
162
fig_ohlc .xaxis .formatter = FuncTickFormatter (
175
163
args = dict (axis = fig_ohlc .xaxis [0 ],
176
164
formatter = DatetimeTickFormatter (days = ['%d %b' , '%a %d' ],
@@ -184,7 +172,7 @@ def plot(*, results, df, indicators, filename='', plot_width=None,
184
172
''' )
185
173
186
174
NBSP = ' ' * 4
187
- ohlc_extreme_values = df [['High' , 'Low' ]].copy (False )
175
+ ohlc_extreme_values = df [['High' , 'Low' ]].copy (deep = False )
188
176
ohlc_tooltips = [
189
177
('x, y' , NBSP .join (('$index' ,
190
178
'$y{0,0.0[0000]}' ))),
@@ -222,39 +210,33 @@ def set_tooltips(fig, tooltips=(), vline=True, renderers=(), show_arrow=True):
222
210
def _plot_equity_section ():
223
211
"""Equity section"""
224
212
# Max DD Dur. line
225
- equity = equity_data ['Equity' ].reset_index (drop = True )
226
- argmax = equity_data ['DrawdownDuration' ].reset_index (drop = True ).idxmax ()
227
- try :
228
- dd_start = equity [:argmax ].idxmax ()
229
- except Exception : # ValueError: attempt to get argmax of an empty sequence
213
+ equity = equity_data ['Equity' ].copy ()
214
+ dd_end = equity_data ['DrawdownDuration' ].idxmax ()
215
+ if np .isnan (dd_end ):
230
216
dd_start = dd_end = equity .index [0 ]
231
- timedelta = 0
232
217
else :
233
- dd_end = argmax
234
- if is_datetime_index and omit_missing :
235
- # "Calendar" duration
236
- timedelta = df .datetime .iloc [dd_end ] - df .datetime .iloc [dd_start ]
237
- else :
238
- timedelta = dd_end - dd_start
239
- # Get point intersection
218
+ dd_start = equity [:dd_end ].idxmax ()
219
+ # If DD not extending into the future, get exact point of intersection with equity
240
220
if dd_end != equity .index [- 1 ]:
241
- x1 , x2 = dd_end - 1 , dd_end
242
- y , y1 , y2 = equity [ dd_start ], equity [x1 ], equity [x2 ]
243
- dd_end -= ( 1 - ( y - y1 ) / ( y2 - y1 )) * ( dd_end - x1 ) # y = a x + b
221
+ dd_end = np . interp ( equity [ dd_start ],
222
+ ( equity [dd_end - 1 ], equity [dd_end ]),
223
+ ( dd_end - 1 , dd_end ))
244
224
245
225
if smooth_equity :
246
- select = (pd .Index (trades ['ExitBar' ]) |
247
- # Include beginning and end
248
- equity .index [:1 ] | equity .index [- 1 :] |
249
- # Include peak equity and peak DD
250
- pd .Index ([equity .idxmax (), argmax ]) |
251
- # Include max dd end points. Otherwise the MaxDD line looks amiss.
252
- pd .Index ([dd_start , int (dd_end ), min (equity .size - 1 , int (dd_end + 1 ))]))
226
+ interest_points = pd .Index ([
227
+ # Beginning and end
228
+ equity .index [0 ], equity .index [- 1 ],
229
+ # Peak equity and peak DD
230
+ equity .idxmax (), equity_data ['DrawdownPct' ].idxmax (),
231
+ # Include max dd end points. Otherwise the MaxDD line looks amiss.
232
+ dd_start , int (dd_end ), min (int (dd_end + 1 ), equity .size - 1 ),
233
+ ])
234
+ select = pd .Index (trades ['ExitBar' ]) | interest_points
253
235
select = select .unique ().dropna ()
254
236
equity = equity .iloc [select ].reindex (equity .index )
255
237
equity .interpolate (inplace = True )
256
238
257
- equity .index = equity_data .index
239
+ assert equity .index . equals ( equity_data .index )
258
240
259
241
if relative_equity :
260
242
equity /= equity .iloc [0 ]
@@ -302,9 +284,10 @@ def _plot_equity_section():
302
284
fig .scatter (argmax , equity [argmax ],
303
285
legend_label = 'Max Drawdown (-{:.1f}%)' .format (100 * drawdown [argmax ]),
304
286
color = 'red' , size = 8 )
305
- fig .line ([index [dd_start ], index [int (dd_end )]], equity .iloc [dd_start ],
287
+ dd_timedelta_label = df ['datetime' ].iloc [int (round (dd_end ))] - df ['datetime' ].iloc [dd_start ]
288
+ fig .line ([dd_start , dd_end ], equity .iloc [dd_start ],
306
289
line_color = 'red' , line_width = 2 ,
307
- legend_label = 'Max Dd Dur. ({})' .format (timedelta )
290
+ legend_label = 'Max Dd Dur. ({})' .format (dd_timedelta_label )
308
291
.replace (' 00:00:00' , '' )
309
292
.replace ('(0 days ' , '(' ))
310
293
@@ -354,7 +337,7 @@ def _plot_volume_section():
354
337
fig .xaxis .formatter = fig_ohlc .xaxis [0 ].formatter
355
338
fig .xaxis .visible = True
356
339
fig_ohlc .xaxis .visible = False # Show only Volume's xaxis
357
- r = fig .vbar ('index' , bar_width , 'Volume' , source = source , color = inc_cmap )
340
+ r = fig .vbar ('index' , BAR_WIDTH , 'Volume' , source = source , color = inc_cmap )
358
341
set_tooltips (fig , [('Volume' , '@Volume{0.00 a}' )], renderers = [r ])
359
342
fig .yaxis .formatter = NumeralTickFormatter (format = "0 a" )
360
343
return fig
@@ -374,48 +357,37 @@ def _plot_superimposed_ohlc():
374
357
stacklevel = 4 )
375
358
return
376
359
377
- orig_df [ '_width' ] = 1
378
- from . lib import OHLCV_AGG
379
- df2 = orig_df . resample ( resample_rule , label = 'left' ) .agg (dict (OHLCV_AGG , _width = 'count' ))
360
+ df2 = ( df . assign ( _width = 1 ). set_index ( 'datetime' )
361
+ . resample ( resample_rule , label = 'left' )
362
+ .agg (dict (OHLCV_AGG , _width = 'count' ) ))
380
363
381
364
# Check if resampling was downsampling; error on upsampling
382
- orig_freq = _data_period (orig_df )
383
- resample_freq = _data_period (df2 )
365
+ orig_freq = _data_period (df [ 'datetime' ] )
366
+ resample_freq = _data_period (df2 . index )
384
367
if resample_freq < orig_freq :
385
368
raise ValueError ('Invalid value for `superimpose`: Upsampling not supported.' )
386
369
if resample_freq == orig_freq :
387
370
warnings .warn ('Superimposed OHLC plot matches the original plot. Skipping.' ,
388
371
stacklevel = 4 )
389
372
return
390
373
391
- if omit_missing :
392
- width2 = '_width'
393
- df2 .index = df2 ['_width' ].cumsum ().shift (1 ).fillna (0 )
394
- df2 .index += df2 ['_width' ] / 2 - .5
395
- df2 ['_width' ] -= .1 # Candles don't touch
396
- else :
397
- del df ['_width' ]
398
- width2 = dict (day = 86400 * 5 ,
399
- hour = 86400 ,
400
- minute = 3600 ,
401
- second = 60 )[time_resolution ] * 1000
402
- df2 .index += pd .Timedelta (
403
- width2 / 2 +
404
- (width2 / 5 if resample_rule == 'W' else 0 ), # Sunday week start
405
- unit = 'ms' )
406
- df2 ['inc' ] = (df2 .Close >= df2 .Open ).astype (np .uint8 ).astype (str )
374
+ df2 .index = df2 ['_width' ].cumsum ().shift (1 ).fillna (0 )
375
+ df2 .index += df2 ['_width' ] / 2 - .5
376
+ df2 ['_width' ] -= .1 # Candles don't touch
377
+
378
+ df2 ['inc' ] = (df2 .Close >= df2 .Open ).astype (int ).astype (str )
407
379
df2 .index .name = None
408
380
source2 = ColumnDataSource (df2 )
409
381
fig_ohlc .segment ('index' , 'High' , 'index' , 'Low' , source = source2 , color = '#bbbbbb' )
410
382
colors_lighter = [lightness (BEAR_COLOR , .92 ),
411
383
lightness (BULL_COLOR , .92 )]
412
- fig_ohlc .vbar ('index' , width2 , 'Open' , 'Close' , source = source2 , line_color = None ,
384
+ fig_ohlc .vbar ('index' , '_width' , 'Open' , 'Close' , source = source2 , line_color = None ,
413
385
fill_color = factor_cmap ('inc' , colors_lighter , ['0' , '1' ]))
414
386
415
387
def _plot_ohlc ():
416
388
"""Main OHLC bars"""
417
389
fig_ohlc .segment ('index' , 'High' , 'index' , 'Low' , source = source , color = "black" )
418
- r = fig_ohlc .vbar ('index' , bar_width , 'Open' , 'Close' , source = source ,
390
+ r = fig_ohlc .vbar ('index' , BAR_WIDTH , 'Open' , 'Close' , source = source ,
419
391
line_color = "black" , fill_color = inc_cmap )
420
392
return r
421
393
@@ -484,7 +456,7 @@ def __eq__(self, other):
484
456
'index' , source_name , source = source ,
485
457
legend_label = legend_label , color = color ,
486
458
line_color = 'black' , fill_alpha = .8 ,
487
- marker = 'circle' , radius = bar_width / 2 * 1.5 )
459
+ marker = 'circle' , radius = BAR_WIDTH / 2 * 1.5 )
488
460
else :
489
461
fig .line (
490
462
'index' , source_name , source = source ,
@@ -495,7 +467,7 @@ def __eq__(self, other):
495
467
r = fig .scatter (
496
468
'index' , source_name , source = source ,
497
469
legend_label = LegendStr (legend_label ), color = color ,
498
- marker = 'circle' , radius = bar_width / 2 * .9 )
470
+ marker = 'circle' , radius = BAR_WIDTH / 2 * .9 )
499
471
else :
500
472
r = fig .line (
501
473
'index' , source_name , source = source ,
0 commit comments