
Commit

chore: edit reinforce
EdanToledo committed Apr 2, 2024
1 parent 7687e9e commit 57953bc
Showing 4 changed files with 143 additions and 9 deletions.
2 changes: 1 addition & 1 deletion stoix/configs/arch/anakin.yaml
@@ -12,6 +12,6 @@ evaluation_greedy: False # Evaluate the policy greedily. If True the policy will
# an action which corresponds to the greatest logit. If false, the policy will sample
# from the logits.
num_eval_episodes: 128 # Number of episodes to evaluate per evaluation.
-num_evaluation: 20 # Number of evenly spaced evaluations to perform during training.
+num_evaluation: 50 # Number of evenly spaced evaluations to perform during training.
absolute_metric: True # Whether the absolute metric should be computed. For more details
# on the absolute metric please see: https://arxiv.org/abs/2209.10485
6 changes: 2 additions & 4 deletions stoix/systems/vpg/ff_reinforce.py
@@ -29,7 +29,7 @@
from stoix.utils.checkpointing import Checkpointer
from stoix.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims
from stoix.utils.logger import LogEvent, StoixLogger
-from stoix.utils.multistep import batch_n_step_bootstrapped_returns
+from stoix.utils.multistep import batch_discounted_returns
from stoix.utils.total_timestep_checker import check_total_timesteps
from stoix.utils.training import make_learning_rate
from stoix.wrappers.episode_metrics import get_final_step_metrics
@@ -86,9 +86,7 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Transi
v_t = jnp.concatenate([traj_batch.value, last_val[..., jnp.newaxis]], axis=-1)[:, 1:]
d_t = 1.0 - traj_batch.done.astype(jnp.float32)
d_t = (d_t * config.system.gamma).astype(jnp.float32)
-monte_carlo_returns = batch_n_step_bootstrapped_returns(
-    r_t, d_t, v_t, config.system.rollout_length
-)
+monte_carlo_returns = batch_discounted_returns(r_t, d_t, v_t, True, False)

def _actor_loss_fn(
actor_params: FrozenDict,
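The effect of this change is that the REINFORCE target becomes the full discounted return over the rollout, bootstrapped only from the final value, rather than a fixed-horizon n-step bootstrapped target (the two positional booleans are stop_target_gradients=True and time_major=False). Below is a minimal standalone sketch of that computation; discounted_returns is a hypothetical helper with illustrative shapes and values, whereas the learner itself uses batch_discounted_returns from stoix.utils.multistep, added further down in this commit.

import jax
import jax.numpy as jnp

def discounted_returns(r_t, discount_t, bootstrap_value):
    # Batch-major rewards and discounts of shape [B, T]; bootstrap_value of shape [B].
    # Computes G_t = r_{t+1} + gamma_{t+1} * G_{t+1}, starting from G_T = bootstrap_value.
    def step(acc, xs):
        r, d = xs
        acc = r + d * acc
        return acc, acc
    # lax.scan runs over the leading axis, so switch to time-major and scan backwards.
    r_tm = jnp.swapaxes(r_t, 0, 1)
    d_tm = jnp.swapaxes(discount_t, 0, 1)
    _, g = jax.lax.scan(step, bootstrap_value, (r_tm, d_tm), reverse=True)
    return jnp.swapaxes(g, 0, 1)

# Toy rollout: batch of 1, three steps, gamma = 0.99, no terminations.
r_t = jnp.ones((1, 3))
d_t = jnp.full((1, 3), 0.99)
last_value = jnp.array([0.5])
print(discounted_returns(r_t, d_t, last_value))  # ~[[3.455 2.480 1.495]]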
6 changes: 2 additions & 4 deletions stoix/systems/vpg/ff_reinforce_continuous.py
@@ -29,7 +29,7 @@
from stoix.utils.checkpointing import Checkpointer
from stoix.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims
from stoix.utils.logger import LogEvent, StoixLogger
-from stoix.utils.multistep import batch_n_step_bootstrapped_returns
+from stoix.utils.multistep import batch_discounted_returns
from stoix.utils.total_timestep_checker import check_total_timesteps
from stoix.utils.training import make_learning_rate
from stoix.wrappers.episode_metrics import get_final_step_metrics
@@ -86,9 +86,7 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Transi
v_t = jnp.concatenate([traj_batch.value, last_val[..., jnp.newaxis]], axis=-1)[:, 1:]
d_t = 1.0 - traj_batch.done.astype(jnp.float32)
d_t = (d_t * config.system.gamma).astype(jnp.float32)
-monte_carlo_returns = batch_n_step_bootstrapped_returns(
-    r_t, d_t, v_t, config.system.rollout_length
-)
+monte_carlo_returns = batch_discounted_returns(r_t, d_t, v_t, True, False)

def _actor_loss_fn(
actor_params: FrozenDict,
138 changes: 138 additions & 0 deletions stoix/utils/multistep.py
@@ -251,3 +251,141 @@ def batch_retrace_continuous(
stop_target_gradients, jax.lax.stop_gradient(target_tm1), target_tm1
)
return target_tm1 - q_tm1


def batch_lambda_returns(
r_t: chex.Array,
discount_t: chex.Array,
v_t: chex.Array,
lambda_: chex.Numeric = 1.0,
stop_target_gradients: bool = False,
time_major: bool = False,
) -> chex.Array:
"""Estimates a multistep truncated lambda return from a trajectory.
Given a trajectory of length `T+1`, generated under some policy π, for each
time-step `t` we can estimate a target return `G_t` by combining rewards,
discounts, and state values, according to a mixing parameter `lambda`.
The parameter `lambda_` mixes the different multi-step bootstrapped returns,
corresponding to accumulating `k` rewards and then bootstrapping using `v_t`.
rₜ₊₁ + γₜ₊₁ vₜ₊₁
rₜ₊₁ + γₜ₊₁ rₜ₊₂ + γₜ₊₁ γₜ₊₂ vₜ₊₂
rₜ₊₁ + γₜ₊₁ rₜ₊₂ + γₜ₊₁ γₜ₊₂ rₜ₊₃ + γₜ₊₁ γₜ₊₂ γₜ₊₃ vₜ₊₃
The returns are computed recursively, from `G_{T-1}` to `G_0`, according to:
Gₜ = rₜ₊₁ + γₜ₊₁ [(1 - λₜ₊₁) vₜ₊₁ + λₜ₊₁ Gₜ₊₁].
In the `on-policy` case, we estimate a return target `G_t` for the same
policy π that was used to generate the trajectory. In this setting the
parameter `lambda_` is typically a fixed scalar factor. Depending
on how values `v_t` are computed, this function can be used to construct
targets for different multistep reinforcement learning updates:
TD(λ): `v_t` contains the state value estimates for each state under π.
Q(λ): `v_t = max(q_t, axis=-1)`, where `q_t` estimates the action values.
Sarsa(λ): `v_t = q_t[..., a_t]`, where `q_t` estimates the action values.
In the `off-policy` case, the mixing factor is a function of state, and
different definitions of `lambda` implement different off-policy corrections:
Per-decision importance sampling: λₜ = λ ρₜ = λ [π(aₜ|sₜ) / μ(aₜ|sₜ)]
V-trace, as instantiated in IMPALA: λₜ = min(1, ρₜ)
Note that the second option is equivalent to applying per-decision importance
sampling, but using an adaptive λ(ρₜ) = min(1/ρₜ, 1), such that the effective
bootstrap parameter at time t becomes λₜ = λ(ρₜ) * ρₜ = min(1, ρₜ).
This is the interpretation used in the ABQ(ζ) algorithm (Mahmood 2017).
Of course this can be augmented to include an additional factor λ. For
instance we could use V-trace with a fixed additional parameter λ = 0.9, by
setting λₜ = 0.9 * min(1, ρₜ) or, alternatively (but not equivalently),
λₜ = min(0.9, ρₜ).
Estimated returns are then often used to define a TD error, e.g. ρₜ(Gₜ - vₜ).
See "Reinforcement Learning: An Introduction" by Sutton and Barto.
(http://incompleteideas.net/sutton/book/ebook/node74.html).
Args:
r_t: sequence of rewards rₜ for timesteps t in B x [1, T].
discount_t: sequence of discounts γₜ for timesteps t in B x [1, T].
v_t: sequence of state value estimates under π for timesteps t in B x [1, T].
lambda_: mixing parameter; a scalar or a vector for timesteps t in B x [1, T].
stop_target_gradients: bool indicating whether or not to apply stop gradient
to targets.
time_major: If True, the first dimension of the input tensors is the time
dimension.
Returns:
Multistep lambda returns.
"""

chex.assert_rank([r_t, discount_t, v_t, lambda_], [2, 2, 2, {0, 1, 2}])
chex.assert_type([r_t, discount_t, v_t, lambda_], float)
chex.assert_equal_shape([r_t, discount_t, v_t])

# Swap axes to make time axis the first dimension
if not time_major:
r_t, discount_t, v_t = jax.tree_map(lambda x: jnp.swapaxes(x, 0, 1), (r_t, discount_t, v_t))

# If scalar make into vector.
lambda_ = jnp.ones_like(discount_t) * lambda_

# Work backwards to compute `G_{T-1}`, ..., `G_0`.
def _body(
acc: chex.Array, xs: Tuple[chex.Array, chex.Array, chex.Array, chex.Array]
) -> Tuple[chex.Array, chex.Array]:
returns, discounts, values, lambda_ = xs
acc = returns + discounts * ((1 - lambda_) * values + lambda_ * acc)
return acc, acc

_, returns = jax.lax.scan(_body, v_t[-1], (r_t, discount_t, v_t, lambda_), reverse=True)

if not time_major:
returns = jax.tree_map(lambda x: jnp.swapaxes(x, 0, 1), returns)

return jax.lax.select(stop_target_gradients, jax.lax.stop_gradient(returns), returns)


def batch_discounted_returns(
r_t: chex.Array,
discount_t: chex.Array,
v_t: chex.Array,
stop_target_gradients: bool = False,
time_major: bool = False,
) -> chex.Array:
"""Calculates a discounted return from a trajectory.
The returns are computed recursively, from `G_{T-1}` to `G_0`, according to:
Gₜ = rₜ₊₁ + γₜ₊₁ Gₜ₊₁.
See "Reinforcement Learning: An Introduction" by Sutton and Barto.
(http://incompleteideas.net/sutton/book/ebook/node61.html).
Args:
r_t: reward sequence at time t.
discount_t: discount sequence at time t.
v_t: value sequence or scalar at time t.
stop_target_gradients: bool indicating whether or not to apply stop gradient
to targets.
time_major: If True, the first dimension of the input tensors is the time
dimension.
Returns:
Discounted returns.
"""
chex.assert_rank([r_t, discount_t, v_t], [2, 2, {0, 1, 2}])
chex.assert_type([r_t, discount_t, v_t], float)

# If scalar make into vector.
bootstrapped_v = jnp.ones_like(discount_t) * v_t
return batch_lambda_returns(
r_t,
discount_t,
bootstrapped_v,
lambda_=1.0,
stop_target_gradients=stop_target_gradients,
time_major=time_major,
)
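
A quick usage sketch of the two new helpers (illustrative shapes and values; assumes stoix is importable). With lambda_=1.0 the lambda return collapses to the plain discounted return, so batch_lambda_returns and batch_discounted_returns agree on the same inputs:

import jax.numpy as jnp
from stoix.utils.multistep import batch_discounted_returns, batch_lambda_returns

# Batch-major inputs of shape [B, T].
r_t = jnp.ones((1, 3))               # rewards
discount_t = jnp.full((1, 3), 0.99)  # gamma, already masked by (1 - done)
v_t = jnp.array([[0.3, 0.4, 0.5]])   # value estimates; only the last entry is bootstrapped from

g_lambda = batch_lambda_returns(r_t, discount_t, v_t, lambda_=1.0)
g_mc = batch_discounted_returns(r_t, discount_t, v_t)
# Both are ~[[3.455 2.480 1.495]].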
