Skip to content

Commit

Permalink
Fix cycle gan: (#1829) (#1835)
Browse files Browse the repository at this point in the history
1. use_cudnn=False
2. fix saving checkpoint
3. using compiled program
  • Loading branch information
wanghaoshuang authored Mar 5, 2019
1 parent a8976f5 commit 8875bb2
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 22 deletions.
4 changes: 2 additions & 2 deletions fluid/PaddleCV/gan/cycle_gan/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ env CUDA_VISIBLE_DEVICES=0 python train.py

```
env CUDA_VISIBLE_DEVICES=0 python infer.py \
--init_model="models/1" --input="./data/inputA/*" \
--output="./output"
--init_model="checkpoints/1" --input="./data/inputA/*" \
--input_style A --output="./output"
```

训练150轮的模型预测效果如图2和图3所示:
Expand Down
4 changes: 3 additions & 1 deletion fluid/PaddleCV/gan/cycle_gan/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ def infer(args):
data_shape = [-1, 3, 256, 256]
input = fluid.layers.data(name='input', shape=data_shape, dtype='float32')
if args.input_style == "A":
model_name = 'g_a'
fake = build_generator_resnet_9blocks(input, name="g_A")
elif args.input_style == "B":
model_name = 'g_b'
fake = build_generator_resnet_9blocks(input, name="g_B")
else:
raise "Input with style [%s] is not supported." % args.input_style
Expand All @@ -37,7 +39,7 @@ def infer(args):
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
fluid.io.load_persistables(exe, args.init_model)
fluid.io.load_persistables(exe, args.init_model + "/" + model_name)

if not os.path.exists(args.output):
os.makedirs(args.output)
Expand Down
6 changes: 5 additions & 1 deletion fluid/PaddleCV/gan/cycle_gan/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import numpy as np
import os

use_cudnn = True
# cuDNN is no better when the batch size is 1, so it is disabled by default.
use_cudnn = False
if 'ce_mode' in os.environ:
use_cudnn = False


def cal_padding(img_size, stride, filter_size, dilation=1):
"""Calculate padding size."""
valid_filter_size = dilation * (filter_size - 1) + 1
Expand All @@ -18,6 +20,8 @@ def cal_padding(img_size, stride, filter_size, dilation=1):


def instance_norm(input, name=None):
# TODO([email protected]): Check the accuracy when using fluid.layers.layer_norm.
# return fluid.layers.layer_norm(input, begin_norm_axis=2)
helper = fluid.layer_helper.LayerHelper("instance_norm", **locals())
dtype = helper.input_dtype()
epsilon = 1e-5
Expand Down
50 changes: 32 additions & 18 deletions fluid/PaddleCV/gan/cycle_gan/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from utility import add_arguments, print_arguments, ImagePool
from trainer import *


parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
Expand All @@ -36,7 +35,7 @@
def train(args):

max_images_num = data_reader.max_images_num()
shuffle=True
shuffle = True
if args.run_ce:
np.random.seed(10)
fluid.default_startup_program().random_seed = 90
Expand Down Expand Up @@ -66,9 +65,11 @@ def train(args):
exe.run(fluid.default_startup_program())
A_pool = ImagePool()
B_pool = ImagePool()

A_reader = paddle.batch(data_reader.a_reader(shuffle=shuffle), args.batch_size)()
B_reader = paddle.batch(data_reader.b_reader(shuffle=shuffle), args.batch_size)()

A_reader = paddle.batch(
data_reader.a_reader(shuffle=shuffle), args.batch_size)()
B_reader = paddle.batch(
data_reader.b_reader(shuffle=shuffle), args.batch_size)()
if not args.run_ce:
A_test_reader = data_reader.a_test_reader()
B_test_reader = data_reader.b_test_reader()
Expand Down Expand Up @@ -119,13 +120,13 @@ def checkpoints(epoch):
if not os.path.exists(out_path):
os.makedirs(out_path)
fluid.io.save_persistables(
exe, out_path + "/g_a", main_program=g_A_trainer.program, filename="params")
exe, out_path + "/g_a", main_program=g_A_trainer.program)
fluid.io.save_persistables(
exe, out_path + "/g_b", main_program=g_B_trainer.program, filename="params")
exe, out_path + "/g_b", main_program=g_B_trainer.program)
fluid.io.save_persistables(
exe, out_path + "/d_a", main_program=d_A_trainer.program, filename="params")
exe, out_path + "/d_a", main_program=d_A_trainer.program)
fluid.io.save_persistables(
exe, out_path + "/d_b", main_program=d_B_trainer.program, filename="params")
exe, out_path + "/d_b", main_program=d_B_trainer.program)
print("saved checkpoint to {}".format(out_path))
sys.stdout.flush()

Expand All @@ -144,8 +145,21 @@ def init_model():

if args.init_model:
init_model()
losses=[[], []]
losses = [[], []]
t_time = 0

g_A_trainer_program = fluid.CompiledProgram(
g_A_trainer.program).with_data_parallel(
loss_name=g_A_trainer.g_loss_A.name)
g_B_trainer_program = fluid.CompiledProgram(
g_B_trainer.program).with_data_parallel(
loss_name=g_B_trainer.g_loss_B.name)
d_B_trainer_program = fluid.CompiledProgram(
d_B_trainer.program).with_data_parallel(
loss_name=d_B_trainer.d_loss_B.name)
d_A_trainer_program = fluid.CompiledProgram(
d_A_trainer.program).with_data_parallel(
loss_name=d_A_trainer.d_loss_A.name)
for epoch in range(args.epoch):
batch_id = 0
for i in range(max_images_num):
Expand All @@ -158,7 +172,7 @@ def init_model():
s_time = time.time()
# optimize the g_A network
g_A_loss, fake_B_tmp = exe.run(
g_A_trainer.program,
g_A_trainer_program,
fetch_list=[g_A_trainer.g_loss_A, g_A_trainer.fake_B],
feed={"input_A": tensor_A,
"input_B": tensor_B})
Expand All @@ -167,14 +181,14 @@ def init_model():

# optimize the d_B network
d_B_loss = exe.run(
d_B_trainer.program,
d_B_trainer_program,
fetch_list=[d_B_trainer.d_loss_B],
feed={"input_B": tensor_B,
"fake_pool_B": fake_pool_B})[0]

# optimize the g_B network
g_B_loss, fake_A_tmp = exe.run(
g_B_trainer.program,
g_B_trainer_program,
fetch_list=[g_B_trainer.g_loss_B, g_B_trainer.fake_A],
feed={"input_A": tensor_A,
"input_B": tensor_B})
Expand All @@ -183,16 +197,16 @@ def init_model():

# optimize the d_A network
d_A_loss = exe.run(
d_A_trainer.program,
d_A_trainer_program,
fetch_list=[d_A_trainer.d_loss_A],
feed={"input_A": tensor_A,
"fake_pool_A": fake_pool_A})[0]
batch_time = time.time() - s_time
t_time += batch_time
print("epoch{}; batch{}; g_A_loss: {}; d_B_loss: {}; g_B_loss: {}; d_A_loss: {}; "
"Batch_time_cost: {:.2f}".format(
epoch, batch_id, g_A_loss[0], d_B_loss[0], g_B_loss[0],
d_A_loss[0], batch_time))
print(
"epoch{}; batch{}; g_A_loss: {}; d_B_loss: {}; g_B_loss: {}; d_A_loss: {}; "
"Batch_time_cost: {:.2f}".format(epoch, batch_id, g_A_loss[
0], d_B_loss[0], g_B_loss[0], d_A_loss[0], batch_time))
losses[0].append(g_A_loss[0])
losses[1].append(d_A_loss[0])
sys.stdout.flush()
Expand Down

0 comments on commit 8875bb2

Please sign in to comment.