-
Notifications
You must be signed in to change notification settings - Fork 5
/
input_data.py
112 lines (91 loc) · 4.36 KB
/
input_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import math
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
# ============================================================================
# -----------------生成图片路径和标签的List------------------------------------
train_dir = 'D:/tensorflow/practicePlus/googLeNet/dataset/input_data'
roses = []
label_roses = []
tulips = []
label_tulips = []
dandelion = []
label_dandelion = []
sunflowers = []
label_sunflowers = []
# step1:获取所有的图片路径名,存放到
# 对应的列表中,同时贴上标签,存放到label列表中。
def get_files(file_dir, ratio):
for file in os.listdir(file_dir + '/roses'):
roses.append(file_dir + '/roses' + '/' + file)
label_roses.append(0)
for file in os.listdir(file_dir + '/tulips'):
tulips.append(file_dir + '/tulips' + '/' + file)
label_tulips.append(1)
for file in os.listdir(file_dir + '/dandelion'):
dandelion.append(file_dir + '/dandelion' + '/' + file)
label_dandelion.append(2)
for file in os.listdir(file_dir + '/sunflowers'):
sunflowers.append(file_dir + '/sunflowers' + '/' + file)
label_sunflowers.append(3)
# step2:对生成的图片路径和标签List做打乱处理
image_list = np.hstack((roses, tulips, dandelion, sunflowers))
label_list = np.hstack((label_roses, label_tulips, label_dandelion, label_sunflowers))
# 利用shuffle打乱顺序
temp = np.array([image_list, label_list])
temp = temp.transpose()
np.random.shuffle(temp)
# 从打乱的temp中再取出list(img和lab)
# image_list = list(temp[:, 0])
# label_list = list(temp[:, 1])
# label_list = [int(i) for i in label_list]
# return image_list, label_list
# 将所有的img和lab转换成list
all_image_list = list(temp[:, 0])
all_label_list = list(temp[:, 1])
# 将所得List分为两部分,一部分用来训练tra,一部分用来测试val
# ratio是测试集的比例
n_sample = len(all_label_list)
n_val = int(math.ceil(n_sample * ratio)) # 测试样本数
n_train = n_sample - n_val # 训练样本数
tra_images = all_image_list[0:n_train]
tra_labels = all_label_list[0:n_train]
tra_labels = [int(float(i)) for i in tra_labels]
val_images = all_image_list[n_train:-1]
val_labels = all_label_list[n_train:-1]
val_labels = [int(float(i)) for i in val_labels]
return tra_images, tra_labels, val_images, val_labels
# ---------------------------------------------------------------------------
# --------------------生成Batch----------------------------------------------
# step1:将上面生成的List传入get_batch() ,转换类型,产生一个输入队列queue,因为img和lab
# 是分开的,所以使用tf.train.slice_input_producer(),然后用tf.read_file()从队列中读取图像
# image_W, image_H, :设置好固定的图像高度和宽度
# 设置batch_size:每个batch要放多少张图片
# capacity:一个队列最大多少
def get_batch(image, label, image_W, image_H, batch_size, capacity):
# 转换类型
image = tf.cast(image, tf.string)
label = tf.cast(label, tf.int32)
# make an input queue
input_queue = tf.train.slice_input_producer([image, label])
label = input_queue[1]
image_contents = tf.read_file(input_queue[0]) # read img from a queue
# step2:将图像解码,不同类型的图像不能混在一起,要么只用jpeg,要么只用png等。
image = tf.image.decode_jpeg(image_contents, channels=3)
# step3:数据预处理,对图像进行旋转、缩放、裁剪、归一化等操作,让计算出的模型更健壮。
image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H)
image = tf.image.per_image_standardization(image)
# step4:生成batch
# image_batch: 4D tensor [batch_size, width, height, 3],dtype=tf.float32
# label_batch: 1D tensor [batch_size], dtype=tf.int32
image_batch, label_batch = tf.train.batch([image, label],
batch_size=batch_size,
num_threads=32,
capacity=capacity)
# 重新排列label,行数为[batch_size]
label_batch = tf.reshape(label_batch, [batch_size])
image_batch = tf.cast(image_batch, tf.float32)
return image_batch, label_batch