-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
70 lines (53 loc) · 2.12 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import tensorflow as tf
import tensorflow_datasets as tfds
def read_png(filename):
    """Loads a PNG image file.

    Args:
      filename: Path (string or string tensor) of the image file to read.

    Returns:
      A uint8 image tensor with 3 channels.
    """
    raw = tf.io.read_file(filename)
    # decode_image handles PNG content; channels=3 forces an RGB result.
    image = tf.image.decode_image(raw, channels=3)
    return image
def write_png(filename, image):
    """Saves an image to a PNG file.

    Args:
      filename: Destination path for the encoded file.
      image: Image tensor to encode (uint8/uint16, HWC layout expected by
        `tf.image.encode_png`).
    """
    encoded = tf.image.encode_png(image)
    tf.io.write_file(filename, encoded)
def list_of_paths(data_dir):
    """Returns the sorted paths of all regular files directly in data_dir.

    Subdirectories are skipped; only plain files are listed.

    Args:
      data_dir: Directory to scan (non-recursive).

    Returns:
      Sorted list of full file paths.
    """
    # Comprehension + sorted() replaces the original manual append/sort loop.
    return sorted(
        path
        for path in (os.path.join(data_dir, name) for name in os.listdir(data_dir))
        if os.path.isfile(path)
    )
def check_image_size(image, patchsize):
    """Checks an image is at least patchsize x patchsize with 3 channels.

    Args:
      image: Image tensor (H, W, C).
      patchsize: Minimum required height and width.

    Returns:
      A scalar boolean tensor suitable for `tf.data.Dataset.filter`.
    """
    shape = tf.shape(image)
    # Use explicit TF logical ops instead of Python `and`: the Python operator
    # on symbolic tensors only works when AutoGraph rewrites it, while
    # tf.logical_and is graph-safe everywhere.
    return tf.logical_and(
        tf.logical_and(shape[0] >= patchsize, shape[1] >= patchsize),
        tf.equal(shape[-1], 3))
def crop_image(image, patchsize):
    """Randomly crops a patchsize x patchsize RGB patch from an image.

    Args:
      image: Image tensor at least patchsize x patchsize x 3.
      patchsize: Side length of the square crop.

    Returns:
      The cropped patch, cast to the global mixed-precision compute dtype.
    """
    patch = tf.image.random_crop(image, size=(patchsize, patchsize, 3))
    compute_dtype = tf.keras.mixed_precision.global_policy().compute_dtype
    return tf.cast(patch, compute_dtype)
def get_dataset(name, split, args):
    """Creates input data pipeline from a TF Datasets dataset.

    Args:
      name: TFDS dataset name.
      split: Split to load (training splits repeat indefinitely).
      args: Namespace with `patchsize` and `batchsize` attributes.

    Returns:
      A batched `tf.data.Dataset` of random image patches.
    """
    with tf.device("/cpu:0"):
        dataset = tfds.load(name, split=split, shuffle_files=True)
        if split == "train":
            dataset = dataset.repeat()
        # Drop images too small to crop a full patch from.
        dataset = dataset.filter(
            lambda x: check_image_size(x["image"], args.patchsize))
        # Parallelize the crop and overlap preprocessing with training via
        # prefetch; the element stream itself is unchanged.
        dataset = dataset.map(
            lambda x: crop_image(x["image"], args.patchsize),
            num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.batch(args.batchsize, drop_remainder=True)
        return dataset.prefetch(tf.data.AUTOTUNE)
def get_custom_dataset(split, args):
    """Creates input data pipeline from custom PNG images.

    Args:
      split: Split name (training splits repeat indefinitely).
      args: Namespace with `train_glob`, `patchsize`, `batchsize`, and
        `preprocess_threads` attributes.

    Returns:
      A batched `tf.data.Dataset` of random image patches.

    Raises:
      RuntimeError: If `args.train_glob` matches no files.
    """
    with tf.device("/cpu:0"):
        # Bug fix: the original called glob.glob, but `glob` is never imported
        # in this module, so this raised NameError. tf.io.gfile.glob is
        # already in scope and behaves identically for local paths (and also
        # supports remote filesystems such as gs://).
        files = tf.io.gfile.glob(args.train_glob)
        if not files:
            raise RuntimeError(f"No training images found with glob "
                               f"'{args.train_glob}'.")
        dataset = tf.data.Dataset.from_tensor_slices(files)
        dataset = dataset.shuffle(len(files), reshuffle_each_iteration=True)
        if split == "train":
            dataset = dataset.repeat()
        dataset = dataset.map(
            lambda x: crop_image(read_png(x), args.patchsize),
            num_parallel_calls=args.preprocess_threads)
        dataset = dataset.batch(args.batchsize, drop_remainder=True)
        return dataset