Added option to execute stateful submodules #13

Merged Dec 6, 2023

30 commits (the file changes shown below are from 9 of them)
a0a2f47  Added option to execute stateful submodules (Jegp, Oct 10, 2023)
1465f29  Returned state if stateful module (Jegp, Oct 10, 2023)
bbe54a0  Ruff (Jegp, Oct 10, 2023)
a64c23d  Added recurrent execution (Jegp, Oct 11, 2023)
0a1d97d  Added tests for recurrent execution (Jegp, Oct 12, 2023)
5cc87de  test for NIR -> NIRTorch -> NIR (stevenabreu7, Oct 13, 2023)
70f4447  refactoring + expose ignore_submodules_of (stevenabreu7, Oct 13, 2023)
175d34f  fix and test for issue #16 (stevenabreu7, Oct 13, 2023)
37e4237  fix recurrent test (stevenabreu7, Oct 16, 2023)
c76fb6f  remove batch froms shape spec (sheiksadique, Oct 18, 2023)
26cadf6  Merge branch 'main' into 17-input-node-retains-batch-dimension (sheiksadique, Oct 18, 2023)
48f9842  bug from hell (stevenabreu7, Oct 18, 2023)
26242d3  from_nir hacks for snnTorch (stevenabreu7, Oct 19, 2023)
668e023  + optional model.forward args for stateful modules (stevenabreu7, Oct 19, 2023)
c555b2a  change subgraphs handlign (flatten + remove I/O) (stevenabreu7, Oct 19, 2023)
60c01f8  model fwd args + ignore_dims arg (stevenabreu7, Oct 19, 2023)
d4b1afb  [hack] remove wrong RNN self-connection (NIRTorch) (stevenabreu7, Oct 19, 2023)
c736c0e  Added proper graph tracing (Jegp, Oct 19, 2023)
fe7188a  + arg to ignore dims in to_nir (stevenabreu7, Oct 20, 2023)
a21819f  add tests (stevenabreu7, Oct 20, 2023)
bef454b  output_shape also uses ignore_dims (sheiksadique, Oct 20, 2023)
b95ad5c  Added test for flatten (Jegp, Oct 20, 2023)
6c1d81e  Merged changes from #18 (Jegp, Oct 20, 2023)
3bc8bd2  minor correction to default value (sheiksadique, Oct 20, 2023)
84e3cc8  Added ability to ignore state in executor (Jegp, Oct 21, 2023)
8278437  Added flag in nirtorch parsing (Jegp, Oct 21, 2023)
ec8cded  Added flag in nirtorch parsing (Jegp, Oct 21, 2023)
5845167  Merged sinabs test changes (Jegp, Oct 21, 2023)
0325c80  minor changes to the doc strings (sheiksadique, Dec 5, 2023)
53109c3  formatting fixes (sheiksadique, Dec 5, 2023)
103 changes: 88 additions & 15 deletions nirtorch/from_nir.py
@@ -1,4 +1,6 @@
from typing import Callable, Dict, List, Optional
import dataclasses
import inspect
from typing import Callable, Dict, List, Optional, Any, Union

import nir
import torch
@@ -45,15 +47,31 @@ def execution_order_up_to_node(
return execution_order + [node]


@dataclasses.dataclass
class GraphExecutorState:
"""State for the GraphExecutor that keeps track of both the state of hidden
units and caches the output of previous modules, for use in (future) recurrent
computations."""

state: Dict[str, Any] = dataclasses.field(default_factory=dict)
cache: Dict[str, Any] = dataclasses.field(default_factory=dict)


class GraphExecutor(nn.Module):
def __init__(self, graph: Graph) -> None:
super().__init__()
self.graph = graph
self.stateful_modules = {}
self.instantiate_modules()
self.execution_order = self.get_execution_order()
if len(self.execution_order) == 0:
raise ValueError("Graph is empty")

def _is_module_stateful(self, module: torch.nn.Module) -> bool:
Collaborator:

The logic here implies that if any module has multiple inputs, it will be assumed to be stateful. This is a deal breaker!

Contributor:

I agree, we need to find a better way to implement this. It currently breaks in snnTorch because you may have multiple inputs but not be stateful (if the node keeps track of its own hidden state by itself).

Collaborator (author):

I'm happy to find other ways of doing this. But how?

Here's the challenge as far as I can tell:

  • Most frameworks can live without state (snnTorch, Sinabs, Rockpool)
  • Norse requires a state parameter (similar to PyTorch RNNs)
  • snnTorch can take spk and mem inputs

Would an option be to look for state in the arguments to account for the Norse case, and spk and mem to account for the snnTorch case?
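A minimal sketch of the parameter-name heuristic suggested above; the names state, spk, and mem are assumptions taken from this thread, not from any library guarantee, and this is an illustration of the idea rather than code from the PR:

```python
import inspect

import torch


def _is_module_stateful_by_name(module: torch.nn.Module) -> bool:
    # Assumed state-carrying argument names: "state" (Norse / PyTorch-RNN style)
    # and "spk"/"mem" (snnTorch), as discussed in this thread.
    state_args = {"state", "spk", "mem"}
    params = inspect.signature(module.forward).parameters
    # Treat the module as stateful only if forward accepts one of these names.
    return any(name in state_args for name in params)
```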

signature = inspect.signature(module.forward)
arguments = len(signature.parameters)
return arguments > 1

def get_execution_order(self) -> List[Node]:
"""Evaluate the execution order and instantiate that as a list."""
execution_order = []
@@ -67,27 +85,79 @@ def get_execution_order(self) -> List[Node]:

def instantiate_modules(self):
for mod, name in self.graph.module_names.items():
self.add_module(sanitize_name(name), mod)
if mod is not None:
self.add_module(sanitize_name(name), mod)
self.stateful_modules[sanitize_name(name)] = self._is_module_stateful(
mod
)

def get_input_nodes(self) -> List[Node]:
# NOTE: This is a hack. Should use the input nodes from NIR graph
return self.graph.get_root()

def forward(self, data: torch.Tensor):
outs = {}
def _apply_module(
self,
node: Node,
input_nodes: List[Node],
old_state: GraphExecutorState,
new_state: GraphExecutorState,
data: Optional[torch.Tensor] = None,
):
"""Applies a module and keeps track of its state.
TODO: Use pytree to recursively construct the state
"""
inputs = []
# Append state if needed
if node.name in self.stateful_modules and node.name in old_state.state:
inputs.extend(old_state.state[node.name])

# Sum recurrence if needed
summed_inputs = [] if data is None else [data]
for input_node in input_nodes:
if (
input_node.name not in new_state.cache
and input_node.name in old_state.cache
):
summed_inputs.append(old_state.cache[input_node.name])
elif input_node.name in new_state.cache:
summed_inputs.append(new_state.cache[input_node.name])

if len(summed_inputs) == 0:
raise ValueError("No inputs found for node {}".format(node.name))
elif len(summed_inputs) == 1:
inputs.insert(0, summed_inputs[0])
elif len(summed_inputs) > 1:
inputs.insert(0, torch.stack(summed_inputs).sum(0))

out = node.elem(*inputs)
# If the module is stateful, we know the output is (at least) a tuple
if self.stateful_modules[node.name]:
new_state.state[node.name] = out[1:] # Store the new state
out = out[0]
return out, new_state

def forward(
self, data: torch.Tensor, old_state: Optional[GraphExecutorState] = None
):
if old_state is None:
old_state = GraphExecutorState()
new_state = GraphExecutorState()
first_node = True
# NOTE: This logic is not yet consistent for models with multiple input nodes
for node in self.execution_order:
input_nodes = self.graph.find_source_nodes_of(node)
if node.elem is None:
continue
if len(input_nodes) == 0 or len(outs) == 0:
# This is the root node
outs[node.name] = node.elem(data)
else:
# Intermediate nodes
input_data = (outs[node.name] for node in input_nodes)
outs[node.name] = node.elem(*input_data)
return outs[node.name]
out, new_state = self._apply_module(
node, input_nodes, new_state, old_state, data if first_node else None
)
new_state.cache[node.name] = out
first_node = False

# If the output node is a dummy nir.Output node, use the second-to-last node
if node.name not in new_state.cache:
node = self.execution_order[-2]
return new_state.cache[node.name], new_state


def _mod_nir_to_graph(nir_graph: nir.NIRNode) -> Graph:
@@ -106,18 +176,21 @@ def _switch_models_with_map(


def load(
nir_graph: nir.NIRNode, model_map: Callable[[nir.NIRNode], nn.Module]
nir_graph: Union[nir.NIRNode, str], model_map: Callable[[nir.NIRNode], nn.Module]
) -> nn.Module:
"""Load a NIR object and convert it to a torch module using the given model map.
"""Load a NIR graph and convert it to a torch module using the given model map.

Args:
nir_graph (nir.NIRNode): NIR object
nir_graph (Union[nir.NIRNode, str]): The NIR object to load, or a string representing
the path to the NIR object.
model_map (Callable[[nn.NIRNode], nn.Module]): A method that returns the a torch
module that corresponds to each NIR node.

Returns:
nn.Module: The generated torch module
"""
if isinstance(nir_graph, str):
nir_graph = nir.read(nir_graph)
# Map modules to the target modules using the model map
nir_module_graph = _switch_models_with_map(nir_graph, model_map)
# Build a nirtorch.Graph based on the nir_graph
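A hedged usage sketch of the updated load signature (which now also accepts a path) and of the stateful GraphExecutor, using the tests/lif_norse.nir file added in this PR; the model map mirrors the placeholder mapping in the tests below and is not a real framework integration:

```python
import nir
import torch
import nirtorch


def my_model_map(node: nir.NIRNode) -> torch.nn.Module:
    # Placeholder mapping, as in the PR's tests: every node becomes a Linear.
    if isinstance(node, (nir.Linear, nir.Affine)):
        return torch.nn.Linear(*node.weight.shape)
    if isinstance(node, (nir.LIF, nir.CubaLIF)):
        return torch.nn.Linear(1, 1)  # stand-in for a stateful neuron layer
    return None


# load() now also accepts a path to a stored NIR graph.
module = nirtorch.load("tests/lif_norse.nir", my_model_map)

# The executor returns (output, GraphExecutorState) and accepts the previous
# state, so hidden state can be threaded through a sequence of time steps.
state = None
for _ in range(10):
    out, state = module(torch.ones(1, 1), state)
```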
22 changes: 13 additions & 9 deletions nirtorch/graph.py
@@ -5,6 +5,8 @@
import torch
import torch.nn as nn

from .utils import sanitize_name


def named_modules_map(
model: nn.Module, model_name: Optional[str] = "model"
@@ -41,7 +43,7 @@ def __init__(
outgoing_nodes: Optional[Dict["Node", torch.Tensor]] = None,
) -> None:
self.elem = elem
self.name = name
self.name = sanitize_name(name)
if not outgoing_nodes:
self.outgoing_nodes = {}
else:
@@ -192,22 +194,24 @@ def populate_from(self, other_graph: "Graph"):
def __str__(self) -> str:
return self.to_md()

def debug_str(self) -> str:
debug_str = ""
for node in self.node_list:
debug_str += f"{node.name} ({node.elem.__class__.__name__})\n"
for outgoing, shape in node.outgoing_nodes.items():
debug_str += f"\t-> {outgoing.name} ({outgoing.elem.__class__.__name__})\n"
return debug_str.strip()

def to_md(self) -> str:
mermaid_md = """
```mermaid
graph TD;
"""
mermaid_md = """```mermaid\ngraph TD;\n"""
for node in self.node_list:
if node.outgoing_nodes:
for outgoing, _ in node.outgoing_nodes.items():
mermaid_md += f"{node.name} --> {outgoing.name};\n"
else:
mermaid_md += f"{node.name};\n"

end = """
```
"""
return mermaid_md + end
return mermaid_md + "\n```\n"

def leaf_only(self) -> "Graph":
leaf_modules = self.get_leaf_modules()
8 changes: 8 additions & 0 deletions nirtorch/to_nir.py
@@ -12,9 +12,14 @@ def extract_nir_graph(
model_map: Callable[[nn.Module], nir.NIRNode],
sample_data: Any,
model_name: Optional[str] = "model",
ignore_submodules_of=None,
) -> nir.NIRNode:
"""Given a `model`, generate an NIR representation using the specified `model_map`.

Assumptions and known issues:
- Cannot deal with layers like torch.nn.Identity(), since the input tensor and output
tensor will be the same object, and therefore lead to cyclic connections.

Args:
model (nn.Module): The model of interest
model_map (Callable[[nn.Module], nir.NIRNode]): A method that converts a given
@@ -36,6 +41,9 @@
model, sample_data=sample_data, model_name=model_name
).ignore_tensors()

if ignore_submodules_of is not None:
torch_graph = torch_graph.ignore_submodules_of(ignore_submodules_of)

# Get the root node
root_nodes = torch_graph.get_root()
if len(root_nodes) != 1:
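A brief sketch of how the new ignore_submodules_of argument might be used, assuming it takes a list of module classes whose children should not be traced; the Block wrapper class here is hypothetical and not part of the PR:

```python
import nir
import torch
import nirtorch


class Block(torch.nn.Module):
    # Hypothetical container that should be exported as a single NIR node.
    def __init__(self):
        super().__init__()
        self.inner = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.inner(x)


def model_map(module: torch.nn.Module) -> nir.NIRNode:
    if isinstance(module, Block):
        return nir.Linear(weight=module.inner.weight.detach().numpy())
    return None


model = torch.nn.Sequential(Block())
graph = nirtorch.extract_nir_graph(
    model,
    model_map,
    sample_data=torch.ones(1, 2),
    # Assumed usage: skip tracing inside children of these module types.
    ignore_submodules_of=[Block],
)
```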
Binary file added tests/braille.nir
Binary file added tests/lif_norse.nir
91 changes: 91 additions & 0 deletions tests/test_bidirectional.py
@@ -0,0 +1,91 @@
import nir
import numpy as np
import torch
import nirtorch


use_snntorch = False
# use_snntorch = True


if use_snntorch:
import snntorch as snn


def _nir_to_snntorch_module(node: nir.NIRNode) -> torch.nn.Module:
if isinstance(node, (nir.Linear, nir.Affine)):
return torch.nn.Linear(*node.weight.shape)

elif isinstance(node, (nir.LIF, nir.CubaLIF)):
return snn.Leaky(0.9, init_hidden=True)

else:
return None


def _nir_to_pytorch_module(node: nir.NIRNode) -> torch.nn.Module:
if isinstance(node, (nir.Linear, nir.Affine)):
return torch.nn.Linear(*node.weight.shape)

elif isinstance(node, (nir.LIF, nir.CubaLIF)):
return torch.nn.Linear(1, 1)

else:
return None


if use_snntorch:
_nir_to_torch_module = _nir_to_snntorch_module
else:
_nir_to_torch_module = _nir_to_pytorch_module


def _create_torch_model() -> torch.nn.Module:
if use_snntorch:
return torch.nn.Sequential(torch.nn.Linear(1, 1), snn.Leaky(0.9, init_hidden=True))
else:
return torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.Identity())


def _torch_to_nir(module: torch.nn.Module) -> nir.NIRNode:
if isinstance(module, torch.nn.Linear):
return nir.Linear(np.array(module.weight.data))

else:
return None


def _lif_nir_graph(from_file=True):
if from_file:
return nir.read('tests/lif_norse.nir')
else:
return nir.NIRGraph(
nodes={
'0': nir.Affine(weight=np.array([[1.]]), bias=np.array([0.])),
'1': nir.LIF(
tau=np.array([0.1]),
r=np.array([1.]),
v_leak=np.array([0.]),
v_threshold=np.array([0.1])
),
'input': nir.Input(input_type={'input': np.array([1])}),
'output': nir.Output(output_type={'output': np.array([1])})
},
edges=[
('input', '0'), ('0', '1'), ('1', 'output')
]
)


def test_nir_to_torch_to_nir(from_file=True):
graph = _lif_nir_graph(from_file=from_file)
assert graph is not None
module = nirtorch.load(graph, _nir_to_torch_module)
assert module is not None
graph2 = nirtorch.extract_nir_graph(module, _torch_to_nir, torch.zeros(1, 1))
assert sorted(graph.edges) == sorted(graph2.edges)
assert graph2 is not None


# if __name__ == '__main__':
# test_nir_to_torch_to_nir(from_file=False)