diff --git a/cadet/cadet.py b/cadet/cadet.py index dcfa6e9..b749fc6 100644 --- a/cadet/cadet.py +++ b/cadet/cadet.py @@ -6,6 +6,8 @@ from typing import Optional import warnings +from addict import Dict + from cadet.h5 import H5 from cadet.runner import CadetRunnerBase, CadetCLIRunner, ReturnInformation from cadet.cadet_dll import CadetDLLRunner @@ -525,3 +527,12 @@ def __del__(self): self.clear() del self._cadet_dll_runner del self._cadet_cli_runner + + def __getstate__(self): + state = self.__dict__.copy() + return state + + def __setstate__(self, state): + # Restore the state and cast to addict.Dict() to add __frozen attributes + state = Dict(state) + self.__dict__.update(state) diff --git a/cadet/cadet_dll.py b/cadet/cadet_dll.py index 487853b..9da32ac 100644 --- a/cadet/cadet_dll.py +++ b/cadet/cadet_dll.py @@ -1626,6 +1626,9 @@ def __init__(self, dll_path: os.PathLike | str) -> None: Path to the CADET DLL. """ self._cadet_path = Path(dll_path) + self._initialize_dll() + + def _initialize_dll(self): self._lib = ctypes.cdll.LoadLibrary(self._cadet_path.as_posix()) # Query meta information @@ -1693,6 +1696,18 @@ def __init__(self, dll_path: os.PathLike | str) -> None: self._driver = self._api.createDriver() self.res: Optional[SimulationResult] = None + def __getstate__(self): + # Exclude all non-pickleable attributes and only keep _cadet_path + state = self.__dict__.copy() + pickleable_keys = ["_cadet_path"] + state = {key: state[key] for key in pickleable_keys} + return state + + def __setstate__(self, state): + # Restore the state and reinitialize the DLL + self.__dict__.update(state) + self._initialize_dll() + def clear(self) -> None: """ Clear the current simulation state. diff --git a/pyproject.toml b/pyproject.toml index 1a1e629..9419c06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,10 @@ dependencies = [ ] [project.optional-dependencies] -testing = ["pytest"] +testing = [ + "pytest", + "joblib" +] [project.urls] "homepage" = "https://github.com/cadet/CADET-Python" diff --git a/tests/test_parallelization.py b/tests/test_parallelization.py new file mode 100644 index 0000000..a0506c6 --- /dev/null +++ b/tests/test_parallelization.py @@ -0,0 +1,40 @@ +from cadet import Cadet +from joblib import Parallel, delayed +from .test_dll import setup_model + +n_jobs = 2 + + +def run_simulation(model): + model.save() + data = model.run_load() + return data + + +def test_parallelization_io(): + model1 = Cadet() + model1.root.input = {'model': 1} + model1.filename = "sim_1.h5" + model2 = Cadet() + model2.root.input = {'model': 2} + model2.filename = "sim_2.h5" + + models = [model1, model2] + + results_sequential = [run_simulation(model) for model in models] + + results_parallel = Parallel(n_jobs=n_jobs, verbose=0)( + delayed(run_simulation)(model, ) for model in models + ) + assert results_sequential == results_parallel + + +def test_parallelization_simulation(): + models = [setup_model(Cadet.autodetect_cadet(), file_name=f"LWE_{i}.h5") for i in range(2)] + + results_sequential = [run_simulation(model) for model in models] + + results_parallel = Parallel(n_jobs=n_jobs, verbose=0)( + delayed(run_simulation)(model, ) for model in models + ) + assert results_sequential == results_parallel