Skip to content

Commit

Permalink
minor changes to the doc strings
Browse files Browse the repository at this point in the history
  • Loading branch information
sheiksadique committed Dec 5, 2023
1 parent 5845167 commit 0325c80
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions nirtorch/from_nir.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@ class GraphExecutorState:
class GraphExecutor(nn.Module):
"""Executes the NIR graph in PyTorch.
By default the graph executor is stateful, since there may be recurrence or stateful modules in the graph.
Specifically, that means accepting and returning a state object (`GraphExecutorState`).
If that is not desired, set `return_state=False` in the constructor.
By default the graph executor is stateful, since there may be recurrence or
stateful modules in the graph. Specifically, that means accepting and returning a
state object (`GraphExecutorState`). If that is not desired,
set `return_state=False` in the constructor.
Arguments:
graph (Graph): The graph to execute
return_state (bool, optional): Whether to return the state object. Defaults to True.
return_state (bool, optional): Whether to return the state object.
Defaults to True.
Raises:
ValueError: If there are no edges in the graph
Expand Down Expand Up @@ -169,9 +171,9 @@ def _mod_nir_to_graph(
graph = Graph(module_names=module_names, inputs=inputs)
for src, dst in torch_graph.edges:
# Allow edges to refer to subgraph inputs and outputs
if not src in torch_graph.nodes and f"{src}.output" in torch_graph.nodes:
if src not in torch_graph.nodes and f"{src}.output" in torch_graph.nodes:
src = f"{src}.output"
if not dst in torch_graph.nodes and f"{dst}.input" in torch_graph.nodes:
if dst not in torch_graph.nodes and f"{dst}.input" in torch_graph.nodes:
dst = f"{dst}.input"
graph.add_edge(torch_graph.nodes[src], torch_graph.nodes[dst])
return graph
Expand Down Expand Up @@ -203,24 +205,25 @@ def load(
"""Load a NIR graph and convert it to a torch module using the given model map.
Because the graph can contain recurrence and stateful modules, the execution accepts
a secondary state argument and returns a tuple of [output, state], instead of just the output as follows
a secondary state argument and returns a tuple of [output, state], instead of just
the output as follows
>>> executor = nirtorch.load(nir_graph, model_map)
>>> old_state = None
>>> output, state = executor(input, old_state) # Notice the second argument and output
>>> output, state = executor(input, old_state) # Notice second argument and output
>>> output, state = executor(input, state) # This can go on for many (time)steps
If you do not wish to operate with state, set `return_state=False`.
Args:
nir_graph (Union[nir.NIRNode, str]): The NIR object to load, or a string representing
the path to the NIR object.
nir_graph (Union[nir.NIRNode, str]): The NIR object to load, or a string
representing the path to the NIR object.
model_map (Callable[[nn.NIRNode], nn.Module]): A method that returns the a torch
module that corresponds to each NIR node.
return_state (bool): If True, the execution of the loaded graph will return a tuple
of [output, state], where state is a GraphExecutorState object. If False, only
the NIR graph output will be returned. Note that state is required for recurrence
to work in the graphs.
return_state (bool): If True, the execution of the loaded graph will return a
tuple of [output, state], where state is a GraphExecutorState object.
If False, only the NIR graph output will be returned. Note that state is
required for recurrence to work in the graphs.
Returns:
nn.Module: The generated torch module
Expand Down

0 comments on commit 0325c80

Please sign in to comment.