From ff1cc60fd9c92d4ad1fcadcf7933f74a97af4cfd Mon Sep 17 00:00:00 2001 From: Sadique Sheik Date: Thu, 8 Jun 2023 17:54:09 +0200 Subject: [PATCH 01/24] wip. basic graph exteaction and tracing added --- sinabs/graph.py | 220 ++++++++++++++++++++++++++++++++++++++++++++ tests/test_graph.py | 76 +++++++++++++++ 2 files changed, 296 insertions(+) create mode 100644 sinabs/graph.py create mode 100644 tests/test_graph.py diff --git a/sinabs/graph.py b/sinabs/graph.py new file mode 100644 index 00000000..5524b710 --- /dev/null +++ b/sinabs/graph.py @@ -0,0 +1,220 @@ +import torch +import torchview +import torch.nn as nn +from typing import Union, Tuple, List, Callable, Any, Dict, Optional +from torchview import ComputationGraph +import torchview.torchview as tw +from torchview import RecorderTensor +from torchview.recorder_tensor import Recorder +from torchview import TensorNode +from torchview.computation_node import NodeContainer +import graphviz +import warnings + + +class Node: + def __init__( + self, + elem: Any, + name: str, + incoming_nodes: Optional[List["Node"]] = None, + outgoing_nodes: Optional[List["Node"]] = None, + ) -> None: + self.elem = elem + self.name = name + # Initialize if None + if not incoming_nodes: + self.incoming_nodes = [] + else: + self.incoming_nodes = incoming_nodes + # Initialize if None + if not outgoing_nodes: + self.outgoing_nodes = [] + else: + self.outgoing_nodes = outgoing_nodes + + def add_incoming(self, node: "Node"): + self.incoming_nodes.append(node) + + def add_outgoing(self, node: "Node"): + self.outgoing_nodes.append(node) + + def __str__(self) -> str: + return f"Node: {self.name}, I: {len(self.incoming_nodes)}, O: {len(self.outgoing_nodes)}" + + def __eq__(self, other: Any) -> bool: + # Two nodes are meant to be the same if they refer to the same element + try: + return self.elem is other.elem + except AttributeError: + return False + + def __hash__(self): + # Two nodes are same if they reference the same element + return hash(self.elem) + + +class Graph: + def __init__(self, mod: nn.Module) -> None: + self.elem_list = [] + self.node_map: Dict[Node, str] = {} + self.modules_map = named_modules_map(mod) + self.tensor_id_list = [] + + @property + def node_map_by_id(self): + return {v: k for k, v in self.node_map.items()} + + def get_unique_tensor_id(self): + if not self.tensor_id_list: + self.tensor_id_list.append(0) + return 0 + else: + self.tensor_id_list.append(self.tensor_id_list[-1] + 1) + return str(self.tensor_id_list[-1] + 1) + + + def __contains__(self, elem: Union[torch.Tensor, nn.Module]): + for elem_in_list in self.elem_list: + if elem is elem_in_list: + return True + return False + + def add_elem(self, elem, name: str): + if elem in self: + warnings.warn(f"{name}: Node already exists for this element ") + return self.find_node(elem) + else: + node = Node(elem, name) + self.elem_list.append(elem) + self.node_map[node] = name + return node + + def add_or_get_node_for_elem(self, elem: Union[torch.Tensor, nn.Module]): + if elem in self: + return self.find_node(elem) + else: + # Generate a name + if elem in self.modules_map: + name = self.modules_map[elem] + else: + assert isinstance(elem, torch.Tensor) + name = f"Tensor_{self.get_unique_tensor_id()}" + # add and return the node + new_node = self.add_elem(elem, name) + return new_node + + def find_node(self, elem: Union[torch.Tensor, nn.Module]): + for node in self.node_map.keys(): + if elem is node.elem: + return node + raise ValueError("elem not found") + + + def add_edge(self, source: Union[torch.Tensor, nn.Module], destination: Union[torch.Tensor, nn.Module]): + source_node = self.add_or_get_node_for_elem(source) + destination_node = self.add_or_get_node_for_elem(destination) + print(f"Adding edge {source_node.name} -> {destination_node.name}") + source_node.add_outgoing(destination_node) + destination_node.add_incoming(source_node) + return source_node, destination_node + + def __str__(self) -> str: + return "\n".join([f"{n}" for n in self.node_map.keys()]) + + def to_md(self)-> str: + mermaid_md = """ +```mermaid +graph TD; +""" + for node, _ in self.node_map.items(): + for outgoing in node.outgoing_nodes: + mermaid_md += f"{node.name} --> {outgoing.name};\n" + end = """ +``` +""" + return mermaid_md + end + + + +def process_input(input_data: torch.Tensor): + # Note: Works only when the input is a single tensor + # Convert to recorder tensor + recorder_tensor = input_data.as_subclass(RecorderTensor) + # Create a corresponding node for it + input_node = TensorNode(tensor=recorder_tensor, depth=0, name="Input") + recorder_tensor.tensor_nodes = [input_node] + return recorder_tensor, {}, NodeContainer(recorder_tensor.tensor_nodes) + + +_torch_module_call = torch.nn.Module.__call__ + + +def module_forward_wrapper(model_graph: Graph) -> Callable[..., Any]: + def _my_forward(mod: nn.Module, *args, **kwargs) -> Any: + # Iterate over all inputs + for i, input_data in enumerate(args): + # Create nodes and edges + model_graph.add_edge(input_data, mod) + out = _torch_module_call(mod, *args, **kwargs) + if isinstance(out, tuple): + out_tuple + elif isinstance(out, torch.Tensor): + out_tuple = out, + else: + raise Exception("Unknown output format") + # Iterate over all outputs and create nodes and edges + for i, output_data in enumerate(out_tuple): + # Create nodes and edges + model_graph.add_edge(mod, output_data) + return out + + return _my_forward + + +def forward_prop( + model: nn.Module, input_data: RecorderTensor, model_graph: ComputationGraph +): + model.eval() + model = model.to("cpu") + new_module_forward = module_forward_wrapper(model_graph) + with Recorder(_torch_module_call, new_module_forward, model_graph): + model(input_data) + return + + +def named_modules_map( + model: nn.Module, model_name: str = "model" +) -> Dict[str, nn.Module]: + """Inverse of named modules dictionary + + Args: + model (nn.Module): The module to be hashed + + Returns: + Dict[str, nn.Module]: A dictionary with modules as keys, and names as values + """ + modules_map = {} + for name, mod in model.named_modules(): + modules_map[mod] = name + modules_map[model] = model_name + return modules_map + + +def extract_graph(model: nn.Module, input_data: torch.Tensor) -> Graph: + # Modify the input somehow + input_record_tensor, kwargs_record_tensor, input_nodes = process_input(input_data) + # Create a graph + visual_graph = graphviz.Digraph( + name="..", engine="dot", strict=True, filename="somefile.dot" + ) + model_graph = ComputationGraph( + visual_graph=visual_graph, root_container=input_nodes + ) + + # Populate it + forward_prop(model, input_record_tensor, model_graph=model_graph) + + model_graph.fill_visual_graph() + + return model_graph diff --git a/tests/test_graph.py b/tests/test_graph.py new file mode 100644 index 00000000..7ba89952 --- /dev/null +++ b/tests/test_graph.py @@ -0,0 +1,76 @@ +import torchinfo +import torch.nn as nn +import torch +from torchview import draw_graph +from sinabs.graph import extract_graph, process_input + + +# Branched model +class MyBranchedModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.relu1 = nn.ReLU() + self.relu2_1 = nn.ReLU() + self.relu2_2 = nn.ReLU() + self.relu3 = nn.ReLU() + + def forward(self, data): + out1 = self.relu1(data) + out2_1 = self.relu2_1(out1) + out2_2 = self.relu2_2(out1) + out3 = self.relu3(out2_1 + out2_2) + out3.foo = "foo" + return out3 + + +input_shape = (2, 28, 28) +batch_size = 1 + +data = torch.ones((batch_size, *input_shape)) + +mymodel = MyBranchedModel() + +torchinfo.summary(mymodel, input_data=data) + + +def test_process_input(): + out = process_input(input_data=data) + print(out) + + + +def test_extract_graph(): + + model_graph = extract_graph(mymodel, input_data=data) + model_graph.visual_graph.save("branched_graph.dot") + + for id, node in model_graph.id_dict.items(): + print(type(id), type(node)) + + +def test_named_modules_map(): + from sinabs.graph import named_modules_map + mod_map = named_modules_map(mymodel) + print(mod_map) + + +#def test_module_forward_wrapper(): +mymodel = MyBranchedModel() + +orig_call = nn.Module.__call__ + +from sinabs.graph import Graph, module_forward_wrapper + +model_graph = Graph(mymodel) +new_call = module_forward_wrapper(model_graph) + +# Override call to the new wrapped call +nn.Module.__call__ = new_call + +with torch.no_grad(): + out = mymodel(data) + +# Restore normal behavior +nn.Module.__call__ = orig_call + +print(model_graph.to_md()) \ No newline at end of file From 5084cec9d2d29d324884665b5e7ae1ca642ddfa2 Mon Sep 17 00:00:00 2001 From: Sadique Sheik Date: Fri, 9 Jun 2023 00:09:53 +0200 Subject: [PATCH 02/24] added context manager --- sinabs/graph.py | 116 +++++++++++++++++++------------------------- tests/test_graph.py | 56 +++++++++------------ 2 files changed, 73 insertions(+), 99 deletions(-) diff --git a/sinabs/graph.py b/sinabs/graph.py index 5524b710..13cac6ed 100644 --- a/sinabs/graph.py +++ b/sinabs/graph.py @@ -1,17 +1,30 @@ import torch -import torchview import torch.nn as nn from typing import Union, Tuple, List, Callable, Any, Dict, Optional from torchview import ComputationGraph -import torchview.torchview as tw from torchview import RecorderTensor from torchview.recorder_tensor import Recorder -from torchview import TensorNode -from torchview.computation_node import NodeContainer -import graphviz import warnings +def named_modules_map( + model: nn.Module, model_name: str = "model" +) -> Dict[str, nn.Module]: + """Inverse of named modules dictionary + + Args: + model (nn.Module): The module to be hashed + + Returns: + Dict[str, nn.Module]: A dictionary with modules as keys, and names as values + """ + modules_map = {} + for name, mod in model.named_modules(): + modules_map[mod] = name + modules_map[model] = model_name + return modules_map + + class Node: def __init__( self, @@ -73,7 +86,6 @@ def get_unique_tensor_id(self): self.tensor_id_list.append(self.tensor_id_list[-1] + 1) return str(self.tensor_id_list[-1] + 1) - def __contains__(self, elem: Union[torch.Tensor, nn.Module]): for elem_in_list in self.elem_list: if elem is elem_in_list: @@ -89,7 +101,7 @@ def add_elem(self, elem, name: str): self.elem_list.append(elem) self.node_map[node] = name return node - + def add_or_get_node_for_elem(self, elem: Union[torch.Tensor, nn.Module]): if elem in self: return self.find_node(elem) @@ -97,32 +109,34 @@ def add_or_get_node_for_elem(self, elem: Union[torch.Tensor, nn.Module]): # Generate a name if elem in self.modules_map: name = self.modules_map[elem] - else: + else: assert isinstance(elem, torch.Tensor) name = f"Tensor_{self.get_unique_tensor_id()}" # add and return the node new_node = self.add_elem(elem, name) return new_node - + def find_node(self, elem: Union[torch.Tensor, nn.Module]): for node in self.node_map.keys(): if elem is node.elem: return node raise ValueError("elem not found") - - def add_edge(self, source: Union[torch.Tensor, nn.Module], destination: Union[torch.Tensor, nn.Module]): + def add_edge( + self, + source: Union[torch.Tensor, nn.Module], + destination: Union[torch.Tensor, nn.Module], + ): source_node = self.add_or_get_node_for_elem(source) destination_node = self.add_or_get_node_for_elem(destination) - print(f"Adding edge {source_node.name} -> {destination_node.name}") source_node.add_outgoing(destination_node) destination_node.add_incoming(source_node) return source_node, destination_node - + def __str__(self) -> str: return "\n".join([f"{n}" for n in self.node_map.keys()]) - def to_md(self)-> str: + def to_md(self) -> str: mermaid_md = """ ```mermaid graph TD; @@ -137,21 +151,12 @@ def to_md(self)-> str: -def process_input(input_data: torch.Tensor): - # Note: Works only when the input is a single tensor - # Convert to recorder tensor - recorder_tensor = input_data.as_subclass(RecorderTensor) - # Create a corresponding node for it - input_node = TensorNode(tensor=recorder_tensor, depth=0, name="Input") - recorder_tensor.tensor_nodes = [input_node] - return recorder_tensor, {}, NodeContainer(recorder_tensor.tensor_nodes) - _torch_module_call = torch.nn.Module.__call__ def module_forward_wrapper(model_graph: Graph) -> Callable[..., Any]: - def _my_forward(mod: nn.Module, *args, **kwargs) -> Any: + def my_forward(mod: nn.Module, *args, **kwargs) -> Any: # Iterate over all inputs for i, input_data in enumerate(args): # Create nodes and edges @@ -160,7 +165,7 @@ def _my_forward(mod: nn.Module, *args, **kwargs) -> Any: if isinstance(out, tuple): out_tuple elif isinstance(out, torch.Tensor): - out_tuple = out, + out_tuple = (out,) else: raise Exception("Unknown output format") # Iterate over all outputs and create nodes and edges @@ -169,52 +174,31 @@ def _my_forward(mod: nn.Module, *args, **kwargs) -> Any: model_graph.add_edge(mod, output_data) return out - return _my_forward - - -def forward_prop( - model: nn.Module, input_data: RecorderTensor, model_graph: ComputationGraph -): - model.eval() - model = model.to("cpu") - new_module_forward = module_forward_wrapper(model_graph) - with Recorder(_torch_module_call, new_module_forward, model_graph): - model(input_data) - return - - -def named_modules_map( - model: nn.Module, model_name: str = "model" -) -> Dict[str, nn.Module]: - """Inverse of named modules dictionary + return my_forward - Args: - model (nn.Module): The module to be hashed - Returns: - Dict[str, nn.Module]: A dictionary with modules as keys, and names as values +class GraphTracer: """ - modules_map = {} - for name, mod in model.named_modules(): - modules_map[mod] = name - modules_map[model] = model_name - return modules_map + Context manager to trace a model's execution graph + Example: -def extract_graph(model: nn.Module, input_data: torch.Tensor) -> Graph: - # Modify the input somehow - input_record_tensor, kwargs_record_tensor, input_nodes = process_input(input_data) - # Create a graph - visual_graph = graphviz.Digraph( - name="..", engine="dot", strict=True, filename="somefile.dot" - ) - model_graph = ComputationGraph( - visual_graph=visual_graph, root_container=input_nodes - ) + ```python + with GraphTracer(mymodel) as tracer, torch.no_grad(): + out = mymodel(data) - # Populate it - forward_prop(model, input_record_tensor, model_graph=model_graph) + print(tracer.graph.to_md()) + ``` + """ + def __init__(self, mod: nn.Module) -> None: + self.original_torch_call = nn.Module.__call__ + self.graph = Graph(mod) - model_graph.fill_visual_graph() + def __enter__(self)->"GraphTracer": + # Override the torch call method + nn.Module.__call__ = module_forward_wrapper(self.graph) + return self - return model_graph + def __exit__(self, exc_type, exc_value, exc_tb): + # Restore normal behavior + nn.Module.__call__ = self.original_torch_call \ No newline at end of file diff --git a/tests/test_graph.py b/tests/test_graph.py index 7ba89952..4bd3ca8b 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,8 +1,5 @@ -import torchinfo -import torch.nn as nn import torch -from torchview import draw_graph -from sinabs.graph import extract_graph, process_input +import torch.nn as nn # Branched model @@ -30,47 +27,40 @@ def forward(self, data): mymodel = MyBranchedModel() -torchinfo.summary(mymodel, input_data=data) - - -def test_process_input(): - out = process_input(input_data=data) - print(out) +def test_named_modules_map(): + from sinabs.graph import named_modules_map + mod_map = named_modules_map(mymodel) + print(mod_map) -def test_extract_graph(): - - model_graph = extract_graph(mymodel, input_data=data) - model_graph.visual_graph.save("branched_graph.dot") - for id, node in model_graph.id_dict.items(): - print(type(id), type(node)) +def test_module_forward_wrapper(): + mymodel = MyBranchedModel() + orig_call = nn.Module.__call__ -def test_named_modules_map(): - from sinabs.graph import named_modules_map - mod_map = named_modules_map(mymodel) - print(mod_map) + from sinabs.graph import Graph, module_forward_wrapper + model_graph = Graph(mymodel) + new_call = module_forward_wrapper(model_graph) -#def test_module_forward_wrapper(): -mymodel = MyBranchedModel() + # Override call to the new wrapped call + nn.Module.__call__ = new_call -orig_call = nn.Module.__call__ + with torch.no_grad(): + out = mymodel(data) -from sinabs.graph import Graph, module_forward_wrapper + # Restore normal behavior + nn.Module.__call__ = orig_call -model_graph = Graph(mymodel) -new_call = module_forward_wrapper(model_graph) + print(model_graph.to_md()) -# Override call to the new wrapped call -nn.Module.__call__ = new_call -with torch.no_grad(): - out = mymodel(data) +def test_graph_tracer(): + from sinabs.graph import GraphTracer -# Restore normal behavior -nn.Module.__call__ = orig_call + with GraphTracer(mymodel) as tracer, torch.no_grad(): + out = mymodel(data) -print(model_graph.to_md()) \ No newline at end of file + print(tracer.graph.to_md()) \ No newline at end of file From 7d8ff0b8df3edcc61a456a730ab36b72fcf6ece6 Mon Sep 17 00:00:00 2001 From: Sadique Sheik Date: Fri, 9 Jun 2023 16:42:42 +0200 Subject: [PATCH 03/24] removing redundant incoming nodes attribute --- sinabs/graph.py | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/sinabs/graph.py b/sinabs/graph.py index 13cac6ed..258c2331 100644 --- a/sinabs/graph.py +++ b/sinabs/graph.py @@ -30,30 +30,20 @@ def __init__( self, elem: Any, name: str, - incoming_nodes: Optional[List["Node"]] = None, outgoing_nodes: Optional[List["Node"]] = None, ) -> None: self.elem = elem self.name = name - # Initialize if None - if not incoming_nodes: - self.incoming_nodes = [] - else: - self.incoming_nodes = incoming_nodes - # Initialize if None if not outgoing_nodes: self.outgoing_nodes = [] else: self.outgoing_nodes = outgoing_nodes - def add_incoming(self, node: "Node"): - self.incoming_nodes.append(node) - def add_outgoing(self, node: "Node"): self.outgoing_nodes.append(node) def __str__(self) -> str: - return f"Node: {self.name}, I: {len(self.incoming_nodes)}, O: {len(self.outgoing_nodes)}" + return f"Node: {self.name}, Out: {len(self.outgoing_nodes)}" def __eq__(self, other: Any) -> bool: # Two nodes are meant to be the same if they refer to the same element @@ -130,7 +120,6 @@ def add_edge( source_node = self.add_or_get_node_for_elem(source) destination_node = self.add_or_get_node_for_elem(destination) source_node.add_outgoing(destination_node) - destination_node.add_incoming(source_node) return source_node, destination_node def __str__(self) -> str: @@ -150,8 +139,6 @@ def to_md(self) -> str: return mermaid_md + end - - _torch_module_call = torch.nn.Module.__call__ @@ -187,18 +174,19 @@ class GraphTracer: with GraphTracer(mymodel) as tracer, torch.no_grad(): out = mymodel(data) - print(tracer.graph.to_md()) + print(tracer.graph.to_md()) ``` """ + def __init__(self, mod: nn.Module) -> None: self.original_torch_call = nn.Module.__call__ self.graph = Graph(mod) - def __enter__(self)->"GraphTracer": + def __enter__(self) -> "GraphTracer": # Override the torch call method nn.Module.__call__ = module_forward_wrapper(self.graph) return self def __exit__(self, exc_type, exc_value, exc_tb): # Restore normal behavior - nn.Module.__call__ = self.original_torch_call \ No newline at end of file + nn.Module.__call__ = self.original_torch_call From 010c89294d2b6b4f4f3d6c5edcbfce025c06b37f Mon Sep 17 00:00:00 2001 From: Sadique Sheik Date: Fri, 9 Jun 2023 17:27:50 +0200 Subject: [PATCH 04/24] updated graph definition. --- sinabs/graph.py | 9 +++++---- tests/test_graph.py | 10 +++++----- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/sinabs/graph.py b/sinabs/graph.py index 258c2331..91a15040 100644 --- a/sinabs/graph.py +++ b/sinabs/graph.py @@ -58,10 +58,10 @@ def __hash__(self): class Graph: - def __init__(self, mod: nn.Module) -> None: + def __init__(self, module_names: Dict[nn.Module, str]) -> None: self.elem_list = [] self.node_map: Dict[Node, str] = {} - self.modules_map = named_modules_map(mod) + self.module_names = module_names self.tensor_id_list = [] @property @@ -97,8 +97,8 @@ def add_or_get_node_for_elem(self, elem: Union[torch.Tensor, nn.Module]): return self.find_node(elem) else: # Generate a name - if elem in self.modules_map: - name = self.modules_map[elem] + if elem in self.module_names: + name = self.module_names[elem] else: assert isinstance(elem, torch.Tensor) name = f"Tensor_{self.get_unique_tensor_id()}" @@ -139,6 +139,7 @@ def to_md(self) -> str: return mermaid_md + end + _torch_module_call = torch.nn.Module.__call__ diff --git a/tests/test_graph.py b/tests/test_graph.py index 4bd3ca8b..d1e5285f 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -40,9 +40,9 @@ def test_module_forward_wrapper(): orig_call = nn.Module.__call__ - from sinabs.graph import Graph, module_forward_wrapper + from sinabs.graph import Graph, module_forward_wrapper, named_modules_map - model_graph = Graph(mymodel) + model_graph = Graph(named_modules_map(mymodel)) new_call = module_forward_wrapper(model_graph) # Override call to the new wrapped call @@ -58,9 +58,9 @@ def test_module_forward_wrapper(): def test_graph_tracer(): - from sinabs.graph import GraphTracer + from sinabs.graph import GraphTracer, named_modules_map - with GraphTracer(mymodel) as tracer, torch.no_grad(): + with GraphTracer(named_modules_map(mymodel)) as tracer, torch.no_grad(): out = mymodel(data) - print(tracer.graph.to_md()) \ No newline at end of file + print(tracer.graph.to_md()) From fada64f7402e73c9ead58e78d52711b927757d5f Mon Sep 17 00:00:00 2001 From: Sadique Sheik Date: Fri, 9 Jun 2023 22:41:56 +0200 Subject: [PATCH 05/24] Added methods to simplify the graph --- sinabs/graph.py | 98 ++++++++++++++++++++++++++++++++++++++++----- tests/test_graph.py | 55 ++++++++++++++++++++++--- 2 files changed, 137 insertions(+), 16 deletions(-) diff --git a/sinabs/graph.py b/sinabs/graph.py index 91a15040..6dcf722d 100644 --- a/sinabs/graph.py +++ b/sinabs/graph.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -from typing import Union, Tuple, List, Callable, Any, Dict, Optional +from typing import Union, Tuple, List, Callable, Any, Dict, Optional, Type from torchview import ComputationGraph from torchview import RecorderTensor from torchview.recorder_tensor import Recorder @@ -60,36 +60,36 @@ def __hash__(self): class Graph: def __init__(self, module_names: Dict[nn.Module, str]) -> None: self.elem_list = [] - self.node_map: Dict[Node, str] = {} + self.node_list: List[Node] = [] self.module_names = module_names self.tensor_id_list = [] @property def node_map_by_id(self): - return {v: k for k, v in self.node_map.items()} + return {n.name: n for n in self.node_list} - def get_unique_tensor_id(self): + def get_unique_tensor_id(self)->str: if not self.tensor_id_list: self.tensor_id_list.append(0) - return 0 + return str(0) else: self.tensor_id_list.append(self.tensor_id_list[-1] + 1) return str(self.tensor_id_list[-1] + 1) - def __contains__(self, elem: Union[torch.Tensor, nn.Module]): + def __contains__(self, elem: Union[torch.Tensor, nn.Module])->bool: for elem_in_list in self.elem_list: if elem is elem_in_list: return True return False - def add_elem(self, elem, name: str): + def add_elem(self, elem, name: str)->Node: if elem in self: warnings.warn(f"{name}: Node already exists for this element ") return self.find_node(elem) else: node = Node(elem, name) self.elem_list.append(elem) - self.node_map[node] = name + self.node_list.append(node) return node def add_or_get_node_for_elem(self, elem: Union[torch.Tensor, nn.Module]): @@ -107,7 +107,7 @@ def add_or_get_node_for_elem(self, elem: Union[torch.Tensor, nn.Module]): return new_node def find_node(self, elem: Union[torch.Tensor, nn.Module]): - for node in self.node_map.keys(): + for node in self.node_list: if elem is node.elem: return node raise ValueError("elem not found") @@ -122,15 +122,58 @@ def add_edge( source_node.add_outgoing(destination_node) return source_node, destination_node + def get_leaf_modules(self) -> Dict[nn.Module, str]: + filtered_module_names = {} + + for mod, _ in self.module_names.items(): + # Add module to dict + filtered_module_names[mod] = self.module_names[mod] + child_in_graph = False + for _, submod in mod.named_children(): + if submod in self: + child_in_graph = True + break + if child_in_graph: + del filtered_module_names[mod] + return filtered_module_names + + + def populate_from(self, other_graph: "Graph"): + + def is_mod_and_not_in_module_names(node: Node)->bool: + """Check if a node is a module and is included in the module_names of this graph + + Args: + node (Node): Node to verify + + Returns: + bool + """ + if isinstance(node.elem, nn.Module) and node.elem not in self.module_names: + return True + else: + return False + + for node in other_graph.node_list: + if is_mod_and_not_in_module_names(node): + # Skip if not included in the module names + continue + for outgoing_node in node.outgoing_nodes: + if is_mod_and_not_in_module_names(outgoing_node): + # Skip if not included in the module names + continue + else: + self.add_edge(node.elem, outgoing_node.elem) + def __str__(self) -> str: - return "\n".join([f"{n}" for n in self.node_map.keys()]) + return self.to_md() def to_md(self) -> str: mermaid_md = """ ```mermaid graph TD; """ - for node, _ in self.node_map.items(): + for node in self.node_list: for outgoing in node.outgoing_nodes: mermaid_md += f"{node.name} --> {outgoing.name};\n" end = """ @@ -138,6 +181,39 @@ def to_md(self) -> str: """ return mermaid_md + end + def leaf_only(self) -> "Graph": + leaf_modules = self.get_leaf_modules() + filtered_graph = Graph(leaf_modules) + # Populate edges + filtered_graph.populate_from(self) + return filtered_graph + + + def ignore_submodules_of(self, classes: List[Type])->"Graph": + new_named_modules = {} + + # Gather a list of all top level modules, whose submodules are to be ignored + top_level_modules: List[nn.Module] = [] + for mod in self.module_names.keys(): + if mod.__class__ in classes: + top_level_modules.append(mod) + + # List all the submodules of the above module list + sub_modules_to_ignore: List[nn.Module] = [] + for top_mod in top_level_modules: + for sub_mod in top_mod.modules(): + if sub_mod is not top_mod: + sub_modules_to_ignore.append(sub_mod) + + # Iterate over all modules and check if they are submodules of the above list + for mod, name in self.module_names.items(): + if mod not in sub_modules_to_ignore: + new_named_modules[mod] = name + # Create a new graph with the allowed modules + new_graph = Graph(new_named_modules) + new_graph.populate_from(self) + return new_graph + _torch_module_call = torch.nn.Module.__call__ diff --git a/tests/test_graph.py b/tests/test_graph.py index d1e5285f..adab5e14 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -2,6 +2,13 @@ import torch.nn as nn +class Add(nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, data1, data2): + return data1 + data2 + # Branched model class MyBranchedModel(nn.Module): def __init__(self) -> None: @@ -9,15 +16,16 @@ def __init__(self) -> None: self.relu1 = nn.ReLU() self.relu2_1 = nn.ReLU() self.relu2_2 = nn.ReLU() + self.add_mod = Add() self.relu3 = nn.ReLU() def forward(self, data): out1 = self.relu1(data) out2_1 = self.relu2_1(out1) out2_2 = self.relu2_2(out1) - out3 = self.relu3(out2_1 + out2_2) - out3.foo = "foo" - return out3 + out3 = self.add_mod(out2_1, out2_2) + out4 = self.relu3(out3) + return out4 input_shape = (2, 28, 28) @@ -27,6 +35,18 @@ def forward(self, data): mymodel = MyBranchedModel() +class DeepModel(nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.block1 = MyBranchedModel() + self.block2 = MyBranchedModel() + + def forward(self, data): + out = self.block1(data) + out2 = self.block2(out) + return out2 + +mydeepmodel = DeepModel() def test_named_modules_map(): from sinabs.graph import named_modules_map @@ -54,7 +74,7 @@ def test_module_forward_wrapper(): # Restore normal behavior nn.Module.__call__ = orig_call - print(model_graph.to_md()) + print(model_graph) def test_graph_tracer(): @@ -63,4 +83,29 @@ def test_graph_tracer(): with GraphTracer(named_modules_map(mymodel)) as tracer, torch.no_grad(): out = mymodel(data) - print(tracer.graph.to_md()) + print(tracer.graph) + + +def test_leaf_only_graph(): + from sinabs.graph import GraphTracer, named_modules_map + + with GraphTracer(named_modules_map(mydeepmodel)) as tracer, torch.no_grad(): + out = mydeepmodel(data) + + + print(tracer.graph) + + # Get graph with just the leaf nodes + leaf_graph = tracer.graph.leaf_only() + print(leaf_graph) + + +def test_ignore_submodules_of(): + from sinabs.graph import GraphTracer, named_modules_map + + with GraphTracer(named_modules_map(mydeepmodel)) as tracer, torch.no_grad(): + out = mydeepmodel(data) + + top_overview_graph = tracer.graph.ignore_submodules_of([MyBranchedModel]).leaf_only() + print(top_overview_graph) + From 664d0ba263b637ffcc9deac5f4073bf55c397793 Mon Sep 17 00:00:00 2001 From: Sadique Sheik Date: Fri, 9 Jun 2023 23:04:37 +0200 Subject: [PATCH 06/24] only saving index of last used tensor id --- sinabs/graph.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/sinabs/graph.py b/sinabs/graph.py index 6dcf722d..7c1686eb 100644 --- a/sinabs/graph.py +++ b/sinabs/graph.py @@ -59,22 +59,21 @@ def __hash__(self): class Graph: def __init__(self, module_names: Dict[nn.Module, str]) -> None: + self.module_names = module_names self.elem_list = [] self.node_list: List[Node] = [] - self.module_names = module_names - self.tensor_id_list = [] + self._last_used_tensor_id = None @property def node_map_by_id(self): return {n.name: n for n in self.node_list} def get_unique_tensor_id(self)->str: - if not self.tensor_id_list: - self.tensor_id_list.append(0) - return str(0) + if self._last_used_tensor_id is None: + self._last_used_tensor_id = 0 else: - self.tensor_id_list.append(self.tensor_id_list[-1] + 1) - return str(self.tensor_id_list[-1] + 1) + self._last_used_tensor_id += 1 + return str(self._last_used_tensor_id) def __contains__(self, elem: Union[torch.Tensor, nn.Module])->bool: for elem_in_list in self.elem_list: From bd5464722c5cdafe9bf7e6e81b3fbbb16ae34ca6 Mon Sep 17 00:00:00 2001 From: Sadique Sheik Date: Fri, 9 Jun 2023 23:16:31 +0200 Subject: [PATCH 07/24] added sensible tests conditions in place of prints --- sinabs/graph.py | 22 ++++++++++++---------- tests/test_graph.py | 29 +++++++++++++++++++++++------ 2 files changed, 35 insertions(+), 16 deletions(-) diff --git a/sinabs/graph.py b/sinabs/graph.py index 7c1686eb..f6767e17 100644 --- a/sinabs/graph.py +++ b/sinabs/graph.py @@ -68,20 +68,26 @@ def __init__(self, module_names: Dict[nn.Module, str]) -> None: def node_map_by_id(self): return {n.name: n for n in self.node_list} - def get_unique_tensor_id(self)->str: + def num_edges(self) -> int: + count = 0 + for node in self.node_list: + count += node.outgoing_nodes + return count + + def get_unique_tensor_id(self) -> str: if self._last_used_tensor_id is None: self._last_used_tensor_id = 0 else: self._last_used_tensor_id += 1 return str(self._last_used_tensor_id) - def __contains__(self, elem: Union[torch.Tensor, nn.Module])->bool: + def __contains__(self, elem: Union[torch.Tensor, nn.Module]) -> bool: for elem_in_list in self.elem_list: if elem is elem_in_list: return True return False - def add_elem(self, elem, name: str)->Node: + def add_elem(self, elem, name: str) -> Node: if elem in self: warnings.warn(f"{name}: Node already exists for this element ") return self.find_node(elem) @@ -136,10 +142,8 @@ def get_leaf_modules(self) -> Dict[nn.Module, str]: del filtered_module_names[mod] return filtered_module_names - def populate_from(self, other_graph: "Graph"): - - def is_mod_and_not_in_module_names(node: Node)->bool: + def is_mod_and_not_in_module_names(node: Node) -> bool: """Check if a node is a module and is included in the module_names of this graph Args: @@ -150,7 +154,7 @@ def is_mod_and_not_in_module_names(node: Node)->bool: """ if isinstance(node.elem, nn.Module) and node.elem not in self.module_names: return True - else: + else: return False for node in other_graph.node_list: @@ -187,8 +191,7 @@ def leaf_only(self) -> "Graph": filtered_graph.populate_from(self) return filtered_graph - - def ignore_submodules_of(self, classes: List[Type])->"Graph": + def ignore_submodules_of(self, classes: List[Type]) -> "Graph": new_named_modules = {} # Gather a list of all top level modules, whose submodules are to be ignored @@ -214,7 +217,6 @@ def ignore_submodules_of(self, classes: List[Type])->"Graph": return new_graph - _torch_module_call = torch.nn.Module.__call__ diff --git a/tests/test_graph.py b/tests/test_graph.py index adab5e14..2b148eb2 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -5,10 +5,11 @@ class Add(nn.Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - + def forward(self, data1, data2): return data1 + data2 + # Branched model class MyBranchedModel(nn.Module): def __init__(self) -> None: @@ -35,24 +36,30 @@ def forward(self, data): mymodel = MyBranchedModel() + class DeepModel(nn.Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.block1 = MyBranchedModel() self.block2 = MyBranchedModel() - + def forward(self, data): out = self.block1(data) out2 = self.block2(out) return out2 + mydeepmodel = DeepModel() + def test_named_modules_map(): from sinabs.graph import named_modules_map mod_map = named_modules_map(mymodel) print(mod_map) + for k, v in mod_map.items(): + assert isinstance(k, nn.Module) + assert isinstance(v, str) def test_module_forward_wrapper(): @@ -75,6 +82,9 @@ def test_module_forward_wrapper(): nn.Module.__call__ = orig_call print(model_graph) + assert ( + len(model_graph.node_list) == 1 + 5 + 5 + 1 + ) # 1 top module + 5 submodules + 5 tensors + 1 output tensor def test_graph_tracer(): @@ -84,6 +94,9 @@ def test_graph_tracer(): out = mymodel(data) print(tracer.graph) + assert ( + len(tracer.graph.node_list) == 1 + 5 + 5 + 1 + ) # 1 top module + 5 submodules + 5 tensors + 1 output tensor def test_leaf_only_graph(): @@ -92,13 +105,15 @@ def test_leaf_only_graph(): with GraphTracer(named_modules_map(mydeepmodel)) as tracer, torch.no_grad(): out = mydeepmodel(data) - print(tracer.graph) # Get graph with just the leaf nodes leaf_graph = tracer.graph.leaf_only() print(leaf_graph) - + assert ( + len(leaf_graph.node_list) == len(tracer.graph.node_list) - 3 + ) # No more top modules + def test_ignore_submodules_of(): from sinabs.graph import GraphTracer, named_modules_map @@ -106,6 +121,8 @@ def test_ignore_submodules_of(): with GraphTracer(named_modules_map(mydeepmodel)) as tracer, torch.no_grad(): out = mydeepmodel(data) - top_overview_graph = tracer.graph.ignore_submodules_of([MyBranchedModel]).leaf_only() + top_overview_graph = tracer.graph.ignore_submodules_of( + [MyBranchedModel] + ).leaf_only() print(top_overview_graph) - + assert len(top_overview_graph.node_list) == 2 + 2 + 1 From fb53a0f38ec98a92fe8ecb947db5c257092a3784 Mon Sep 17 00:00:00 2001 From: Sadique Sheik Date: Sat, 10 Jun 2023 17:05:40 +0200 Subject: [PATCH 08/24] removed torchview imports --- sinabs/graph.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/sinabs/graph.py b/sinabs/graph.py index f6767e17..2466cfff 100644 --- a/sinabs/graph.py +++ b/sinabs/graph.py @@ -1,10 +1,8 @@ +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union + import torch import torch.nn as nn -from typing import Union, Tuple, List, Callable, Any, Dict, Optional, Type -from torchview import ComputationGraph -from torchview import RecorderTensor -from torchview.recorder_tensor import Recorder -import warnings def named_modules_map( From 5d4f8dce54b100daea4d1d5ace06f06264e19ef0 Mon Sep 17 00:00:00 2001 From: Sadique Sheik Date: Sun, 11 Jun 2023 00:30:53 +0200 Subject: [PATCH 09/24] added convenience method for graph extraction --- sinabs/graph.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sinabs/graph.py b/sinabs/graph.py index 2466cfff..4b77c1b6 100644 --- a/sinabs/graph.py +++ b/sinabs/graph.py @@ -104,7 +104,7 @@ def add_or_get_node_for_elem(self, elem: Union[torch.Tensor, nn.Module]): name = self.module_names[elem] else: assert isinstance(elem, torch.Tensor) - name = f"Tensor_{self.get_unique_tensor_id()}" + name = f"Tensor_{self.get_unique_tensor_id()}{tuple(elem.shape)}" # add and return the node new_node = self.add_elem(elem, name) return new_node @@ -266,3 +266,11 @@ def __enter__(self) -> "GraphTracer": def __exit__(self, exc_type, exc_value, exc_tb): # Restore normal behavior nn.Module.__call__ = self.original_torch_call + + + +def extract_graph(model: nn.Module, sample_data: Any)->Graph: + with GraphTracer(named_modules_map(model)) as tracer, torch.no_grad(): + out = model(sample_data) + + return tracer.graph \ No newline at end of file From 7ee1f8ea4d94268d3f733a1c794b77397083c0ac Mon Sep 17 00:00:00 2001 From: Sadique Sheik Date: Sun, 11 Jun 2023 00:31:19 +0200 Subject: [PATCH 10/24] added modules for addition and concatenation --- sinabs/layers/__init__.py | 2 ++ sinabs/layers/add_module.py | 12 ++++++++++++ sinabs/layers/channel_concat.py | 10 ++++++++++ 3 files changed, 24 insertions(+) create mode 100644 sinabs/layers/add_module.py create mode 100644 sinabs/layers/channel_concat.py diff --git a/sinabs/layers/__init__.py b/sinabs/layers/__init__.py index b47c4bfc..6001bc4e 100644 --- a/sinabs/layers/__init__.py +++ b/sinabs/layers/__init__.py @@ -9,3 +9,5 @@ from .reshape import FlattenTime, Repeat, SqueezeMixin, UnflattenTime from .stateful_layer import StatefulLayer from .to_spike import Img2SpikeLayer, Sig2SpikeLayer +from .add_module import Add +from .channel_concat import ConcatenateChannel \ No newline at end of file diff --git a/sinabs/layers/add_module.py b/sinabs/layers/add_module.py new file mode 100644 index 00000000..821430dc --- /dev/null +++ b/sinabs/layers/add_module.py @@ -0,0 +1,12 @@ +import torch.nn as nn + +class Add(nn.Module): + def __init__(self, *args, **kwargs) -> None: + """ + Module form for a simple addition operation. + In the context of events/spikes, events/spikes from two different sources/rasters will be added. + """ + super().__init__(*args, **kwargs) + + def forward(self, data1, data2): + return data1 + data2 \ No newline at end of file diff --git a/sinabs/layers/channel_concat.py b/sinabs/layers/channel_concat.py new file mode 100644 index 00000000..0f8a675b --- /dev/null +++ b/sinabs/layers/channel_concat.py @@ -0,0 +1,10 @@ +import torch +import torch.nn as nn + +class ConcatenateChannel(nn.Module): + def __init__(self, channel_axis=-3) -> None: + super().__init__() + self.channel_axis = -3 + + def forward(self, x, y): + return torch.concat((x, y), self.channel_axis) \ No newline at end of file From b1fb43d4f58ecd2706c9e1b677cb47dfc7d68794 Mon Sep 17 00:00:00 2001 From: Sadique Sheik Date: Sun, 11 Jun 2023 00:31:29 +0200 Subject: [PATCH 11/24] added test for branched SNN --- tests/test_graph.py | 57 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 50 insertions(+), 7 deletions(-) diff --git a/tests/test_graph.py b/tests/test_graph.py index 2b148eb2..a25f7154 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,13 +1,7 @@ import torch import torch.nn as nn - -class Add(nn.Module): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - - def forward(self, data1, data2): - return data1 + data2 +from sinabs.layers import Add # Branched model @@ -126,3 +120,52 @@ def test_ignore_submodules_of(): ).leaf_only() print(top_overview_graph) assert len(top_overview_graph.node_list) == 2 + 2 + 1 + + + +def test_snn_branched(): + from sinabs.layers import IAFSqueeze, ConcatenateChannel, SumPool2d + from torch.nn import Conv2d + from sinabs.graph import extract_graph + + + class MySNN(nn.Module): + def __init__(self) -> None: + super().__init__() + self.conv1 = Conv2d(2, 8, 3, bias=False) + self.iaf1 = IAFSqueeze(batch_size=1) + self.pool1 = SumPool2d(2) + self.conv2_1 = Conv2d(8, 16, 3, stride=1, padding=1, bias=False) + self.iaf2_1 = IAFSqueeze(batch_size=1) + self.pool2_1 = SumPool2d(2) + self.conv2_2 = Conv2d(8, 16, 5, stride=1, padding=2, bias=False) + self.iaf2_2 = IAFSqueeze(batch_size=1) + self.pool2_2 = SumPool2d(2) + self.concat = ConcatenateChannel() + self.conv3 = Conv2d(32, 10, 3, stride=3, bias=False) + self.iaf3 = IAFSqueeze(batch_size=1) + + def forward(self, spikes): + out = self.conv1(spikes) + out = self.iaf1(out) + out = self.pool1(out) + + out1 = self.conv2_1(out) + out1 = self.iaf2_1(out1) + out1 = self.pool2_1(out1) + + out2 = self.conv2_2(out) + out2 = self.iaf2_2(out2) + out2 = self.pool2_2(out2) + + out = self.concat(out1, out2) + out = self.conv3(out) + out = self.iaf3(out) + return out + + my_snn = MySNN() + graph = extract_graph(my_snn, sample_data=torch.rand((100, 2, 14, 14))) + + print(graph) + + From 745b153be3a3e04ff2f1baa1e911caa8a68cbaf8 Mon Sep 17 00:00:00 2001 From: Sadique Sheik Date: Sun, 11 Jun 2023 00:49:41 +0200 Subject: [PATCH 12/24] added optional model name to ignore it from graph --- sinabs/graph.py | 48 ++++++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/sinabs/graph.py b/sinabs/graph.py index 4b77c1b6..26423357 100644 --- a/sinabs/graph.py +++ b/sinabs/graph.py @@ -6,12 +6,13 @@ def named_modules_map( - model: nn.Module, model_name: str = "model" + model: nn.Module, model_name: Optional[str] = "model" ) -> Dict[str, nn.Module]: """Inverse of named modules dictionary Args: model (nn.Module): The module to be hashed + model_name (str | None): Name of the top level module. If this doesn't need to be include, this option can be set to None Returns: Dict[str, nn.Module]: A dictionary with modules as keys, and names as values @@ -19,7 +20,10 @@ def named_modules_map( modules_map = {} for name, mod in model.named_modules(): modules_map[mod] = name - modules_map[model] = model_name + if model_name is None: + del modules_map[model] + else: + modules_map[model] = model_name return modules_map @@ -120,6 +124,9 @@ def add_edge( source: Union[torch.Tensor, nn.Module], destination: Union[torch.Tensor, nn.Module], ): + if self._is_mod_and_not_in_module_names(source): return + if self._is_mod_and_not_in_module_names(destination): return + source_node = self.add_or_get_node_for_elem(source) destination_node = self.add_or_get_node_for_elem(destination) source_node.add_outgoing(destination_node) @@ -140,31 +147,24 @@ def get_leaf_modules(self) -> Dict[nn.Module, str]: del filtered_module_names[mod] return filtered_module_names - def populate_from(self, other_graph: "Graph"): - def is_mod_and_not_in_module_names(node: Node) -> bool: - """Check if a node is a module and is included in the module_names of this graph + def _is_mod_and_not_in_module_names(self, elem: Any) -> bool: + """Check if a node is a module and is included in the module_names of this graph - Args: - node (Node): Node to verify + Args: + node (Node): Node to verify - Returns: - bool - """ - if isinstance(node.elem, nn.Module) and node.elem not in self.module_names: - return True - else: - return False + Returns: + bool + """ + if isinstance(elem, nn.Module) and elem not in self.module_names: + return True + else: + return False + def populate_from(self, other_graph: "Graph"): for node in other_graph.node_list: - if is_mod_and_not_in_module_names(node): - # Skip if not included in the module names - continue for outgoing_node in node.outgoing_nodes: - if is_mod_and_not_in_module_names(outgoing_node): - # Skip if not included in the module names - continue - else: - self.add_edge(node.elem, outgoing_node.elem) + self.add_edge(node.elem, outgoing_node.elem) def __str__(self) -> str: return self.to_md() @@ -269,8 +269,8 @@ def __exit__(self, exc_type, exc_value, exc_tb): -def extract_graph(model: nn.Module, sample_data: Any)->Graph: - with GraphTracer(named_modules_map(model)) as tracer, torch.no_grad(): +def extract_graph(model: nn.Module, sample_data: Any, model_name: Optional[str] = "model")->Graph: + with GraphTracer(named_modules_map(model, model_name=model_name)) as tracer, torch.no_grad(): out = model(sample_data) return tracer.graph \ No newline at end of file From 39e5be9013c8033e2ce41d1beab17ddce3933334 Mon Sep 17 00:00:00 2001 From: Sadique Sheik Date: Sun, 11 Jun 2023 00:50:11 +0200 Subject: [PATCH 13/24] balckened --- sinabs/graph.py | 17 +++++++++++------ tests/test_graph.py | 12 +++++------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/sinabs/graph.py b/sinabs/graph.py index 26423357..122a8566 100644 --- a/sinabs/graph.py +++ b/sinabs/graph.py @@ -124,8 +124,10 @@ def add_edge( source: Union[torch.Tensor, nn.Module], destination: Union[torch.Tensor, nn.Module], ): - if self._is_mod_and_not_in_module_names(source): return - if self._is_mod_and_not_in_module_names(destination): return + if self._is_mod_and_not_in_module_names(source): + return + if self._is_mod_and_not_in_module_names(destination): + return source_node = self.add_or_get_node_for_elem(source) destination_node = self.add_or_get_node_for_elem(destination) @@ -268,9 +270,12 @@ def __exit__(self, exc_type, exc_value, exc_tb): nn.Module.__call__ = self.original_torch_call - -def extract_graph(model: nn.Module, sample_data: Any, model_name: Optional[str] = "model")->Graph: - with GraphTracer(named_modules_map(model, model_name=model_name)) as tracer, torch.no_grad(): +def extract_graph( + model: nn.Module, sample_data: Any, model_name: Optional[str] = "model" +) -> Graph: + with GraphTracer( + named_modules_map(model, model_name=model_name) + ) as tracer, torch.no_grad(): out = model(sample_data) - return tracer.graph \ No newline at end of file + return tracer.graph diff --git a/tests/test_graph.py b/tests/test_graph.py index a25f7154..b9d8bc63 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -122,13 +122,11 @@ def test_ignore_submodules_of(): assert len(top_overview_graph.node_list) == 2 + 2 + 1 - def test_snn_branched(): from sinabs.layers import IAFSqueeze, ConcatenateChannel, SumPool2d from torch.nn import Conv2d from sinabs.graph import extract_graph - class MySNN(nn.Module): def __init__(self) -> None: super().__init__() @@ -144,7 +142,7 @@ def __init__(self) -> None: self.concat = ConcatenateChannel() self.conv3 = Conv2d(32, 10, 3, stride=3, bias=False) self.iaf3 = IAFSqueeze(batch_size=1) - + def forward(self, spikes): out = self.conv1(spikes) out = self.iaf1(out) @@ -153,7 +151,7 @@ def forward(self, spikes): out1 = self.conv2_1(out) out1 = self.iaf2_1(out1) out1 = self.pool2_1(out1) - + out2 = self.conv2_2(out) out2 = self.iaf2_2(out2) out2 = self.pool2_2(out2) @@ -164,8 +162,8 @@ def forward(self, spikes): return out my_snn = MySNN() - graph = extract_graph(my_snn, sample_data=torch.rand((100, 2, 14, 14))) + graph = extract_graph( + my_snn, sample_data=torch.rand((100, 2, 14, 14)), model_name=None + ) print(graph) - - From 3ae616cca2ed68298238caf2430ae9454a3905a3 Mon Sep 17 00:00:00 2001 From: Sadique Sheik Date: Sun, 11 Jun 2023 00:53:06 +0200 Subject: [PATCH 14/24] added assert --- tests/test_graph.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_graph.py b/tests/test_graph.py index b9d8bc63..c0e3a9bd 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -167,3 +167,4 @@ def forward(self, spikes): ) print(graph) + assert len(graph.elem_list) == 25 # 2*12 + 1 From 5d37abf490789da3754cfe05732b23bf510a0430 Mon Sep 17 00:00:00 2001 From: Sadique Sheik Date: Sun, 11 Jun 2023 00:57:43 +0200 Subject: [PATCH 15/24] added doc string --- sinabs/graph.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/sinabs/graph.py b/sinabs/graph.py index 122a8566..14a123a0 100644 --- a/sinabs/graph.py +++ b/sinabs/graph.py @@ -273,6 +273,20 @@ def __exit__(self, exc_type, exc_value, exc_tb): def extract_graph( model: nn.Module, sample_data: Any, model_name: Optional[str] = "model" ) -> Graph: + """Extract computational graph between various modules in the model + NOTE: This method is not capable of any compute happening outside of module definitions. + + Args: + model (nn.Module): The module to be analysed + sample_data (Any): Sample data to be used to run by the model + model_name (Optional[str], optional): Name of the top level module. + If specified, it will be included in the graph. + If set to None, only its submodules will be listed in the graph. + Defaults to "model". + + Returns: + Graph: A graph object representing the computational graph of the given model + """ with GraphTracer( named_modules_map(model, model_name=model_name) ) as tracer, torch.no_grad(): From ce55563e37a5639ae9452fa49ad7779a0284a3ed Mon Sep 17 00:00:00 2001 From: Sadique Sheik Date: Sun, 11 Jun 2023 01:08:00 +0200 Subject: [PATCH 16/24] using cat instead of concat --- sinabs/layers/channel_concat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sinabs/layers/channel_concat.py b/sinabs/layers/channel_concat.py index 0f8a675b..cd12d8db 100644 --- a/sinabs/layers/channel_concat.py +++ b/sinabs/layers/channel_concat.py @@ -7,4 +7,4 @@ def __init__(self, channel_axis=-3) -> None: self.channel_axis = -3 def forward(self, x, y): - return torch.concat((x, y), self.channel_axis) \ No newline at end of file + return torch.cat((x, y), self.channel_axis) \ No newline at end of file From f734449bd8f8cd60306fe12c7127f5c259e88b15 Mon Sep 17 00:00:00 2001 From: Sadique Sheik Date: Sun, 11 Jun 2023 01:09:28 +0200 Subject: [PATCH 17/24] removed torchvision from test requirements --- tests/requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/requirements.txt b/tests/requirements.txt index 49a294b1..ceaca48d 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -3,5 +3,4 @@ pytest-cov onnx onnxruntime torch>=1.8 -torchvision matplotlib From 7a3410b0d96104632162065aa50e0481eeb03f40 Mon Sep 17 00:00:00 2001 From: Sadique Sheik Date: Wed, 14 Jun 2023 00:34:35 +0200 Subject: [PATCH 18/24] added method to ignore tensors in graph --- sinabs/__init__.py | 1 + sinabs/graph.py | 60 ++++++++++++++++++++++++++++++++++++++++++--- tests/test_graph.py | 7 ++++++ 3 files changed, 65 insertions(+), 3 deletions(-) diff --git a/sinabs/__init__.py b/sinabs/__init__.py index a9628b55..3d6d183b 100644 --- a/sinabs/__init__.py +++ b/sinabs/__init__.py @@ -7,3 +7,4 @@ from .network import Network from .synopcounter import SNNAnalyzer, SynOpCounter from .utils import reset_states, zero_grad +from .graph import extract_graph diff --git a/sinabs/graph.py b/sinabs/graph.py index 14a123a0..3db531cf 100644 --- a/sinabs/graph.py +++ b/sinabs/graph.py @@ -65,6 +65,9 @@ def __init__(self, module_names: Dict[nn.Module, str]) -> None: self.elem_list = [] self.node_list: List[Node] = [] self._last_used_tensor_id = None + # Add modules to node_list + for mod, name in self.module_names.items(): + self.add_elem(mod, name) @property def node_map_by_id(self): @@ -177,8 +180,12 @@ def to_md(self) -> str: graph TD; """ for node in self.node_list: - for outgoing in node.outgoing_nodes: - mermaid_md += f"{node.name} --> {outgoing.name};\n" + if node.outgoing_nodes: + for outgoing in node.outgoing_nodes: + mermaid_md += f"{node.name} --> {outgoing.name};\n" + else: + mermaid_md += f"{node.name};\n" + end = """ ``` """ @@ -216,6 +223,53 @@ def ignore_submodules_of(self, classes: List[Type]) -> "Graph": new_graph.populate_from(self) return new_graph + def find_source_nodes_of(self, node: Node)->List[Node]: + """Find all the sources of a node in the graph + + Args: + node (Node): Node of interest + + Returns: + List[Node]: A list of all nodes that have this node as outgoing_node + """ + source_node_list = [] + for source_node in self.node_list: + for outnode in source_node.outgoing_nodes: + if node == outnode: + source_node_list.append(source_node) + return source_node_list + + def ignore_tensors(self)->"Graph": + """Simplify the graph by ignoring all the tensors in it + + Returns: + Graph: Returns a simplified graph with only modules in it + """ + graph = Graph(self.module_names) + # Iterate over all the nodes + for node in self.node_list: + if isinstance(node.elem, torch.Tensor): + # Get its source + source_node_list = self.find_source_nodes_of(node) + # If no source, this is probably origin node, just drop it + if len(source_node_list) == 0: + continue + # Get all of its destinations + # If no destinations, it is a leaf node, just drop it. + if node.outgoing_nodes: + for outgoing_node in node.outgoing_nodes: + # Directly add an edge from source to destination + for source_node in source_node_list: + graph.add_edge(source_node.elem, outgoing_node.elem) + # NOTE: Assuming that the destination is a module here + else: + # If it is a module, filter out all edges that have a tensor + # This is to preserve the graph if executed on a graph that is already filtered + for outnode in node.outgoing_nodes: + if isinstance(outnode.elem, nn.Module): + graph.add(node.elem, outnode.elem) + return graph + _torch_module_call = torch.nn.Module.__call__ @@ -234,7 +288,7 @@ def my_forward(mod: nn.Module, *args, **kwargs) -> Any: else: raise Exception("Unknown output format") # Iterate over all outputs and create nodes and edges - for i, output_data in enumerate(out_tuple): + for output_data in out_tuple: # Create nodes and edges model_graph.add_edge(mod, output_data) return out diff --git a/tests/test_graph.py b/tests/test_graph.py index c0e3a9bd..ff13a3d3 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -168,3 +168,10 @@ def forward(self, spikes): print(graph) assert len(graph.elem_list) == 25 # 2*12 + 1 + + +def test_ignore_tensors(): + from sinabs import extract_graph + graph = extract_graph(mymodel, sample_data=data) + mod_only_graph = graph.ignore_tensors() + assert len(mod_only_graph.node_list) == 6 \ No newline at end of file From b7427092c725435af2c3b7501bcc33ca0a069553 Mon Sep 17 00:00:00 2001 From: Sadique Sheik Date: Wed, 14 Jun 2023 00:39:38 +0200 Subject: [PATCH 19/24] blackened --- sinabs/graph.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sinabs/graph.py b/sinabs/graph.py index 3db531cf..33805d77 100644 --- a/sinabs/graph.py +++ b/sinabs/graph.py @@ -223,7 +223,7 @@ def ignore_submodules_of(self, classes: List[Type]) -> "Graph": new_graph.populate_from(self) return new_graph - def find_source_nodes_of(self, node: Node)->List[Node]: + def find_source_nodes_of(self, node: Node) -> List[Node]: """Find all the sources of a node in the graph Args: @@ -239,7 +239,7 @@ def find_source_nodes_of(self, node: Node)->List[Node]: source_node_list.append(source_node) return source_node_list - def ignore_tensors(self)->"Graph": + def ignore_tensors(self) -> "Graph": """Simplify the graph by ignoring all the tensors in it Returns: @@ -334,8 +334,8 @@ def extract_graph( model (nn.Module): The module to be analysed sample_data (Any): Sample data to be used to run by the model model_name (Optional[str], optional): Name of the top level module. - If specified, it will be included in the graph. - If set to None, only its submodules will be listed in the graph. + If specified, it will be included in the graph. + If set to None, only its submodules will be listed in the graph. Defaults to "model". Returns: From 7f42bce9cad365af85ab3d68dbc246d1312e6861 Mon Sep 17 00:00:00 2001 From: Sadique Sheik Date: Wed, 21 Jun 2023 16:46:17 +0200 Subject: [PATCH 20/24] function call bug fix in ignore_nodes --- sinabs/graph.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/sinabs/graph.py b/sinabs/graph.py index 33805d77..fc67608c 100644 --- a/sinabs/graph.py +++ b/sinabs/graph.py @@ -239,16 +239,21 @@ def find_source_nodes_of(self, node: Node) -> List[Node]: source_node_list.append(source_node) return source_node_list + def ignore_tensors(self) -> "Graph": """Simplify the graph by ignoring all the tensors in it Returns: Graph: Returns a simplified graph with only modules in it """ + return self.ignore_nodes(torch.Tensor) + + + def ignore_nodes(self, class_type: Type)->"Graph": graph = Graph(self.module_names) # Iterate over all the nodes for node in self.node_list: - if isinstance(node.elem, torch.Tensor): + if isinstance(node.elem, class_type): # Get its source source_node_list = self.find_source_nodes_of(node) # If no source, this is probably origin node, just drop it @@ -261,13 +266,12 @@ def ignore_tensors(self) -> "Graph": # Directly add an edge from source to destination for source_node in source_node_list: graph.add_edge(source_node.elem, outgoing_node.elem) - # NOTE: Assuming that the destination is a module here + # NOTE: Assuming that the destination is not of the same type here else: - # If it is a module, filter out all edges that have a tensor # This is to preserve the graph if executed on a graph that is already filtered for outnode in node.outgoing_nodes: - if isinstance(outnode.elem, nn.Module): - graph.add(node.elem, outnode.elem) + if not isinstance(outnode.elem, class_type): + graph.add_edge(node.elem, outnode.elem) return graph From 7a9f0ea8f1a5ca8f206a910d20207b7cb77b3936 Mon Sep 17 00:00:00 2001 From: Sadique Sheik Date: Wed, 21 Jun 2023 23:24:42 +0200 Subject: [PATCH 21/24] modified Add to Morph --- sinabs/layers/__init__.py | 2 +- sinabs/layers/add_module.py | 12 ------------ sinabs/layers/merge.py | 28 ++++++++++++++++++++++++++++ tests/test_graph.py | 4 ++-- tests/test_merge.py | 22 ++++++++++++++++++++++ 5 files changed, 53 insertions(+), 15 deletions(-) delete mode 100644 sinabs/layers/add_module.py create mode 100644 sinabs/layers/merge.py create mode 100644 tests/test_merge.py diff --git a/sinabs/layers/__init__.py b/sinabs/layers/__init__.py index 6001bc4e..59fa7388 100644 --- a/sinabs/layers/__init__.py +++ b/sinabs/layers/__init__.py @@ -9,5 +9,5 @@ from .reshape import FlattenTime, Repeat, SqueezeMixin, UnflattenTime from .stateful_layer import StatefulLayer from .to_spike import Img2SpikeLayer, Sig2SpikeLayer -from .add_module import Add +from .merge import Merge from .channel_concat import ConcatenateChannel \ No newline at end of file diff --git a/sinabs/layers/add_module.py b/sinabs/layers/add_module.py deleted file mode 100644 index 821430dc..00000000 --- a/sinabs/layers/add_module.py +++ /dev/null @@ -1,12 +0,0 @@ -import torch.nn as nn - -class Add(nn.Module): - def __init__(self, *args, **kwargs) -> None: - """ - Module form for a simple addition operation. - In the context of events/spikes, events/spikes from two different sources/rasters will be added. - """ - super().__init__(*args, **kwargs) - - def forward(self, data1, data2): - return data1 + data2 \ No newline at end of file diff --git a/sinabs/layers/merge.py b/sinabs/layers/merge.py new file mode 100644 index 00000000..cd631469 --- /dev/null +++ b/sinabs/layers/merge.py @@ -0,0 +1,28 @@ +import torch.nn as nn + +class Merge(nn.Module): + def __init__(self) -> None: + """ + Module form for a merge operation. + In the context of events/spikes, events/spikes from two different sources/rasters will be added. + """ + super().__init__() + + def forward(self, data1, data2): + size1 = data1.shape + size2 = data2.shape + if size1 == size2: + return data1 + data2 + # If the sizes are not the same, find the larger size and pad the data accordingly + assert len(size1) == len(size2) + pad1 = () + pad2 = () + # Find the larger sizes + for s1, s2 in zip(size1, size2): + s_max = max(s1, s2) + pad1 = (0, s_max-s1, *pad1) + pad2 = (0, s_max-s2, *pad2) + + data1 = nn.functional.pad(input=data1, pad=pad1, mode="constant", value=0) + data2 = nn.functional.pad(input=data2, pad=pad2, mode="constant", value=0) + return data1 + data2 diff --git a/tests/test_graph.py b/tests/test_graph.py index ff13a3d3..c755d663 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn -from sinabs.layers import Add +from sinabs.layers import Merge # Branched model @@ -11,7 +11,7 @@ def __init__(self) -> None: self.relu1 = nn.ReLU() self.relu2_1 = nn.ReLU() self.relu2_2 = nn.ReLU() - self.add_mod = Add() + self.add_mod = Merge() self.relu3 = nn.ReLU() def forward(self, data): diff --git a/tests/test_merge.py b/tests/test_merge.py new file mode 100644 index 00000000..1b3911e5 --- /dev/null +++ b/tests/test_merge.py @@ -0,0 +1,22 @@ +import torch +import sinabs.layers as sl + + +def test_morph_same_size(): + data1 = (torch.rand((100, 1, 20, 20)) > 0.5).float() + data2 = (torch.rand((100, 1, 20, 20)) > 0.5).float() + + merge = sl.Merge() + out = merge(data1, data2) + assert out.shape == (100, 1, 20, 20) + + +def test_morph_different_size(): + data1 = (torch.rand((100, 1, 5, 6)) > 0.5).float() + data2 = (torch.rand((100, 10, 5, 5)) > 0.5).float() + + merge = sl.Merge() + out = merge(data1, data2) + + assert out.shape == (100, 10, 5, 6) + assert out.sum() == data1.sum() + data2.sum() From 89cfe360cd1dae480cfe24128f85a0a0029c69d9 Mon Sep 17 00:00:00 2001 From: Sadique Sheik Date: Sun, 2 Jul 2023 21:53:10 -0600 Subject: [PATCH 22/24] removed graph (migrated to nirtorch) --- sinabs/__init__.py | 1 - sinabs/layers/__init__.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sinabs/__init__.py b/sinabs/__init__.py index 3d6d183b..a9628b55 100644 --- a/sinabs/__init__.py +++ b/sinabs/__init__.py @@ -7,4 +7,3 @@ from .network import Network from .synopcounter import SNNAnalyzer, SynOpCounter from .utils import reset_states, zero_grad -from .graph import extract_graph diff --git a/sinabs/layers/__init__.py b/sinabs/layers/__init__.py index 59fa7388..406a4a38 100644 --- a/sinabs/layers/__init__.py +++ b/sinabs/layers/__init__.py @@ -10,4 +10,4 @@ from .stateful_layer import StatefulLayer from .to_spike import Img2SpikeLayer, Sig2SpikeLayer from .merge import Merge -from .channel_concat import ConcatenateChannel \ No newline at end of file +from .channel_shift import ChannelShift \ No newline at end of file From fbf408e65bbee1880cfdbb03104222e63311db00 Mon Sep 17 00:00:00 2001 From: Sadique Sheik Date: Sun, 2 Jul 2023 21:53:21 -0600 Subject: [PATCH 23/24] added channel shift layer --- sinabs/layers/channel_shift.py | 23 +++++++++++++++++++++++ tests/test_channelshift.py | 21 +++++++++++++++++++++ 2 files changed, 44 insertions(+) create mode 100644 sinabs/layers/channel_shift.py create mode 100644 tests/test_channelshift.py diff --git a/sinabs/layers/channel_shift.py b/sinabs/layers/channel_shift.py new file mode 100644 index 00000000..34021b35 --- /dev/null +++ b/sinabs/layers/channel_shift.py @@ -0,0 +1,23 @@ +import torch +import torch.nn as nn + + +class ChannelShift(nn.Module): + def __init__(self, channel_shift: int = 0, channel_axis=-3) -> None: + """Given a tensor, shift the channel from the left, ie zero pad from the left. + + Args: + channel_shift (int, optional): Number of channels to shift by. Defaults to 0. + channel_axis (int, optional): The channel axis dimension + NOTE: This has to be a negative dimension such that it counts the dimension from the right. Defaults to -3. + """ + super().__init__() + self.padding = [] + self.channel_shift = channel_shift + self.channel_axis = channel_axis + for axis in range(-channel_axis): + self.padding += [0, 0] + self.padding[-2] = channel_shift + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return nn.functional.pad(input=x, pad=self.padding, mode="constant", value=0) diff --git a/tests/test_channelshift.py b/tests/test_channelshift.py new file mode 100644 index 00000000..bc2e0321 --- /dev/null +++ b/tests/test_channelshift.py @@ -0,0 +1,21 @@ +import torch +from sinabs.layers.channel_shift import ChannelShift + + +def test_channel_shift_default(): + x = torch.rand(1, 10, 5, 5) + cs = ChannelShift() + + out = cs(x) + assert out.shape == x.shape + + +def test_channel_shift(): + num_channels = 10 + channel_shift = 14 + x = torch.rand(1, num_channels, 5, 5) + cs = ChannelShift(channel_shift=channel_shift) + + out = cs(x) + assert len(out.shape) == len(x.shape) + assert out.shape[1] == num_channels + channel_shift From 8a6c7a954e48ff280e42b14bb491afe0ef42aec5 Mon Sep 17 00:00:00 2001 From: Sadique Sheik Date: Sun, 2 Jul 2023 21:55:45 -0600 Subject: [PATCH 24/24] deleted old files --- sinabs/graph.py | 353 -------------------------------- sinabs/layers/channel_concat.py | 10 - tests/test_graph.py | 177 ---------------- 3 files changed, 540 deletions(-) delete mode 100644 sinabs/graph.py delete mode 100644 sinabs/layers/channel_concat.py delete mode 100644 tests/test_graph.py diff --git a/sinabs/graph.py b/sinabs/graph.py deleted file mode 100644 index fc67608c..00000000 --- a/sinabs/graph.py +++ /dev/null @@ -1,353 +0,0 @@ -import warnings -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union - -import torch -import torch.nn as nn - - -def named_modules_map( - model: nn.Module, model_name: Optional[str] = "model" -) -> Dict[str, nn.Module]: - """Inverse of named modules dictionary - - Args: - model (nn.Module): The module to be hashed - model_name (str | None): Name of the top level module. If this doesn't need to be include, this option can be set to None - - Returns: - Dict[str, nn.Module]: A dictionary with modules as keys, and names as values - """ - modules_map = {} - for name, mod in model.named_modules(): - modules_map[mod] = name - if model_name is None: - del modules_map[model] - else: - modules_map[model] = model_name - return modules_map - - -class Node: - def __init__( - self, - elem: Any, - name: str, - outgoing_nodes: Optional[List["Node"]] = None, - ) -> None: - self.elem = elem - self.name = name - if not outgoing_nodes: - self.outgoing_nodes = [] - else: - self.outgoing_nodes = outgoing_nodes - - def add_outgoing(self, node: "Node"): - self.outgoing_nodes.append(node) - - def __str__(self) -> str: - return f"Node: {self.name}, Out: {len(self.outgoing_nodes)}" - - def __eq__(self, other: Any) -> bool: - # Two nodes are meant to be the same if they refer to the same element - try: - return self.elem is other.elem - except AttributeError: - return False - - def __hash__(self): - # Two nodes are same if they reference the same element - return hash(self.elem) - - -class Graph: - def __init__(self, module_names: Dict[nn.Module, str]) -> None: - self.module_names = module_names - self.elem_list = [] - self.node_list: List[Node] = [] - self._last_used_tensor_id = None - # Add modules to node_list - for mod, name in self.module_names.items(): - self.add_elem(mod, name) - - @property - def node_map_by_id(self): - return {n.name: n for n in self.node_list} - - def num_edges(self) -> int: - count = 0 - for node in self.node_list: - count += node.outgoing_nodes - return count - - def get_unique_tensor_id(self) -> str: - if self._last_used_tensor_id is None: - self._last_used_tensor_id = 0 - else: - self._last_used_tensor_id += 1 - return str(self._last_used_tensor_id) - - def __contains__(self, elem: Union[torch.Tensor, nn.Module]) -> bool: - for elem_in_list in self.elem_list: - if elem is elem_in_list: - return True - return False - - def add_elem(self, elem, name: str) -> Node: - if elem in self: - warnings.warn(f"{name}: Node already exists for this element ") - return self.find_node(elem) - else: - node = Node(elem, name) - self.elem_list.append(elem) - self.node_list.append(node) - return node - - def add_or_get_node_for_elem(self, elem: Union[torch.Tensor, nn.Module]): - if elem in self: - return self.find_node(elem) - else: - # Generate a name - if elem in self.module_names: - name = self.module_names[elem] - else: - assert isinstance(elem, torch.Tensor) - name = f"Tensor_{self.get_unique_tensor_id()}{tuple(elem.shape)}" - # add and return the node - new_node = self.add_elem(elem, name) - return new_node - - def find_node(self, elem: Union[torch.Tensor, nn.Module]): - for node in self.node_list: - if elem is node.elem: - return node - raise ValueError("elem not found") - - def add_edge( - self, - source: Union[torch.Tensor, nn.Module], - destination: Union[torch.Tensor, nn.Module], - ): - if self._is_mod_and_not_in_module_names(source): - return - if self._is_mod_and_not_in_module_names(destination): - return - - source_node = self.add_or_get_node_for_elem(source) - destination_node = self.add_or_get_node_for_elem(destination) - source_node.add_outgoing(destination_node) - return source_node, destination_node - - def get_leaf_modules(self) -> Dict[nn.Module, str]: - filtered_module_names = {} - - for mod, _ in self.module_names.items(): - # Add module to dict - filtered_module_names[mod] = self.module_names[mod] - child_in_graph = False - for _, submod in mod.named_children(): - if submod in self: - child_in_graph = True - break - if child_in_graph: - del filtered_module_names[mod] - return filtered_module_names - - def _is_mod_and_not_in_module_names(self, elem: Any) -> bool: - """Check if a node is a module and is included in the module_names of this graph - - Args: - node (Node): Node to verify - - Returns: - bool - """ - if isinstance(elem, nn.Module) and elem not in self.module_names: - return True - else: - return False - - def populate_from(self, other_graph: "Graph"): - for node in other_graph.node_list: - for outgoing_node in node.outgoing_nodes: - self.add_edge(node.elem, outgoing_node.elem) - - def __str__(self) -> str: - return self.to_md() - - def to_md(self) -> str: - mermaid_md = """ -```mermaid -graph TD; -""" - for node in self.node_list: - if node.outgoing_nodes: - for outgoing in node.outgoing_nodes: - mermaid_md += f"{node.name} --> {outgoing.name};\n" - else: - mermaid_md += f"{node.name};\n" - - end = """ -``` -""" - return mermaid_md + end - - def leaf_only(self) -> "Graph": - leaf_modules = self.get_leaf_modules() - filtered_graph = Graph(leaf_modules) - # Populate edges - filtered_graph.populate_from(self) - return filtered_graph - - def ignore_submodules_of(self, classes: List[Type]) -> "Graph": - new_named_modules = {} - - # Gather a list of all top level modules, whose submodules are to be ignored - top_level_modules: List[nn.Module] = [] - for mod in self.module_names.keys(): - if mod.__class__ in classes: - top_level_modules.append(mod) - - # List all the submodules of the above module list - sub_modules_to_ignore: List[nn.Module] = [] - for top_mod in top_level_modules: - for sub_mod in top_mod.modules(): - if sub_mod is not top_mod: - sub_modules_to_ignore.append(sub_mod) - - # Iterate over all modules and check if they are submodules of the above list - for mod, name in self.module_names.items(): - if mod not in sub_modules_to_ignore: - new_named_modules[mod] = name - # Create a new graph with the allowed modules - new_graph = Graph(new_named_modules) - new_graph.populate_from(self) - return new_graph - - def find_source_nodes_of(self, node: Node) -> List[Node]: - """Find all the sources of a node in the graph - - Args: - node (Node): Node of interest - - Returns: - List[Node]: A list of all nodes that have this node as outgoing_node - """ - source_node_list = [] - for source_node in self.node_list: - for outnode in source_node.outgoing_nodes: - if node == outnode: - source_node_list.append(source_node) - return source_node_list - - - def ignore_tensors(self) -> "Graph": - """Simplify the graph by ignoring all the tensors in it - - Returns: - Graph: Returns a simplified graph with only modules in it - """ - return self.ignore_nodes(torch.Tensor) - - - def ignore_nodes(self, class_type: Type)->"Graph": - graph = Graph(self.module_names) - # Iterate over all the nodes - for node in self.node_list: - if isinstance(node.elem, class_type): - # Get its source - source_node_list = self.find_source_nodes_of(node) - # If no source, this is probably origin node, just drop it - if len(source_node_list) == 0: - continue - # Get all of its destinations - # If no destinations, it is a leaf node, just drop it. - if node.outgoing_nodes: - for outgoing_node in node.outgoing_nodes: - # Directly add an edge from source to destination - for source_node in source_node_list: - graph.add_edge(source_node.elem, outgoing_node.elem) - # NOTE: Assuming that the destination is not of the same type here - else: - # This is to preserve the graph if executed on a graph that is already filtered - for outnode in node.outgoing_nodes: - if not isinstance(outnode.elem, class_type): - graph.add_edge(node.elem, outnode.elem) - return graph - - -_torch_module_call = torch.nn.Module.__call__ - - -def module_forward_wrapper(model_graph: Graph) -> Callable[..., Any]: - def my_forward(mod: nn.Module, *args, **kwargs) -> Any: - # Iterate over all inputs - for i, input_data in enumerate(args): - # Create nodes and edges - model_graph.add_edge(input_data, mod) - out = _torch_module_call(mod, *args, **kwargs) - if isinstance(out, tuple): - out_tuple - elif isinstance(out, torch.Tensor): - out_tuple = (out,) - else: - raise Exception("Unknown output format") - # Iterate over all outputs and create nodes and edges - for output_data in out_tuple: - # Create nodes and edges - model_graph.add_edge(mod, output_data) - return out - - return my_forward - - -class GraphTracer: - """ - Context manager to trace a model's execution graph - - Example: - - ```python - with GraphTracer(mymodel) as tracer, torch.no_grad(): - out = mymodel(data) - - print(tracer.graph.to_md()) - ``` - """ - - def __init__(self, mod: nn.Module) -> None: - self.original_torch_call = nn.Module.__call__ - self.graph = Graph(mod) - - def __enter__(self) -> "GraphTracer": - # Override the torch call method - nn.Module.__call__ = module_forward_wrapper(self.graph) - return self - - def __exit__(self, exc_type, exc_value, exc_tb): - # Restore normal behavior - nn.Module.__call__ = self.original_torch_call - - -def extract_graph( - model: nn.Module, sample_data: Any, model_name: Optional[str] = "model" -) -> Graph: - """Extract computational graph between various modules in the model - NOTE: This method is not capable of any compute happening outside of module definitions. - - Args: - model (nn.Module): The module to be analysed - sample_data (Any): Sample data to be used to run by the model - model_name (Optional[str], optional): Name of the top level module. - If specified, it will be included in the graph. - If set to None, only its submodules will be listed in the graph. - Defaults to "model". - - Returns: - Graph: A graph object representing the computational graph of the given model - """ - with GraphTracer( - named_modules_map(model, model_name=model_name) - ) as tracer, torch.no_grad(): - out = model(sample_data) - - return tracer.graph diff --git a/sinabs/layers/channel_concat.py b/sinabs/layers/channel_concat.py deleted file mode 100644 index cd12d8db..00000000 --- a/sinabs/layers/channel_concat.py +++ /dev/null @@ -1,10 +0,0 @@ -import torch -import torch.nn as nn - -class ConcatenateChannel(nn.Module): - def __init__(self, channel_axis=-3) -> None: - super().__init__() - self.channel_axis = -3 - - def forward(self, x, y): - return torch.cat((x, y), self.channel_axis) \ No newline at end of file diff --git a/tests/test_graph.py b/tests/test_graph.py deleted file mode 100644 index c755d663..00000000 --- a/tests/test_graph.py +++ /dev/null @@ -1,177 +0,0 @@ -import torch -import torch.nn as nn - -from sinabs.layers import Merge - - -# Branched model -class MyBranchedModel(nn.Module): - def __init__(self) -> None: - super().__init__() - self.relu1 = nn.ReLU() - self.relu2_1 = nn.ReLU() - self.relu2_2 = nn.ReLU() - self.add_mod = Merge() - self.relu3 = nn.ReLU() - - def forward(self, data): - out1 = self.relu1(data) - out2_1 = self.relu2_1(out1) - out2_2 = self.relu2_2(out1) - out3 = self.add_mod(out2_1, out2_2) - out4 = self.relu3(out3) - return out4 - - -input_shape = (2, 28, 28) -batch_size = 1 - -data = torch.ones((batch_size, *input_shape)) - -mymodel = MyBranchedModel() - - -class DeepModel(nn.Module): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.block1 = MyBranchedModel() - self.block2 = MyBranchedModel() - - def forward(self, data): - out = self.block1(data) - out2 = self.block2(out) - return out2 - - -mydeepmodel = DeepModel() - - -def test_named_modules_map(): - from sinabs.graph import named_modules_map - - mod_map = named_modules_map(mymodel) - print(mod_map) - for k, v in mod_map.items(): - assert isinstance(k, nn.Module) - assert isinstance(v, str) - - -def test_module_forward_wrapper(): - mymodel = MyBranchedModel() - - orig_call = nn.Module.__call__ - - from sinabs.graph import Graph, module_forward_wrapper, named_modules_map - - model_graph = Graph(named_modules_map(mymodel)) - new_call = module_forward_wrapper(model_graph) - - # Override call to the new wrapped call - nn.Module.__call__ = new_call - - with torch.no_grad(): - out = mymodel(data) - - # Restore normal behavior - nn.Module.__call__ = orig_call - - print(model_graph) - assert ( - len(model_graph.node_list) == 1 + 5 + 5 + 1 - ) # 1 top module + 5 submodules + 5 tensors + 1 output tensor - - -def test_graph_tracer(): - from sinabs.graph import GraphTracer, named_modules_map - - with GraphTracer(named_modules_map(mymodel)) as tracer, torch.no_grad(): - out = mymodel(data) - - print(tracer.graph) - assert ( - len(tracer.graph.node_list) == 1 + 5 + 5 + 1 - ) # 1 top module + 5 submodules + 5 tensors + 1 output tensor - - -def test_leaf_only_graph(): - from sinabs.graph import GraphTracer, named_modules_map - - with GraphTracer(named_modules_map(mydeepmodel)) as tracer, torch.no_grad(): - out = mydeepmodel(data) - - print(tracer.graph) - - # Get graph with just the leaf nodes - leaf_graph = tracer.graph.leaf_only() - print(leaf_graph) - assert ( - len(leaf_graph.node_list) == len(tracer.graph.node_list) - 3 - ) # No more top modules - - -def test_ignore_submodules_of(): - from sinabs.graph import GraphTracer, named_modules_map - - with GraphTracer(named_modules_map(mydeepmodel)) as tracer, torch.no_grad(): - out = mydeepmodel(data) - - top_overview_graph = tracer.graph.ignore_submodules_of( - [MyBranchedModel] - ).leaf_only() - print(top_overview_graph) - assert len(top_overview_graph.node_list) == 2 + 2 + 1 - - -def test_snn_branched(): - from sinabs.layers import IAFSqueeze, ConcatenateChannel, SumPool2d - from torch.nn import Conv2d - from sinabs.graph import extract_graph - - class MySNN(nn.Module): - def __init__(self) -> None: - super().__init__() - self.conv1 = Conv2d(2, 8, 3, bias=False) - self.iaf1 = IAFSqueeze(batch_size=1) - self.pool1 = SumPool2d(2) - self.conv2_1 = Conv2d(8, 16, 3, stride=1, padding=1, bias=False) - self.iaf2_1 = IAFSqueeze(batch_size=1) - self.pool2_1 = SumPool2d(2) - self.conv2_2 = Conv2d(8, 16, 5, stride=1, padding=2, bias=False) - self.iaf2_2 = IAFSqueeze(batch_size=1) - self.pool2_2 = SumPool2d(2) - self.concat = ConcatenateChannel() - self.conv3 = Conv2d(32, 10, 3, stride=3, bias=False) - self.iaf3 = IAFSqueeze(batch_size=1) - - def forward(self, spikes): - out = self.conv1(spikes) - out = self.iaf1(out) - out = self.pool1(out) - - out1 = self.conv2_1(out) - out1 = self.iaf2_1(out1) - out1 = self.pool2_1(out1) - - out2 = self.conv2_2(out) - out2 = self.iaf2_2(out2) - out2 = self.pool2_2(out2) - - out = self.concat(out1, out2) - out = self.conv3(out) - out = self.iaf3(out) - return out - - my_snn = MySNN() - graph = extract_graph( - my_snn, sample_data=torch.rand((100, 2, 14, 14)), model_name=None - ) - - print(graph) - assert len(graph.elem_list) == 25 # 2*12 + 1 - - -def test_ignore_tensors(): - from sinabs import extract_graph - graph = extract_graph(mymodel, sample_data=data) - mod_only_graph = graph.ignore_tensors() - assert len(mod_only_graph.node_list) == 6 \ No newline at end of file