-
Notifications
You must be signed in to change notification settings - Fork 5
/
her.py
196 lines (161 loc) · 7.47 KB
/
her.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import os
import click
import numpy as np
import json
from mpi4py import MPI
from baselines import logger
from baselines.common import set_global_seeds, tf_util
from baselines.common.mpi_moments import mpi_moments
import baselines.her.experiment.config as config
from baselines.her.rollout import RolloutWorker
def mpi_average(value):
if not isinstance(value, list):
value = [value]
if not any(value):
value = [0.]
return mpi_moments(np.array(value))[0]
def train(*, policy, rollout_worker, evaluator,
n_epochs, n_test_rollouts, n_cycles, n_batches, policy_save_interval,
save_path, demo_file, **kwargs):
rank = MPI.COMM_WORLD.Get_rank()
if save_path:
latest_policy_path = os.path.join(save_path, 'policy_latest.pkl')
best_policy_path = os.path.join(save_path, 'policy_best.pkl')
periodic_policy_path = os.path.join(save_path, 'policy_{}.pkl')
logger.info("Training...")
best_success_rate = -1
if policy.bc_loss == 1: policy.init_demo_buffer(demo_file) #initialize demo buffer if training with demonstrations
# num_timesteps = n_epochs * n_cycles * rollout_length * number of rollout workers
for epoch in range(n_epochs):
# train
rollout_worker.clear_history()
for _ in range(n_cycles):
episode = rollout_worker.generate_rollouts()
policy.store_episode(episode)
for _ in range(n_batches):
policy.train()
policy.update_target_net()
# test
evaluator.clear_history()
for _ in range(n_test_rollouts):
evaluator.generate_rollouts()
# record logs
logger.record_tabular('epoch', epoch)
for key, val in evaluator.logs('test'):
logger.record_tabular(key, mpi_average(val))
for key, val in rollout_worker.logs('train'):
logger.record_tabular(key, mpi_average(val))
for key, val in policy.logs():
logger.record_tabular(key, mpi_average(val))
if rank == 0:
logger.dump_tabular()
# save the policy if it's better than the previous ones
success_rate = mpi_average(evaluator.current_success_rate())
if rank == 0 and success_rate >= best_success_rate and save_path:
best_success_rate = success_rate
logger.info('New best success rate: {}. Saving policy to {} ...'.format(best_success_rate, best_policy_path))
evaluator.save_policy(best_policy_path)
evaluator.save_policy(latest_policy_path)
if rank == 0 and policy_save_interval > 0 and epoch % policy_save_interval == 0 and save_path:
policy_path = periodic_policy_path.format(epoch)
logger.info('Saving periodic policy to {} ...'.format(policy_path))
evaluator.save_policy(policy_path)
# make sure that different threads have different seeds
local_uniform = np.random.uniform(size=(1,))
root_uniform = local_uniform.copy()
MPI.COMM_WORLD.Bcast(root_uniform, root=0)
if rank != 0:
assert local_uniform[0] != root_uniform[0]
return policy
def learn(*, network, env, total_timesteps,
seed=None,
eval_env=None,
replay_strategy='future',
policy_save_interval=5,
clip_return=True,
demo_file=None,
override_params=None,
load_path=None,
save_path=None,
**kwargs
):
raise TypeError('HER is not supported in TF2 branch yet, we are still working on it.')
override_params = override_params or {}
if MPI is not None:
rank = MPI.COMM_WORLD.Get_rank()
num_cpu = MPI.COMM_WORLD.Get_size()
# Seed everything.
rank_seed = seed + 1000000 * rank if seed is not None else None
set_global_seeds(rank_seed)
# Prepare params.
params = config.DEFAULT_PARAMS
env_name = env.spec.id
params['env_name'] = env_name
params['replay_strategy'] = replay_strategy
if env_name in config.DEFAULT_ENV_PARAMS:
params.update(config.DEFAULT_ENV_PARAMS[env_name]) # merge env-specific parameters in
params.update(**override_params) # makes it possible to override any parameter
with open(os.path.join(logger.get_dir(), 'params.json'), 'w') as f:
json.dump(params, f)
params = config.prepare_params(params)
params['rollout_batch_size'] = env.num_envs
if demo_file is not None:
params['bc_loss'] = 1
params.update(kwargs)
config.log_params(params, logger=logger)
if num_cpu == 1:
logger.warn()
logger.warn('*** Warning ***')
logger.warn(
'You are running HER with just a single MPI worker. This will work, but the ' +
'experiments that we report in Plappert et al. (2018, https://arxiv.org/abs/1802.09464) ' +
'were obtained with --num_cpu 19. This makes a significant difference and if you ' +
'are looking to reproduce those results, be aware of this. Please also refer to ' +
'https://github.com/openai/baselines/issues/314 for further details.')
logger.warn('****************')
logger.warn()
dims = config.configure_dims(params)
print('dims are {}'.format(dims))
policy = config.configure_ddpg(dims=dims, params=params, clip_return=clip_return)
#TODO: load path
# if load_path is not None:
# tf_util.load_variables(load_path)
rollout_params = {
'exploit': False,
'use_target_net': False,
'use_demo_states': True,
'compute_Q': False,
'T': params['T'],
}
eval_params = {
'exploit': True,
'use_target_net': params['test_with_polyak'],
'use_demo_states': False,
'compute_Q': True,
'T': params['T'],
}
for name in ['T', 'rollout_batch_size', 'gamma', 'noise_eps', 'random_eps']:
rollout_params[name] = params[name]
eval_params[name] = params[name]
eval_env = eval_env or env
rollout_worker = RolloutWorker(env, policy, dims, logger, monitor=True, **rollout_params)
evaluator = RolloutWorker(eval_env, policy, dims, logger, **eval_params)
n_cycles = params['n_cycles']
n_epochs = total_timesteps // n_cycles // rollout_worker.T // rollout_worker.rollout_batch_size
return train(
save_path=save_path, policy=policy, rollout_worker=rollout_worker,
evaluator=evaluator, n_epochs=n_epochs, n_test_rollouts=params['n_test_rollouts'],
n_cycles=params['n_cycles'], n_batches=params['n_batches'],
policy_save_interval=policy_save_interval, demo_file=demo_file)
@click.command()
@click.option('--env', type=str, default='FetchReach-v1', help='the name of the OpenAI Gym environment that you want to train on')
@click.option('--total_timesteps', type=int, default=int(5e5), help='the number of timesteps to run')
@click.option('--seed', type=int, default=0, help='the random seed used to seed both the environment and the training code')
@click.option('--policy_save_interval', type=int, default=5, help='the interval with which policy pickles are saved. If set to 0, only the best and latest policy will be pickled.')
@click.option('--replay_strategy', type=click.Choice(['future', 'none']), default='future', help='the HER replay strategy to be used. "future" uses HER, "none" disables HER.')
@click.option('--clip_return', type=int, default=1, help='whether or not returns should be clipped')
@click.option('--demo_file', type=str, default = 'PATH/TO/DEMO/DATA/FILE.npz', help='demo data file path')
def main(**kwargs):
learn(**kwargs)
if __name__ == '__main__':
main()