-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathem_graph_pickler.py
131 lines (93 loc) · 3.26 KB
/
em_graph_pickler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from __future__ import annotations
from typing import Any
import io
import os
import pathlib
import pickle
import pickletools
import tempfile
import uuid
import networkx
import pyrsistence
class Pickler(pickle.Pickler):
    """Reusable pickler that turns one object at a time into ``bytes``.

    The instance owns a single in-memory ``io.BytesIO`` which is handed to
    ``pickle.Pickler`` once and then overwritten from offset 0 on every call
    to :meth:`pickle`.
    """

    def __init__(self) -> None:
        """Create the scratch buffer and bind the base pickler to it."""
        buffer = io.BytesIO()
        self._io = buffer
        super().__init__(buffer)
        # Per the pickle docs, fast mode switches the memo off (and must not
        # be used with self-referential objects). Presumably chosen so that
        # successive dumps on this reused pickler stay independent.
        self.fast = True

    def pickle(self, obj: Any) -> bytes:
        """Serialize *obj* and return the optimized pickle as ``bytes``."""
        stream = self._io
        stream.seek(0)
        self.dump(obj)
        end = stream.tell()
        # The buffer is never truncated, so it can still hold the tail of an
        # earlier, longer pickle; keep only what this dump wrote.
        fresh = stream.getvalue()[:end]
        return pickletools.optimize(fresh)
class Unpickler(pickle.Unpickler):
    """Reusable unpickler that decodes one ``bytes`` payload at a time.

    The instance owns a single in-memory ``io.BytesIO``; each call to
    :meth:`unpickle` overwrites it from offset 0 and loads one object from it.
    """

    def __init__(self) -> None:
        """Create the scratch buffer and bind the base unpickler to it."""
        buffer = io.BytesIO()
        self._io = buffer
        super().__init__(buffer)

    def unpickle(self, data: bytes) -> Any:
        """Return the object encoded by the pickle payload *data*."""
        stream = self._io
        payload = pickletools.optimize(data)
        stream.seek(0)
        stream.write(payload)
        # Rewind so load() starts at the beginning of the payload just written.
        stream.seek(0)
        return self.load()
class ObservedDict(dict):
    """A ``dict`` that writes itself back into an observer whenever it changes.

    ``_observer`` is either ``None`` or an ``(observer, key)`` pair. After any
    in-place mutation the dict performs ``observer[key] = self``, which lets a
    container that stores serialized copies of its values re-store this one.

    Every mutating ``dict`` method notifies the observer, not only
    ``__setitem__``; otherwise deletions and bulk updates would be applied to
    this in-memory object but never written back.
    """

    def __new__(cls, *args: Any, **kwargs: Any) -> "ObservedDict":
        # Initialised in __new__ rather than __init__ so the attribute exists
        # on every instance regardless of whether __init__ is ever invoked.
        self = super().__new__(cls, *args, **kwargs)
        self._observer = None
        return self

    def _notify(self) -> None:
        """Re-store this dict in its observer, if one is attached."""
        if self._observer is not None:
            observer, observed_key = self._observer
            observer[observed_key] = self

    def __setitem__(self, key: Any, value: Any) -> None:
        super().__setitem__(key, value)
        self._notify()

    def __delitem__(self, key: Any) -> None:
        super().__delitem__(key)
        self._notify()

    def __ior__(self, other: Any) -> "ObservedDict":
        # Routed through update() so that ``d |= other`` also notifies.
        self.update(other)
        return self

    def update(self, *args: Any, **kwargs: Any) -> None:
        super().update(*args, **kwargs)
        self._notify()

    def setdefault(self, key: Any, default: Any = None) -> Any:
        was_missing = key not in self
        value = super().setdefault(key, default)
        if was_missing:
            # Only an actual insertion changes the contents.
            self._notify()
        return value

    def pop(self, key: Any, *default: Any) -> Any:
        was_present = key in self
        value = super().pop(key, *default)
        if was_present:
            self._notify()
        return value

    def popitem(self) -> tuple:
        item = super().popitem()
        self._notify()
        return item

    def clear(self) -> None:
        had_items = bool(self)
        super().clear()
        if had_items:
            self._notify()

    def __getstate__(self) -> dict:
        # Drop the observer from the pickled state so that serializing this
        # dict never drags the observing container along with it.
        return dict(self.__dict__, _observer=None)
class EMDict(pyrsistence.EMDict):
    """``pyrsistence.EMDict`` opened at a path, wired up to ``ObservedDict`` values.

    Any ``ObservedDict`` stored in or fetched from this mapping gets
    ``(self, key)`` attached as its observer, so a later in-place change to
    that value assigns it back under the same key.

    NOTE(review): presumably the base class keeps serialized entries on disk
    under the given path -- confirm against the pyrsistence documentation.
    """

    def __init__(self, path: pathlib.Path | str | None = None) -> None:
        """Open the mapping at *path*.

        Args:
            path: Location passed (as ``str``) to the base class. When falsy,
                a fresh ``em_dict-<uuid1>`` name inside the system temporary
                directory is used instead.
        """
        self._path = (
            path or pathlib.Path(tempfile.gettempdir()) / f"em_dict-{uuid.uuid1()}"
        )
        super().__init__(str(self._path), pickler=Pickler(), unpickler=Unpickler())
        print(f"EMDict created at {self._path}")

    def _attach_observer(self, key: Any, value: Any) -> None:
        """Point an ``ObservedDict`` value back at this mapping and its key."""
        if isinstance(value, ObservedDict):
            value._observer = (self, key)

    def __setitem__(self, key: Any, value: Any) -> None:
        # ObservedDict.__getstate__ strips the observer, so attaching it
        # before storing does not pull this mapping into the pickled value.
        self._attach_observer(key, value)
        super().__setitem__(key, value)

    def __getitem__(self, key: Any) -> Any:
        value = super().__getitem__(key)
        # Re-attach on every read so that the returned value can write itself
        # back when it is modified.
        self._attach_observer(key, value)
        return value

    def __getstate__(self) -> dict:
        # Only the path is pickled, not the stored entries.
        return {"path": self._path}

    def __setstate__(self, state: dict) -> None:
        # Re-run the constructor at the saved path (this also prints the
        # creation message again).
        path = state["path"]
        self.__init__(path)
class Graph(networkx.DiGraph):
    """``networkx.DiGraph`` whose internal dicts use ``EMDict`` and ``ObservedDict``.

    NetworkX documents these ``*_factory`` class attributes as the hooks a
    subclass overrides to choose the dict-like types of its data structure.
    """

    # Outer adjacency mapping (keyed by node): the only structure that is an
    # EMDict.
    adjlist_outer_dict_factory = EMDict

    # Per-node neighbour dicts and per-edge attribute dicts.
    adjlist_inner_dict_factory = ObservedDict
    edge_attr_dict_factory = ObservedDict

    # Node table, per-node attribute dicts and graph-level attributes.
    node_dict_factory = ObservedDict
    node_attr_dict_factory = ObservedDict
    graph_attr_dict_factory = ObservedDict
def main() -> None:
    """Copy a random digraph into a ``Graph`` and compare immediate dominators.

    Raises:
        AssertionError: If the two graphs yield different immediate
            dominators from start node 0.
    """
    em_graph = Graph()

    # Draw 20-node directed Erdos-Renyi graphs (edge probability 0.1) until
    # one of them is weakly connected.
    random_graph = networkx.erdos_renyi_graph(20, 0.1, directed=True)
    while not networkx.is_weakly_connected(random_graph):
        random_graph = networkx.erdos_renyi_graph(20, 0.1, directed=True)

    # Mirror the nodes and edges of the in-memory graph into the EM graph.
    em_graph.add_nodes_from(random_graph)
    em_graph.add_edges_from(random_graph.edges())

    print(f"Random graph : {random_graph}")
    print(f"EM graph : {em_graph}")

    #
    # Both graphs should now have the same structure. As a simple check,
    # compute the immediate dominators from node 0 on each and compare.
    #
    expected = networkx.immediate_dominators(random_graph, 0)
    actual = networkx.immediate_dominators(em_graph, 0)
    assert expected == actual, "Immediate dominators not the same"

    print("Test successful")


# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()