Skip to content

Commit

Permalink
add unittest for drex
Browse files Browse the repository at this point in the history
  • Loading branch information
ruoyuGao committed May 4, 2023
1 parent ff4de47 commit a6baf30
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
37 changes: 37 additions & 0 deletions ding/entry/tests/test_serial_entry_reward_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
from dizoo.classic_control.cartpole.config.cartpole_rnd_onppo_config import cartpole_ppo_rnd_config, cartpole_ppo_rnd_create_config # noqa
from dizoo.classic_control.cartpole.config.cartpole_ppo_icm_config import cartpole_ppo_icm_config, cartpole_ppo_icm_create_config # noqa
from dizoo.classic_control.cartpole.config.cartpole_ngu_config import cartpole_ngu_config, cartpole_ngu_create_config
from dizoo.classic_control.cartpole.config.cartpole_drex_dqn_config import cartpole_drex_dqn_config, cartpole_drex_dqn_create_config
from ding.entry import serial_pipeline, collect_demo_data, serial_pipeline_reward_model_offpolicy, \
serial_pipeline_reward_model_onpolicy
from ding.entry.application_entry_trex_collect_data import trex_collecting_data
from ding.entry.application_entry_drex_collect_data import drex_collecting_data

cfg = [
{
Expand Down Expand Up @@ -131,3 +133,38 @@ def test_trex():
assert False, "pipeline fail"
finally:
os.popen('rm -rf test_serial_pipeline_trex*')


@pytest.mark.unittest
def test_drex():
    """Smoke-test the DREX reward-model pipeline on CartPole.

    The test runs three stages in order:

    1. Train a DQN expert with ``serial_pipeline`` under the experiment name
       ``test_serial_pipeline_drex_expert``.
    2. Call ``drex_collecting_data`` with a ``bc``-typed copy of the DREX
       config that points at the expert checkpoint.
    3. Run ``serial_pipeline_reward_model_offpolicy`` for a single training
       iteration with reward pre-training enabled.

    Every stage runs inside one ``try``/``finally`` so that the
    ``test_serial_pipeline_drex*`` artifacts are removed even when the expert
    training or data collection stage raises, not only when the final
    pipeline does.
    """
    # Function-scope imports: only the cleanup below needs them.
    import glob
    import shutil

    try:
        # Stage 1: train the expert policy whose checkpoint DREX consumes.
        exp_name = 'test_serial_pipeline_drex_expert'
        config = [deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)]
        config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100
        config[0].exp_name = exp_name
        expert_policy = serial_pipeline(config, seed=0)

        # Stage 2: configure DREX so that every path lives under a
        # ``test_serial_pipeline_drex*`` prefix, which the cleanup relies on.
        exp_name = 'test_serial_pipeline_drex_collect'
        config = [deepcopy(cartpole_drex_dqn_config), deepcopy(cartpole_drex_dqn_create_config)]
        config[0].exp_name = exp_name
        config[0].reward_model.exp_name = exp_name
        config[0].reward_model.expert_model_path = 'test_serial_pipeline_drex_expert/ckpt/ckpt_best.pth.tar'
        config[0].reward_model.reward_model_path = 'test_serial_pipeline_drex_collect/cartpole.params'
        config[0].reward_model.offline_data_path = 'test_serial_pipeline_drex_collect'
        config[0].reward_model.checkpoint_max = 100
        config[0].reward_model.checkpoint_step = 100
        config[0].reward_model.num_snippets = 100

        # The collection step works on its own deep copy so that the edits
        # below (episode-based collection, ``bc`` policy type) do not leak
        # into ``config``, which is reused for the reward-model pipeline.
        args = EasyDict({'cfg': deepcopy(config), 'seed': 0, 'device': 'cpu'})
        args.cfg[0].policy.collect.n_episode = 8
        del args.cfg[0].policy.collect.n_sample
        args.cfg[0].bc_iteration = 1000  # for unittest
        args.cfg[1].policy.type = 'bc'
        drex_collecting_data(args=args)

        # Stage 3: one training iteration is enough for a smoke test.
        try:
            serial_pipeline_reward_model_offpolicy(
                config, seed=0, max_train_iter=1, pretrain_reward=True, cooptrain_reward=False
            )
        except Exception:
            assert False, "pipeline fail"
    finally:
        # Synchronous, shell-independent replacement for
        # ``rm -rf test_serial_pipeline_drex*``: the artifacts are gone by the
        # time the test returns, instead of being removed by a background
        # process that nobody waits for.
        for leftover in glob.glob('test_serial_pipeline_drex*'):
            if os.path.isdir(leftover) and not os.path.islink(leftover):
                shutil.rmtree(leftover, ignore_errors=True)
            else:
                try:
                    os.remove(leftover)
                except OSError:
                    # Best-effort cleanup, mirroring the ``-f`` in ``rm -rf``.
                    pass
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
update_per_collect=5,
batch_size=64,
learning_rate=0.001,
learner=dict(hook=dict(save_ckpt_after_iter=1000)),
),
collect=dict(n_sample=8),
eval=dict(evaluator=dict(eval_freq=40, )),
Expand Down

0 comments on commit a6baf30

Please sign in to comment.