Skip to content

Commit

Permalink
Fix pickle support to allow for parallelization and add parallelizati…
Browse files Browse the repository at this point in the history
…on tests

Fixed pickle support by removing all ctypes pointers from the state in CadetDLLRunner.__getstate__ and recreating the dll interface in CadetDLLRunner.__setstate__ .
Fixed "no attribute __frozen" error by casting Cadet state into addict.Dict on Cadet.__setstate__ .
  • Loading branch information
ronald-jaepel committed Nov 19, 2024
1 parent b99ef3c commit 9cd4160
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 1 deletion.
11 changes: 11 additions & 0 deletions cadet/cadet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from typing import Optional
import warnings

from addict import Dict

from cadet.h5 import H5
from cadet.runner import CadetRunnerBase, CadetCLIRunner, ReturnInformation
from cadet.cadet_dll import CadetDLLRunner
Expand Down Expand Up @@ -525,3 +527,12 @@ def __del__(self):
self.clear()
del self._cadet_dll_runner
del self._cadet_cli_runner

def __getstate__(self):
state = self.__dict__.copy()
return state

def __setstate__(self, state):
# Restore the state and cast to addict.Dict() to add __frozen attributes
state = Dict(state)
self.__dict__.update(state)
15 changes: 15 additions & 0 deletions cadet/cadet_dll.py
Original file line number Diff line number Diff line change
Expand Up @@ -1626,6 +1626,9 @@ def __init__(self, dll_path: os.PathLike | str) -> None:
Path to the CADET DLL.
"""
self._cadet_path = Path(dll_path)
self._initialize_dll()

def _initialize_dll(self):
self._lib = ctypes.cdll.LoadLibrary(self._cadet_path.as_posix())

# Query meta information
Expand Down Expand Up @@ -1693,6 +1696,18 @@ def __init__(self, dll_path: os.PathLike | str) -> None:
self._driver = self._api.createDriver()
self.res: Optional[SimulationResult] = None

def __getstate__(self):
# Exclude all non-pickleable attributes and only keep _cadet_path
state = self.__dict__.copy()
pickleable_keys = ["_cadet_path"]
state = {key: state[key] for key in pickleable_keys}
return state

def __setstate__(self, state):
# Restore the state and reinitialize the DLL
self.__dict__.update(state)
self._initialize_dll()

def clear(self) -> None:
"""
Clear the current simulation state.
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ dependencies = [
]

[project.optional-dependencies]
testing = ["pytest"]
testing = [
"pytest",
"joblib"
]

[project.urls]
"homepage" = "https://github.com/cadet/CADET-Python"
Expand Down
40 changes: 40 additions & 0 deletions tests/test_parallelization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from cadet import Cadet
from joblib import Parallel, delayed
from .test_dll import setup_model

n_jobs = 2


def run_simulation(model):
model.save()
data = model.run_load()
return data


def test_parallelization_io():
model1 = Cadet()
model1.root.input = {'model': 1}
model1.filename = "sim_1.h5"
model2 = Cadet()
model2.root.input = {'model': 2}
model2.filename = "sim_2.h5"

models = [model1, model2]

results_sequential = [run_simulation(model) for model in models]

results_parallel = Parallel(n_jobs=n_jobs, verbose=0)(
delayed(run_simulation)(model, ) for model in models
)
assert results_sequential == results_parallel


def test_parallelization_simulation():
models = [setup_model(Cadet.autodetect_cadet(), file_name=f"LWE_{i}.h5") for i in range(2)]

results_sequential = [run_simulation(model) for model in models]

results_parallel = Parallel(n_jobs=n_jobs, verbose=0)(
delayed(run_simulation)(model, ) for model in models
)
assert results_sequential == results_parallel

0 comments on commit 9cd4160

Please sign in to comment.