Skip to content

Commit

Permalink
Merge pull request #13 from neuromorphs/feature-state
Browse files Browse the repository at this point in the history
paper version with support for stateful submodules
  • Loading branch information
sheiksadique authored Dec 6, 2023
2 parents c7aba27 + 53109c3 commit 5fa4c07
Show file tree
Hide file tree
Showing 13 changed files with 632 additions and 138 deletions.
2 changes: 1 addition & 1 deletion nirtorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .graph import extract_torch_graph # noqa F401
from .from_nir import load # noqa F401
from .graph import extract_torch_graph # noqa F401
from .to_nir import extract_nir_graph # noqa F401

__version__ = version = "0.2.1"
248 changes: 180 additions & 68 deletions nirtorch/from_nir.py
Original file line number Diff line number Diff line change
@@ -1,127 +1,239 @@
from typing import Callable, Dict, List, Optional
import dataclasses
import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import nir
import torch
import torch.nn as nn

from .graph import Graph, Node
from .graph_utils import trace_execution
from .utils import sanitize_name


def execution_order_up_to_node(
node: Node,
graph: Graph,
execution_order: List[Node],
visited: Optional[Dict[Node, bool]] = None,
) -> List[Node]:
"""Recursive function to evaluate execution order until a given node.
@dataclasses.dataclass
class GraphExecutorState:
"""State for the GraphExecutor that keeps track of both the state of hidden units
and caches the output of previous modules, for use in (future) recurrent
computations."""

Args:
node (Node): Execution order for the node of interest
graph (Graph): Graph object describing the network
execution_order (List[Node]): The current known execution order.
Returns:
List[Node]: Execution order
"""
if visited is None:
visited = {n: False for n in graph.node_list}
is_recursive = False
if len(execution_order) == list(graph.node_list):
# All nodes are executed
return execution_order
for parent in graph.find_source_nodes_of(node):
if parent not in execution_order and not visited[parent]:
visited[parent] = True
execution_order = execution_order_up_to_node(
parent, graph, execution_order, visited
)
if node in parent.outgoing_nodes:
is_recursive = True
# Ensure we're not re-adding a recursive node
if is_recursive and node in execution_order:
return execution_order
else: # Finally since all parents are known and executed
return execution_order + [node]
state: Dict[str, Any] = dataclasses.field(default_factory=dict)
cache: Dict[str, Any] = dataclasses.field(default_factory=dict)


class GraphExecutor(nn.Module):
def __init__(self, graph: Graph) -> None:
"""Executes the NIR graph in PyTorch.
By default the graph executor is stateful, since there may be recurrence or
stateful modules in the graph. Specifically, that means accepting and returning a
state object (`GraphExecutorState`). If that is not desired,
set `return_state=False` in the constructor.
Arguments:
graph (Graph): The graph to execute
return_state (bool, optional): Whether to return the state object.
Defaults to True.
Raises:
ValueError: If there are no edges in the graph
"""

def __init__(self, graph: Graph, return_state: bool = True) -> None:
super().__init__()
self.graph = graph
self.stateful_modules = set()
self.return_state = return_state
self.instantiate_modules()
self.execution_order = self.get_execution_order()
if len(self.execution_order) == 0:
raise ValueError("Graph is empty")

def _is_module_stateful(self, module: torch.nn.Module) -> bool:
signature = inspect.signature(module.forward)
arguments = len(signature.parameters)
# HACK for snntorch modules
if "snntorch" in str(module.__class__):
if module.__class__.__name__ in [
"Synaptic",
"RSynaptic",
"Leaky",
"RLeaky",
]:
return not module.init_hidden
return "state" in signature.parameters and arguments > 1

def get_execution_order(self) -> List[Node]:
"""Evaluate the execution order and instantiate that as a list."""
execution_order = []
# Then loop over all nodes and check that they are added to the execution order.
for node in self.graph.node_list:
if node not in execution_order and isinstance(node.elem, nn.Module):
execution_order = execution_order_up_to_node(
node, self.graph, execution_order
)
return execution_order
# TODO: Adapt this for graphs with multiple inputs
inputs = self.graph.inputs
if len(inputs) != 1:
raise ValueError(
f"Currently, only one input is supported, but {len(inputs)} was given"
)
return trace_execution(inputs[0], lambda n: n.outgoing_nodes.keys())

def instantiate_modules(self):
for mod, name in self.graph.module_names.items():
if isinstance(mod, nn.Module):
if mod is not None:
self.add_module(sanitize_name(name), mod)
if self._is_module_stateful(mod):
self.stateful_modules.add(sanitize_name(name))

def get_input_nodes(self) -> List[Node]:
# NOTE: This is a hack. Should use the input nodes from NIR graph
return self.graph.get_root()

def forward(self, data: torch.Tensor):
outs = {}
def _apply_module(
self,
node: Node,
input_nodes: List[Node],
new_state: GraphExecutorState,
old_state: GraphExecutorState,
data: Optional[torch.Tensor] = None,
):
"""Applies a module and keeps track of its state.
TODO: Use pytree to recursively construct the state
"""
inputs = []
# Append state if needed
if node.name in self.stateful_modules and node.name in old_state.state:
inputs.extend(old_state.state[node.name])

# Sum recurrence if needed
summed_inputs = [] if data is None else [data]
for input_node in input_nodes:
if (
input_node.name not in new_state.cache
and input_node.name in old_state.cache
):
summed_inputs.append(old_state.cache[input_node.name])
elif input_node.name in new_state.cache:
summed_inputs.append(new_state.cache[input_node.name])

if len(summed_inputs) == 0:
raise ValueError("No inputs found for node {}".format(node.name))
elif len(summed_inputs) == 1:
inputs.insert(0, summed_inputs[0])
elif len(summed_inputs) > 1:
inputs.insert(0, torch.stack(summed_inputs).sum(0))

out = node.elem(*inputs)
# If the module is stateful, we know the output is (at least) a tuple
# HACK to make it work for snnTorch
is_rsynaptic = "snntorch._neurons.rsynaptic.RSynaptic" in str(
node.elem.__class__
)
if is_rsynaptic and not node.elem.init_hidden:
assert "lif" in node.name, "this shouldnt happen.."
new_state.state[node.name] = out # snnTorch requires output inside state
out = out[0]
elif node.name in self.stateful_modules:
new_state.state[node.name] = out[1:] # Store the new state
out = out[0]
return out, new_state

def forward(
self, data: torch.Tensor, old_state: Optional[GraphExecutorState] = None
):
if old_state is None:
old_state = GraphExecutorState()
new_state = GraphExecutorState()
first_node = True
# NOTE: This logic is not yet consistent for models with multiple input nodes
for node in self.execution_order:
input_nodes = self.graph.find_source_nodes_of(node)
if node.elem is None:
continue
if len(input_nodes) == 0 or len(outs) == 0:
# This is the root node
outs[node.name] = node.elem(data)
else:
# Intermediate nodes
input_data = (outs[node.name] for node in input_nodes)
outs[node.name] = node.elem(*input_data)
return outs[node.name]


def _mod_nir_to_graph(nir_graph: nir.NIRNode) -> Graph:
module_names = {module: name for name, module in nir_graph.nodes.items()}
graph = Graph(module_names=module_names)
for src, dst in nir_graph.edges:
graph.add_edge(nir_graph.nodes[src], nir_graph.nodes[dst])
out, new_state = self._apply_module(
node,
input_nodes,
new_state=new_state,
old_state=old_state,
data=data if first_node else None,
)
new_state.cache[node.name] = out
first_node = False

# If the output node is a dummy nir.Output node, use the second-to-last node
if node.name not in new_state.cache:
node = self.execution_order[-2]
if self.return_state:
return new_state.cache[node.name], new_state
else:
return new_state.cache[node.name]


def _mod_nir_to_graph(
torch_graph: nir.NIRGraph, nir_nodes: Dict[str, nir.NIRNode]
) -> Graph:
module_names = {module: name for name, module in torch_graph.nodes.items()}
inputs = [name for name, node in nir_nodes.items() if isinstance(node, nir.Input)]
graph = Graph(module_names=module_names, inputs=inputs)
for src, dst in torch_graph.edges:
# Allow edges to refer to subgraph inputs and outputs
if src not in torch_graph.nodes and f"{src}.output" in torch_graph.nodes:
src = f"{src}.output"
if dst not in torch_graph.nodes and f"{dst}.input" in torch_graph.nodes:
dst = f"{dst}.input"
graph.add_edge(torch_graph.nodes[src], torch_graph.nodes[dst])
return graph


def _switch_default_models(nir_graph: nir.NIRNode) -> Optional[torch.nn.Module]:
if isinstance(nir_graph, nir.Input) or isinstance(nir_graph, nir.Output):
return torch.nn.Identity()


def _switch_models_with_map(
nir_graph: nir.NIRNode, model_map: Callable[[nn.Module], nn.Module]
) -> nir.NIRNode:
nodes = {name: model_map(node) for name, node in nir_graph.nodes.items()}
nodes = {}
for name, node in nir_graph.nodes.items():
mapped_module = model_map(node)
if mapped_module is None:
mapped_module = _switch_default_models(node)
nodes[name] = mapped_module
# nodes = {name: model_map(node) for name, node in nir_graph.nodes.items()}
return nir.NIRGraph(nodes, nir_graph.edges)


def load(
nir_graph: nir.NIRNode, model_map: Callable[[nir.NIRNode], nn.Module]
nir_graph: Union[nir.NIRNode, str],
model_map: Callable[[nir.NIRNode], nn.Module],
return_state: bool = True,
) -> nn.Module:
"""Load a NIR object and convert it to a torch module using the given model map.
"""Load a NIR graph and convert it to a torch module using the given model map.
Because the graph can contain recurrence and stateful modules, the execution accepts
a secondary state argument and returns a tuple of [output, state], instead of just
the output as follows
>>> executor = nirtorch.load(nir_graph, model_map)
>>> old_state = None
>>> output, state = executor(input, old_state) # Notice second argument and output
>>> output, state = executor(input, state) # This can go on for many (time)steps
If you do not wish to operate with state, set `return_state=False`.
Args:
nir_graph (nir.NIRNode): NIR object
nir_graph (Union[nir.NIRNode, str]): The NIR object to load, or a string
representing the path to the NIR object.
model_map (Callable[[nn.NIRNode], nn.Module]): A method that returns the a torch
module that corresponds to each NIR node.
return_state (bool): If True, the execution of the loaded graph will return a
tuple of [output, state], where state is a GraphExecutorState object.
If False, only the NIR graph output will be returned. Note that state is
required for recurrence to work in the graphs.
Returns:
nn.Module: The generated torch module
"""
if isinstance(nir_graph, str):
nir_graph = nir.read(nir_graph)
# Map modules to the target modules using th emodel map
nir_module_graph = _switch_models_with_map(nir_graph, model_map)
# Build a nirtorch.Graph based on the nir_graph
graph = _mod_nir_to_graph(nir_module_graph)
graph = _mod_nir_to_graph(nir_module_graph, nir_nodes=nir_graph.nodes)
# Build and return a graph executor module
return GraphExecutor(graph)
return GraphExecutor(graph, return_state=return_state)
Loading

0 comments on commit 5fa4c07

Please sign in to comment.