From dbb9b079d2ea6f911b742a7109e395f13589e718 Mon Sep 17 00:00:00 2001
From: sapetnioc
Date: Fri, 26 Apr 2024 10:06:31 +0200
Subject: [PATCH] Workers count can be given at engine creation

---
Review notes (text in this region is ignored by git am / git apply):

* The flattened copy of this patch had lost its line breaks, blank lines and
  leading whitespace. They are reconstructed below; the indentation of
  context and removed lines is a best guess and must be checked against the
  target tree before applying. In the worker_started hunk the removed and
  added return/raise lines differ only by indentation, which could not be
  recovered: the levels shown there are an assumption.
* capsul/test/test_workers.py was reworked (single polling helper instead of
  two copy-pasted loops, dispose() moved to a finally clause, stdlib import
  first), so the "index" blob ids and the commit id on the first line no
  longer describe the content.
* NOTE(review): Capsul.engine() inserts workers_count before
  update_database, so a caller passing update_database positionally would
  now set workers_count instead. Confirm no such caller exists.
* NOTE(review): the copied start_workers dict is stored back on the object
  read from self.config, so the override presumably persists for later
  engine() calls on the same name. Confirm this is intended.

 capsul/application.py         |  5 ++++-
 capsul/database/populse_db.py | 32 ++++++++++++++++----------------
 capsul/test/test_workers.py   | 46 ++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 66 insertions(+), 17 deletions(-)
 create mode 100644 capsul/test/test_workers.py

diff --git a/capsul/application.py b/capsul/application.py
index 3f9f0c66..798f07e0 100644
--- a/capsul/application.py
+++ b/capsul/application.py
@@ -151,7 +151,7 @@ def engines(self):
             if field.name != "databases":
                 yield self.engine(field.name)
 
-    def engine(self, name="builtin", update_database=False):
+    def engine(self, name="builtin", workers_count=None, update_database=False):
         """Get a :class:`~capsul.engine.Engine` instance"""
 
         from .engine import Engine
@@ -159,6 +159,9 @@ def engine(self, name="builtin", update_database=False):
         engine_config = getattr(self.config, name, None)
         if engine_config is None:
             raise ValueError(f'engine "{name}" is not configured.')
+        if workers_count is not None:
+            engine_config.start_workers = engine_config.start_workers.copy()
+            engine_config.start_workers["count"] = workers_count
         return Engine(
             name,
             engine_config,
diff --git a/capsul/database/populse_db.py b/capsul/database/populse_db.py
index daf314ac..f4d631aa 100644
--- a/capsul/database/populse_db.py
+++ b/capsul/database/populse_db.py
@@ -19,7 +19,7 @@
                     "label": [str, {"index": True}],
                     "config": dict,
                     "workers": list[str],
-                    "executions": list[dict],
+                    "executions": list[str],
                     "persistent": bool,
                     "connections": int,
                 }
@@ -183,14 +183,14 @@ def worker_database_config(self, engine_id):
         return self.config
 
     def worker_started(self, engine_id):
-        with self.storage.data(write=True) as db:
+        with self.storage.data(write=True, exclusive=True) as db:
             worker_id = str(uuid4())
             workers = db.capsul_engine[engine_id].workers.get()
             if workers is not None:
                 workers.append(worker_id)
                 db.capsul_engine[engine_id].workers = workers
-            return worker_id
-        raise ValueError(f"Invalid engine_id: {engine_id}")
+                return worker_id
+            raise ValueError(f"Invalid engine_id: {engine_id}")
 
     def worker_ended(self, engine_id, worker_id):
         with self.storage.data(write=True) as db:
@@ -395,9 +395,9 @@ def job_finished_json(
                     waiting_job["return_code"] = (
                         "Not started because de dependent job failed"
                     )
-                    db.capsul_job[
-                        engine_id, execution_id, waiting_id
-                    ].job = waiting_job
+                    db.capsul_job[engine_id, execution_id, waiting_id].job = (
+                        waiting_job
+                    )
                     waiting.remove(waiting_id)
                     failed.append(waiting_id)
                     stack.update(waiting_job.get("waited_by", []))
@@ -427,9 +427,9 @@ def job_finished_json(
 
             if not ongoing and not ready:
                 if failed:
-                    db.capsul_execution[
-                        engine_id, execution_id
-                    ].error = "Some jobs failed"
+                    db.capsul_execution[engine_id, execution_id].error = (
+                        "Some jobs failed"
+                    )
                 db.capsul_execution[engine_id, execution_id].update(
                     {
                         "status": "finalization",
@@ -483,12 +483,12 @@ def set_job_output_parameters(
             indices = job.get("parameters_index", {})
             for name, value in output_parameters.items():
                 values[indices[name]] = value
-            db.capsul_job[
-                engine_id, execution_id, job_id
-            ].job.output_parameters = output_parameters
-            db.capsul_execution[
-                engine_id, execution_id
-            ].workflow_parameters_values = values
+            db.capsul_job[engine_id, execution_id, job_id].job.output_parameters = (
+                output_parameters
+            )
+            db.capsul_execution[engine_id, execution_id].workflow_parameters_values = (
+                values
+            )
 
     def job_json(self, engine_id, execution_id, job_id):
         if os.path.exists(self.path):
diff --git a/capsul/test/test_workers.py b/capsul/test/test_workers.py
new file mode 100644
index 00000000..e5f1ea88
--- /dev/null
+++ b/capsul/test/test_workers.py
@@ -0,0 +1,46 @@
+import time
+
+from capsul.api import Capsul
+
+
+# Minimal executable: its only purpose is to give the engine something to run.
+def noop() -> None:
+    pass
+
+
+def _wait_for_workers_count(engine, expected, attempts=100, delay=0.2):
+    """Poll the engine database until it reports ``expected`` workers.
+
+    Return True as soon as the count matches, or False when it still does
+    not match after ``attempts`` polls separated by ``delay`` seconds.
+    """
+    for _ in range(attempts):
+        if engine.database.workers_count(engine.engine_id) == expected:
+            return True
+        time.sleep(delay)
+    return False
+
+
+def test_start_workers():
+    """Check that ``workers_count`` passed to ``Capsul.engine`` is applied."""
+    capsul = Capsul(database_path="")
+    noop_executable = capsul.executable(noop)
+    for wc in [3, 2, 1]:
+        with capsul.engine(workers_count=wc) as engine:
+            requested = engine.config.start_workers.get("count", 0)
+            assert requested == wc
+            noop_id = engine.start(noop_executable)
+            try:
+                if not _wait_for_workers_count(engine, wc):
+                    raise RuntimeError(
+                        f"expected {wc} workers to be created, got "
+                        f"{engine.database.workers_count(engine.engine_id)}"
+                    )
+            finally:
+                # Dispose of the execution even when the wait above fails.
+                engine.dispose(noop_id)
+            if not _wait_for_workers_count(engine, 0):
+                raise RuntimeError(
+                    "expected workers to be stopped; running workers = "
+                    f"{engine.database.workers_count(engine.engine_id)}"
+                )