diff --git a/src/molexp/project.py b/src/molexp/project.py index 4e77400..a71b877 100644 --- a/src/molexp/project.py +++ b/src/molexp/project.py @@ -163,14 +163,22 @@ def mapper(tasks: list[me.Task], params: list[me.Param]) -> Parallelizable[Any]: for task, param in zip(tasks, params): dr = self._get_driver(task, config) task.param |= param - yield dr.execute(inputs=task.param, final_vars=final_vars) + yield {"dr": dr, "inputs": task.param, "final_vars": final_vars} + # yield dr.execute(inputs=task.param, final_vars=final_vars) - def reducer(mapper: Collect[Any]) -> Any: + def dag_result(mapper: dict) -> dict: + _dr = mapper["dr"] + _inputs = mapper["inputs"] + _final_vars = mapper["final_vars"] + return _dr.execute(inputs=_inputs, final_vars=_final_vars) - return reducer_fn(mapper) + def reducer(dag_result: Collect[dict]) -> Any: + + return reducer_fn(dag_result) temp_module = ad_hoc_utils.create_temporary_module( mapper, + dag_result, reducer, module_name="start_tasks_mapper_reducer", ) @@ -179,7 +187,7 @@ def reducer(mapper: Collect[Any]) -> Any: driver.Builder() .with_modules(temp_module) .enable_dynamic_execution(allow_experimental_mode=True) - .with_remote_executor(executors.MultiProcessingExecutor(8)) + .with_remote_executor(executors.MultiThreadingExecutor(8)) .build() ) dr.execute(final_vars=["reducer"], inputs={"tasks": tasks, "params": params})