# dataset_create.py
# Forked from JeasunLok/ResNet-pytorch.
import os
from os import getcwd
def get_classes(classes_path):
    """Load class names from a UTF-8 text file, one name per line.

    Each line is stripped of surrounding whitespace. Blank lines are not
    filtered out; they are returned as empty-string entries and are counted.

    Args:
        classes_path: Path to the class-list text file.

    Returns:
        A ``(names, count)`` tuple: the list of stripped lines in file order
        and the length of that list.
    """
    with open(classes_path, encoding='utf-8') as handle:
        # Iterating the file yields the same lines as readlines().
        names = list(map(str.strip, handle))
    return names, len(names)
#-------------------------------------------------------------------#
#   classes_path points to the txt file that lists the classes of
#   your own dataset (one class name per line).
#   Be sure to update classes_path before training so that it
#   matches your dataset.
#   It should be the same classes_path that is used for training
#   and prediction.
#   NOTE(review): the original comment said this file lives under
#   model_data, but the default below is under dataset/ - confirm.
#-------------------------------------------------------------------#
classes_path = 'dataset/cls_classes.txt'
#-------------------------------------------------------#
#   datasets_path points to the directory that contains
#   the dataset.
#-------------------------------------------------------#
datasets_path = 'images'
# Names of the split sub-directories looked up under datasets_path;
# one cls_<split>.txt list file is written per entry.
sets = ["train", "val", "test"]
# Loaded at import time; the class count returned by get_classes is discarded.
classes, _ = get_classes(classes_path)
# Echo the class list (this also runs on import, not only as a script).
print(classes)
# Image file extensions (lower-case) that are included in the list files.
_IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg', '.tif')


def build_annotation_lines(split_dir, class_names):
    """Build the ``<class_id>;<image_path>`` lines for one dataset split.

    ``split_dir`` is expected to contain one sub-directory per class. Only
    sub-directories whose name appears in ``class_names`` are scanned; the
    class id is the index of that name in ``class_names``.

    Args:
        split_dir: Directory of a single split (e.g. ``images/train``).
        class_names: Class names; position in the list is the class id.

    Returns:
        A list of newline-terminated strings, ordered by class directory name
        and then by file name.

    Raises:
        OSError: If ``split_dir`` (or a class directory) cannot be listed.
    """
    lines = []
    # os.listdir returns entries in arbitrary order; sort so that repeated
    # runs over the same data produce identical list files.
    for type_name in sorted(os.listdir(split_dir)):
        if type_name not in class_names:
            continue
        photos_path = os.path.join(split_dir, type_name)
        # A stray plain file named like a class must not abort the run.
        if not os.path.isdir(photos_path):
            continue
        cls_id = class_names.index(type_name)
        for photo_name in sorted(os.listdir(photos_path)):
            _, postfix = os.path.splitext(photo_name)
            # Compare case-insensitively so that e.g. "IMG_1.JPG" is kept.
            if postfix.lower() not in _IMAGE_EXTENSIONS:
                continue
            lines.append(
                str(cls_id) + ";" + os.path.join(photos_path, photo_name) + '\n'
            )
    return lines


if __name__ == "__main__":
    for se in sets:
        # Collect the lines first: if the split directory cannot be read, the
        # existing list file is left untouched instead of being truncated.
        annotation_lines = build_annotation_lines(
            os.path.join(datasets_path, se), classes
        )
        # The output location stays the literal "images" directory, exactly
        # as before; the context manager closes the file even on error.
        with open(os.path.join("images", 'cls_' + se + '.txt'), 'w') as list_file:
            list_file.writelines(annotation_lines)