Skip to content

Commit 1faa7e8

Browse files
committed
REF: Improve multiprocessing on Linux (fork)
... by passing less objects around (less pickling) Fixes: #51
1 parent c37e7a9 commit 1faa7e8

File tree

2 files changed

+66
-21
lines changed

2 files changed

+66
-21
lines changed

backtesting/backtesting.py

+37-21
Original file line numberDiff line numberDiff line change
@@ -824,23 +824,35 @@ def _batch(seq):
824824
for i in range(0, len(seq), n):
825825
yield seq[i:i + n]
826826

827-
# If multiprocessing start method is 'fork' (i.e. on POSIX), use
828-
# a pool of processes to compute results in parallel.
829-
# Otherwise (i.e. on Windos), sequential computation will be "faster".
830-
if mp.get_start_method(allow_none=False) == 'fork':
831-
with ProcessPoolExecutor() as executor:
832-
futures = [executor.submit(self._mp_task, params)
833-
for params in _batch(param_combos)]
834-
for future in _tqdm(as_completed(futures), total=len(futures)):
835-
for params, stats in future.result():
836-
heatmap[tuple(params.values())] = maximize(stats)
837-
else:
838-
if os.name == 'posix':
839-
warnings.warn("For multiprocessing support in `Backtest.optimize()` "
840-
"set multiprocessing start method to 'fork'.")
841-
for params in _tqdm(param_combos):
842-
for _, stats in self._mp_task([params]):
843-
heatmap[tuple(params.values())] = maximize(stats)
827+
# Save necessary objects into "global" state; pass into concurrent executor
828+
# (and thus pickle) nothing but two numbers; receive nothing but numbers.
829+
# With start method "fork", children processes will inherit parent address space
830+
# in a copy-on-write manner, achieving better performance/RAM benefit.
831+
backtest_uuid = np.random.random()
832+
param_batches = list(_batch(param_combos))
833+
Backtest._mp_backtests[backtest_uuid] = (self, param_batches, maximize)
834+
try:
835+
# If multiprocessing start method is 'fork' (i.e. on POSIX), use
836+
# a pool of processes to compute results in parallel.
837+
# Otherwise (i.e. on Windos), sequential computation will be "faster".
838+
if mp.get_start_method(allow_none=False) == 'fork':
839+
with ProcessPoolExecutor() as executor:
840+
futures = [executor.submit(Backtest._mp_task, backtest_uuid, i)
841+
for i in range(len(param_batches))]
842+
for future in _tqdm(as_completed(futures), total=len(futures)):
843+
batch_index, values = future.result()
844+
for value, params in zip(values, param_batches[batch_index]):
845+
heatmap[tuple(params.values())] = value
846+
else:
847+
if os.name == 'posix':
848+
warnings.warn("For multiprocessing support in `Backtest.optimize()` "
849+
"set multiprocessing start method to 'fork'.")
850+
for batch_index in _tqdm(range(len(param_batches))):
851+
_, values = Backtest._mp_task(backtest_uuid, batch_index)
852+
for value, params in zip(values, param_batches[batch_index]):
853+
heatmap[tuple(params.values())] = value
854+
finally:
855+
del Backtest._mp_backtests[backtest_uuid]
844856

845857
best_params = heatmap.idxmax()
846858

@@ -856,10 +868,14 @@ def _batch(seq):
856868
return self._results, heatmap
857869
return self._results
858870

859-
def _mp_task(self, param_combos):
860-
return [(params, stats) for params, stats in ((params, self.run(**params))
861-
for params in param_combos)
862-
if stats['# Trades']]
871+
@staticmethod
872+
def _mp_task(backtest_uuid, batch_index):
873+
bt, param_batches, maximize_func = Backtest._mp_backtests[backtest_uuid]
874+
return batch_index, [maximize_func(stats) if stats['# Trades'] else np.nan
875+
for stats in (bt.run(**params)
876+
for params in param_batches[batch_index])]
877+
878+
_mp_backtests = {}
863879

864880
@staticmethod
865881
def _compute_drawdown_duration_peaks(dd: pd.Series):

backtesting/test/_test.py

+29
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from runpy import run_path
99
from tempfile import NamedTemporaryFile, gettempdir
1010
from unittest import TestCase
11+
from unittest.mock import patch
1112

1213
import numpy as np
1314
import pandas as pd
@@ -372,6 +373,34 @@ def test_optimize(self):
372373
with _tempfile() as f:
373374
bt.plot(filename=f, open_browser=False)
374375

376+
def test_nowrite_df(self):
377+
# Test we don't write into passed data df by default.
378+
# Important for copy-on-write in Backtest.optimize()
379+
df = EURUSD.astype(float)
380+
values = df.values.ctypes.data
381+
assert values == df.values.ctypes.data
382+
383+
class S(SmaCross):
384+
def init(self):
385+
super().init()
386+
assert values == self.data.df.values.ctypes.data
387+
388+
bt = Backtest(df, S)
389+
_ = bt.run()
390+
assert values == bt._data.values.ctypes.data
391+
392+
def test_multiprocessing_windows_spawn(self):
393+
df = GOOG.iloc[:100]
394+
kw = dict(fast=[10])
395+
396+
stats1 = Backtest(df, SmaCross).optimize(**kw)
397+
with patch('multiprocessing.get_start_method', lambda **_: 'spawn'):
398+
with self.assertWarns(UserWarning) as cm:
399+
stats2 = Backtest(df, SmaCross).optimize(**kw)
400+
401+
self.assertIn('multiprocessing support', cm.warning.args[0])
402+
assert stats1.filter('[^_]').equals(stats2.filter('[^_]')), (stats1, stats2)
403+
375404
def test_optimize_invalid_param(self):
376405
bt = Backtest(GOOG.iloc[:100], SmaCross)
377406
self.assertRaises(AttributeError, bt.optimize, foo=range(3))

0 commit comments

Comments
 (0)