Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

modify create_cls_trainval_lists.py #3066

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 38 additions & 33 deletions ppcls/utils/create_cls_trainval_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,16 @@

import argparse
import os
import random
from random import shuffle
import string


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_path', type=str, default='./data')
parser.add_argument('--save_img_list_path', type=str, default='train.txt')
parser.add_argument(
'--train', action='store_true', help='Create train list.')
parser.add_argument('--val', action='store_true', help='Create val list.')

parser.add_argument('--train_list_rate', type=int, default=80)
parser.add_argument('--val_list_rate', type=int, default=20)
parser.add_argument('--test_list_rate', type=int, default=0)
args = parser.parse_args()
return args

Expand Down Expand Up @@ -55,11 +53,10 @@ def main(args):
if os.path.isdir(os.path.join(args.dataset_path, label))
]

if not os.path.exists(
os.path.join(os.path.dirname(args.dataset_path),
'label.txt')) and args.val:
raise Exception(
'The label file is not exist. Please set "--train" first.')
sum_rate = args.train_list_rate + args.val_list_rate + args.test_list_rate
if sum_rate != 100:
raise Exception("训练集、验证集、测试集比例之和需要等于100,请修改后重试")
tags = ["train", "val", "test"]

for index, label_name in enumerate(label_name_list):
for root, dirs, files in os.walk(
Expand All @@ -69,42 +66,50 @@ def main(args):
img_path = os.path.relpath(
os.path.join(root, single_file),
os.path.dirname(args.dataset_path))
if args.val:
class_id_map = parse_class_id_map(
os.path.join(
os.path.dirname(args.dataset_path),
'label.txt'))
img_list.append(
f'{img_path} {class_id_map[label_name]}')
else:
img_list.append(f'{img_path} {index}')
img_list.append(f'{img_path} {index}')
else:
print(
f'WARNING: File {os.path.join(root, single_file)} end with {single_file.split(".")[-1]} is not supported.'
)
label_list.append(f'{index} {label_name}')

shuffle(img_list)
if len(img_list) == 0:
raise Exception(f"Not found any images file in {args.dataset_path}.")

start = 0
image_num = len(img_list)
rate_list = [args.train_list_rate, args.val_list_rate, args.test_list_rate]

for i, tag in enumerate(tags):
rate = rate_list[i]
if rate == 0:
continue
if rate > 100 or rate < 0:
return f"{tag} 数据集的比例应该在0~100之间."

end = start + round(image_num * rate / 100)
if sum(rate_list[i + 1:]) == 0:
end = image_num

txt_file = os.path.abspath(
os.path.join(os.path.dirname(args.dataset_path), tag + '.txt'))
with open(txt_file, 'w') as f:
m = 0
for id in range(start, end):
m += 1
f.write('\n' + img_list[id])
print(f'Already save label.txt in {txt_file}.')
start = end

with open(
os.path.join(
os.path.dirname(args.dataset_path), args.save_img_list_path),
os.path.join(os.path.dirname(args.dataset_path), 'label.txt'),
'w') as f:
f.write('\n'.join(img_list))
f.write('\n'.join(label_list))
print(
f'Already save {args.save_img_list_path} in {os.path.join(os.path.dirname(args.dataset_path), args.save_img_list_path)}.'
f'Already save label.txt in {os.path.abspath(os.path.join(os.path.dirname(args.dataset_path), "label.txt"))}.'
)

if not args.val:
with open(
os.path.join(os.path.dirname(args.dataset_path), 'label.txt'),
'w') as f:
f.write('\n'.join(label_list))
print(
f'Already save label.txt in {os.path.join(os.path.dirname(args.dataset_path), "label.txt")}.'
)


if __name__ == '__main__':
args = parse_args()
Expand Down