-
Notifications
You must be signed in to change notification settings - Fork 0
/
write_datasets.py
53 lines (41 loc) · 1.58 KB
/
write_datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os
import sys
from glob import glob
'''
Credited to [email protected]
https://github.com/GotG/yolotinyv3_medmask_demo
'''
def write_datasets(train_pct, valid_pct, source, target):
    """Split the images under *source* into train/valid/test listing files.

    Every ``*.jpg`` file located exactly one directory level below *source*
    is collected and sorted by path.  The first ``train_pct`` percent of the
    sorted list goes to ``train.txt``, the next ``valid_pct`` percent to
    ``valid.txt`` and whatever remains to ``test.txt``.  Each listing file is
    written into *target* with one image path per line.

    Args:
        train_pct: Percentage of images used for training (anything
            ``int()`` accepts, e.g. a command-line string).
        valid_pct: Percentage of images used for validation (same
            conversion as ``train_pct``).
        source: Directory whose immediate sub-directories hold the images.
        target: Directory receiving ``train.txt``, ``valid.txt`` and
            ``test.txt``; it is created when it does not exist yet.

    Raises:
        ValueError: If a percentage is negative or the two percentages add
            up to more than 100 (``int()`` raises it as well for strings
            that are not integers).
    """
    train_pct = int(train_pct)
    valid_pct = int(valid_pct)
    # A negative percentage would become a negative slice index and a sum
    # above 100 would silently starve the later splits, so reject both.
    if train_pct < 0 or valid_pct < 0 or train_pct + valid_pct > 100:
        raise ValueError(
            'train_pct and valid_pct must be non-negative and sum to at '
            'most 100, got {} and {}'.format(train_pct, valid_pct)
        )

    # Only .jpg files exactly one directory below `source` are picked up.
    # Sorting makes the split reproducible between runs.
    images = sorted(glob(source + '/*/*.jpg'))
    number_of_images = len(images)

    # Split boundaries inside the sorted list: [0, train_end) is training,
    # [train_end, valid_end) is validation, the rest is testing.
    train_end = round(number_of_images * train_pct / 100)
    valid_end = train_end + round(number_of_images * valid_pct / 100)
    trainfiles = images[:train_end]
    validfiles = images[train_end:valid_end]
    testfiles = images[valid_end:]
    print('Number of images: ', number_of_images)

    # Create the output directory up front so the writes below cannot fail
    # merely because it is missing.
    os.makedirs(target, exist_ok=True)
    listings = (
        ('train.txt', trainfiles),
        ('valid.txt', validfiles),
        ('test.txt', testfiles),
    )
    for filename, files in listings:
        with open(os.path.join(target, filename), mode='w') as f:
            f.writelines(item + "\n" for item in files)

    print('Number of images for training: ', str(len(trainfiles)))
    print('Number of images for validation: ', str(len(validfiles)))
    print('Number of images for testing: ', str(len(testfiles)))
if __name__ == '__main__':
    # Usage: write_datasets.py <train_pct> <valid_pct> <source> <target>
    # At least four positional arguments are required; extras are ignored.
    if len(sys.argv) <= 4:
        raise ValueError('Please enter the correct number of inputs')
    train_pct, valid_pct, source, target = sys.argv[1:5]
    write_datasets(train_pct, valid_pct, source, target)