From 6f121462b30968dc38e02d890b80822191b60406 Mon Sep 17 00:00:00 2001
From: Tom Close <tom.g.close@gmail.com>
Date: Fri, 1 Mar 2024 21:49:32 +1100
Subject: [PATCH] ENH: Add the ability to bring your own worker (#733)

* add the ability to BYO your own worker, i.e. without it needing to monkey path the pydra.engine.workers.WORKERS dict

* changed byo plugins to be classes not instances

* touch up

* added test to catch missing patch lines

* touch up
---
 pydra/engine/submitter.py            | 29 +++++++++----
 pydra/engine/tests/test_submitter.py | 63 +++++++++++++++++++++++++++-
 pydra/engine/workers.py              | 53 ++++++++++++++---------
 3 files changed, 116 insertions(+), 29 deletions(-)

diff --git a/pydra/engine/submitter.py b/pydra/engine/submitter.py
index 3906955b2c..23c8f50b0e 100644
--- a/pydra/engine/submitter.py
+++ b/pydra/engine/submitter.py
@@ -1,9 +1,10 @@
 """Handle execution backends."""
 
 import asyncio
+import typing as ty
 import pickle
 from uuid import uuid4
-from .workers import WORKERS
+from .workers import Worker, WORKERS
 from .core import is_workflow
 from .helpers import get_open_loop, load_and_run_async
 
@@ -16,24 +17,34 @@
 class Submitter:
     """Send a task to the execution backend."""
 
-    def __init__(self, plugin="cf", **kwargs):
+    def __init__(self, plugin: ty.Union[str, ty.Type[Worker]] = "cf", **kwargs):
         """
         Initialize task submission.
 
         Parameters
         ----------
-        plugin : :obj:`str`
-            The identifier of the execution backend.
+        plugin : :obj:`str` or :obj:`ty.Type[pydra.engine.core.Worker]`
+            Either the identifier of the execution backend or the worker class itself.
             Default is ``cf`` (Concurrent Futures).
+        **kwargs
+            Additional keyword arguments to pass to the worker.
 
         """
         self.loop = get_open_loop()
         self._own_loop = not self.loop.is_running()
-        self.plugin = plugin
-        try:
-            self.worker = WORKERS[self.plugin](**kwargs)
-        except KeyError:
-            raise NotImplementedError(f"No worker for {self.plugin}")
+        if isinstance(plugin, str):
+            self.plugin = plugin
+            try:
+                worker_cls = WORKERS[self.plugin]
+            except KeyError:
+                raise NotImplementedError(f"No worker for '{self.plugin}' plugin")
+        else:
+            try:
+                self.plugin = plugin.plugin_name
+            except AttributeError:
+                raise ValueError("Worker class must have a 'plugin_name' str attribute")
+            worker_cls = plugin
+        self.worker = worker_cls(**kwargs)
         self.worker.loop = self.loop
 
     def __call__(self, runnable, cache_locations=None, rerun=False, environment=None):
diff --git a/pydra/engine/tests/test_submitter.py b/pydra/engine/tests/test_submitter.py
index d65247e96a..a3219521a0 100644
--- a/pydra/engine/tests/test_submitter.py
+++ b/pydra/engine/tests/test_submitter.py
@@ -2,6 +2,8 @@
 import re
 import subprocess as sp
 import time
+import os
+from unittest.mock import patch
 
 import pytest
 
@@ -12,8 +14,9 @@
     gen_basic_wf_with_threadcount,
     gen_basic_wf_with_threadcount_concurrent,
 )
-from ..core import Workflow
+from ..core import Workflow, TaskBase
 from ..submitter import Submitter
+from ..workers import SerialWorker
 from ... import mark
 from pathlib import Path
 from datetime import datetime
@@ -612,3 +615,61 @@ def alter_input(x):
 @mark.task
 def to_tuple(x, y):
     return (x, y)
+
+
+class BYOAddVarWorker(SerialWorker):
+    """A dummy worker that adds 1 to the output of the task"""
+
+    plugin_name = "byo_add_env_var"
+
+    def __init__(self, add_var, **kwargs):
+        super().__init__(**kwargs)
+        self.add_var = add_var
+
+    async def exec_serial(self, runnable, rerun=False, environment=None):
+        if isinstance(runnable, TaskBase):
+            with patch.dict(os.environ, {"BYO_ADD_VAR": str(self.add_var)}):
+                result = runnable._run(rerun, environment=environment)
+            return result
+        else:  # it could be tuple that includes pickle files with tasks and inputs
+            return super().exec_serial(runnable, rerun, environment)
+
+
+@mark.task
+def add_env_var_task(x: int) -> int:
+    return x + int(os.environ.get("BYO_ADD_VAR", 0))
+
+
+def test_byo_worker():
+
+    task1 = add_env_var_task(x=1)
+
+    with Submitter(plugin=BYOAddVarWorker, add_var=10) as sub:
+        assert sub.plugin == "byo_add_env_var"
+        result = task1(submitter=sub)
+
+    assert result.output.out == 11
+
+    task2 = add_env_var_task(x=2)
+
+    with Submitter(plugin="serial") as sub:
+        result = task2(submitter=sub)
+
+    assert result.output.out == 2
+
+
+def test_bad_builtin_worker():
+
+    with pytest.raises(NotImplementedError, match="No worker for 'bad-worker' plugin"):
+        Submitter(plugin="bad-worker")
+
+
+def test_bad_byo_worker():
+
+    class BadWorker:
+        pass
+
+    with pytest.raises(
+        ValueError, match="Worker class must have a 'plugin_name' str attribute"
+    ):
+        Submitter(plugin=BadWorker)
diff --git a/pydra/engine/workers.py b/pydra/engine/workers.py
index 155a2800d9..eaa40beb0a 100644
--- a/pydra/engine/workers.py
+++ b/pydra/engine/workers.py
@@ -128,6 +128,8 @@ async def fetch_finished(self, futures):
 class SerialWorker(Worker):
     """A worker to execute linearly."""
 
+    plugin_name = "serial"
+
     def __init__(self, **kwargs):
         """Initialize worker."""
         logger.debug("Initialize SerialWorker")
@@ -157,6 +159,8 @@ async def fetch_finished(self, futures):
 class ConcurrentFuturesWorker(Worker):
     """A worker to execute in parallel using Python's concurrent futures."""
 
+    plugin_name = "cf"
+
     def __init__(self, n_procs=None):
         """Initialize Worker."""
         super().__init__()
@@ -192,6 +196,7 @@ def close(self):
 class SlurmWorker(DistributedWorker):
     """A worker to execute tasks on SLURM systems."""
 
+    plugin_name = "slurm"
     _cmd = "sbatch"
     _sacct_re = re.compile(
         "(?P<jobid>\\d*) +(?P<status>\\w*)\\+? +" "(?P<exit_code>\\d+):\\d+"
@@ -367,6 +372,8 @@ async def _verify_exit_code(self, jobid):
 class SGEWorker(DistributedWorker):
     """A worker to execute tasks on SLURM systems."""
 
+    plugin_name = "sge"
+
     _cmd = "qsub"
     _sacct_re = re.compile(
         "(?P<jobid>\\d*) +(?P<status>\\w*)\\+? +" "(?P<exit_code>\\d+):\\d+"
@@ -860,6 +867,8 @@ class DaskWorker(Worker):
     This is an experimental implementation with limited testing.
     """
 
+    plugin_name = "dask"
+
     def __init__(self, **kwargs):
         """Initialize Worker."""
         super().__init__()
@@ -898,7 +907,7 @@ def close(self):
 class PsijWorker(Worker):
     """A worker to execute tasks using PSI/J."""
 
-    def __init__(self, subtype, **kwargs):
+    def __init__(self, **kwargs):
         """
         Initialize PsijWorker.
 
@@ -915,15 +924,6 @@ def __init__(self, subtype, **kwargs):
         logger.debug("Initialize PsijWorker")
         self.psij = psij
 
-        # Check if the provided subtype is valid
-        valid_subtypes = ["local", "slurm"]
-        if subtype not in valid_subtypes:
-            raise ValueError(
-                f"Invalid 'subtype' provided. Available options: {', '.join(valid_subtypes)}"
-            )
-
-        self.subtype = subtype
-
     def run_el(self, interface, rerun=False, **kwargs):
         """Run a task."""
         return self.exec_psij(interface, rerun=rerun)
@@ -1039,14 +1039,29 @@ def close(self):
         pass
 
 
+class PsijLocalWorker(PsijWorker):
+    """A worker to execute tasks using PSI/J on the local machine."""
+
+    subtype = "local"
+    plugin_name = f"psij-{subtype}"
+
+
+class PsijSlurmWorker(PsijWorker):
+    """A worker to execute tasks using PSI/J using SLURM."""
+
+    subtype = "slurm"
+    plugin_name = f"psij-{subtype}"
+
+
 WORKERS = {
-    "serial": SerialWorker,
-    "cf": ConcurrentFuturesWorker,
-    "slurm": SlurmWorker,
-    "dask": DaskWorker,
-    "sge": SGEWorker,
-    **{
-        "psij-" + subtype: lambda subtype=subtype: PsijWorker(subtype=subtype)
-        for subtype in ["local", "slurm"]
-    },
+    w.plugin_name: w
+    for w in (
+        SerialWorker,
+        ConcurrentFuturesWorker,
+        SlurmWorker,
+        DaskWorker,
+        SGEWorker,
+        PsijLocalWorker,
+        PsijSlurmWorker,
+    )
 }