forked from Hippogriff/CSGNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgrouping.py
110 lines (96 loc) · 3.79 KB
/
grouping.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import numpy as np
from src.utils import read_config
from src.utils import train_utils
from src.Models.models import ParseModelOutput
max_len = 13
config = read_config.Config("config_synthetic.yml")
valid_permutations = train_utils.valid_permutations

# Load the terminal symbols of the grammar, one per line.
# Strip only a trailing "\n" instead of blindly slicing off the last
# character: the old `e[0:-1]` truncated the final terminal whenever the
# file did not end with a newline.
with open("terminals.txt", "r") as file:
    unique_draw = [line.rstrip("\n") for line in file]

parser = ParseModelOutput(unique_draw, max_len // 2 + 1, max_len, config.canvas_shape)
class EditDistance:
    """
    Defines edit distance between two programs. The following criteria are
    used to find edit distance:
    1. Done: Subset string
    2. % Subset
    3. Primitive type based subsetting
    4. Done: Permutation invariant subsetting
    """

    # Distance recorded for a program pair in which the first program is not
    # a substring of the second (see ``exhaustive_subsets_edit_distance``).
    NOT_SUBSET_DISTANCE = 100

    def __init__(self):
        pass

    def edit_distance(self, prog1, prog2, iou):
        """
        Calculates edit distance between two programs.

        :param prog1: program expression (parsed with ``self.parse``)
        :param prog2: program expression (parsed with ``self.parse``)
        :param iou: presumably the IoU between the two programs' renderings;
            a value exactly equal to 1 short-circuits to a distance of 0
        :return: 0 when ``iou == 1``. Otherwise, over all pairs of valid
            permutations (from ``valid_permutations``) of the program with
            fewer tokens and the other program, the minimum of the token-count
            difference for pairs where the shorter one is a substring of the
            longer one, or ``NOT_SUBSET_DISTANCE`` for pairs where it is not.
        """
        # Check the cheap case first: permutation enumeration is expensive
        # and its result is not needed when the IoU is perfect.
        if iou == 1:
            return 0

        prog1_tokens = self.parse(prog1)
        prog2_tokens = self.parse(prog2)
        # NOTE(review): valid_permutations is assumed to return hashable
        # program strings (they are deduplicated with set() and later
        # substring-tested) — confirm against train_utils.
        all_valid_programs1 = list(set(valid_permutations(
            prog1_tokens, permutations=[], stack=[], start=True)))
        all_valid_programs2 = list(set(valid_permutations(
            prog2_tokens, permutations=[], stack=[], start=True)))

        # Look for (permutations of) the program with fewer tokens inside
        # (permutations of) the longer one; ties search prog1 inside prog2.
        if len(prog1_tokens) <= len(prog2_tokens):
            shorter, longer = all_valid_programs1, all_valid_programs2
        else:
            shorter, longer = all_valid_programs2, all_valid_programs1
        return np.min(self.exhaustive_subsets_edit_distance(shorter, longer))

    def exhaustive_subsets_edit_distance(self, progs1, progs2):
        """
        Builds the pairwise subset-distance matrix between two program lists.

        Entry ``[i, j]`` is ``len(parse(progs2[j])) - len(parse(progs1[i]))``
        when ``progs1[i] in progs2[j]`` (substring containment), and
        ``NOT_SUBSET_DISTANCE`` otherwise.

        :param progs1: sequence of program expressions
        :param progs2: sequence of program expressions
        :return: float ndarray of shape ``(len(progs1), len(progs2))``
        """
        # Parse each program once instead of once per (p1, p2) pair.
        token_counts1 = [len(self.parse(p)) for p in progs1]
        token_counts2 = [len(self.parse(p)) for p in progs2]

        subset_flag = np.full((len(progs1), len(progs2)),
                              float(self.NOT_SUBSET_DISTANCE))
        for index1, (p1, count1) in enumerate(zip(progs1, token_counts1)):
            for index2, (p2, count2) in enumerate(zip(progs2, token_counts2)):
                if p1 in p2:
                    subset_flag[index1, index2] = count2 - count1
        return subset_flag

    def subset_program_structure_primitives(self, prog1, prog2):
        """
        Define edit distance based on partial program structure and primitive
        types. If the partial program structure is same and the position of
        the primitives is same, then edit distance is positive.

        Not implemented yet: currently does nothing and returns ``None``.
        """
        pass

    def parse(self, expression):
        """
        NOTE: This method is different from parse method in Parser class.
        Takes an expression, returns a serial program.

        Each shape letter ``c``/``s``/``t`` starts a ``"draw"`` token that
        runs through the next ``")"``; each of ``*``, ``+``, ``-`` is an
        ``"op"`` token. All other characters are ignored.

        :param expression: program expression in postfix notation
        :return program: list of ``{"type": ..., "value": ...}`` dicts
        :raises ValueError: if a shape letter has no closing parenthesis
            after it
        """
        shape_types = ["c", "s", "t"]
        op = ["*", "+", "-"]
        program = []
        for index, value in enumerate(expression):
            if value in shape_types:
                # Find where the parenthesis closes without copying the tail.
                close_paren = expression.index(")", index)
                program.append({"type": "draw",
                                "value": expression[index:close_paren + 1]})
            elif value in op:
                program.append({"type": "op", "value": value})
        return program