Merge pull request #566 from deepsense-ai/rl_init
Initial commit of reinforcement learning module.
Showing 13 changed files with 930 additions and 0 deletions.
@@ -0,0 +1,16 @@

```python
#!/usr/bin/env python
"""t2t-rl-trainer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensor2tensor.bin import t2t_rl_trainer

import tensorflow as tf


def main(argv):
  t2t_rl_trainer.main(argv)


if __name__ == "__main__":
  tf.app.run()
```
@@ -0,0 +1,92 @@

```python
# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Training of RL agent with PPO algorithm."""

from __future__ import absolute_import

import functools
from munch import Munch
import tensorflow as tf

from tensor2tensor.rl.collect import define_collect
from tensor2tensor.rl.envs.utils import define_batch_env
from tensor2tensor.rl.ppo import define_ppo_epoch


def define_train(policy_lambda, env_lambda, config):
  env = env_lambda()
  action_space = env.action_space
  observation_space = env.observation_space

  batch_env = define_batch_env(env_lambda, config["num_agents"])

  policy_factory = tf.make_template(
      'network',
      functools.partial(policy_lambda, observation_space,
                        action_space, config))

  (collect_op, memory) = define_collect(policy_factory, batch_env, config)

  with tf.control_dependencies([collect_op]):
    ppo_op = define_ppo_epoch(memory, policy_factory, config)

  return ppo_op


def main(_=None):
  # The unused argument keeps the signature compatible with the t2t-rl-trainer
  # launcher, which calls main(argv) through tf.app.run().
  train(example_params())


def train(params):
  policy_lambda, env_lambda, config = params
  ppo_op = define_train(policy_lambda, env_lambda, config)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(config.epochs_num):
      sess.run(ppo_op)


def example_params():
  from tensor2tensor.rl import networks
  config = {}
  config['init_mean_factor'] = 0.1
  config['init_logstd'] = 0.1
  config['policy_layers'] = 100, 100
  config['value_layers'] = 100, 100
  config['num_agents'] = 30
  config['clipping_coef'] = 0.2
  config['gae_gamma'] = 0.99
  config['gae_lambda'] = 0.95
  config['entropy_loss_coef'] = 0.01
  config['value_loss_coef'] = 1
  config['optimizer'] = tf.train.AdamOptimizer
  config['learning_rate'] = 1e-4
  config['optimization_epochs'] = 15
  config['epoch_length'] = 200
  config['epochs_num'] = 2000

  config = Munch(config)
  return networks.feed_forward_gaussian_fun, pendulum_lambda, config


def pendulum_lambda():
  import gym
  return gym.make("Pendulum-v0")


if __name__ == '__main__':
  main()
```
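For a quick experiment with different settings, `example_params()` can be copied and adapted. The sketch below is only an illustration, not part of this commit; it assumes the trainer module above is the `tensor2tensor.bin.t2t_rl_trainer` imported by the launcher script, and it keeps the Gaussian policy network, so the swapped-in Gym environment must have a continuous action space.

```python
# Hypothetical adaptation of example_params() (not part of this commit).
# Assumes the trainer module is tensor2tensor.bin.t2t_rl_trainer, as the
# launcher script's import suggests.
import gym
from munch import Munch

from tensor2tensor.bin import t2t_rl_trainer
from tensor2tensor.rl import networks


def mountain_car_lambda():
  # Continuous action space, so it works with the Gaussian policy network.
  return gym.make("MountainCarContinuous-v0")


def custom_params():
  _, _, config = t2t_rl_trainer.example_params()
  config = Munch(config)     # shallow copy before overriding
  config.num_agents = 10     # fewer parallel agents
  config.epochs_num = 500    # shorter run
  return networks.feed_forward_gaussian_fun, mountain_car_lambda, config


if __name__ == "__main__":
  t2t_rl_trainer.train(custom_params())
```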
@@ -0,0 +1,10 @@

# Tensor2Tensor Reinforcement Learning starter.

The rl package aims to make it possible to run reinforcement learning
algorithms within TensorFlow's computation graph.

Currently the only supported algorithm is Proximal Policy Optimization (PPO).

## Sample usage - training in the Pendulum-v0 environment.

```t2t-rl-trainer```
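The same example run can also be started from Python, without the command-line wrapper. A minimal sketch, assuming the trainer module shown earlier is importable as `tensor2tensor.bin.t2t_rl_trainer` (as the launcher script's import indicates):

```python
# Minimal sketch: train PPO on Pendulum-v0 programmatically.
from tensor2tensor.bin import t2t_rl_trainer

t2t_rl_trainer.train(t2t_rl_trainer.example_params())
```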
Empty file.
@@ -0,0 +1,94 @@

```python
# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Collect trajectories from interactions of agent with environment."""

import tensorflow as tf


def define_collect(policy_factory, batch_env, config):

  memory_shape = [config.epoch_length] + [batch_env.observ.shape.as_list()[0]]
  memories_shapes_and_types = [
      # observation
      (memory_shape + [batch_env.observ.shape.as_list()[1]], tf.float32),
      (memory_shape, tf.float32),                            # reward
      (memory_shape, tf.bool),                               # done
      (memory_shape + batch_env.action_shape, tf.float32),   # action
      (memory_shape, tf.float32),                            # pdf
      (memory_shape, tf.float32),                            # value function
  ]
  memory = [tf.Variable(tf.zeros(shape, dtype), trainable=False)
            for (shape, dtype) in memories_shapes_and_types]
  cumulative_rewards = tf.Variable(
      tf.zeros(config.num_agents, tf.float32), trainable=False)

  should_reset_var = tf.Variable(True, trainable=False)
  reset_op = tf.cond(should_reset_var,
                     lambda: batch_env.reset(tf.range(config.num_agents)),
                     lambda: 0.0)
  with tf.control_dependencies([reset_op]):
    reset_once_op = tf.assign(should_reset_var, False)

  with tf.control_dependencies([reset_once_op]):

    def step(index, scores_sum, scores_num):
      # Note - the only way to ensure making a copy of tensor is to run simple
      # operation. We are waiting for tf.copy:
      # https://github.com/tensorflow/tensorflow/issues/11186
      obs_copy = batch_env.observ + 0
      actor_critic = policy_factory(tf.expand_dims(obs_copy, 0))
      policy = actor_critic.policy
      action = policy.sample()
      postprocessed_action = actor_critic.action_postprocessing(action)
      simulate_output = batch_env.simulate(postprocessed_action[0, ...])
      pdf = policy.prob(action)[0]
      with tf.control_dependencies(simulate_output):
        reward, done = simulate_output
        done = tf.reshape(done, (config.num_agents,))
        to_save = [obs_copy, reward, done, action[0, ...], pdf,
                   actor_critic.value[0]]
        save_ops = [tf.scatter_update(memory_slot, index, value)
                    for memory_slot, value in zip(memory, to_save)]
        cumulate_rewards_op = cumulative_rewards.assign_add(reward)
        agent_indicies_to_reset = tf.where(done)[:, 0]
        with tf.control_dependencies([cumulate_rewards_op]):
          scores_sum_delta = tf.reduce_sum(
              tf.gather(cumulative_rewards, agent_indicies_to_reset))
          scores_num_delta = tf.count_nonzero(done, dtype=tf.int32)
        with tf.control_dependencies(save_ops + [scores_sum_delta,
                                                 scores_num_delta]):
          reset_env_op = batch_env.reset(agent_indicies_to_reset)
          reset_cumulative_rewards_op = tf.scatter_update(
              cumulative_rewards, agent_indicies_to_reset,
              tf.zeros(tf.shape(agent_indicies_to_reset)))
        with tf.control_dependencies([reset_env_op,
                                      reset_cumulative_rewards_op]):
          return [index + 1, scores_sum + scores_sum_delta,
                  scores_num + scores_num_delta]

    init = [tf.constant(0), tf.constant(0.0), tf.constant(0)]
    index, scores_sum, scores_num = tf.while_loop(
        lambda c, _1, _2: c < config.epoch_length,
        step,
        init,
        parallel_iterations=1,
        back_prop=False)
    mean_score = tf.cond(tf.greater(scores_num, 0),
                         lambda: scores_sum / tf.cast(scores_num, tf.float32),
                         lambda: 0.)
    printing = tf.Print(0, [mean_score, scores_sum, scores_num], "mean_score: ")
    with tf.control_dependencies([printing]):
      return tf.identity(index), memory
```
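`define_collect` keeps the whole rollout inside the TensorFlow graph: a `tf.while_loop` body performs one environment step per iteration and writes the transition into preallocated, non-trainable memory variables via `tf.scatter_update`. The stripped-down TF 1.x sketch below is only an illustration of that storage pattern, not part of this commit; the "environment" is replaced by a random reward tensor.

```python
# Illustration only: the in-graph rollout storage pattern used above,
# with a random reward standing in for batch_env.simulate().
import tensorflow as tf

EPOCH_LENGTH = 5
NUM_AGENTS = 3

# Preallocated memory slot, one row per time step (here only rewards).
reward_memory = tf.Variable(
    tf.zeros([EPOCH_LENGTH, NUM_AGENTS]), trainable=False)


def step(index, reward_sum):
  reward = tf.random_uniform([NUM_AGENTS])  # fake per-agent reward
  save_op = tf.scatter_update(reward_memory, index, reward)
  with tf.control_dependencies([save_op]):
    # Ops created here run only after the row has been written.
    return [index + 1, reward_sum + tf.reduce_sum(reward)]


index, reward_sum = tf.while_loop(
    lambda i, _: i < EPOCH_LENGTH,
    step,
    [tf.constant(0), tf.constant(0.0)],
    parallel_iterations=1,  # keep the per-step writes ordered
    back_prop=False)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run([index, reward_sum])
  print(sess.run(reward_memory))  # one stored reward row per step
```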
Empty file.
@@ -0,0 +1,129 @@

```python
# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# The code was based on Danijar Hafner's code from tf.agents:
# https://github.com/tensorflow/agents/blob/master/agents/tools/batch_env.py

"""Combine multiple environments to step them in batch."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np


class BatchEnv(object):
  """Combine multiple environments to step them in batch."""

  def __init__(self, envs, blocking):
    """Combine multiple environments to step them in batch.

    To step environments in parallel, environments must support a
    `blocking=False` argument to their step and reset functions that makes them
    return callables instead to receive the result at a later time.

    Args:
      envs: List of environments.
      blocking: Step environments one after another rather than in parallel.

    Raises:
      ValueError: Environments have different observation or action spaces.
    """
    self._envs = envs
    self._blocking = blocking
    observ_space = self._envs[0].observation_space
    if not all(env.observation_space == observ_space for env in self._envs):
      raise ValueError('All environments must use the same observation space.')
    action_space = self._envs[0].action_space
    if not all(env.action_space == action_space for env in self._envs):
      raise ValueError('All environments must use the same action space.')

  def __len__(self):
    """Number of combined environments."""
    return len(self._envs)

  def __getitem__(self, index):
    """Access an underlying environment by index."""
    return self._envs[index]

  def __getattr__(self, name):
    """Forward unimplemented attributes to one of the original environments.

    Args:
      name: Attribute that was accessed.

    Returns:
      Value behind the attribute name in one of the wrapped environments.
    """
    return getattr(self._envs[0], name)

  def step(self, actions):
    """Forward a batch of actions to the wrapped environments.

    Args:
      actions: Batched action to apply to the environment.

    Raises:
      ValueError: Invalid actions.

    Returns:
      Batch of observations, rewards, and done flags.
    """
    for index, (env, action) in enumerate(zip(self._envs, actions)):
      if not env.action_space.contains(action):
        message = 'Invalid action at index {}: {}'
        raise ValueError(message.format(index, action))
    if self._blocking:
      transitions = [
          env.step(action)
          for env, action in zip(self._envs, actions)]
    else:
      transitions = [
          env.step(action, blocking=False)
          for env, action in zip(self._envs, actions)]
      transitions = [transition() for transition in transitions]
    observs, rewards, dones, infos = zip(*transitions)
    observ = np.stack(observs).astype(np.float32)
    reward = np.stack(rewards).astype(np.float32)
    done = np.stack(dones)
    info = tuple(infos)
    return observ, reward, done, info

  def reset(self, indices=None):
    """Reset the environment and convert the resulting observation.

    Args:
      indices: The batch indices of environments to reset; defaults to all.

    Returns:
      Batch of observations.
    """
    if indices is None:
      indices = np.arange(len(self._envs))
    if self._blocking:
      observs = [self._envs[index].reset() for index in indices]
    else:
      observs = [self._envs[index].reset(blocking=False) for index in indices]
      observs = [observ() for observ in observs]
    observ = np.stack(observs)
    observ = observ.astype(np.float32)
    return observ

  def close(self):
    """Send close messages to the external processes and join them."""
    for env in self._envs:
      if hasattr(env, 'close'):
        env.close()
```
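For reference, a minimal sketch of driving `BatchEnv` directly with ordinary blocking Gym environments. The import path below is an assumption; this part of the diff does not show where the file lives inside `tensor2tensor.rl.envs`.

```python
# Hypothetical usage of BatchEnv with blocking Gym environments.
# The import path below is assumed, not shown in this diff.
import gym

from tensor2tensor.rl.envs.batch_env import BatchEnv

envs = [gym.make("Pendulum-v0") for _ in range(4)]
batch_env = BatchEnv(envs, blocking=True)

observ = batch_env.reset()                        # float32, shape (4, 3)
actions = [env.action_space.sample() for env in envs]
observ, reward, done, info = batch_env.step(actions)
print(observ.shape, reward.shape, done.shape)     # (4, 3) (4,) (4,)

batch_env.close()
```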