diff --git a/python/lsst/pipe/base/tests/mocks/_pipeline_task.py b/python/lsst/pipe/base/tests/mocks/_pipeline_task.py index 3b68d02ee..73bd635a6 100644 --- a/python/lsst/pipe/base/tests/mocks/_pipeline_task.py +++ b/python/lsst/pipe/base/tests/mocks/_pipeline_task.py @@ -55,6 +55,7 @@ from lsst.utils.introspection import get_full_type_name from lsst.utils.iteration import ensure_iterable +from ... import automatic_connection_constants as acc from ... import connectionTypes as cT from ...config import PipelineTaskConfig from ...connections import InputQuantizedConnection, OutputQuantizedConnection, PipelineTaskConnections @@ -381,12 +382,22 @@ def __init__(self, *, config: MockPipelineTaskConfig): self.unmocked_dataset_types = frozenset(config.unmocked_dataset_types) for name, connection in self.original.allConnections.items(): if connection.name not in self.unmocked_dataset_types: - # We register the mock storage class with the global singleton - # here, but can only put its name in the connection. That means - # the same global singleton (or one that also has these - # registrations) has to be available whenever this dataset type - # is used. - storage_class = MockStorageClass.get_or_register_mock(connection.storageClass) + if connection.storageClass in ( + acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS, + acc.METADATA_OUTPUT_STORAGE_CLASS, + acc.LOG_OUTPUT_STORAGE_CLASS, + ): + # We don't mock the automatic output connections, so if + # they're used as an input in any other connection, we + # can't mock them there either. + storage_class_name = connection.storageClass + else: + # We register the mock storage class with the global + # singleton here, but can only put its name in the + # connection. That means the same global singleton (or one + # that also has these registrations) has to be available + # whenever this dataset type is used. + storage_class_name = MockStorageClass.get_or_register_mock(connection.storageClass).name kwargs: dict[str, Any] = {} if hasattr(connection, "dimensions"): connection_dimensions = set(connection.dimensions) @@ -400,7 +411,7 @@ def __init__(self, *, config: MockPipelineTaskConfig): connection = dataclasses.replace( connection, name=get_mock_name(connection.name), - storageClass=storage_class.name, + storageClass=storage_class_name, **kwargs, ) elif name in self.original.outputs: