Skip to content

Commit

Permalink
Add new QuantumGraph.get_refs method and use it in scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
timj committed Nov 4, 2024
1 parent 1378e19 commit 261ac6c
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 82 deletions.
152 changes: 152 additions & 0 deletions python/lsst/pipe/base/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import getpass
import io
import json
import logging
import lzma
import os
import struct
Expand Down Expand Up @@ -75,6 +76,7 @@
from .quantumNode import BuildId, QuantumNode

_T = TypeVar("_T", bound="QuantumGraph")
_LOG = logging.getLogger(__name__)

# modify this constant any time the on disk representation of the save file
# changes, and update the load helpers to behave properly for each version.
Expand Down Expand Up @@ -1656,3 +1658,153 @@ def init_output_run(self, butler: LimitedButler, existing: bool = True) -> None:
self.write_configs(butler, compare_existing=existing)
self.write_packages(butler, compare_existing=existing)
self.write_init_outputs(butler, skip_existing=existing)

def get_refs(
    self,
    *,
    include_init_inputs: bool = False,
    include_inputs: bool = False,
    include_intermediates: bool | None = None,
    include_init_outputs: bool = False,
    include_outputs: bool = False,
    conform_outputs: bool = True,
    include_global_init_outputs: bool = False,
) -> tuple[set[DatasetRef], dict[str, DatastoreRecordData]]:
    """Get the requested dataset refs from the graph.

    Parameters
    ----------
    include_init_inputs : `bool`, optional
        Include init inputs.
    include_inputs : `bool`, optional
        Include inputs.
    include_intermediates : `bool` or `None`, optional
        If `None`, no special handling for intermediates is performed.
        If `True` intermediates are calculated even if other flags
        do not request datasets. If `False` intermediates will be removed
        from any results.
    include_init_outputs : `bool`, optional
        Include per-task init outputs.
    include_outputs : `bool`, optional
        Include outputs.
    conform_outputs : `bool`, optional
        Whether any outputs found should have their dataset types conformed
        with the registry dataset types.
    include_global_init_outputs : `bool`, optional
        Include the refs reported by ``globalInitOutputRefs``. These are
        not reachable through the per-task init outputs, so callers that
        need every output the graph can produce must request them
        explicitly. They are reported as init outputs and take no part in
        the intermediates calculation.

    Returns
    -------
    refs : `set` [ `lsst.daf.butler.DatasetRef` ]
        The requested dataset refs found in the graph.
    datastore_records : `dict` [ `str`, \
            `lsst.daf.butler.datastore.record_data.DatastoreRecordData` ]
        Any datastore records found. These are only gathered when quantum
        inputs are being scanned.

    Notes
    -----
    Conforming and requesting inputs and outputs can result in the same
    dataset appearing in the results twice with differing storage classes.
    """
    datastore_records: dict[str, DatastoreRecordData] = {}
    init_input_refs: set[DatasetRef] = set()
    init_output_refs: set[DatasetRef] = set()
    input_refs: set[DatasetRef] = set()
    output_refs: set[DatasetRef] = set()

    if include_intermediates is True:
        # Intermediates can only be identified by comparing everything
        # consumed with everything produced, so scan all categories even
        # if the caller did not ask for them.
        request_include_init_inputs = True
        request_include_inputs = True
        request_include_init_outputs = True
        request_include_outputs = True
    else:
        request_include_init_inputs = include_init_inputs
        request_include_inputs = include_inputs
        request_include_init_outputs = include_init_outputs
        request_include_outputs = include_outputs

    if request_include_init_inputs or request_include_init_outputs:
        for task_def in self.iterTaskGraph():
            if request_include_init_inputs:
                if in_refs := self.initInputRefs(task_def):
                    init_input_refs.update(in_refs)
            if request_include_init_outputs:
                if out_refs := self.initOutputRefs(task_def):
                    init_output_refs.update(out_refs)

    for qnode in self:
        if request_include_inputs:
            for other_refs in qnode.quantum.inputs.values():
                input_refs.update(other_refs)
            # Inputs can come with datastore records.
            for store_name, records in qnode.quantum.datastore_records.items():
                datastore_records.setdefault(store_name, DatastoreRecordData()).update(records)
        if request_include_outputs:
            for other_refs in qnode.quantum.outputs.values():
                output_refs.update(other_refs)

    # Intermediates are the intersection of inputs and outputs. Must do
    # this analysis before conforming since dataset type changes will
    # change set membership.
    inter_msg = ""
    if include_intermediates is not None:
        intermediates = (input_refs | init_input_refs) & (output_refs | init_output_refs)

        if include_intermediates is False:
            # Remove intermediates from results.
            init_input_refs -= intermediates
            input_refs -= intermediates
            init_output_refs -= intermediates
            output_refs -= intermediates
            inter_msg = f"; Intermediates removed: {len(intermediates)}"
        elif include_intermediates is True:
            # Do not mention intermediates if all the input/output flags
            # would have resulted in them anyhow.
            if not all((include_init_inputs, include_inputs, include_init_outputs, include_outputs)):
                inter_msg = f"; including intermediates: {len(intermediates)}"

            # Categories that were only scanned to find intermediates
            # must be cut back to just the intermediates they contain.
            if not include_init_inputs:
                init_input_refs = init_input_refs & intermediates
            if not include_inputs:
                input_refs = input_refs & intermediates
            if not include_init_outputs:
                init_output_refs = init_output_refs & intermediates
            if not include_outputs:
                output_refs = output_refs & intermediates

    # Added after the intermediates bookkeeping so that they are never
    # trimmed away, but before conforming so that they are conformed like
    # every other output.
    if include_global_init_outputs:
        init_output_refs = init_output_refs | set(self.globalInitOutputRefs())

    # Conforming can result in an input ref and an output ref appearing
    # in the returned results that are identical apart from storage class.
    if conform_outputs:
        # Get data repository definitions from the QuantumGraph; these can
        # have different storage classes than those in the quanta.
        dataset_types = {dstype.name: dstype for dstype in self.registryDatasetTypes()}

        def _update_ref(ref: DatasetRef) -> DatasetRef:
            # Fall back to the ref's own dataset type when the graph has
            # no registry definition for it, which leaves it unchanged.
            internal_dataset_type = dataset_types.get(ref.datasetType.name, ref.datasetType)
            if internal_dataset_type.storageClass_name != ref.datasetType.storageClass_name:
                ref = ref.overrideStorageClass(internal_dataset_type.storageClass_name)
            return ref

        # Convert output_refs to the data repository storage classes, too.
        output_refs = {_update_ref(ref) for ref in output_refs}
        init_output_refs = {_update_ref(ref) for ref in init_output_refs}

    _LOG.info(
        "Found the following datasets. InitInputs: %d; Inputs: %d; InitOutputs: %d; Outputs: %d%s",
        len(init_input_refs),
        len(input_refs),
        len(init_output_refs),
        len(output_refs),
        inter_msg,
    )

    refs = input_refs | init_input_refs | init_output_refs | output_refs
    return refs, datastore_records
51 changes: 10 additions & 41 deletions python/lsst/pipe/base/script/retrieve_artifacts_for_quanta.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@

import logging

from lsst.daf.butler import DatasetRef, QuantumBackedButler
from lsst.daf.butler.datastore.record_data import DatastoreRecordData
from lsst.daf.butler import QuantumBackedButler
from lsst.pipe.base import QuantumGraph
from lsst.resources import ResourcePath

Expand All @@ -48,7 +47,7 @@ def retrieve_artifacts_for_quanta(
include_inputs: bool,
include_outputs: bool,
) -> list[ResourcePath]:
"""Retrieve artifacts from a graph and store locally.
"""Retrieve artifacts referenced in a graph and store locally.
Parameters
----------
Expand Down Expand Up @@ -81,48 +80,18 @@ def retrieve_artifacts_for_quanta(
nodes = qgraph_node_id or None
qgraph = QuantumGraph.loadUri(graph, nodes=nodes)

Check warning on line 81 in python/lsst/pipe/base/script/retrieve_artifacts_for_quanta.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/script/retrieve_artifacts_for_quanta.py#L80-L81

Added lines #L80 - L81 were not covered by tests

refs, datastore_records = qgraph.get_refs(

Check warning on line 83 in python/lsst/pipe/base/script/retrieve_artifacts_for_quanta.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/script/retrieve_artifacts_for_quanta.py#L83

Added line #L83 was not covered by tests
include_inputs=include_inputs,
include_init_inputs=include_inputs,
include_outputs=include_outputs,
include_init_outputs=include_outputs,
conform_outputs=True, # Need to look for predicted outputs with correct storage class.
)

# Get data repository definitions from the QuantumGraph; these can have
# different storage classes than those in the quanta.
dataset_types = {dstype.name: dstype for dstype in qgraph.registryDatasetTypes()}

Check warning on line 93 in python/lsst/pipe/base/script/retrieve_artifacts_for_quanta.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/script/retrieve_artifacts_for_quanta.py#L93

Added line #L93 was not covered by tests

datastore_records: dict[str, DatastoreRecordData] = {}
refs: set[DatasetRef] = set()
if include_inputs:
# Collect input refs used by this graph.
for task_def in qgraph.iterTaskGraph():
if in_refs := qgraph.initInputRefs(task_def):
refs.update(in_refs)
for qnode in qgraph:
for otherRefs in qnode.quantum.inputs.values():
refs.update(otherRefs)
for store_name, records in qnode.quantum.datastore_records.items():
datastore_records.setdefault(store_name, DatastoreRecordData()).update(records)
n_inputs = len(refs)
if n_inputs:
_LOG.info("Found %d input dataset%s.", n_inputs, "" if n_inputs == 1 else "s")

if include_outputs:
# Collect output refs that could be created by this graph.
original_output_refs: set[DatasetRef] = set(qgraph.globalInitOutputRefs())
for task_def in qgraph.iterTaskGraph():
if out_refs := qgraph.initOutputRefs(task_def):
original_output_refs.update(out_refs)
for qnode in qgraph:
for otherRefs in qnode.quantum.outputs.values():
original_output_refs.update(otherRefs)

# Convert output_refs to the data repository storage classes, too.
for ref in original_output_refs:
internal_dataset_type = dataset_types.get(ref.datasetType.name, ref.datasetType)
if internal_dataset_type.storageClass_name != ref.datasetType.storageClass_name:
refs.add(ref.overrideStorageClass(internal_dataset_type.storageClass_name))
else:
refs.add(ref)

n_outputs = len(refs) - n_inputs
if n_outputs:
_LOG.info("Found %d output dataset%s.", n_outputs, "" if n_outputs == 1 else "s")

# Make QBB, its config is the same as output Butler.
qbb = QuantumBackedButler.from_predicted(

Check warning on line 96 in python/lsst/pipe/base/script/retrieve_artifacts_for_quanta.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/script/retrieve_artifacts_for_quanta.py#L96

Added line #L96 was not covered by tests
config=repo,
Expand Down
25 changes: 4 additions & 21 deletions python/lsst/pipe/base/script/transfer_from_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

__all__ = ["transfer_from_graph"]

from lsst.daf.butler import Butler, CollectionType, DatasetRef, QuantumBackedButler, Registry
from lsst.daf.butler import Butler, CollectionType, QuantumBackedButler, Registry
from lsst.daf.butler.registry import MissingCollectionError
from lsst.pipe.base import QuantumGraph

Expand Down Expand Up @@ -69,27 +69,10 @@ def transfer_from_graph(
# Read whole graph into memory
qgraph = QuantumGraph.loadUri(graph)

# Collect output refs that could be created by this graph.
original_output_refs: set[DatasetRef] = set(qgraph.globalInitOutputRefs())
for task_def in qgraph.iterTaskGraph():
if refs := qgraph.initOutputRefs(task_def):
original_output_refs.update(refs)
for qnode in qgraph:
for otherRefs in qnode.quantum.outputs.values():
original_output_refs.update(otherRefs)

# Get data repository definitions from the QuantumGraph; these can have
# different storage classes than those in the quanta.
dataset_types = {dstype.name: dstype for dstype in qgraph.registryDatasetTypes()}
output_refs, _ = qgraph.get_refs(include_outputs=True, include_init_outputs=True, conform_outputs=True)

Check warning on line 72 in python/lsst/pipe/base/script/transfer_from_graph.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/script/transfer_from_graph.py#L72

Added line #L72 was not covered by tests

# Convert output_refs to the data repository storage classes, too.
output_refs = set()
for ref in original_output_refs:
internal_dataset_type = dataset_types.get(ref.datasetType.name, ref.datasetType)
if internal_dataset_type.storageClass_name != ref.datasetType.storageClass_name:
output_refs.add(ref.overrideStorageClass(internal_dataset_type.storageClass_name))
else:
output_refs.add(ref)
# Get data repository dataset type definitions from the QuantumGraph.
dataset_types = {dstype.name: dstype for dstype in qgraph.registryDatasetTypes()}

Check warning on line 75 in python/lsst/pipe/base/script/transfer_from_graph.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/script/transfer_from_graph.py#L75

Added line #L75 was not covered by tests

# Make QBB, its config is the same as output Butler.
qbb = QuantumBackedButler.from_predicted(
Expand Down
23 changes: 3 additions & 20 deletions python/lsst/pipe/base/script/zip_from_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import logging
import re

from lsst.daf.butler import DatasetRef, QuantumBackedButler
from lsst.daf.butler import QuantumBackedButler
from lsst.daf.butler.utils import globToRegex
from lsst.pipe.base import QuantumGraph
from lsst.resources import ResourcePath
Expand Down Expand Up @@ -67,28 +67,11 @@ def zip_from_graph(
# Read whole graph into memory
qgraph = QuantumGraph.loadUri(graph)

# Collect output refs that could be created by this graph.
original_output_refs: set[DatasetRef] = set(qgraph.globalInitOutputRefs())
for task_def in qgraph.iterTaskGraph():
if refs := qgraph.initOutputRefs(task_def):
original_output_refs.update(refs)
for qnode in qgraph:
for otherRefs in qnode.quantum.outputs.values():
original_output_refs.update(otherRefs)
output_refs, _ = qgraph.get_refs(include_outputs=True, include_init_outputs=True, conform_outputs=True)

Check warning on line 70 in python/lsst/pipe/base/script/zip_from_graph.py

View check run for this annotation

Codecov / codecov/patch

python/lsst/pipe/base/script/zip_from_graph.py#L70

Added line #L70 was not covered by tests

# Get data repository definitions from the QuantumGraph; these can have
# different storage classes than those in the quanta.
# Get data repository dataset type definitions from the QuantumGraph.
dataset_types = {dstype.name: dstype for dstype in qgraph.registryDatasetTypes()}

# Convert output_refs to the data repository storage classes, too.
output_refs = set()
for ref in original_output_refs:
internal_dataset_type = dataset_types.get(ref.datasetType.name, ref.datasetType)
if internal_dataset_type.storageClass_name != ref.datasetType.storageClass_name:
output_refs.add(ref.overrideStorageClass(internal_dataset_type.storageClass_name))
else:
output_refs.add(ref)

# Make QBB, its config is the same as output Butler.
qbb = QuantumBackedButler.from_predicted(
config=repo,
Expand Down
23 changes: 23 additions & 0 deletions tests/test_quantumGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,29 @@ def testGetSummary(self) -> None:
self.assertEqual(self.qGraph.graphID, summary.graphID)
self.assertEqual(len(summary.qgraphTaskSummaries), len(self.qGraph.taskGraph))

def test_get_refs(self) -> None:
    """Check how many refs ``get_refs`` returns for each flag combination."""
    # Each entry pairs the keyword arguments handed to ``get_refs`` with
    # the number of refs expected back from the test graph.
    expectations: tuple[tuple[dict[str, bool], int], ...] = (
        ({"include_inputs": True}, 8),
        ({"include_init_inputs": True}, 2),
        ({"include_init_outputs": True}, 3),
        ({"include_outputs": True}, 8),
        ({"include_inputs": True, "include_outputs": True}, 12),
        (
            {
                "include_inputs": True,
                "include_outputs": True,
                "include_init_inputs": True,
                "include_init_outputs": True,
            },
            15,
        ),
        ({"include_intermediates": True}, 6),
        ({"include_intermediates": False}, 0),
        ({"include_intermediates": False, "include_inputs": True, "include_outputs": True}, 8),
    )
    for flags, expected_count in expectations:
        found, _ = self.qGraph.get_refs(**flags)
        # Include the refs themselves in the failure message to aid debugging.
        self.assertEqual(len(found), expected_count, str(found))


class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
"""Run file leak tests."""
Expand Down

0 comments on commit 261ac6c

Please sign in to comment.