Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert joblib.Parallel and use mp.Pool + mp.SharedMemory instead #1222

Merged
merged 7 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 53 additions & 37 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,52 +5,68 @@ on:
schedule: [ cron: '2 2 * * 6' ] # Every Saturday, 02:02

jobs:
lint:
runs-on: ubuntu-latest
timeout-minutes: 1
steps:
- uses: actions/checkout@v4
- run: pip install flake8 mypy
- run: flake8 backtesting setup.py
- run: mypy --no-warn-unused-ignores backtesting

coverage:
needs: lint
runs-on: ubuntu-latest
timeout-minutes: 4
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.10'
- run: pip install -U --pre bokeh pandas numpy coverage && pip install -U .[test]
- env: { BOKEH_BROWSER: none }
run: time coverage run -m backtesting.test
- run: coverage combine && coverage report

build:
name: Build
needs: lint
runs-on: ubuntu-latest
timeout-minutes: 3
strategy:
matrix:
python-version: ['3.10', 3.13]
include:
- python-version: 3.12
test-type: lint
- python-version: 3.11
test-type: docs

python-version: [3.11, 3.12, 3.13]
steps:
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-py${{ matrix.python-version }}
- uses: actions/checkout@v4
with:
fetch-depth: 3
- name: Fetch tags
run: git fetch --depth=1 origin +refs/tags/*:refs/tags/*

- run: pip install -U pip setuptools wheel
- if: matrix.test-type == 'lint'
run: pip install -U --pre bokeh pandas numpy && pip install -U .[dev]
- if: matrix.test-type == 'docs'
run: pip install -e .[doc] # -e provides _version.py for pdoc
- run: pip install -U .[test]

- if: matrix.test-type == 'lint'
run: flake8 backtesting setup.py
- if: matrix.test-type == 'lint'
run: mypy backtesting
- if: matrix.test-type == 'lint'
env: { BOKEH_BROWSER: none }
run: time coverage run -m backtesting.test
- if: matrix.test-type == 'lint'
run: coverage combine && coverage report

- if: '! matrix.test-type'
env: { BOKEH_BROWSER: none }
- env: { BOKEH_BROWSER: none }
run: time python -m backtesting.test

- if: matrix.test-type == 'docs'
run: time doc/build.sh
docs:
needs: lint
runs-on: ubuntu-latest
timeout-minutes: 3
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 3
- run: git fetch --depth=1 origin +refs/tags/*:refs/tags/*
- run: pip install -e .[doc,test] # -e provides ./backtesting/_version.py for pdoc
- run: time doc/build.sh

win64:
needs:
- build
- docs
runs-on: windows-latest
timeout-minutes: 4
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: 3.13
- run: pip install .[test]
- env: { BOKEH_BROWSER: none }
run: python -m backtesting.test
63 changes: 63 additions & 0 deletions backtesting/_util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from __future__ import annotations

import sys
import warnings
from contextlib import contextmanager
from multiprocessing import resource_tracker as _mprt
from multiprocessing import shared_memory as _mpshm
from numbers import Number
from threading import Lock
from typing import Dict, List, Optional, Sequence, Union, cast

import numpy as np
Expand All @@ -15,6 +20,20 @@ def try_(lazy_func, default=None, exception=Exception):
return default


@contextmanager
def patch(obj, attr, newvalue):
    """
    Temporarily set attribute `attr` of `obj` to `newvalue`.

    On leaving the `with` block (normally or through an exception),
    the attribute is set back to the value it had on entry, or
    deleted again if `obj` did not have it on entry.
    """
    if hasattr(obj, attr):
        previous = getattr(obj, attr, None)

        def undo():
            setattr(obj, attr, previous)
    else:
        def undo():
            delattr(obj, attr)

    setattr(obj, attr, newvalue)
    try:
        yield
    finally:
        undo()


def _as_str(value) -> str:
if isinstance(value, (Number, str)):
return str(value)
Expand Down Expand Up @@ -210,3 +229,47 @@ def __getstate__(self):

def __setstate__(self, state):
self.__dict__ = state


if sys.version_info >= (3, 13):
    # The standard library's SharedMemory accepts `track=` from Python 3.13 on,
    # so both names can be re-exported as they are.
    SharedMemory = _mpshm.SharedMemory
    from multiprocessing.managers import SharedMemoryManager  # noqa: F401
else:
    class SharedMemory(_mpshm.SharedMemory):
        """
        `multiprocessing.shared_memory.SharedMemory` with an added
        `track` keyword argument for Python versions before 3.13.

        With `track=False`, the segment is not registered with
        `multiprocessing.resource_tracker` on construction and is not
        unregistered from it in `unlink()`.
        """
        # From https://github.com/python/cpython/issues/82300#issuecomment-2169035092
        __lock = Lock()

        def __init__(self, *args, track: bool = True, **kwargs):
            self._track = track
            if track:
                super().__init__(*args, **kwargs)
                return
            # `register` is swapped on the resource_tracker module itself,
            # i.e. process-wide, hence the lock around the patched section.
            with self.__lock:
                with patch(_mprt, 'register', lambda *a, **kw: None):  # TODO lambda
                    super().__init__(*args, **kwargs)

        def unlink(self):
            """Remove the named segment; unregister it only if it was tracked."""
            if _mpshm._USE_POSIX and self._name:
                _mpshm._posixshmem.shm_unlink(self._name)
                if self._track:
                    _mprt.unregister(self._name, "shared_memory")

    class SharedMemoryManager:
        """
        Minimal in-process stand-in for
        `multiprocessing.managers.SharedMemoryManager`.

        Segments created through `SharedMemory()` are remembered and
        closed and unlinked when the `with` block is left.
        """

        def __init__(self) -> None:
            self._shms: list[SharedMemory] = []

        def SharedMemory(self, size):
            """Create a new tracked segment of `size` bytes and remember it."""
            shm = SharedMemory(create=True, size=size, track=True)
            self._shms.append(shm)
            return shm

        def __enter__(self):
            return self

        def __exit__(self, *args, **kwargs):
            # Detach the list first so that exiting twice does not try to
            # unlink the same segments again.
            shms, self._shms = self._shms, []
            problems = []
            for shm in shms:
                # close() and unlink() are attempted independently: close()
                # can fail (e.g. BufferError while a view of shm.buf is still
                # alive) and that must not prevent the segment's removal.
                try:
                    shm.close()
                except Exception:
                    problems.append(f'Failed to close shared memory {shm.name!r}')
                try:
                    shm.unlink()
                except Exception:
                    problems.append(f'Failed to unlink shared memory {shm.name!r}')
            # Warn only after every segment was handled, so a warning that is
            # escalated to an error cannot cut the cleanup short.
            for message in problems:
                warnings.warn(message, category=ResourceWarning, stacklevel=2)
82 changes: 66 additions & 16 deletions backtesting/backtesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from __future__ import annotations

import multiprocessing as mp
import os
import sys
import warnings
from abc import ABCMeta, abstractmethod
Expand All @@ -16,11 +18,10 @@
from itertools import chain, product, repeat
from math import copysign
from numbers import Number
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
from typing import Callable, List, Optional, Sequence, Tuple, Type, Union

import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from numpy.random import default_rng

try:
Expand All @@ -32,7 +33,10 @@ def _tqdm(seq, **_):

from ._plotting import plot # noqa: I001
from ._stats import compute_stats
from ._util import _as_str, _Indicator, _Data, _indicator_warmup_nbars, _strategy_indicators, try_
from ._util import (
SharedMemory, SharedMemoryManager, _as_str, _Indicator, _Data, _indicator_warmup_nbars,
_strategy_indicators, patch, try_,
)

__pdoc__ = {
'Strategy.__init__': False,
Expand Down Expand Up @@ -1495,15 +1499,44 @@ def _optimize_grid() -> Union[pd.Series, Tuple[pd.Series, pd.Series]]:
[p.values() for p in param_combos],
names=next(iter(param_combos)).keys()))

with Parallel(prefer='threads', require='sharedmem', max_nbytes='50M',
n_jobs=-2, return_as='generator') as parallel:
def _batch(seq):
    """
    Yield consecutive slices of `seq`.

    The slice length is `len(seq)` floor-divided by the CPU count
    (taken as 1 when `os.cpu_count()` returns None), clamped to
    the range 1..300.
    """
    # XXX: Replace with itertools.batched
    total = len(seq)
    workers = os.cpu_count() or 1
    size = int(np.clip(total // workers, 1, 300))
    start = 0
    while start < total:
        yield seq[start:start + size]
        start += size

with mp.Pool() as pool, \
SharedMemoryManager() as smm:

shm_refs = [] # https://stackoverflow.com/questions/74193377/filenotfounderror-when-passing-a-shared-memory-to-a-new-process#comment130999060_74194875 # noqa: E501

def arr2shm(vals):
nonlocal smm
shm = smm.SharedMemory(size=vals.nbytes)
buf = np.ndarray(vals.shape, dtype=vals.dtype, buffer=shm.buf)
buf[:] = vals[:] # Copy into shared memory
assert vals.ndim == 1, (vals.ndim, vals.shape, vals)
shm_refs.append(shm)
return shm.name, vals.shape, vals.dtype

data_shm = tuple((
(column, *arr2shm(values))
for column, values in chain([(Backtest._mp_task_INDEX_COL, self._data.index)],
self._data.items())
))
with patch(self, '_data', None):
bt = copy(self) # bt._data will be reassigned in _mp_task worker
results = _tqdm(
parallel(delayed(self._mp_task)(self, params, maximize=maximize)
for params in param_combos),
pool.imap(Backtest._mp_task,
((bt, data_shm, params_batch)
for params_batch in _batch(param_combos))),
total=len(param_combos),
desc='Backtest.optimize')
for value, params in zip(results, param_combos):
heatmap[tuple(params.values())] = value
desc='Backtest.optimize'
)
for param_batch, result in zip(_batch(param_combos), results):
for params, stats in zip(param_batch, result):
if stats is not None:
heatmap[tuple(params.values())] = maximize(stats)

if pd.isnull(heatmap).all():
# No trade was made in any of the runs. Just make a random
Expand Down Expand Up @@ -1552,7 +1585,7 @@ def memoized_run(tup):
stats = self.run(**dict(tup))
return -maximize(stats)

progress = iter(_tqdm(repeat(None), total=max_tries, desc='Backtest.optimize'))
progress = iter(_tqdm(repeat(None), total=max_tries, leave=False, desc='Backtest.optimize'))
_names = tuple(kwargs.keys())

def objective_function(x):
Expand Down Expand Up @@ -1597,11 +1630,28 @@ def cons(x):
return output

@staticmethod
def _mp_task(bt, params, *, maximize):
stats = bt.run(**params)
return maximize(stats) if stats['# Trades'] else np.nan

_mp_backtests: Dict[float, Tuple['Backtest', List, Callable]] = {}
def _mp_task(arg):
    """
    Run one batch of backtests against data held in shared memory.

    `arg` is a `(bt, data_shm, params_batch)` tuple. Each entry of
    `data_shm` is `(column, shm_name, shape, dtype)`. The segments are
    attached by name (untracked), exposed as read-only arrays and
    assembled into a DataFrame that becomes `bt._data`, indexed by the
    `Backtest._mp_task_INDEX_COL` column.

    Returns a list with one item per parameter set in `params_batch`:
    the run's stats with keys starting with an underscore filtered out,
    or None if the run's '# Trades' value is falsy.
    """
    bt, data_shm, params_batch = arg
    segments = [SharedMemory(name=segment_name, create=False, track=False)
                for _, segment_name, *_ in data_shm]

    def readonly_view(segment, shape, dtype):
        view = np.ndarray(shape, dtype=dtype, buffer=segment.buf)
        view.setflags(write=False)
        return view

    try:
        # The dict of views is handed straight to DataFrame() and not bound
        # to a local name, so this function keeps no extra reference to it.
        frame = pd.DataFrame({
            col: readonly_view(segment, shape, dtype)
            for segment, (col, _, shape, dtype) in zip(segments, data_shm)})
        bt._data = frame
        frame.set_index(Backtest._mp_task_INDEX_COL, drop=True, inplace=True)

        results = []
        for params in params_batch:
            stats = bt.run(**params)
            results.append(stats.filter(regex='^[^_]') if stats['# Trades'] else None)
        return results
    finally:
        # Detach from every segment, whether or not the runs succeeded.
        for segment in segments:
            segment.close()

_mp_task_INDEX_COL = '__bt_index'

def plot(self, *, results: pd.Series = None, filename=None, plot_width=None,
plot_equity=True, plot_return=False, plot_pl=True,
Expand Down
7 changes: 3 additions & 4 deletions backtesting/test/__main__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import unittest
import warnings

suite = unittest.defaultTestLoader.discover('backtesting.test',
pattern='_test*.py')
unittest.defaultTestLoader.suiteClass = lambda _: suite

if __name__ == '__main__':
unittest.main(verbosity=2)
warnings.filterwarnings('error')
unittest.main(module='backtesting.test._test', verbosity=2)
Loading