forked from sseung0703/Knowledge_distillation_via_TF2.0
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataloader.py
55 lines (46 loc) · 2.15 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import tensorflow as tf
import numpy as np
import scipy.io as sio
def Dataloader(name, data_path = None):
    """Return the dataset bundle for the dataset called *name*.

    Args:
        name: Dataset identifier. The supported values are ``'cifar10'`` and
            ``'cifar100'`` (matched case-sensitively).
        data_path: Accepted for interface compatibility; this function does
            not use it.

    Returns:
        Whatever the matching loader (``Cifar10`` or ``Cifar100``) returns.

    Raises:
        ValueError: If *name* is not one of the supported identifiers.
    """
    if name == 'cifar10':
        return Cifar10()
    if name == 'cifar100':
        return Cifar100()
    # Fail loudly here rather than returning None, which would only surface
    # later as an unrelated unpacking error in the caller.
    raise ValueError(
        "Unknown dataset name %r; expected 'cifar10' or 'cifar100'." % (name,)
    )
def Cifar10():
    """Load CIFAR-10 via Keras and build its per-example preprocessing.

    Returns:
        A 5-tuple ``(train_images, train_labels, test_images, test_labels,
        pre_processing)``. The first four entries are the values produced by
        ``tensorflow.keras.datasets.cifar10.load_data``. ``pre_processing``
        is a factory: ``pre_processing(is_training)`` returns the training
        transform when ``is_training`` is truthy and the inference transform
        otherwise.
    """
    from tensorflow.keras.datasets.cifar10 import load_data

    train_set, test_set = load_data()
    train_images, train_labels = train_set
    test_images, test_labels = test_set

    # Per-channel statistics used to standardize every image.
    # NOTE(review): these values look like they are listed in B, G, R order,
    # while the Keras loader is presumably RGB -- confirm before reusing them.
    # They are kept exactly as they were.
    channel_mean = np.array([113.9,123.0,125.3])
    channel_std = np.array([66.7,62.1,63.0])

    def _standardize(raw):
        # Cast to float32, then subtract the mean and divide by the std.
        as_float = tf.cast(raw, tf.float32)
        return (as_float - channel_mean) / channel_std

    def pre_processing(is_training = False):
        def training(image, *argv):
            # Standardize, randomly mirror, then reflect-pad by 4 pixels on
            # the first two axes and randomly crop back to the pre-pad shape.
            mirrored = tf.image.random_flip_left_right(_standardize(image))
            original_shape = tf.shape(mirrored)
            padded = tf.pad(mirrored, [[4,4],[4,4],[0,0]], 'REFLECT')
            cropped = tf.image.random_crop(padded, original_shape)
            # Extra positional arguments are passed through unchanged.
            return [cropped, *argv]

        def inference(image, label):
            # Evaluation only standardizes; no augmentation is applied.
            return _standardize(image), label

        if is_training:
            return training
        return inference

    return train_images, train_labels, test_images, test_labels, pre_processing
def Cifar100():
    """Load CIFAR-100 via Keras and build its per-example preprocessing.

    Returns:
        A 5-tuple ``(train_images, train_labels, test_images, test_labels,
        pre_processing)``. The first four entries are the values produced by
        ``tensorflow.keras.datasets.cifar100.load_data``. ``pre_processing``
        is a factory: ``pre_processing(is_training)`` returns the
        ``tf.function``-wrapped training transform when ``is_training`` is
        truthy and the ``tf.function``-wrapped inference transform otherwise.
    """
    from tensorflow.keras.datasets.cifar100 import load_data

    train_set, test_set = load_data()
    train_images, train_labels = train_set
    test_images, test_labels = test_set

    # Per-channel statistics used to standardize every image.
    # NOTE(review): these values look like they are listed in B, G, R order,
    # while the Keras loader is presumably RGB -- confirm before reusing them.
    # They are kept exactly as they were.
    channel_mean = np.array([112,124,129])
    channel_std = np.array([70,65,68])

    def _standardize(raw):
        # Cast to float32, then subtract the mean and divide by the std.
        as_float = tf.cast(raw, tf.float32)
        return (as_float - channel_mean) / channel_std

    def pre_processing(is_training = False):
        @tf.function
        def training(image, *argv):
            # Standardize, randomly mirror, then reflect-pad by 4 pixels on
            # the first two axes and randomly crop back to the pre-pad shape.
            mirrored = tf.image.random_flip_left_right(_standardize(image))
            original_shape = tf.shape(mirrored)
            padded = tf.pad(mirrored, [[4,4],[4,4],[0,0]], 'REFLECT')
            cropped = tf.image.random_crop(padded, original_shape)
            # Extra positional arguments are passed through unchanged.
            return [cropped, *argv]

        @tf.function
        def inference(image, label):
            # Evaluation only standardizes; no augmentation is applied.
            return _standardize(image), label

        if is_training:
            return training
        return inference

    return train_images, train_labels, test_images, test_labels, pre_processing