Merge pull request #566 from deepsense-ai/rl_init
Initial commit of reinforcement learning module.
lukaszkaiser authored Feb 7, 2018
2 parents 103d057 + 3707499 commit 1c98b8e
Showing 13 changed files with 930 additions and 0 deletions.
2 changes: 2 additions & 0 deletions setup.py
@@ -36,6 +36,8 @@
'future',
'gevent',
'gunicorn',
'gym<=0.9.5', # gym in version 0.9.6 has some temporary issues.
'munch',
'numpy',
'requests',
'scipy',
16 changes: 16 additions & 0 deletions tensor2tensor/bin/t2t-rl-trainer
@@ -0,0 +1,16 @@
#!/usr/bin/env python
"""t2t-rl-trainer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensor2tensor.bin import t2t_rl_trainer

import tensorflow as tf

def main(argv):
t2t_rl_trainer.main(argv)


if __name__ == "__main__":
tf.app.run()
92 changes: 92 additions & 0 deletions tensor2tensor/bin/t2t_rl_trainer.py
@@ -0,0 +1,92 @@
# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Training of RL agent with PPO algorithm."""

from __future__ import absolute_import

import functools
from munch import Munch
import tensorflow as tf

from tensor2tensor.rl.collect import define_collect
from tensor2tensor.rl.envs.utils import define_batch_env
from tensor2tensor.rl.ppo import define_ppo_epoch


def define_train(policy_lambda, env_lambda, config):
env = env_lambda()
action_space = env.action_space
observation_space = env.observation_space

batch_env = define_batch_env(env_lambda, config["num_agents"])

policy_factory = tf.make_template(
'network',
functools.partial(policy_lambda, observation_space,
action_space, config))

(collect_op, memory) = define_collect(policy_factory, batch_env, config)

with tf.control_dependencies([collect_op]):
ppo_op = define_ppo_epoch(memory, policy_factory, config)

return ppo_op


def main(argv=None):
  # argv comes from tf.app.run() in the t2t-rl-trainer binary and is unused.
  train(example_params())


def train(params):
policy_lambda, env_lambda, config = params
ppo_op = define_train(policy_lambda, env_lambda, config)

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for _ in range(config.epochs_num):
sess.run(ppo_op)


def example_params():
from tensor2tensor.rl import networks
config = {}
config['init_mean_factor'] = 0.1
config['init_logstd'] = 0.1
config['policy_layers'] = 100, 100
config['value_layers'] = 100, 100
config['num_agents'] = 30
config['clipping_coef'] = 0.2
config['gae_gamma'] = 0.99
config['gae_lambda'] = 0.95
config['entropy_loss_coef'] = 0.01
config['value_loss_coef'] = 1
config['optimizer'] = tf.train.AdamOptimizer
config['learning_rate'] = 1e-4
config['optimization_epochs'] = 15
config['epoch_length'] = 200
config['epochs_num'] = 2000

config = Munch(config)
return networks.feed_forward_gaussian_fun, pendulum_lambda, config


def pendulum_lambda():
import gym
return gym.make("Pendulum-v0")


if __name__ == '__main__':
main()
10 changes: 10 additions & 0 deletions tensor2tensor/rl/README.md
@@ -0,0 +1,10 @@
# Tensor2Tensor Reinforcement Learning starter.

The intention of the rl package is to make it possible to run reinforcement
learning algorithms within TensorFlow's computation graph.

Currently the only supported algorithm is Proximal Policy Optimization (PPO).

## Sample usage - training in the Pendulum-v0 environment.

```t2t-rl-trainer```
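
As a programmatic variation on the sample usage above, the sketch below reuses the example configuration from `t2t_rl_trainer.example_params()` with a different continuous-control environment. This is an illustration only, not part of this commit: the environment name and the overridden hyperparameter values are assumptions.

```python
# Illustrative sketch only: reuse the example setup from this commit with a
# hypothetical environment choice and a smaller run. The environment name and
# the overridden hyperparameters are assumptions, not values from the PR.
import gym

from tensor2tensor.bin import t2t_rl_trainer


def mountain_car_lambda():
  # The Gaussian policy assumes a continuous action space.
  return gym.make("MountainCarContinuous-v0")


policy_fun, _, config = t2t_rl_trainer.example_params()
config.num_agents = 10    # fewer parallel environments
config.epochs_num = 100   # shorter illustrative run

t2t_rl_trainer.train((policy_fun, mountain_car_lambda, config))
```

Because `train()` simply unpacks `(policy_lambda, env_lambda, config)`, any environment factory and Munch-style config with the keys from `example_params()` can be dropped in.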
Empty file added tensor2tensor/rl/__init__.py
Empty file.
94 changes: 94 additions & 0 deletions tensor2tensor/rl/collect.py
@@ -0,0 +1,94 @@
# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Collect trajectories from interactions of agent with environment."""

import tensorflow as tf


def define_collect(policy_factory, batch_env, config):

memory_shape = [config.epoch_length] + [batch_env.observ.shape.as_list()[0]]
memories_shapes_and_types = [
# observation
(memory_shape + [batch_env.observ.shape.as_list()[1]], tf.float32),
(memory_shape, tf.float32), # reward
(memory_shape, tf.bool), # done
(memory_shape + batch_env.action_shape, tf.float32), # action
(memory_shape, tf.float32), # pdf
(memory_shape, tf.float32), # value function
]
memory = [tf.Variable(tf.zeros(shape, dtype), trainable=False)
for (shape, dtype) in memories_shapes_and_types]
cumulative_rewards = tf.Variable(
tf.zeros(config.num_agents, tf.float32), trainable=False)

should_reset_var = tf.Variable(True, trainable=False)
reset_op = tf.cond(should_reset_var,
lambda: batch_env.reset(tf.range(config.num_agents)),
lambda: 0.0)
with tf.control_dependencies([reset_op]):
reset_once_op = tf.assign(should_reset_var, False)

with tf.control_dependencies([reset_once_op]):

def step(index, scores_sum, scores_num):
      # Note: the only way to ensure that a copy of a tensor is made is to run
      # a simple operation on it. We are waiting for tf.copy:
# https://github.com/tensorflow/tensorflow/issues/11186
obs_copy = batch_env.observ + 0
actor_critic = policy_factory(tf.expand_dims(obs_copy, 0))
policy = actor_critic.policy
action = policy.sample()
postprocessed_action = actor_critic.action_postprocessing(action)
simulate_output = batch_env.simulate(postprocessed_action[0, ...])
pdf = policy.prob(action)[0]
with tf.control_dependencies(simulate_output):
reward, done = simulate_output
done = tf.reshape(done, (config.num_agents,))
to_save = [obs_copy, reward, done, action[0, ...], pdf,
actor_critic.value[0]]
save_ops = [tf.scatter_update(memory_slot, index, value)
for memory_slot, value in zip(memory, to_save)]
cumulate_rewards_op = cumulative_rewards.assign_add(reward)
agent_indicies_to_reset = tf.where(done)[:, 0]
with tf.control_dependencies([cumulate_rewards_op]):
scores_sum_delta = tf.reduce_sum(
tf.gather(cumulative_rewards, agent_indicies_to_reset))
scores_num_delta = tf.count_nonzero(done, dtype=tf.int32)
with tf.control_dependencies(save_ops + [scores_sum_delta,
scores_num_delta]):
reset_env_op = batch_env.reset(agent_indicies_to_reset)
reset_cumulative_rewards_op = tf.scatter_update(
cumulative_rewards, agent_indicies_to_reset,
tf.zeros(tf.shape(agent_indicies_to_reset)))
with tf.control_dependencies([reset_env_op,
reset_cumulative_rewards_op]):
return [index + 1, scores_sum + scores_sum_delta,
scores_num + scores_num_delta]

init = [tf.constant(0), tf.constant(0.0), tf.constant(0)]
index, scores_sum, scores_num = tf.while_loop(
lambda c, _1, _2: c < config.epoch_length,
step,
init,
parallel_iterations=1,
back_prop=False)
mean_score = tf.cond(tf.greater(scores_num, 0),
lambda: scores_sum / tf.cast(scores_num, tf.float32),
lambda: 0.)
printing = tf.Print(0, [mean_score, scores_sum, scores_num], "mean_score: ")
with tf.control_dependencies([printing]):
return tf.identity(index), memory
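
The collection loop above relies on a graph-mode pattern that can be hard to parse: persistent `tf.Variable` buffers written via `tf.scatter_update` inside a `tf.while_loop`, with `tf.control_dependencies` forcing the writes to happen before the loop advances. The following is a minimal standalone sketch of that pattern (TF 1.x style, not code from this commit):

```python
# Minimal sketch of the while_loop + scatter_update + control_dependencies
# pattern used in define_collect above (illustration only, TF 1.x API).
import tensorflow as tf

length = 5
memory = tf.Variable(tf.zeros([length]), trainable=False)


def step(index, total):
  value = tf.cast(index, tf.float32) * 2.0            # something to record
  save_op = tf.scatter_update(memory, index, value)   # write into slot `index`
  with tf.control_dependencies([save_op]):            # force the write first
    return [index + 1, total + value]


index, total = tf.while_loop(
    lambda i, _: i < length, step,
    [tf.constant(0), tf.constant(0.0)],
    parallel_iterations=1, back_prop=False)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(total)              # running the loop fills the memory buffer
  print(sess.run(memory))      # [0. 2. 4. 6. 8.]
```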
Empty file.
129 changes: 129 additions & 0 deletions tensor2tensor/rl/envs/batch_env.py
@@ -0,0 +1,129 @@
# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# The code was based on Danijar Hafner's code from tf.agents:
# https://github.com/tensorflow/agents/blob/master/agents/tools/batch_env.py

"""Combine multiple environments to step them in batch."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np


class BatchEnv(object):
"""Combine multiple environments to step them in batch."""

def __init__(self, envs, blocking):
"""Combine multiple environments to step them in batch.
To step environments in parallel, environments must support a
`blocking=False` argument to their step and reset functions that makes them
return callables instead to receive the result at a later time.
Args:
envs: List of environments.
      blocking: Step environments one after another rather than in parallel.
Raises:
ValueError: Environments have different observation or action spaces.
"""
self._envs = envs
self._blocking = blocking
observ_space = self._envs[0].observation_space
if not all(env.observation_space == observ_space for env in self._envs):
raise ValueError('All environments must use the same observation space.')
action_space = self._envs[0].action_space
if not all(env.action_space == action_space for env in self._envs):
      raise ValueError('All environments must use the same action space.')

def __len__(self):
"""Number of combined environments."""
return len(self._envs)

def __getitem__(self, index):
"""Access an underlying environment by index."""
return self._envs[index]

def __getattr__(self, name):
"""Forward unimplemented attributes to one of the original environments.
Args:
name: Attribute that was accessed.
Returns:
      Value behind the attribute name in one of the wrapped environments.
"""
return getattr(self._envs[0], name)

def step(self, actions):
"""Forward a batch of actions to the wrapped environments.
Args:
actions: Batched action to apply to the environment.
Raises:
ValueError: Invalid actions.
Returns:
Batch of observations, rewards, and done flags.
"""
for index, (env, action) in enumerate(zip(self._envs, actions)):
if not env.action_space.contains(action):
message = 'Invalid action at index {}: {}'
raise ValueError(message.format(index, action))
if self._blocking:
transitions = [
env.step(action)
for env, action in zip(self._envs, actions)]
else:
transitions = [
env.step(action, blocking=False)
for env, action in zip(self._envs, actions)]
transitions = [transition() for transition in transitions]
observs, rewards, dones, infos = zip(*transitions)
observ = np.stack(observs).astype(np.float32)
reward = np.stack(rewards).astype(np.float32)
done = np.stack(dones)
info = tuple(infos)
return observ, reward, done, info

def reset(self, indices=None):
"""Reset the environment and convert the resulting observation.
Args:
indices: The batch indices of environments to reset; defaults to all.
Returns:
Batch of observations.
"""
if indices is None:
indices = np.arange(len(self._envs))
if self._blocking:
observs = [self._envs[index].reset() for index in indices]
else:
observs = [self._envs[index].reset(blocking=False) for index in indices]
observs = [observ() for observ in observs]
observ = np.stack(observs)
observ = observ.astype(np.float32)
return observ

def close(self):
"""Send close messages to the external process and join them."""
for env in self._envs:
if hasattr(env, 'close'):
env.close()
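
The `blocking=False` protocol that `BatchEnv` expects from its wrapped environments (step/reset returning callables that yield the result later) can be illustrated with a toy wrapper. This is a hypothetical sketch, not a class from this commit; a real non-blocking wrapper would defer the work to another process instead of computing it eagerly.

```python
# Hypothetical illustration of the blocking=False protocol BatchEnv relies on.
# Not part of this commit: a real wrapper would run the environment in a
# separate process and only resolve the result when the callable is invoked.
import gym


class ToyNonBlockingEnv(object):
  """Wraps a gym environment so step/reset accept a `blocking` flag."""

  def __init__(self, env):
    self._env = env

  def __getattr__(self, name):
    # Forward attribute lookups (observation_space, action_space, ...) to the
    # wrapped environment, as BatchEnv.__getattr__ does above.
    return getattr(self._env, name)

  def step(self, action, blocking=True):
    transition = self._env.step(action)  # eager here; deferred in a real impl
    return transition if blocking else (lambda: transition)

  def reset(self, blocking=True):
    observ = self._env.reset()
    return observ if blocking else (lambda: observ)
```

With such wrappers, `BatchEnv(envs, blocking=False)` first issues all the non-blocking calls and only then resolves the returned callables, which is what lets truly asynchronous environments step in parallel.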