Skip to content

Commit

Permalink
MNT Switch Task imports to a conditional import behind a function to …
Browse files Browse the repository at this point in the history
…prevent environment conflicts.
  • Loading branch information
gadorlhiac committed Mar 29, 2024
1 parent f356203 commit f9732a7
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 27 deletions.
63 changes: 63 additions & 0 deletions lute/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""LUTE Tasks
Functions:
import_task(task_name: str) -> Type[Task]: Provides conditional import of
Task's. This prevents import conflicts as Task's may be intended to run
in different environments.
Exceptions:
TaskNotFoundError: Raised if
"""

from typing import Type
from .task import Task


class TaskNotFoundError(Exception):
"""Exception raised if an unrecognized Task is requested.
The Task could be invalid (e.g. misspelled, nonexistent) or it may not have
been registered with the `import_task` function below.
"""

...


def import_task(task_name: str) -> Type[Task]:
"""Conditionally imports Task's to prevent environment conflicts.
Args:
task_name (str): The name of the Task to import.
Returns:
TaskType (Type[Task]): The requested Task class.
Raises:
TaskNotFoundError: Raised if the requested Task is unrecognized.
If the Task exits it may not have been registered.
"""
if task_name == "Test":
from .test import Test

return Test

if task_name == "TestSocket":
from .test import TestSocket

return TestSocket

if task_name == "TestReadOutput":
from .test import TestReadOutput

return TestReadOutput

if task_name == "TestWriteOutput":
from .test import TestWriteOutput

return TestWriteOutput

if task_name == "FindPeaksPyAlgos":
from .sfx_find_peaks import FindPeaksPyAlgos

return FindPeaksPyAlgos

raise TaskNotFoundError
37 changes: 10 additions & 27 deletions subprocess_task.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import sys
import os
import argparse
import logging
import signal
import types
import importlib.util
from typing import Type, Optional, Dict, Any

from lute.tasks.task import Task, ThirdPartyTask
Expand Down Expand Up @@ -60,37 +58,22 @@ def timeout_handler(signum: int, frame: types.FrameType) -> None:

# Hack to avoid importing modules with conflicting dependencie
TaskType: Type[Task]
module_with_task: Optional[str] = None
lute_path: str = os.getenv("LUTE_PATH", os.path.dirname(__file__))
if isinstance(task_parameters, BaseBinaryParameters):
TaskType = ThirdPartyTask
else:
for module_name in os.listdir(f"{lute_path}/lute/tasks"):
if module_name.endswith(".py") and module_name not in [
"dataclasses.py",
"task.py",
"__init__.py",
]:
with open(f"{lute_path}/lute/tasks/{module_name}", "r") as f:
txt: str = f.read()
if f"class {task_name}(Task):" in txt:
module_with_task = module_name[:-3]
del txt
break
else:
from lute.tasks import import_task, TaskNotFoundError

try:
TaskType = import_task(task_name)
except TaskNotFoundError as err:
logger.debug(
f"Task {task_name} not found while scanning directory: `{lute_path}/lute/tasks`."
(
f"Task {task_name} not found! Things to double check:\n"
"\t - The spelling of the Task name.\n"
"\t - Has the Task been registered in lute.tasks.import_task."
)
)
sys.exit(-1)

# If we got this far we should have a module or are ThirdPartyTask
if module_with_task is not None:
spec: importlib.machinery.ModuleSpec = importlib.util.spec_from_file_location(
module_with_task, f"{lute_path}/lute/tasks/{module_with_task}.py"
)
task_module: types.ModuleType = importlib.util.module_from_spec(spec)
spec.loader.exec_module(task_module)
TaskType: Type[Task] = getattr(task_module, f"{task_name}")

task: Task = TaskType(params=task_parameters)
task.run()

0 comments on commit f9732a7

Please sign in to comment.