diff --git a/capsul/execution_context.py b/capsul/execution_context.py
index daf8c9863..9068da6f4 100644
--- a/capsul/execution_context.py
+++ b/capsul/execution_context.py
@@ -8,7 +8,9 @@
 from .dataset import Dataset
 from .pipeline.pipeline import Process, Pipeline
-from .pipeline.process_iteration import IndependentExecutables, ProcessIteration
+from .pipeline.process_iteration import (IndependentExecutables,
+                                         ProcessIteration)
+from .pipeline import pipeline_tools
 from capsul.config.configuration import get_config_class
 from .config.configuration import ModuleConfiguration
@@ -104,6 +106,8 @@ def __init__(self, executable, debug=False):
         jobs_per_process = {}
         process_chronology = {}
         processes_proxies = {}
+        nodes = pipeline_tools.topological_sort_nodes(executable.all_nodes())
+        pipeline_tools.propagate_meta(executable, nodes)
         job_parameters = self._create_jobs(
             top_parameters=top_parameters,
             executable=executable,
@@ -132,7 +136,8 @@ def __init__(self, executable, debug=False):
                     bj['waited_by'].add(after_job)
 
         # Resolve disabled jobs
-        disabled_jobs = [(uuid, job) for uuid, job in self.jobs.items() if job['disabled']]
+        disabled_jobs = [(uuid, job) for uuid, job in self.jobs.items()
+                         if job['disabled']]
         for disabled_job in disabled_jobs:
             wait_for = set()
             stack = disabled_job[1]['wait_for']
@@ -153,7 +158,7 @@ def __init__(self, executable, debug=False):
             for job in disabled_job[1]['waited_by']:
                 self.jobs[job]['wait_for'].remove(disabled_job[0])
             del self.jobs[disabled_job[0]]
-        
+
         # Transform wait_for sets to lists for json storage
         # and add waited_by
         for job_id, job in self.jobs.items():
@@ -211,10 +216,10 @@ def _create_jobs(self,
             find_temporary_to_generate(executable)
             disabled_nodes = process.disabled_pipeline_steps_nodes()
             for node_name, node in process.nodes.items():
-                if (node is process 
-                        or not node.activated
-                        or not isinstance(node, Process) 
-                        or node in disabled_nodes):
+                if (node is process
+                        or not node.activated
+                        or not isinstance(node, Process)
+                        or node in disabled_nodes):
                     continue
                 nodes_dict[node_name] = {}
                 job_parameters = self._create_jobs(
@@ -231,9 +236,17 @@ def _create_jobs(self,
                     disabled=disabled or node in disabled_nodes)
                 nodes.append(node)
             for field in process.user_fields():
-                for dest_node, plug_name in executable.get_linked_items(process,
-                        field.name,
-                        in_sub_pipelines=False):
+                links = list(executable.get_linked_items(
+                        process,
+                        field.name,
+                        in_sub_pipelines=False,
+                        direction='links_from')) \
+                    + list(executable.get_linked_items(
+                        process,
+                        field.name,
+                        in_sub_pipelines=False,
+                        direction='links_to'))
+                for dest_node, plug_name in links:
                     if dest_node in disabled_nodes:
                         continue
                     if field.metadata('write', False) \
@@ -243,12 +256,16 @@ def _create_jobs(self,
                    else:
                        parameters.content[field.name] \
                            = nodes_dict.get(dest_node.name, {}).get(plug_name)
-                    break
+                    # break
                if field.is_output():
-                    for dest_node_name, dest_plug_name, dest_node, dest_plug, is_weak in process.plugs[field.name].links_to:
-                        if (isinstance(dest_node, Process) and dest_node.activated and
-                                dest_node not in disabled_nodes and
-                                not dest_node.field(dest_plug_name).is_output()):
+                    for dest_node, dest_plug_name \
+                            in executable.get_linked_items(
+                                process, field.name, direction='links_to'):
+                        if (isinstance(dest_node, Process)
+                                and dest_node.activated
+                                and dest_node not in disabled_nodes
+                                and not dest_node.field(
+                                    dest_plug_name).is_output()):
                            if isinstance(dest_node, Pipeline):
                                continue
                            process_chronology.setdefault(
@@ -258,9 +275,11 @@ def _create_jobs(self,
         for node in nodes:
             for plug_name in node.plugs:
                 first = nodes_dict[node.name].get(plug_name)
-                for dest_node, dest_plug_name in process.get_linked_items(node, plug_name,
-                        in_sub_pipelines=False):
-
+                for dest_node, dest_plug_name in process.get_linked_items(
+                        node, plug_name,
+                        in_sub_pipelines=False,
+                        direction=('links_from', 'links_to')):
+
                     second = nodes_dict.get(dest_node.name, {}).get(dest_plug_name)
                     if dest_node.pipeline is not node.pipeline:
                         continue
@@ -385,12 +404,14 @@ def _create_jobs(self,
                         suffix = ''
                     uuid = str(uuid4())
                     value[i] = f'{prefix}.{field.name}_{i}_{uuid}{suffix}'
+                # print('generate tmp:', value)
             if value is undefined:
                 value = process.getattr(field.name, None)
             # print(' ', field.name, '<-', repr(value), getattr(field, 'generate_temporary', False))
             proxy = parameters.proxy(executable.json_value(value))
             parameters[field.name] = proxy
-            if field.is_output() and isinstance(executable, (Pipeline, ProcessIteration)):
+            if field.is_output() and isinstance(
+                    executable, (Pipeline, ProcessIteration)):
                 for dest_node, plug_name in executable.get_linked_items(process,
                         field.name, direction='links_to'):
                     if isinstance(dest_node, Pipeline):
@@ -401,6 +422,7 @@ def _create_jobs(self,
                         process.uuid + ','.join(process_iterations.get(process.uuid, [])))
         return parameters
 
+
 def find_temporary_to_generate(executable):
     # print('!temporaries! ->', executable.label)
     if isinstance(executable, Pipeline):
diff --git a/capsul/pipeline/pipeline.py b/capsul/pipeline/pipeline.py
index 5c4c750b3..b9bf6bdbb 100644
--- a/capsul/pipeline/pipeline.py
+++ b/capsul/pipeline/pipeline.py
@@ -1518,7 +1518,7 @@ def workflow_ordered_nodes(self, remove_disabled_steps=True):
 
         graph = self.workflow_graph(remove_disabled_steps)
 
-        # Start the topologival sort
+        # Start the topological sort
         ordered_list = graph.topological_sort()
 
         def walk_workflow(wokflow, workflow_list):
@@ -1541,7 +1541,7 @@ def walk_workflow(wokflow, workflow_list):
         workflow_list = []
         walk_workflow(ordered_list, workflow_list)
 
-        return workflow_list    
+        return workflow_list
 
     def find_empty_parameters(self):
         """ Find internal File/Directory parameters not exported to the main
@@ -2248,18 +2248,21 @@ def dispatch_plugs(self, node, name):
                                            name,
                                            in_sub_pipelines=False,
                                            activated_only=False,
-                                            process_only=False))
+                                            process_only=False,
+                                            direction=('links_from', 'links_to')))
         while stack:
             item = stack.pop()
             if item not in done:
                 node, plug = item
                 yield (node, plug)
                 done.add(item)
-                stack.extend(self.get_linked_items(node,
-                                                   plug,
+                stack.extend(self.get_linked_items(
+                    node,
+                    plug,
                     in_sub_pipelines=False,
                     activated_only=False,
-                    process_only=False))
+                    process_only=False,
+                    direction=('links_from', 'links_to')))
         self.enable_parameter_links = enable_parameter_links
 
     def dispatch_all_values(self):
@@ -2278,19 +2281,31 @@ def get_linked_items(self, node, plug_name=None, in_sub_pipelines=True,
         Going through switches and inside subpipelines, ignoring nodes that
         are not activated.
         The result is a generator of pairs (node, plug_name).
+
+        direction may be a string, 'links_from' or 'links_to', or a tuple
+        ('links_from', 'links_to').
         '''
         if plug_name is None:
             stack = [(node, plug) for plug in node.plugs]
         else:
             stack = [(node, plug_name)]
+        done = set()
+
         while stack:
-            node, plug_name = stack.pop(0)
+            current = stack.pop(0)
+            if current in done:
+                continue
+            done.add(current)
+            node, plug_name = current
             if activated_only and not node.activated:
                 continue
             plug = node.plugs.get(plug_name)
             if plug:
                 if direction is not None:
-                    directions = (direction,)
+                    if isinstance(direction, (tuple, list)):
+                        directions = direction
+                    else:
+                        directions = (direction,)
                 else:
                     if isinstance(node, Pipeline):
                         if in_outer_pipelines:
@@ -2304,41 +2319,68 @@ def get_linked_items(self, node, plug_name=None, in_sub_pipelines=True,
                     else:
                         directions = ('links_from',)
                 for current_direction in directions:
-                    for dest_plug_name, dest_node in (i[1:3] for i in getattr(plug, current_direction)):
-                        if dest_node is node or (activated_only
-                                                 and not dest_node.activated):
+                    for dest_plug_name, dest_node in \
+                            (i[1:3] for i in getattr(plug, current_direction)):
+                        if dest_node is node \
+                                or (activated_only
+                                    and not dest_node.activated):
                             continue
                         if isinstance(dest_node, Pipeline):
-                            if ((in_sub_pipelines and dest_node is not self) or
-                                    (in_outer_pipelines and isinstance(dest_node, Pipeline))):
-                                for n, p in self.get_linked_items(dest_node,
-                                        dest_plug_name,
-                                        activated_only=activated_only,
-                                        process_only=process_only,
-                                        in_sub_pipelines=in_sub_pipelines,
-                                        direction=current_direction,
-                                        in_outer_pipelines=in_outer_pipelines):
+                            if ((in_sub_pipelines and dest_node is not self)
+                                    or in_outer_pipelines):
+                                for n, p in self.get_linked_items(
+                                        dest_node,
+                                        dest_plug_name,
+                                        activated_only=activated_only,
+                                        process_only=process_only,
+                                        in_sub_pipelines=in_sub_pipelines,
+                                        direction=current_direction,
+                                        in_outer_pipelines=in_outer_pipelines):
                                     if n is not node:
-                                        yield (n, p)
-                                yield (dest_node, dest_plug_name)
+                                        if (n, p) not in done:
+                                            yield (n, p)
+                                if (dest_node, dest_plug_name) not in done:
+                                    yield (dest_node, dest_plug_name)
                         elif isinstance(dest_node, Switch):
                             if dest_plug_name == 'switch':
                                 if not process_only:
-                                    yield (dest_node, dest_plug_name)
+                                    if (dest_node, dest_plug_name) \
+                                            not in done:
+                                        yield (dest_node, dest_plug_name)
                             else:
-                                for input_plug_name, output_plug_name in dest_node.connections():
-                                    if plug.output ^ isinstance(node, Pipeline):
+                                if direction is None \
+                                        or (isinstance(direction,
+                                                       (tuple, list))
+                                            and len(direction) == 2):
+                                    # if bidirectional search only
+                                    stack.append((dest_node, dest_plug_name))
+                                for input_plug_name, output_plug_name \
+                                        in dest_node.connections():
+                                    if plug.output ^ isinstance(node,
+                                                                Pipeline):
                                         if dest_plug_name == input_plug_name:
-                                            if not process_only:
-                                                yield (dest_node, output_plug_name)
-                                            stack.append((dest_node, output_plug_name))
+                                            if not process_only \
+                                                    and (dest_node,
+                                                         output_plug_name) \
+                                                        not in done:
+                                                yield (
+                                                    dest_node,
+                                                    output_plug_name)
+                                            stack.append((dest_node,
+                                                          output_plug_name))
                                         else:
                                             if dest_plug_name == output_plug_name:
-                                                if not process_only:
-                                                    yield (dest_node, input_plug_name)
-                                                stack.append((dest_node, input_plug_name))
+                                                if not process_only \
+                                                        and (dest_node,
+                                                             input_plug_name) \
+                                                            not in done:
+                                                    yield (
+                                                        dest_node, input_plug_name)
+                                                stack.append((dest_node,
+                                                              input_plug_name))
                         else:
-                            yield (dest_node, dest_plug_name)
+                            if (dest_node, dest_plug_name) not in done:
+                                yield (dest_node, dest_plug_name)
 
     def json(self, include_parameters=True):
         result = super().json(include_parameters=include_parameters)
diff --git a/capsul/pipeline/pipeline_nodes.py b/capsul/pipeline/pipeline_nodes.py
index 008e8684c..dc78cdecb 100644
--- a/capsul/pipeline/pipeline_nodes.py
+++ b/capsul/pipeline/pipeline_nodes.py
@@ -235,21 +235,13 @@ def _switch_changed(self, new_selection, old_selection):
             setattr(self, output_plug_name,
                     getattr(self, corresponding_input_plug_name,
                             undefined))
-            if self.pipeline is not None:
-                f = self.field(output_plug_name)
-                for n, p in self.pipeline.get_linked_items(
-                        self, corresponding_input_plug_name,
-                        direction='links_from'):
-                    # copy input field metadata
-                    for k, v in n.field(p).metadata().items():
-                        setattr(f, k, v)
-                    break
-
             # Propagate the associated field documentation
             out_field = self.field(output_plug_name)
             in_field = self.field(corresponding_input_plug_name)
             out_field.doc = in_field.metadata('doc', None)
 
+        self.propagate_fields_metadata()
+
         self.pipeline.restore_update_nodes_and_plugs_activation()
         self.__block_output_propagation = False
 
@@ -314,14 +306,6 @@ def _any_attribute_changed(self, new, old, name):
                 if self.switch == switch_selection:
                     self.__block_output_propagation = True
                     setattr(self, output_plug_name, new)
-                    if self.pipeline is not None:
-                        f = self.field(output_plug_name)
-                        for n, p in self.pipeline.get_linked_items(
-                                self, name, direction='links_from'):
-                            # copy input field metadata
-                            for k, v in n.field(p).metadata().items():
-                                setattr(f, k, v)
-                            break
                     self.__block_output_propagation = False
 
     def __setstate__(self, state):
@@ -393,6 +377,26 @@ def configured_controller(self):
         c.optional_params = [self.field(p).optional for p in self.inputs]
         return c
 
+    def propagate_fields_metadata(self):
+        ''' Propagate metadata from connected inputs (that is, outputs of
+        upstream processes) to outputs.
+        This is needed to get correct status (read/write) on output pipeline
+        plugs once the switch state is chosen.
+        '''
+        for output_plug_name in self._outputs:
+            # Get the associated input name
+            input_plug_name = f'{self.switch}_switch_{output_plug_name}'
+
+            if self.pipeline is not None:
+                f = self.field(output_plug_name)
+                for n, p in self.pipeline.get_linked_items(
+                        self, input_plug_name,
+                        direction='links_from'):
+                    # copy input field metadata
+                    for k, v in n.field(p).metadata().items():
+                        setattr(f, k, v)
+                    break
+
     @classmethod
     def build_node(cls, pipeline, name, conf_controller):
         node = Switch(pipeline, name, conf_controller.inputs,
diff --git a/capsul/pipeline/pipeline_tools.py b/capsul/pipeline/pipeline_tools.py
index d9807dcf9..72f142c70 100644
--- a/capsul/pipeline/pipeline_tools.py
+++ b/capsul/pipeline/pipeline_tools.py
@@ -1526,3 +1526,95 @@ def replace_node(node, module_name, dirname, done, parent, node_name):
             save_pipeline(pipeline, filename)
 
     del sys.path[0]
+
+
+def topological_sort_nodes(nodes):
+    ''' Sort nodes topologically according to their links.
+    All linked nodes must be in the nodes list: if switches or pipelines are
+    removed, the sort will be broken.
+
+    In the output list, pipeline nodes will appear twice, in tuples:
+
+    (pipeline, 0) is the position of the input plugs of the pipeline
+
+    (pipeline, 1) is the position of the output plugs of the pipeline
+
+    nodes inside the pipeline will logically be between both.
+    '''
+    nsort = []
+    done = set()
+    todo = list(nodes)
+    while todo:
+        node = todo.pop(0)
+        if node in done:
+            continue
+
+        i = -1
+        cont = False
+        for plug in node.plugs.values():
+            if not plug.output:
+                for ld in plug.links_from:
+                    n = ld[2]
+                    n0 = n
+                    if isinstance(n, Pipeline):
+                        if not ld[3].output:
+                            n = (n, 0)  # beginning of pipeline
+                        else:
+                            n = (n, 1)  # end of pipeline
+                    if n in done:
+                        ni = nsort.index(n)  # WARNING: expensive search
+                        if ni > i:
+                            i = ni
+                    else:
+                        todo.insert(0, n0)
+                        cont = True
+                        break
+                if cont:
+                    break
+        if cont:
+            continue
+
+        # print('insert', node.full_name, ':', i+1)
+        # if i >= 0:
+        #     print(' after', nsort[i].full_name if not isinstance(nsort[i], tuple) else (nsort[i][0].full_name, nsort[i][1]) )
+        if isinstance(node, Pipeline):
+            nsort.insert(i+1, (node, 0))
+            nsort.insert(i+2, (node, 1))
+            done.add((node, 0))
+            done.add((node, 1))
+        else:
+            nsort.insert(i+1, node)
+            done.add(node)
+    return nsort
+
+
+def propagate_meta(executable, nodes=None):
+    ''' Propagate metadata from process output plugs to downstream switches
+    and upper-level pipeline plugs, recursively, in topological order.
+
+    If ``nodes`` is provided, it should be the nodes list already in
+    topological order. It may be passed when the list is reused, in order
+    to avoid calling :func:`topological_sort_nodes` several times.
+    '''
+    if nodes is None:
+        nodes = topological_sort_nodes(
+            executable.all_nodes())
+    for node in nodes:
+        if isinstance(node, Switch):
+            node.propagate_fields_metadata()
+        if isinstance(node, tuple):
+            if node[1] == 0:
+                # pipeline inputs
+                continue
+            node = node[0]
+        for pname, plug in node.plugs.items():
+            if plug.output:
+                f = node.field(pname)
+                for ld in plug.links_to:
+                    n = ld[2]
+                    p = ld[1]
+                    fo = n.field(p)
+                    if fo.is_output():
+                        # copy field metadata
+                        for k, v in f.metadata().items():
+                            setattr(fo, k, v)
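
Reviewer note: the sketch below shows how the new pipeline_tools helpers and the extended get_linked_items() direction argument are meant to be combined, mirroring what CapsulWorkflow.__init__ now does before building jobs. It is an illustration only, not part of the patch; the pipeline definition string and the plug name 'output_image' are hypothetical placeholders, and the executable() factory is assumed from capsul v3's capsul.api.

from capsul.api import executable          # assumed capsul v3 factory
from capsul.pipeline import pipeline_tools

# Build a pipeline; the definition string is a placeholder.
pipeline = executable('example.module.SomePipeline')

# Sort all nodes topologically; sub-pipelines appear twice in the result,
# as (pipeline, 0) for their input side and (pipeline, 1) for their output side.
nodes = pipeline_tools.topological_sort_nodes(pipeline.all_nodes())

# Propagate read/write metadata from process outputs to downstream switches
# and to the enclosing pipeline plugs, following the topological order.
pipeline_tools.propagate_meta(pipeline, nodes)

# get_linked_items() can now follow links in a single direction, or in both
# directions when given a tuple.
for node, plug in pipeline.get_linked_items(
        pipeline, 'output_image',           # placeholder plug name
        direction=('links_from', 'links_to')):
    print(node.name, plug)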