forked from fire717/Machine-Learning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
56 lines (38 loc) · 1.51 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#coding:utf-8
# @fire
import cv2
import os,sys
import numpy as np
from PIL import Image
import random
from my_data import myData
from my_model import myModel
def getAllName(file_dir, exts=(".jpg", ".png")):
    """Recursively collect image file paths under a directory.

    Walks ``file_dir`` with ``os.walk`` and keeps every file whose extension,
    as returned by ``os.path.splitext``, is one of ``exts``. The comparison is
    exact and case-sensitive, so ``photo.JPG`` does not match ``".jpg"``.

    Args:
        file_dir: Root directory to search. A directory that does not exist
            yields an empty list, because ``os.walk`` ignores errors by default.
        exts: Extensions to keep, each including the leading dot. A single
            string is treated as one extension. Defaults to
            ``(".jpg", ".png")``.

    Returns:
        A list of paths, each built by joining the walked directory with the
        file name, in ``os.walk`` traversal order.
    """
    # Guard against exts=".jpg": iterating a bare string would yield
    # single characters instead of one extension.
    if isinstance(exts, str):
        exts = (exts,)
    wanted = frozenset(exts)

    paths = []
    # os.walk yields (dirpath, dirnames, filenames) for every directory in the
    # tree. Sub-directories are visited by the walk itself, so dirnames is unused.
    for root, _dirs, files in os.walk(file_dir):
        for name in files:
            if os.path.splitext(name)[1] in wanted:
                paths.append(os.path.join(root, name))
    return paths
# --- Training split: gather the "fake" and "true" image paths. ---
data_path_fake = "data/train/fake/"
data_path_true = "data/train/true/"
fake_imgs_train, true_imgs_train = getAllName(data_path_fake), getAllName(data_path_true)

# --- Validation split: same layout, different root. The path variables are
# rebound here, so after this point they hold the validation directories. ---
data_path_fake = "data/val/fake/"
data_path_true = "data/val/true/"
fake_imgs_val, true_imgs_val = getAllName(data_path_fake), getAllName(data_path_true)

# Hyper-parameters forwarded to myData.
batch_size = 16
nb_epoch = 20

# myData receives the "true" lists as category 1 and the "fake" lists as
# category 2, for both the training and validation splits.
img_name_list_train_cate1, img_name_list_train_cate2 = true_imgs_train, fake_imgs_train
img_name_list_val_cate1, img_name_list_val_cate2 = true_imgs_val, fake_imgs_val

my_data = myData(
    batch_size,
    nb_epoch,
    img_name_list_train_cate1,
    img_name_list_train_cate2,
    img_name_list_val_cate1,
    img_name_list_val_cate2,
)
# total_train / total_val come from myData; presumably sample counts (verify in my_data).
print(my_data.total_train, my_data.total_val)

my_model = myModel()
my_model.train(my_data)