Skip to content

Commit 0a89a98

Browse files
ENH: lib.MultiBacktest multi-dataset backtesting wrapper
Fixes #508 Thanks! Co-Authored-By: Mike Judge <[email protected]>
1 parent 03b05c9 commit 0a89a98

File tree

4 files changed

+165
-69
lines changed

backtesting/_util.py

+79-20
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,11 @@
11
from __future__ import annotations
22

3+
import os
34
import sys
45
import warnings
56
from contextlib import contextmanager
7+
from functools import partial
8+
from itertools import chain
69
from multiprocessing import resource_tracker as _mprt
710
from multiprocessing import shared_memory as _mpshm
811
from numbers import Number
@@ -12,6 +15,13 @@
1215
import numpy as np
1316
import pandas as pd
1417

18+
try:
19+
from tqdm.auto import tqdm as _tqdm
20+
_tqdm = partial(_tqdm, leave=False)
21+
except ImportError:
22+
def _tqdm(seq, **_):
23+
return seq
24+
1525

1626
def try_(lazy_func, default=None, exception=Exception):
1727
try:
@@ -55,6 +65,13 @@ def _as_list(value) -> List:
5565
return [value]
5666

5767

68+
def _batch(seq):
69+
# XXX: Replace with itertools.batched
70+
n = np.clip(int(len(seq) // (os.cpu_count() or 1)), 1, 300)
71+
for i in range(0, len(seq), n):
72+
yield seq[i:i + n]
73+
74+
5875
def _data_period(index) -> Union[pd.Timedelta, Number]:
5976
"""Return data index period as pd.Timedelta"""
6077
values = pd.Series(index[-100:])
@@ -233,7 +250,6 @@ def __setstate__(self, state):
233250

234251
if sys.version_info >= (3, 13):
235252
SharedMemory = _mpshm.SharedMemory
236-
from multiprocessing.managers import SharedMemoryManager # noqa: F401
237253
else:
238254
class SharedMemory(_mpshm.SharedMemory):
239255
# From https://github.com/python/cpython/issues/82300#issuecomment-2169035092
@@ -244,7 +260,7 @@ def __init__(self, *args, track: bool = True, **kwargs):
244260
if track:
245261
return super().__init__(*args, **kwargs)
246262
with self.__lock:
247-
with patch(_mprt, 'register', lambda *a, **kw: None): # TODO lambda
263+
with patch(_mprt, 'register', lambda *a, **kw: None):
248264
super().__init__(*args, **kwargs)
249265

250266
def unlink(self):
@@ -253,23 +269,66 @@ def unlink(self):
253269
if self._track:
254270
_mprt.unregister(self._name, "shared_memory")
255271

256-
class SharedMemoryManager:
257-
def __init__(self) -> None:
258-
self._shms: list[SharedMemory] = []
259-
260-
def SharedMemory(self, size):
261-
shm = SharedMemory(create=True, size=size, track=True)
262-
self._shms.append(shm)
263-
return shm
264-
265-
def __enter__(self):
266-
return self
267272

268-
def __exit__(self, *args, **kwargs):
269-
for shm in self._shms:
270-
try:
271-
shm.close()
273+
class SharedMemoryManager:
274+
"""
275+
A simple shared memory contextmanager based on
276+
https://docs.python.org/3/library/multiprocessing.shared_memory.html#multiprocessing.shared_memory.SharedMemory
277+
"""
278+
def __init__(self, create=False) -> None:
279+
self._shms: list[SharedMemory] = []
280+
self.__create = create
281+
282+
def SharedMemory(self, *, name=None, create=False, size=0, track=True):
283+
shm = SharedMemory(name=name, create=create, size=size, track=track)
284+
shm._create = create
285+
# Essential to keep refs on Windows
286+
# https://stackoverflow.com/questions/74193377/filenotfounderror-when-passing-a-shared-memory-to-a-new-process#comment130999060_74194875 # noqa: E501
287+
self._shms.append(shm)
288+
return shm
289+
290+
def __enter__(self):
291+
return self
292+
293+
def __exit__(self, *args, **kwargs):
294+
for shm in self._shms:
295+
try:
296+
shm.close()
297+
if shm._create:
272298
shm.unlink()
273-
except Exception:
274-
warnings.warn(f'Failed to unlink shared memory {shm.name!r}',
275-
category=ResourceWarning, stacklevel=2)
299+
except Exception:
300+
warnings.warn(f'Failed to unlink shared memory {shm.name!r}',
301+
category=ResourceWarning, stacklevel=2)
302+
raise
303+
304+
def arr2shm(self, vals):
305+
"""Array to shared memory. Returns (shm_name, shape, dtype) used for restore."""
306+
assert vals.ndim == 1, (vals.ndim, vals.shape, vals)
307+
shm = self.SharedMemory(size=vals.nbytes, create=True)
308+
buf = np.ndarray(vals.shape, dtype=vals.dtype, buffer=shm.buf)
309+
buf[:] = vals[:] # Copy into shared memory
310+
return shm.name, vals.shape, vals.dtype
311+
312+
def df2shm(self, df):
313+
return tuple((
314+
(column, *self.arr2shm(values))
315+
for column, values in chain([(self._DF_INDEX_COL, df.index)], df.items())
316+
))
317+
318+
@staticmethod
319+
def shm2arr(shm, shape, dtype):
320+
arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
321+
arr.setflags(write=False)
322+
return arr
323+
324+
_DF_INDEX_COL = '__bt_index'
325+
326+
@staticmethod
327+
def shm2df(data_shm):
328+
shm = [SharedMemory(name=name, create=False, track=False) for _, name, _, _ in data_shm]
329+
df = pd.DataFrame({
330+
col: SharedMemoryManager.shm2arr(shm, shape, dtype)
331+
for shm, (col, _, shape, dtype) in zip(shm, data_shm)})
332+
df.set_index(SharedMemoryManager._DF_INDEX_COL, drop=True, inplace=True)
333+
df.index.name = None
334+
return df, shm

backtesting/backtesting.py

+4-46
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99
from __future__ import annotations
1010

1111
import multiprocessing as mp
12-
import os
1312
import sys
1413
import warnings
1514
from abc import ABCMeta, abstractmethod
@@ -24,18 +23,11 @@
2423
import pandas as pd
2524
from numpy.random import default_rng
2625

27-
try:
28-
from tqdm.auto import tqdm as _tqdm
29-
_tqdm = partial(_tqdm, leave=False)
30-
except ImportError:
31-
def _tqdm(seq, **_):
32-
return seq
33-
3426
from ._plotting import plot # noqa: I001
3527
from ._stats import compute_stats
3628
from ._util import (
37-
SharedMemory, SharedMemoryManager, _as_str, _Indicator, _Data, _indicator_warmup_nbars,
38-
_strategy_indicators, patch, try_,
29+
SharedMemoryManager, _as_str, _Indicator, _Data, _batch, _indicator_warmup_nbars,
30+
_strategy_indicators, patch, try_, _tqdm,
3931
)
4032

4133
__pdoc__ = {
@@ -1507,36 +1499,14 @@ def _optimize_grid() -> Union[pd.Series, Tuple[pd.Series, pd.Series]]:
15071499
[p.values() for p in param_combos],
15081500
names=next(iter(param_combos)).keys()))
15091501

1510-
def _batch(seq):
1511-
# XXX: Replace with itertools.batched
1512-
n = np.clip(int(len(seq) // (os.cpu_count() or 1)), 1, 300)
1513-
for i in range(0, len(seq), n):
1514-
yield seq[i:i + n]
1515-
15161502
with mp.Pool() as pool, \
15171503
SharedMemoryManager() as smm:
15181504

1519-
shm_refs = [] # https://stackoverflow.com/questions/74193377/filenotfounderror-when-passing-a-shared-memory-to-a-new-process#comment130999060_74194875 # noqa: E501
1520-
1521-
def arr2shm(vals):
1522-
nonlocal smm
1523-
shm = smm.SharedMemory(size=vals.nbytes)
1524-
buf = np.ndarray(vals.shape, dtype=vals.dtype, buffer=shm.buf)
1525-
buf[:] = vals[:] # Copy into shared memory
1526-
assert vals.ndim == 1, (vals.ndim, vals.shape, vals)
1527-
shm_refs.append(shm)
1528-
return shm.name, vals.shape, vals.dtype
1529-
1530-
data_shm = tuple((
1531-
(column, *arr2shm(values))
1532-
for column, values in chain([(Backtest._mp_task_INDEX_COL, self._data.index)],
1533-
self._data.items())
1534-
))
15351505
with patch(self, '_data', None):
15361506
bt = copy(self) # bt._data will be reassigned in _mp_task worker
15371507
results = _tqdm(
15381508
pool.imap(Backtest._mp_task,
1539-
((bt, data_shm, params_batch)
1509+
((bt, smm.df2shm(self._data), params_batch)
15401510
for params_batch in _batch(param_combos))),
15411511
total=len(param_combos),
15421512
desc='Backtest.optimize'
@@ -1640,27 +1610,15 @@ def cons(x):
16401610
@staticmethod
16411611
def _mp_task(arg):
16421612
bt, data_shm, params_batch = arg
1643-
shm = [SharedMemory(name=shm_name, create=False, track=False)
1644-
for _, shm_name, *_ in data_shm]
1613+
bt._data, shm = SharedMemoryManager.shm2df(data_shm)
16451614
try:
1646-
def shm2arr(shm, shape, dtype):
1647-
arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
1648-
arr.setflags(write=False)
1649-
return arr
1650-
1651-
bt._data = df = pd.DataFrame({
1652-
col: shm2arr(shm, shape, dtype)
1653-
for shm, (col, _, shape, dtype) in zip(shm, data_shm)})
1654-
df.set_index(Backtest._mp_task_INDEX_COL, drop=True, inplace=True)
16551615
return [stats.filter(regex='^[^_]') if stats['# Trades'] else None
16561616
for stats in (bt.run(**params)
16571617
for params in params_batch)]
16581618
finally:
16591619
for shmem in shm:
16601620
shmem.close()
16611621

1662-
_mp_task_INDEX_COL = '__bt_index'
1663-
16641622
def plot(self, *, results: pd.Series = None, filename=None, plot_width=None,
16651623
plot_equity=True, plot_return=False, plot_pl=True,
16661624
plot_volume=True, plot_drawdown=False, plot_trades=True,

backtesting/lib.py

+71-2
Original file line number | Diff line number | Diff line change
@@ -13,9 +13,10 @@
1313

1414
from __future__ import annotations
1515

16+
import multiprocessing as mp
1617
from collections import OrderedDict
1718
from inspect import currentframe
18-
from itertools import compress
19+
from itertools import chain, compress, count
1920
from numbers import Number
2021
from typing import Callable, Generator, Optional, Sequence, Union
2122

@@ -24,7 +25,7 @@
2425

2526
from ._plotting import plot_heatmaps as _plot_heatmaps
2627
from ._stats import compute_stats as _compute_stats
27-
from ._util import _Array, _as_str
28+
from ._util import SharedMemoryManager, _Array, _as_str, _batch, _tqdm
2829
from .backtesting import Backtest, Strategy
2930

3031
__pdoc__ = {}
@@ -535,6 +536,74 @@ def __init__(self,
535536
__pdoc__[f'{cls.__name__}.__init__'] = False
536537

537538

539+
class MultiBacktest:
540+
"""
541+
Multi-dataset `backtesting.backtesting.Backtest` wrapper.
542+
543+
Run supplied `backtesting.backtesting.Strategy` on several instruments,
544+
in parallel. Used for comparing strategy runs across many instruments
545+
or classes of instruments. Example:
546+
547+
from backtesting.test import EURUSD, BTCUSD, SmaCross
548+
btm = MultiBacktest([EURUSD, BTCUSD], SmaCross)
549+
stats_per_ticker: pd.DataFrame = btm.run(fast=10, slow=20)
550+
heatmap_per_ticker: pd.DataFrame = btm.optimize(...)
551+
"""
552+
def __init__(self, df_list, strategy_cls, **kwargs):
553+
self._dfs = df_list
554+
self._strategy = strategy_cls
555+
self._bt_kwargs = kwargs
556+
557+
def run(self, **kwargs):
558+
"""
559+
Wraps `backtesting.backtesting.Backtest.run`. Returns `pd.DataFrame` with
560+
dataset indexes (positions in `df_list`) in columns.
561+
"""
562+
with mp.Pool() as pool, \
563+
SharedMemoryManager() as smm:
564+
shm = [smm.df2shm(df) for df in self._dfs]
565+
results = _tqdm(
566+
pool.imap(self._mp_task_run,
567+
((df_batch, self._strategy, self._bt_kwargs, kwargs)
568+
for df_batch in _batch(shm))),
569+
total=len(shm),
570+
desc=self.__class__.__name__,
571+
)
572+
df = pd.DataFrame(list(chain(*results))).transpose()
573+
return df
574+
575+
@staticmethod
576+
def _mp_task_run(args):
577+
data_shm, strategy, bt_kwargs, run_kwargs = args
578+
dfs, shms = zip(*(SharedMemoryManager.shm2df(i) for i in data_shm))
579+
try:
580+
return [stats.filter(regex='^[^_]') if stats['# Trades'] else None
581+
for stats in (Backtest(df, strategy, **bt_kwargs).run(**run_kwargs)
582+
for df in dfs)]
583+
finally:
584+
for shmem in chain(*shms):
585+
shmem.close()
586+
587+
def optimize(self, **kwargs) -> pd.DataFrame:
588+
"""
589+
Wraps `backtesting.backtesting.Backtest.optimize`, but returns `pd.DataFrame` with
590+
dataset indexes (positions in `df_list`) in columns.
591+
592+
heatmap: pd.DataFrame = btm.optimize(...)
593+
from backtesting.lib import plot_heatmaps
594+
plot_heatmaps(heatmap.mean(axis=1))
595+
"""
596+
heatmaps = []
597+
# Simple loop since bt.optimize already does its own multiprocessing
598+
for df in _tqdm(self._dfs, desc=self.__class__.__name__):
599+
bt = Backtest(df, self._strategy, **self._bt_kwargs)
600+
_best_stats, heatmap = bt.optimize( # type: ignore
601+
return_heatmap=True, return_optimization=False, **kwargs)
602+
heatmaps.append(heatmap)
603+
heatmap = pd.DataFrame(dict(zip(count(), heatmaps)))
604+
return heatmap
605+
606+
538607
# NOTE: Don't put anything below this __all__ list
539608

540609
__all__ = [getattr(v, '__name__', k)

backtesting/test/_test.py

+11-1
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@
1818
from backtesting._stats import compute_drawdown_duration_peaks
1919
from backtesting._util import _Array, _as_str, _Indicator, patch, try_
2020
from backtesting.lib import (
21-
FractionalBacktest, OHLCV_AGG,
21+
FractionalBacktest, MultiBacktest, OHLCV_AGG,
2222
SignalStrategy,
2323
TrailingStrategy,
2424
barssince,
@@ -943,6 +943,16 @@ def test_FractionalBacktest(self):
943943
stats = ubtc_bt.run(fast=2, slow=3)
944944
self.assertEqual(stats['# Trades'], 41)
945945

946+
def test_MultiBacktest(self):
947+
btm = MultiBacktest([GOOG, EURUSD, BTCUSD], SmaCross, cash=100_000)
948+
res = btm.run(fast=2)
949+
self.assertIsInstance(res, pd.DataFrame)
950+
self.assertEqual(res.columns.tolist(), [0, 1, 2])
951+
heatmap = btm.optimize(fast=[2, 4], slow=[10, 20])
952+
self.assertIsInstance(heatmap, pd.DataFrame)
953+
self.assertEqual(heatmap.columns.tolist(), [0, 1, 2])
954+
plot_heatmaps(heatmap.mean(axis=1), open_browser=False)
955+
946956

947957
class TestUtil(TestCase):
948958
def test_as_str(self):

0 commit comments

Comments
 (0)