Skip to content

Commit

Permalink
Don't mock automatic-output connections that are used as inputs.
Browse files Browse the repository at this point in the history
  • Loading branch information
TallJimbo committed Oct 16, 2023
1 parent d4b6093 commit b8f0569
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions python/lsst/pipe/base/tests/mocks/_pipeline_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from lsst.utils.introspection import get_full_type_name
from lsst.utils.iteration import ensure_iterable

from ... import automatic_connection_constants as acc
from ... import connectionTypes as cT
from ...config import PipelineTaskConfig
from ...connections import InputQuantizedConnection, OutputQuantizedConnection, PipelineTaskConnections
Expand Down Expand Up @@ -381,12 +382,22 @@ def __init__(self, *, config: MockPipelineTaskConfig):
self.unmocked_dataset_types = frozenset(config.unmocked_dataset_types)
for name, connection in self.original.allConnections.items():
if connection.name not in self.unmocked_dataset_types:
# We register the mock storage class with the global singleton
# here, but can only put its name in the connection. That means
# the same global singleton (or one that also has these
# registrations) has to be available whenever this dataset type
# is used.
storage_class = MockStorageClass.get_or_register_mock(connection.storageClass)
if connection.storageClass in (
acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
acc.METADATA_OUTPUT_STORAGE_CLASS,
acc.LOG_OUTPUT_STORAGE_CLASS,
):
# We don't mock the automatic output connections, so if
# they're used as an input in any other connection, we
# can't mock them there either.
storage_class_name = connection.storageClass
else:
# We register the mock storage class with the global
# singleton here, but can only put its name in the
# connection. That means the same global singleton (or one
# that also has these registrations) has to be available
# whenever this dataset type is used.
storage_class_name = MockStorageClass.get_or_register_mock(connection.storageClass).name
kwargs: dict[str, Any] = {}
if hasattr(connection, "dimensions"):
connection_dimensions = set(connection.dimensions)
Expand All @@ -400,7 +411,7 @@ def __init__(self, *, config: MockPipelineTaskConfig):
connection = dataclasses.replace(
connection,
name=get_mock_name(connection.name),
storageClass=storage_class.name,
storageClass=storage_class_name,
**kwargs,
)
elif name in self.original.outputs:
Expand Down

0 comments on commit b8f0569

Please sign in to comment.