-
Notifications
You must be signed in to change notification settings - Fork 0
/
datasets.py
33 lines (27 loc) · 1.03 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
from torch.utils.data import Dataset
import numpy as np
from config import DEVICE
class HARWindows(Dataset):
    """Map-style dataset of pre-built windows and their labels.

    Both arrays are read eagerly from a single ``.npz`` file and kept as
    tensors on one device: windows as ``torch.float`` and labels as
    ``torch.long``.
    """

    def __init__(self, filepath, device=None) -> None:
        """
        :param filepath: path (or file-like object) of an .npz file with
            "data_x" (windows) and "data_y" (labels) arrays
        :param device: torch device to place the tensors on; defaults to
            ``DEVICE`` from ``config``
        :raises ValueError: if "data_x" and "data_y" have different lengths
        """
        if device is None:
            device = DEVICE
        # np.load on an .npz returns an NpzFile holding an open file handle;
        # the context manager closes it. Arrays read from it are already
        # materialized in memory and stay valid after the file is closed.
        with np.load(filepath) as np_dataset:
            raw_x = np_dataset["data_x"]
            raw_y = np_dataset["data_y"]
        # Validate before transferring anything to the (possibly non-CPU) device.
        if len(raw_x) != len(raw_y):
            raise ValueError("invalid dataset")
        self.data_x = torch.from_numpy(raw_x).to(dtype=torch.float, device=device)
        self.data_y = torch.from_numpy(raw_y).to(dtype=torch.long, device=device)

    def __len__(self) -> int:
        """Return the number of windows in the dataset."""
        return len(self.data_x)

    def __getitem__(self, idx):
        """Return a window together with its corresponding label.

        :param idx: index of the window; negative indices count from the end
        :returns: tuple ``(data, label)``. ``data`` is ``data_x[idx]`` with a
            leading axis of size 1 prepended (for 2-D windows this gives
            shape ``(1, rows, columns)``). That leading axis is the depth of
            the data: it is 1 here and grows as the data is forwarded through
            the network. ``label`` is ``data_y[idx]``, which is a 0-d tensor
            when ``data_y`` is 1-D.
        """
        # unsqueeze(0) rather than slicing data_x[idx:idx+1]: that slice is
        # empty for idx == -1 (it becomes [-1:0]) while data_y[-1] is not.
        return (self.data_x[idx].unsqueeze(0), self.data_y[idx])