Added option to execute stateful submodules #13
Merged
Commits (30 total; the diff below shows changes from 9 of them):

a0a2f47 Added option to execute stateful submodules (Jegp)
1465f29 Returned state if stateful module (Jegp)
bbe54a0 Ruff (Jegp)
a64c23d Added recurrent execution (Jegp)
0a1d97d Added tests for recurrent execution (Jegp)
5cc87de test for NIR -> NIRTorch -> NIR (stevenabreu7)
70f4447 refactoring + expose ignore_submodules_of (stevenabreu7)
175d34f fix and test for issue #16 (stevenabreu7)
37e4237 fix recurrent test (stevenabreu7)
c76fb6f remove batch from shape spec (sheiksadique)
26cadf6 Merge branch 'main' into 17-input-node-retains-batch-dimension (sheiksadique)
48f9842 bug from hell (stevenabreu7)
26242d3 from_nir hacks for snnTorch (stevenabreu7)
668e023 + optional model.forward args for stateful modules (stevenabreu7)
c555b2a change subgraphs handling (flatten + remove I/O) (stevenabreu7)
60c01f8 model fwd args + ignore_dims arg (stevenabreu7)
d4b1afb [hack] remove wrong RNN self-connection (NIRTorch) (stevenabreu7)
c736c0e Added proper graph tracing (Jegp)
fe7188a + arg to ignore dims in to_nir (stevenabreu7)
a21819f add tests (stevenabreu7)
bef454b output_shape also uses ignore_dims (sheiksadique)
b95ad5c Added test for flatten (Jegp)
6c1d81e Merged changes from #18 (Jegp)
3bc8bd2 minor correction to default value (sheiksadique)
84e3cc8 Added ability to ignore state in executor (Jegp)
8278437 Added flag in nirtorch parsing (Jegp)
ec8cded Added flag in nirtorch parsing (Jegp)
5845167 Merged sinabs test changes (Jegp)
0325c80 minor changes to the doc strings (sheiksadique)
53109c3 formatting fixes (sheiksadique)
(The diff spans several files, including two binary files that are not shown; only the new test file below was captured here.)
New test file (91 lines):

# Round-trip test: build a torch module from a NIR graph with
# nirtorch.load, then recover a NIR graph via extract_nir_graph.
import nir
import numpy as np
import torch
import nirtorch

use_snntorch = False
# use_snntorch = True

if use_snntorch:
    import snntorch as snn


def _nir_to_snntorch_module(node: nir.NIRNode) -> torch.nn.Module:
    if isinstance(node, (nir.Linear, nir.Affine)):
        return torch.nn.Linear(*node.weight.shape)
    elif isinstance(node, (nir.LIF, nir.CubaLIF)):
        return snn.Leaky(0.9, init_hidden=True)
    else:
        return None


def _nir_to_pytorch_module(node: nir.NIRNode) -> torch.nn.Module:
    if isinstance(node, (nir.Linear, nir.Affine)):
        return torch.nn.Linear(*node.weight.shape)
    elif isinstance(node, (nir.LIF, nir.CubaLIF)):
        return torch.nn.Linear(1, 1)
    else:
        return None


if use_snntorch:
    _nir_to_torch_module = _nir_to_snntorch_module
else:
    _nir_to_torch_module = _nir_to_pytorch_module


def _create_torch_model() -> torch.nn.Module:
    if use_snntorch:
        return torch.nn.Sequential(torch.nn.Linear(1, 1), snn.Leaky(0.9, init_hidden=True))
    else:
        return torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.Identity())


def _torch_to_nir(module: torch.nn.Module) -> nir.NIRNode:
    if isinstance(module, torch.nn.Linear):
        return nir.Linear(np.array(module.weight.data))
    else:
        return None


def _lif_nir_graph(from_file=True):
    if from_file:
        return nir.read('tests/lif_norse.nir')
    else:
        return nir.NIRGraph(
            nodes={
                '0': nir.Affine(weight=np.array([[1.]]), bias=np.array([0.])),
                '1': nir.LIF(
                    tau=np.array([0.1]),
                    r=np.array([1.]),
                    v_leak=np.array([0.]),
                    v_threshold=np.array([0.1])
                ),
                'input': nir.Input(input_type={'input': np.array([1])}),
                'output': nir.Output(output_type={'output': np.array([1])})
            },
            edges=[
                ('input', '0'), ('0', '1'), ('1', 'output')
            ]
        )


def test_nir_to_torch_to_nir(from_file=True):
    graph = _lif_nir_graph(from_file=from_file)
    assert graph is not None
    module = nirtorch.load(graph, _nir_to_torch_module)
    assert module is not None
    graph2 = nirtorch.extract_nir_graph(module, _torch_to_nir, torch.zeros(1, 1))
    assert sorted(graph.edges) == sorted(graph2.edges)
    assert graph2 is not None


# if __name__ == '__main__':
#     test_nir_to_torch_to_nir(from_file=False)
Conversations
The logic here implies that if any module has multiple inputs, it will be assumed to be stateful. This is a deal breaker!
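For concreteness, the check being objected to is presumably something like the following reconstruction (hypothetical — not the PR's actual code): any forward signature with more than one parameter gets treated as taking explicit state.

import inspect
import torch

def _naive_is_stateful(module: torch.nn.Module) -> bool:
    # Assumed heuristic: more than one forward argument => stateful.
    return len(inspect.signature(module.forward).parameters) > 1

# An RNN cell really does take explicit state (input, hx) ...
assert _naive_is_stateful(torch.nn.RNNCell(1, 1))
# ... but a Bilinear layer takes two *data* inputs and no state,
# so this rule misclassifies it as stateful.
assert _naive_is_stateful(torch.nn.Bilinear(1, 1, 1))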
I agree, we need to find a better way to implement this. It currently breaks in snnTorch because you may have multiple inputs but not be stateful (if the node keeps track of its own hidden state by itself).
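To see the snnTorch ambiguity concretely — a minimal sketch, not code from this PR — the same Leaky cell either hides or exposes its state depending on how it is constructed, so the forward signature alone is a poor statefulness signal:

import torch
import snntorch as snn

# With init_hidden=True the cell keeps its membrane potential itself:
# forward takes a single argument, yet the dynamics are stateful.
auto = snn.Leaky(beta=0.9, init_hidden=True)
spk = auto(torch.zeros(1, 1))

# With init_hidden=False the caller owns the state: forward takes
# (input, mem) and returns (spk, mem).
manual = snn.Leaky(beta=0.9)
spk, mem = manual(torch.zeros(1, 1), torch.zeros(1, 1))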
I'm happy to find other ways of doing this. But how?
Here's the challenge as far as I can tell:
- norse passes a state parameter (similar to PyTorch RNNs)
- snnTorch passes spk and mem inputs

Would an option be to look for state in the arguments to account for the norse case, and spk and mem to account for the snnTorch case?
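If that is the direction, one possible shape for it is sketched below. This is a hypothetical illustration, not code from this PR; the set of recognised names (state for the norse case, spk and mem for the snnTorch case, plus hx so the plain-PyTorch example works) is an assumption.

import inspect
import torch

# Hypothetical helper: treat a module as expecting explicit state if
# its forward signature mentions a known state-carrying argument name.
STATE_ARG_NAMES = {"state", "spk", "mem", "hx"}

def expects_explicit_state(module: torch.nn.Module) -> bool:
    params = inspect.signature(module.forward).parameters
    return any(name in STATE_ARG_NAMES for name in params)

# RNNCell takes an explicit hidden state ("hx"); Linear does not, and
# Bilinear's two data inputs no longer trigger a false positive.
assert expects_explicit_state(torch.nn.RNNCell(1, 1))
assert not expects_explicit_state(torch.nn.Linear(1, 1))
assert not expects_explicit_state(torch.nn.Bilinear(1, 1, 1))

The obvious limitation is that a name-based check only covers frameworks whose argument conventions are baked into the set.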