-
Notifications
You must be signed in to change notification settings - Fork 1
/
get_dataset.py
70 lines (52 loc) · 1.89 KB
/
get_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from datasets.image_dataset import ImageDataset, WFDD
from datasets.hdf5_dataset import WMCA_H5, ThreeDMAD, CSMAD
import pandas as pd
import torch
# Names of the datasets intended for ImageDataset. ImageDataset does not support WFDD yet.
# Every entry needs its own trailing comma: adjacent string literals are otherwise
# silently concatenated into a single bogus name.
# NOTE(review): 'CelabA-Spoof' looks like a misspelling of 'CelebA-Spoof'; kept verbatim
# because other code may compare against this exact string — confirm before renaming.
ImageDataset_List = [
    'CASIA-FASD', 'REPLAY-ATTACK', 'ROSE-YOUTU', 'OULU-NPU', 'MSU-MFSD', 'SIW', 'CASIA-SURF-3DMASK',
    'CASIA-SURF', 'CASIA-HIFI-MASK', 'CeFA', 'CelabA-Spoof',
]
# Names of the datasets stored as HDF5 files.
HDF5_dataset = ['ThreeDMAD', 'WMCA', 'CSMAD']
def parse_data_list_csv(data_list_path):
    """
    Read a header-less data-list CSV and split it into its first two columns.

    :param data_list_path: the CSV source handed to ``pandas.read_csv``; the first
        column holds image paths, the second holds class labels
        (0-genuine, 1-photo, 2-replay, 3-mask).
    :return: tuple ``(image_list, label_list)``: column 0 (image paths) and
        column 1 (labels), each as returned by ``DataFrame.get`` — i.e. ``None``
        when that column is absent from the file.
    """
    # header=None makes pandas number the columns 0, 1, ... instead of
    # consuming the first row as a header.
    frame = pd.read_csv(data_list_path, header=None)
    return frame.get(0), frame.get(1)
def label_transform(label):
    """
    Collapse a multi-class label into a binary real/spoof label.

    Labels parsed from the data-list CSV ('0': genuine, '1': photo, '2': replay,
    '3': mask) map to 0 for genuine/real and 1 for any spoofing/fake class.
    Define your own transform if a different labelling is needed.

    :param label: class label as a string or integer, e.g. '1'
    :return: 0 if ``int(label)`` is zero, otherwise 1
    """
    return 0 if int(label) == 0 else 1
def get_image_dataset_from_list(csv_path, torchvision_transform=None):
    """
    Build an ImageDataset from a data-list CSV.

    :param csv_path: path to a header-less CSV whose first column holds image paths
        and whose second column holds class labels (0-genuine, 1-photo, 2-replay, 3-mask).
    :param torchvision_transform: optional transform, forwarded unchanged to ImageDataset.
    :return: an ImageDataset over the listed images, labelled with the binary
        labels (0-genuine/real, 1-spoofing/fake) produced by ``label_transform``.
    """
    image_path_list, label_list = parse_data_list_csv(csv_path)
    # Collapse the per-attack-type labels into binary real/spoof labels.
    transformed_label_list = [label_transform(label) for label in label_list]
    image_dataset = ImageDataset(
        file_list=image_path_list,
        label_list=transformed_label_list,
        torchvision_transform=torchvision_transform,
        use_original_frame=False,
        # NOTE(review): presumably selects per-image bounding-box files produced by an
        # MTCNN face detector — confirm the expected naming in ImageDataset.
        bbox_suffix='_bbox_mtccnn.txt'
    )
    return image_dataset
if __name__ == '__main__':
    # Smoke test: build a dataset from examples/example.csv and inspect its first sample.
    data_list_csv_path = 'examples/example.csv'
    image_dataset = get_image_dataset_from_list(data_list_csv_path)
    x = image_dataset[0]
    try:
        import IPython
    except ImportError:
        # IPython is only a debugging convenience; fall back to printing the sample.
        print(x)
    else:
        IPython.embed()