# voxel_dataset.py — PyTorch Dataset classes pairing rendered voxel-grid camera
# views with per-voxel target grids (height or visibility) for training.
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import faiss
from einops import rearrange
from config import *
import faiss
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
class VoxelHeightDataset(Dataset):
    """Monocular forward-facing views of a 64x64 voxel grid.

    Camera poses: 45 or 115 degrees x rotation, -44 to 44 degrees z rotation.

    Each sample pairs a rendered camera view (``voxels_{idx}.png``) with a
    per-voxel height grid and a per-voxel reward grid parsed from
    ``visible_elements_{idx}`` (CSV lines of ``index,height,reward``).
    """

    def __init__(self, dataset_size, dir, transform=None):
        # NOTE(review): `dir` shadows the builtin; name kept for caller compatibility.
        self.dataset_size = dataset_size
        self.dir = dir
        self.transform = transform

    def __len__(self):
        return self.dataset_size

    def get_camera_view(self, idx):
        """Load view `idx` as a float32 CHW tensor with any alpha channel dropped."""
        img = io.imread(os.path.join(self.dir, "voxels_{}.png".format(idx)))
        cv = rearrange(img, 'h w c -> c h w')
        # Keep only the first 3 channels. Identical to np.delete(cv, 3, 0) for
        # RGBA input, but also works when the image has no alpha channel.
        cv = cv[:3]
        return torch.from_numpy(cv.astype('float32'))

    def get_voxel_grid(self, idx):
        """Parse `visible_elements_{idx}` into (height_grid, reward_grid) tensors.

        Voxels absent from the file keep their zero defaults.
        """
        height_grid = torch.zeros(voxel_grid_size, dtype=torch.float32)
        reward_grid = torch.zeros(voxel_grid_size, dtype=torch.int32)
        with open(os.path.join(self.dir, "visible_elements_{}".format(idx))) as visible:
            for line in visible:
                # Each line: voxel_index,height,reward
                data = [int(d) for d in line.split(",")]
                height_grid[data[0]] = data[1]
                reward_grid[data[0]] = data[2]
        return height_grid, reward_grid

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        camera_view = self.get_camera_view(idx)
        # TODO: Use Faiss to generate k means so that the input data can use a
        # color index byte to reduce input Tensor size while retaining color.
        height_grid, reward_grid = self.get_voxel_grid(idx)
        sample = {'view': camera_view, 'grid': height_grid, 'rewards': reward_grid}
        if self.transform:
            sample = self.transform(sample)
        return sample
# Same as VoxelHeightDataset, but the prediction target is the number of
# visible pixels on each voxel rather than the voxel height.
class VoxelVisibilityDataset(Dataset):
    """Monocular voxel-grid dataset targeting per-voxel visible-pixel counts.

    Same file layout as VoxelHeightDataset, but both the prediction grid and
    the reward grid are filled from the visibility count (third CSV column)
    of ``visible_elements_{idx}``.
    """

    def __init__(self, dataset_size, dir, transform=None):
        # NOTE(review): `dir` shadows the builtin; name kept for caller compatibility.
        self.dataset_size = dataset_size
        self.dir = dir
        self.transform = transform

    def __len__(self):
        return self.dataset_size

    def get_camera_view(self, idx):
        """Load view `idx` as a float32 CHW tensor with any alpha channel dropped."""
        img = io.imread(os.path.join(self.dir, "voxels_{}.png".format(idx)))
        cv = rearrange(img, 'h w c -> c h w')
        # Keep only the first 3 channels. Identical to np.delete(cv, 3, 0) for
        # RGBA input, but also works when the image has no alpha channel.
        cv = cv[:3]
        return torch.from_numpy(cv.astype('float32'))

    def get_voxel_grid(self, idx):
        """Parse `visible_elements_{idx}` into (visibility_grid, reward_grid).

        Voxels absent from the file keep their zero defaults.
        """
        visibility_grid = torch.zeros(voxel_grid_size, dtype=torch.float32)
        reward_grid = torch.zeros(voxel_grid_size, dtype=torch.int32)
        with open(os.path.join(self.dir, "visible_elements_{}".format(idx))) as visible:
            for line in visible:
                # Each line: voxel_index,height,visible_pixel_count. Column 1
                # (height) is intentionally unused; both grids take column 2.
                data = [int(d) for d in line.split(",")]
                visibility_grid[data[0]] = data[2]
                reward_grid[data[0]] = data[2]
        return visibility_grid, reward_grid

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        camera_view = self.get_camera_view(idx)
        # TODO: Use Faiss to generate k means so that the input data can use a
        # color index byte to reduce input Tensor size while retaining color.
        visibility_grid, reward_grid = self.get_voxel_grid(idx)
        sample = {'view': camera_view, 'grid': visibility_grid, 'rewards': reward_grid}
        if self.transform:
            sample = self.transform(sample)
        return sample