-
Notifications
You must be signed in to change notification settings - Fork 13
/
grassdata.py
76 lines (66 loc) · 3.07 KB
/
grassdata.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
from enum import Enum

import torch
from scipy.io import loadmat
from torch.utils import data
class Tree(object):
    """A binary part hierarchy rebuilt from a post-order op sequence.

    The op codes are scanned left to right as a stack machine:

    * a BOX code pushes a new leaf holding the next unused box row,
    * an ADJ code pops two nodes and pushes an adjacency node over them,
    * a SYM code pops one node and pushes a symmetry node over it, holding
      the next unused symmetry row.

    After the scan exactly one node must remain; it becomes ``self.root``.
    """

    class NodeType(Enum):
        BOX = 0  # box node
        ADJ = 1  # adjacency (adjacent part assembly) node
        SYM = 2  # symmetry (symmetric part grouping) node

    class Node(object):
        """One node of the hierarchy; which fields are set depends on node_type."""

        def __init__(self, box=None, left=None, right=None, node_type=None, sym=None):
            self.box = box  # box feature vector for a leaf node
            self.sym = sym  # symmetry parameter vector for a symmetry node
            self.left = left  # left child for ADJ, or the symmetry generator for SYM
            self.right = right  # right child (set only for ADJ)
            self.node_type = node_type
            # Node class as a 1-element long tensor; None when no type is given
            # (node_type defaults to None, so .value must not be read blindly).
            self.label = (
                None if node_type is None else torch.LongTensor([node_type.value])
            )

        def is_leaf(self):
            """Return True for a BOX node that actually carries a box."""
            return self.node_type == Tree.NodeType.BOX and self.box is not None

        def is_adj(self):
            """Return True for an adjacency node."""
            return self.node_type == Tree.NodeType.ADJ

        def is_sym(self):
            """Return True for a symmetry node."""
            return self.node_type == Tree.NodeType.SYM

    def __init__(self, boxes, ops, syms):
        """Decode one example into a tree.

        Args:
            boxes: 2-D tensor; each row (taken as a 1-row slice) is one box,
                consumed in row order by BOX ops. Unused trailing rows are
                ignored.
            ops: 2-D tensor whose row 0 holds the op codes. Codes that are not
                a NodeType value are skipped (presumably padding -- confirm
                against the data files).
            syms: 2-D tensor; each row (taken as a 1-row slice) is one set of
                symmetry parameters, consumed in row order by SYM ops.

        Raises:
            ValueError: if the op sequence does not reduce to exactly one node.
        """
        # Reversed so that list.pop() hands rows out in their original order.
        box_list = list(torch.split(boxes, 1, 0))[::-1]
        sym_list = list(torch.split(syms, 1, 0))[::-1]
        stack = []
        for code in ops[0].tolist():
            if code == Tree.NodeType.BOX.value:
                stack.append(Tree.Node(box=box_list.pop(), node_type=Tree.NodeType.BOX))
            elif code == Tree.NodeType.ADJ.value:
                # The most recently pushed node becomes the LEFT child.
                left_node = stack.pop()
                right_node = stack.pop()
                stack.append(
                    Tree.Node(left=left_node, right=right_node, node_type=Tree.NodeType.ADJ)
                )
            elif code == Tree.NodeType.SYM.value:
                generator = stack.pop()
                stack.append(
                    Tree.Node(left=generator, sym=sym_list.pop(), node_type=Tree.NodeType.SYM)
                )
            # Any other code is skipped deliberately.
        # Explicit check instead of `assert`, which disappears under `python -O`.
        if len(stack) != 1:
            raise ValueError(
                f"op sequence must reduce to exactly one root node, got {len(stack)}"
            )
        self.root = stack[0]
class GRASSDataset(data.Dataset):
    """Dataset of ``Tree`` objects decoded from three MATLAB files in ``dir``.

    * ``op_data.mat`` (variable ``'ops'``) holds one column per example.
    * ``box_data.mat`` (``'boxes'``) and ``sym_data.mat`` (``'syms'``) are split
      column-wise into one equal-width group per example.

    Each group is transposed, so its columns become the rows that ``Tree``
    consumes.
    """

    def __init__(self, dir, transform=None):
        """Load and decode every example eagerly.

        Args:
            dir: directory containing the three ``.mat`` files.
            transform: optional callable applied to each tree by
                ``__getitem__``.

        Raises:
            ValueError: if the box or sym matrix cannot be split evenly into
                one group per example.
        """
        self.dir = dir
        self.transform = transform
        box_data = torch.from_numpy(
            loadmat(os.path.join(self.dir, 'box_data.mat'))['boxes']).float()
        op_data = torch.from_numpy(
            loadmat(os.path.join(self.dir, 'op_data.mat'))['ops']).int()
        sym_data = torch.from_numpy(
            loadmat(os.path.join(self.dir, 'sym_data.mat'))['syms']).float()
        num_examples = op_data.size()[1]
        box_chunks = self._split_examples(box_data, num_examples, 'boxes')
        op_chunks = torch.chunk(op_data, num_examples, 1)
        sym_chunks = self._split_examples(sym_data, num_examples, 'syms')
        self.trees = [
            Tree(torch.t(boxes), torch.t(ops), torch.t(syms))
            for boxes, ops, syms in zip(box_chunks, op_chunks, sym_chunks)
        ]

    @staticmethod
    def _split_examples(matrix, num_examples, name):
        """Split *matrix* column-wise into ``num_examples`` equal-width chunks."""
        # torch.chunk quietly returns uneven (or fewer) chunks when the width is
        # not divisible, which would hand examples the wrong columns.
        if matrix.size(1) % num_examples != 0:
            raise ValueError(
                f"'{name}' has {matrix.size(1)} columns, which cannot be split "
                f"evenly into {num_examples} examples"
            )
        return torch.chunk(matrix, num_examples, 1)

    def __getitem__(self, index):
        """Return the tree at *index*, passed through ``transform`` if one was given."""
        tree = self.trees[index]
        if self.transform is not None:
            tree = self.transform(tree)
        return tree

    def __len__(self):
        """Return the number of decoded trees."""
        return len(self.trees)