-
Notifications
You must be signed in to change notification settings - Fork 0
/
graph.py
88 lines (74 loc) · 2.58 KB
/
graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# library imports
import random
import itertools
import numpy as np
# project imports
from node import Node
from edge import Edge
class Graph:
    """
    A simple directed graph - implemented using a list of nodes with IDs and
    a list of edges that reference those IDs (``s_id`` -> ``t_id``, weight ``w``).
    """

    def __init__(self,
                 nodes: list,
                 edges: list):
        self.nodes = nodes
        self.edges = edges

    def get_size(self) -> int:
        """Return the number of nodes in the graph."""
        return len(self.nodes)

    def next_nodes(self,
                   id: int) -> list:
        """
        Return the target IDs of all edges whose source is ``id``, in edge order.
        """
        # NOTE: parameter name ``id`` shadows the builtin; kept for caller compatibility.
        return [edge.t_id for edge in self.edges if edge.s_id == id]

    def next_nodes_with_weight(self,
                               id: int) -> tuple:
        """
        Return ``(ids, weights)``: the target IDs and weights of all edges whose
        source is ``id``, as two parallel lists in edge order.
        """
        ids = []
        ws = []
        for edge in self.edges:
            if edge.s_id == id:
                ids.append(edge.t_id)
                ws.append(edge.w)
        return ids, ws

    @staticmethod
    def generate_random(node_count: int,
                        edge_count: int) -> "Graph":
        """
        Generate a random graph with a given number of nodes and distinct edges.

        Edges are directed, have weight 1, contain no self-loops and no
        duplicate (source, target) pairs.

        Raises:
            ValueError: if ``edge_count`` exceeds the number of possible
                distinct directed edges, ``node_count * (node_count - 1)``;
                without this check the sampling loop would never terminate.
        """
        max_edges = node_count * (node_count - 1) if node_count > 1 else 0
        if edge_count > max_edges:
            raise ValueError(
                "edge_count={} exceeds the maximum of {} distinct edges for "
                "{} nodes".format(edge_count, max_edges, node_count))
        nodes = [Node(id=i) for i in range(node_count)]
        edges = []
        # Track used (source, target) pairs in a set: O(1) duplicate checks that
        # do not depend on how Edge.__eq__ treats the weight.
        used_pairs = set()
        while len(edges) < edge_count:
            s_id = random.randint(0, node_count - 1)
            t_id = random.randint(0, node_count - 1)
            if s_id != t_id and (s_id, t_id) not in used_pairs:
                used_pairs.add((s_id, t_id))
                edges.append(Edge(s_id=s_id, t_id=t_id, w=1))
        return Graph(nodes=nodes,
                     edges=edges)

    @staticmethod
    def fully_connected(node_count: int) -> "Graph":
        """
        Generate a fully connected graph with a given number of nodes: one
        weight-1 directed edge for every ordered pair of distinct nodes.
        """
        nodes = [Node(id=i) for i in range(node_count)]
        # permutations yields (0, 1), (0, 2), ..., (1, 0), (1, 2), ... -- the
        # same order as a nested i/j loop skipping i == j.
        edges = [Edge(s_id=i, t_id=j, w=1)
                 for i, j in itertools.permutations(range(node_count), 2)]
        return Graph(nodes=nodes,
                     edges=edges)

    @staticmethod
    def table_to_edges(data: np.ndarray) -> list:
        """
        Convert a 2-D table into a flat list of edges.

        Each cell ``data[row][col]`` with a value > 0 becomes
        ``Edge(row, col, value)``; cells <= 0 are skipped. Edges are returned
        in row-major order.
        """
        edges = [[Edge(row_index, col_index, val)
                  for col_index, val in enumerate(row) if val > 0]
                 for row_index, row in enumerate(data)]
        return list(itertools.chain.from_iterable(edges))

    def copy(self) -> "Graph":
        """Return a new Graph built from ``copy()`` of every node and edge."""
        return Graph(nodes=[node.copy() for node in self.nodes],
                     edges=[edge.copy() for edge in self.edges])

    def __hash__(self):
        # Lists are unhashable, so hash tuple snapshots of the contents instead.
        # This requires the node and edge objects themselves to be hashable.
        # NOTE(review): the hash follows the current contents; mutating a graph
        # after placing it in a set/dict key will break lookups.
        return hash((tuple(self.nodes), tuple(self.edges)))

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        return "<Graph: V={}, E={}>".format(len(self.nodes),
                                            len(self.edges))