Skip to content

Commit

Permalink
Use new QG method for PreExecInit.savePackageVersion implementation.
Browse files Browse the repository at this point in the history
This again changes the exception types raised (to be more consistent
with what saveInitOutputs already raised), but these weren't really
recoverable errors so this shouldn't break anything.
  • Loading branch information
TallJimbo committed Sep 10, 2024
1 parent a2d2a8b commit 3226412
Showing 1 changed file with 4 additions and 135 deletions.
139 changes: 4 additions & 135 deletions python/lsst/ctrl/mpexec/preExecInit.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,54 +34,19 @@
# -------------------------------
import abc
import logging
from collections.abc import Iterable, Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

# -----------------------------
# Imports for other modules --
# -----------------------------
from lsst.daf.butler import DatasetRef
from lsst.daf.butler.registry import ConflictingDefinitionError
from lsst.pipe.base.automatic_connection_constants import PACKAGES_INIT_OUTPUT_NAME
from lsst.utils.packages import Packages

if TYPE_CHECKING:
from lsst.daf.butler import Butler, LimitedButler
from lsst.pipe.base import QuantumGraph, TaskDef, TaskFactory
from lsst.pipe.base import QuantumGraph, TaskFactory

_LOG = logging.getLogger(__name__)


class MissingReferenceError(Exception):
"""Exception raised when resolved reference is missing from graph."""

pass


def _compare_packages(old_packages: Packages, new_packages: Packages) -> None:
"""Compare two versions of Packages.
Parameters
----------
old_packages : `Packages`
Previously recorded package versions.
new_packages : `Packages`
New set of package versions.
Raises
------
TypeError
Raised if parameters are inconsistent.
"""
diff = new_packages.difference(old_packages)
if diff:
versions_str = "; ".join(f"{pkg}: {diff[pkg][1]} vs {diff[pkg][0]}" for pkg in diff)
raise TypeError(f"Package versions mismatch: ({versions_str})")
else:
_LOG.debug("new packages are consistent with old")


class PreExecInitBase(abc.ABC):
"""Common part of the implementation of PreExecInit classes that does not
depend on Butler type.
Expand All @@ -91,14 +56,13 @@ class PreExecInitBase(abc.ABC):
butler : `~lsst.daf.butler.LimitedButler`
Butler to use.
taskFactory : `lsst.pipe.base.TaskFactory`
Task factory.
Ignored and accepted for backwards compatibility.
extendRun : `bool`
Whether extend run parameter is in use.
"""

def __init__(self, butler: LimitedButler, taskFactory: TaskFactory, extendRun: bool):
self.butler = butler
self.taskFactory = taskFactory
self.extendRun = extendRun

def initialize(
Expand Down Expand Up @@ -214,96 +178,7 @@ def savePackageVersions(self, graph: QuantumGraph) -> None:
TypeError
Raised if existing object in butler is incompatible with new data.
"""
packages = Packages.fromSystem()
_LOG.debug("want to save packages: %s", packages)

# start transaction to rollback any changes on exceptions
with self.transaction():
# Packages dataset ref is stored in graph's global init outputs,
# but it may be also be missing.

packages_ref, old_packages = self._find_dataset(
graph.globalInitOutputRefs(), PACKAGES_INIT_OUTPUT_NAME
)
if packages_ref is None:
return

if old_packages is not None:
# Note that because we can only detect python modules that have
# been imported, the stored list of products may be more or
# less complete than what we have now. What's important is
# that the products that are in common have the same version.
_compare_packages(old_packages, packages)
# Update the old set of packages in case we have more packages
# that haven't been persisted.
extra = packages.extra(old_packages)
if extra:
_LOG.debug("extra packages: %s", extra)
old_packages.update(packages)
# have to remove existing dataset first, butler has no
# replace option.
self.butler.pruneDatasets([packages_ref], unstore=True, purge=True)
self.butler.put(old_packages, packages_ref)
else:
self.butler.put(packages, packages_ref)

def _find_dataset(
self, refs: Iterable[DatasetRef], dataset_type: str
) -> tuple[DatasetRef | None, Any | None]:
"""Find a ref with a given dataset type name in a list of references
and try to retrieve its data from butler.
Parameters
----------
refs : `~collections.abc.Iterable` [ `~lsst.daf.butler.DatasetRef` ]
References to check for matching dataset type.
dataset_type : `str`
Name of a dataset type to look for.
Returns
-------
ref : `~lsst.daf.butler.DatasetRef` or `None`
Dataset reference or `None` if there is no matching dataset type.
data : `Any`
An existing object extracted from butler, `None` if ``ref`` is
`None` or if there is no existing object for that reference.
"""
ref: DatasetRef | None = None
for ref in refs:
if ref.datasetType.name == dataset_type:
break
else:
return None, None

try:
data = self.butler.get(ref)
if data is not None and not self.extendRun:
# It must not exist unless we are extending run.
raise ConflictingDefinitionError(f"Dataset {ref} already exists in butler")
except (LookupError, FileNotFoundError):
data = None
return ref, data

def _task_iter(self, graph: QuantumGraph) -> Iterator[TaskDef]:
"""Iterate over TaskDefs in a graph, return only tasks that have one or
more associated quanta.
"""
for taskDef in graph.iterTaskGraph():
if graph.getNumberOfQuantaForTask(taskDef) > 0:
yield taskDef

@contextmanager
def transaction(self) -> Iterator[None]:
"""Context manager for transaction.
Default implementation has no transaction support.
Yields
------
`None`
No transaction support.
"""
yield
graph.write_packages(self.butler, compare_existing=self.extendRun)


class PreExecInit(PreExecInitBase):
Expand Down Expand Up @@ -334,12 +209,6 @@ def __init__(self, butler: Butler, taskFactory: TaskFactory, extendRun: bool = F
"with a default output RUN collection."
)

@contextmanager
def transaction(self) -> Iterator[None]:
# dosctring inherited
with self.full_butler.transaction():
yield

def initializeDatasetTypes(self, graph: QuantumGraph, registerDatasetTypes: bool = False) -> None:
# docstring inherited
if registerDatasetTypes:
Expand Down

0 comments on commit 3226412

Please sign in to comment.