
Commit

trying to fix several workflow issues
linked parameters with different values [fixed]
job dependencies [partly fixed, still not OK]
(#286)
denisri committed Aug 3, 2023
1 parent 4eef044 commit e7813ba
Showing 4 changed files with 229 additions and 69 deletions.
60 changes: 41 additions & 19 deletions capsul/execution_context.py
@@ -8,7 +8,9 @@
 
 from .dataset import Dataset
 from .pipeline.pipeline import Process, Pipeline
-from .pipeline.process_iteration import IndependentExecutables, ProcessIteration
+from .pipeline.process_iteration import (IndependentExecutables,
+                                         ProcessIteration)
+from .pipeline import pipeline_tools
 from capsul.config.configuration import get_config_class
 from .config.configuration import ModuleConfiguration
 
@@ -104,6 +106,8 @@ def __init__(self, executable, debug=False):
         jobs_per_process = {}
         process_chronology = {}
         processes_proxies = {}
+        nodes = pipeline_tools.topological_sort_nodes(executable.all_nodes())
+        pipeline_tools.propagate_meta(executable, nodes)
         job_parameters = self._create_jobs(
             top_parameters=top_parameters,
             executable=executable,
@@ -132,7 +136,8 @@ def __init__(self, executable, debug=False):
                 bj['waited_by'].add(after_job)
 
         # Resolve disabled jobs
-        disabled_jobs = [(uuid, job) for uuid, job in self.jobs.items() if job['disabled']]
+        disabled_jobs = [(uuid, job) for uuid, job in self.jobs.items()
+                         if job['disabled']]
         for disabled_job in disabled_jobs:
             wait_for = set()
             stack = disabled_job[1]['wait_for']
@@ -153,7 +158,7 @@ def __init__(self, executable, debug=False):
             for job in disabled_job[1]['waited_by']:
                 self.jobs[job]['wait_for'].remove(disabled_job[0])
             del self.jobs[disabled_job[0]]
-
+
         # Transform wait_for sets to lists for json storage
         # and add waited_by
         for job_id, job in self.jobs.items():
@@ -211,10 +216,10 @@ def _create_jobs(self,
             find_temporary_to_generate(executable)
             disabled_nodes = process.disabled_pipeline_steps_nodes()
             for node_name, node in process.nodes.items():
-                if (node is process
-                        or not node.activated
-                        or not isinstance(node, Process)
-                        or node in disabled_nodes):
+                if (node is process
+                        or not node.activated
+                        or not isinstance(node, Process)
+                        or node in disabled_nodes):
                     continue
                 nodes_dict[node_name] = {}
                 job_parameters = self._create_jobs(
@@ -231,9 +236,17 @@ def _create_jobs(self,
                     disabled=disabled or node in disabled_nodes)
                 nodes.append(node)
             for field in process.user_fields():
-                for dest_node, plug_name in executable.get_linked_items(process,
-                        field.name,
-                        in_sub_pipelines=False):
+                links = list(executable.get_linked_items(
+                    process,
+                    field.name,
+                    in_sub_pipelines=False,
+                    direction='links_from')) \
+                    + list(executable.get_linked_items(
+                        process,
+                        field.name,
+                        in_sub_pipelines=False,
+                        direction='links_to'))
+                for dest_node, plug_name in links:
                     if dest_node in disabled_nodes:
                         continue
                     if field.metadata('write', False) \
@@ -243,12 +256,16 @@ def _create_jobs(self,
                     else:
                         parameters.content[field.name] \
                             = nodes_dict.get(dest_node.name, {}).get(plug_name)
-                    break
+                    # break
                 if field.is_output():
-                    for dest_node_name, dest_plug_name, dest_node, dest_plug, is_weak in process.plugs[field.name].links_to:
-                        if (isinstance(dest_node, Process) and dest_node.activated and
-                                dest_node not in disabled_nodes and
-                                not dest_node.field(dest_plug_name).is_output()):
+                    for dest_node, dest_plug_name \
+                            in executable.get_linked_items(
+                                process, field.name, direction='links_to'):
+                        if (isinstance(dest_node, Process)
+                                and dest_node.activated
+                                and dest_node not in disabled_nodes
+                                and not dest_node.field(
+                                    dest_plug_name).is_output()):
                             if isinstance(dest_node, Pipeline):
                                 continue
                             process_chronology.setdefault(
@@ -258,9 +275,11 @@ def _create_jobs(self,
             for node in nodes:
                 for plug_name in node.plugs:
                     first = nodes_dict[node.name].get(plug_name)
-                    for dest_node, dest_plug_name in process.get_linked_items(node, plug_name,
-                            in_sub_pipelines=False):
-
+                    for dest_node, dest_plug_name in process.get_linked_items(
+                            node, plug_name,
+                            in_sub_pipelines=False,
+                            direction=('links_from', 'links_to')):
+
                         second = nodes_dict.get(dest_node.name, {}).get(dest_plug_name)
                         if dest_node.pipeline is not node.pipeline:
                             continue
@@ -385,12 +404,14 @@ def _create_jobs(self,
                     suffix = ''
                 uuid = str(uuid4())
                 value[i] = f'{prefix}.{field.name}_{i}_{uuid}{suffix}'
+            # print('generate tmp:', value)
         if value is undefined:
             value = process.getattr(field.name, None)
         # print(' ', field.name, '<-', repr(value), getattr(field, 'generate_temporary', False))
         proxy = parameters.proxy(executable.json_value(value))
         parameters[field.name] = proxy
-        if field.is_output() and isinstance(executable, (Pipeline, ProcessIteration)):
+        if field.is_output() and isinstance(
+                executable, (Pipeline, ProcessIteration)):
             for dest_node, plug_name in executable.get_linked_items(process, field.name,
                                                                     direction='links_to'):
                 if isinstance(dest_node, Pipeline):
@@ -401,6 +422,7 @@ def _create_jobs(self,
                     process.uuid + ','.join(process_iterations.get(process.uuid, [])))
         return parameters
 
+
 def find_temporary_to_generate(executable):
     # print('!temporaries! ->', executable.label)
     if isinstance(executable, Pipeline):
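As a reading aid, here is a minimal standalone sketch of the disabled-job rewiring that the "Resolve disabled jobs" pass above performs, assuming the same wait_for/waited_by dict layout as in the diff. The three-job graph ('a' feeds 'b' feeds 'c', with 'b' disabled) and all names are invented for illustration:

# Illustrative sketch, not capsul code: remove a disabled job and re-wire
# its neighbours so downstream jobs wait for the nearest enabled ancestors.
jobs = {
    'a': {'disabled': False, 'wait_for': set(), 'waited_by': {'b'}},
    'b': {'disabled': True, 'wait_for': {'a'}, 'waited_by': {'c'}},
    'c': {'disabled': False, 'wait_for': {'b'}, 'waited_by': set()},
}

disabled = [(uid, job) for uid, job in jobs.items() if job['disabled']]
for uid, job in disabled:
    # collect the nearest enabled ancestors, walking through disabled ones
    wait_for = set()
    stack = list(job['wait_for'])
    while stack:
        upstream = stack.pop()
        if jobs[upstream]['disabled']:
            stack.extend(jobs[upstream]['wait_for'])
        else:
            wait_for.add(upstream)
    # re-wire dependents, then drop the disabled job
    for downstream in job['waited_by']:
        jobs[downstream]['wait_for'].discard(uid)
        jobs[downstream]['wait_for'] |= wait_for
    for upstream in wait_for:
        jobs[upstream]['waited_by'].discard(uid)
        jobs[upstream]['waited_by'] |= job['waited_by']
    del jobs[uid]

assert jobs['c']['wait_for'] == {'a'}
assert jobs['a']['waited_by'] == {'c'}

After the pass, 'c' waits directly for 'a', its nearest enabled ancestor, and the disabled job is gone.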
106 changes: 74 additions & 32 deletions capsul/pipeline/pipeline.py
@@ -1518,7 +1518,7 @@ def workflow_ordered_nodes(self, remove_disabled_steps=True):
         graph = self.workflow_graph(remove_disabled_steps)
 
 
-        # Start the topologival sort
+        # Start the topological sort
         ordered_list = graph.topological_sort()
 
         def walk_workflow(wokflow, workflow_list):
@@ -1541,7 +1541,7 @@ def walk_workflow(wokflow, workflow_list):
         workflow_list = []
         walk_workflow(ordered_list, workflow_list)
 
-        return workflow_list
+        return workflow_list
 
     def find_empty_parameters(self):
         """ Find internal File/Directory parameters not exported to the main
@@ -2248,18 +2261,21 @@ def dispatch_plugs(self, node, name):
                 name,
                 in_sub_pipelines=False,
                 activated_only=False,
-                process_only=False))
+                process_only=False,
+                direction=('links_from', 'links_to')))
         while stack:
             item = stack.pop()
             if item not in done:
                 node, plug = item
                 yield (node, plug)
                 done.add(item)
-                stack.extend(self.get_linked_items(node,
-                                                   plug,
+                stack.extend(self.get_linked_items(
+                    node,
+                    plug,
                     in_sub_pipelines=False,
                     activated_only=False,
-                    process_only=False))
+                    process_only=False,
+                    direction=('links_from', 'links_to')))
         self.enable_parameter_links = enable_parameter_links
 
     def dispatch_all_values(self):
@@ -2278,19 +2281,31 @@ def get_linked_items(self, node, plug_name=None, in_sub_pipelines=True,
         Going through switches and inside subpipelines, ignoring nodes that are
         not activated.
         The result is a generator of pairs (node, plug_name).
+        direction may be a string, 'links_from', 'links_to', or a tuple
+        ('links_from', 'links_to').
         '''
         if plug_name is None:
             stack = [(node, plug) for plug in node.plugs]
         else:
             stack = [(node, plug_name)]
+        done = set()
 
         while stack:
-            node, plug_name = stack.pop(0)
+            current = stack.pop(0)
+            if current in done:
+                continue
+            done.add(current)
+            node, plug_name = current
             if activated_only and not node.activated:
                 continue
             plug = node.plugs.get(plug_name)
             if plug:
                 if direction is not None:
-                    directions = (direction,)
+                    if isinstance(direction, (tuple, list)):
+                        directions = direction
+                    else:
+                        directions = (direction,)
                 else:
                     if isinstance(node, Pipeline):
                         if in_outer_pipelines:
@@ -2304,41 +2319,68 @@ def get_linked_items(self, node, plug_name=None, in_sub_pipelines=True,
                     else:
                         directions = ('links_from',)
                 for current_direction in directions:
-                    for dest_plug_name, dest_node in (i[1:3] for i in getattr(plug, current_direction)):
-                        if dest_node is node or (activated_only
-                                                 and not dest_node.activated):
+                    for dest_plug_name, dest_node in \
+                            (i[1:3] for i in getattr(plug, current_direction)):
+                        if dest_node is node \
+                                or (activated_only
+                                    and not dest_node.activated):
                             continue
                         if isinstance(dest_node, Pipeline):
-                            if ((in_sub_pipelines and dest_node is not self) or
-                                    (in_outer_pipelines and isinstance(dest_node, Pipeline))):
-                                for n, p in self.get_linked_items(dest_node,
-                                        dest_plug_name,
-                                        activated_only=activated_only,
-                                        process_only=process_only,
-                                        in_sub_pipelines=in_sub_pipelines,
-                                        direction=current_direction,
-                                        in_outer_pipelines=in_outer_pipelines):
+                            if ((in_sub_pipelines and dest_node is not self)
+                                    or in_outer_pipelines):
+                                for n, p in self.get_linked_items(
+                                        dest_node,
+                                        dest_plug_name,
+                                        activated_only=activated_only,
+                                        process_only=process_only,
+                                        in_sub_pipelines=in_sub_pipelines,
+                                        direction=current_direction,
+                                        in_outer_pipelines=in_outer_pipelines):
                                     if n is not node:
-                                        yield (n, p)
-                                yield (dest_node, dest_plug_name)
+                                        if (n, p) not in done:
+                                            yield (n, p)
+                                if (dest_node, dest_plug_name) not in done:
+                                    yield (dest_node, dest_plug_name)
                         elif isinstance(dest_node, Switch):
                             if dest_plug_name == 'switch':
                                 if not process_only:
-                                    yield (dest_node, dest_plug_name)
+                                    if (dest_node, dest_plug_name) \
+                                            not in done:
+                                        yield (dest_node, dest_plug_name)
                             else:
-                                for input_plug_name, output_plug_name in dest_node.connections():
-                                    if plug.output ^ isinstance(node, Pipeline):
+                                if direction is None \
+                                        or (isinstance(direction,
+                                                       (tuple, list))
+                                            and len(direction) == 2):
+                                    # if bidirectional search only
+                                    stack.append((dest_node, dest_plug_name))
+                                for input_plug_name, output_plug_name \
+                                        in dest_node.connections():
+                                    if plug.output ^ isinstance(node,
+                                                                Pipeline):
                                         if dest_plug_name == input_plug_name:
-                                            if not process_only:
-                                                yield (dest_node, output_plug_name)
-                                            stack.append((dest_node, output_plug_name))
+                                            if not process_only \
+                                                    and (dest_node,
+                                                         output_plug_name) \
+                                                    not in done:
+                                                yield (
+                                                    dest_node,
+                                                    output_plug_name)
+                                            stack.append((dest_node,
+                                                          output_plug_name))
                                         else:
                                             if dest_plug_name == output_plug_name:
-                                                if not process_only:
-                                                    yield (dest_node, input_plug_name)
-                                                stack.append((dest_node, input_plug_name))
+                                                if not process_only \
+                                                        and (dest_node,
+                                                             input_plug_name) \
+                                                        not in done:
+                                                    yield (
+                                                        dest_node, input_plug_name)
+                                                stack.append((dest_node,
+                                                              input_plug_name))
                         else:
-                            yield (dest_node, dest_plug_name)
+                            if (dest_node, dest_plug_name) not in done:
+                                yield (dest_node, dest_plug_name)
 
     def json(self, include_parameters=True):
         result = super().json(include_parameters=include_parameters)
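A short usage sketch of the extended get_linked_items() API from this commit: the direction argument now accepts a single direction or a 2-tuple for bidirectional search, and the new done set prevents duplicate yields. The pipeline, node and plug names below are hypothetical; only the signature comes from the diff:

def show_links(pipeline, node_name, plug_name):
    # pipeline is assumed to be a built capsul Pipeline instance
    node = pipeline.nodes[node_name]
    # downstream links only
    for dest_node, dest_plug in pipeline.get_linked_items(
            node, plug_name, direction='links_to'):
        print('feeds:', dest_node.name, dest_plug)
    # bidirectional search: pass both directions as a tuple
    for linked_node, linked_plug in pipeline.get_linked_items(
            node, plug_name, direction=('links_from', 'links_to')):
        print('linked:', linked_node.name, linked_plug)

Passing the tuple is what the reworked dispatch_plugs() above relies on to walk linked plugs in both directions.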
40 changes: 22 additions & 18 deletions capsul/pipeline/pipeline_nodes.py
@@ -235,21 +235,13 @@ def _switch_changed(self, new_selection, old_selection):
             setattr(self, output_plug_name,
                     getattr(self, corresponding_input_plug_name, undefined))
 
-            if self.pipeline is not None:
-                f = self.field(output_plug_name)
-                for n, p in self.pipeline.get_linked_items(
-                        self, corresponding_input_plug_name,
-                        direction='links_from'):
-                    # copy input field metadata
-                    for k, v in n.field(p).metadata().items():
-                        setattr(f, k, v)
-                    break
-
             # Propagate the associated field documentation
             out_field = self.field(output_plug_name)
             in_field = self.field(corresponding_input_plug_name)
             out_field.doc = in_field.metadata('doc', None)
 
+        self.propagate_fields_metadata()
+
         self.pipeline.restore_update_nodes_and_plugs_activation()
         self.__block_output_propagation = False
@@ -314,14 +306,6 @@ def _any_attribute_changed(self, new, old, name):
             if self.switch == switch_selection:
                 self.__block_output_propagation = True
                 setattr(self, output_plug_name, new)
-                if self.pipeline is not None:
-                    f = self.field(output_plug_name)
-                    for n, p in self.pipeline.get_linked_items(
-                            self, name, direction='links_from'):
-                        # copy input field metadata
-                        for k, v in n.field(p).metadata().items():
-                            setattr(f, k, v)
-                        break
                 self.__block_output_propagation = False
 
     def __setstate__(self, state):
@@ -393,6 +377,26 @@ def configured_controller(self):
         c.optional_params = [self.field(p).optional for p in self.inputs]
         return c
 
+    def propagate_fields_metadata(self):
+        ''' Propagate metadata from connected inputs (that is, outputs of
+        upstream processes) to outputs.
+        This is needed to get correct status (read/write) on output pipeline
+        plugs once the switch state is chosen.
+        '''
+        for output_plug_name in self._outputs:
+            # Get the associated input name
+            input_plug_name = f'{self.switch}_switch_{output_plug_name}'
+
+            if self.pipeline is not None:
+                f = self.field(output_plug_name)
+                for n, p in self.pipeline.get_linked_items(
+                        self, input_plug_name,
+                        direction='links_from'):
+                    # copy input field metadata
+                    for k, v in n.field(p).metadata().items():
+                        setattr(f, k, v)
+                    break
+
     @classmethod
     def build_node(cls, pipeline, name, conf_controller):
         node = Switch(pipeline, name, conf_controller.inputs,
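To make the intent of propagate_fields_metadata() concrete, here is a standalone toy model, not capsul code: a switch output plug takes its metadata (for example its read/write status) from whichever input the switch currently selects. The Plug class and propagate() helper are invented for the example; only the f'{switch}_switch_{output}' naming convention comes from the diff:

# Toy model of switch metadata propagation (illustrative only).
from dataclasses import dataclass, field

@dataclass
class Plug:
    name: str
    metadata: dict = field(default_factory=dict)

def propagate(switch_value, inputs, output):
    # copy metadata from the selected input plug onto the output plug
    selected = inputs[f'{switch_value}_switch_{output.name}']
    output.metadata.update(selected.metadata)

inputs = {
    'case1_switch_out': Plug('out', {'write': True}),
    'case2_switch_out': Plug('out', {'write': False}),
}
out = Plug('out')
propagate('case1', inputs, out)
assert out.metadata == {'write': True}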
