
Merge pull request #53 from EdanToledo/chore/make_standardise_advantage_a_choice

Chore/make standardise advantage a choice
EdanToledo authored Apr 2, 2024
2 parents 57953bc + 9d58ba6 commit 78fe0f6
Showing 12 changed files with 45 additions and 8 deletions.
1 change: 1 addition & 0 deletions stoix/configs/system/ff_awr.yaml
@@ -20,3 +20,4 @@ decay_learning_rates: False # Whether learning rates should be linearly decayed
gae_lambda: 0.95 # The lambda parameter for the generalized advantage estimator.
beta: 0.05 # The temperature of the exponentiated advantage weights.
weight_clip: 20.0 # The maximum absolute value of the advantage weights.
+standardize_advantages: False # Whether to standardize the advantages.
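
The standardize_advantages key is the flag this PR threads through every system config: it defaults to False for AWR here and to True for the DPO/PPO systems below. As a minimal illustrative sketch (not Stoix's code), enabling it rescales the advantages to zero mean and unit variance before they reach the loss, matching the formula that was previously hard-coded in stoix/utils/loss.py:

# Illustrative only; the epsilon mirrors the old hard-coded normalisation in loss.py.
import jax.numpy as jnp

def standardize(adv: jnp.ndarray, eps: float = 1e-8) -> jnp.ndarray:
    return (adv - adv.mean()) / (adv.std() + eps)
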
1 change: 1 addition & 0 deletions stoix/configs/system/ff_dpo.yaml
@@ -16,5 +16,6 @@ ent_coef: 0.001 # Entropy regularisation term for loss function.
vf_coef: 1.0 # Critic weight in
max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update.
decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
+standardize_advantages: True # Whether to standardize the advantages.
alpha : 2.0
beta : 0.6
1 change: 1 addition & 0 deletions stoix/configs/system/ff_ppo.yaml
@@ -16,3 +16,4 @@ ent_coef: 0.001 # Entropy regularisation term for loss function.
vf_coef: 1.0 # Critic weight in
max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update.
decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
+standardize_advantages: True # Whether to standardize the advantages.
1 change: 1 addition & 0 deletions stoix/configs/system/rec_ppo.yaml
@@ -16,6 +16,7 @@ ent_coef: 0.01 # Entropy regularisation term for loss function.
vf_coef: 0.5 # Critic weight in
max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update.
decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
+standardize_advantages: True # Whether to standardize the advantages.

# --- Recurrent hyperparameters ---
recurrent_chunk_size: ~ # The size of the chunks in which the recurrent sequences are divided during the training process.
7 changes: 6 additions & 1 deletion stoix/systems/awr/ff_awr.py
@@ -245,7 +245,12 @@ def _actor_loss_fn(
r_t = sequence.reward[:, :-1]
d_t = (1 - sequence.done.astype(jnp.float32)[:, :-1]) * config.system.gamma
advantages, _ = batch_truncated_generalized_advantage_estimation(
-    r_t, d_t, config.system.gae_lambda, v_t, time_major=False
+    r_t,
+    d_t,
+    config.system.gae_lambda,
+    v_t,
+    time_major=False,
+    standardize_advantages=config.system.standardize_advantages,
)
weights = jnp.exp(advantages / config.system.beta)
weights = jnp.minimum(weights, config.system.weight_clip)
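
Both AWR systems just forward the new config flag into the GAE helper; the exponentiated-advantage weighting that follows is unchanged. One plausible reason for the False default here (not stated in the PR): the weights exp(advantages / beta), clipped at weight_clip, are sensitive to the scale of the advantages, so rescaling them changes how often the clip is hit. A tiny illustrative sketch using the defaults from ff_awr.yaml:

# Illustrative only; beta and weight_clip are the ff_awr.yaml defaults.
import jax.numpy as jnp

advantages = jnp.array([-2.0, 0.5, 3.0])
beta, weight_clip = 0.05, 20.0
weights = jnp.minimum(jnp.exp(advantages / beta), weight_clip)
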
7 changes: 6 additions & 1 deletion stoix/systems/awr/ff_awr_continuous.py
@@ -245,7 +245,12 @@ def _actor_loss_fn(
r_t = sequence.reward[:, :-1]
d_t = (1 - sequence.done.astype(jnp.float32)[:, :-1]) * config.system.gamma
advantages, _ = batch_truncated_generalized_advantage_estimation(
-    r_t, d_t, config.system.gae_lambda, v_t, time_major=False
+    r_t,
+    d_t,
+    config.system.gae_lambda,
+    v_t,
+    time_major=False,
+    standardize_advantages=config.system.standardize_advantages,
)
weights = jnp.exp(advantages / config.system.beta)
weights = jnp.minimum(weights, config.system.weight_clip)
7 changes: 6 additions & 1 deletion stoix/systems/ppo/ff_dpo_continuous.py
@@ -107,7 +107,12 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra
d_t = 1.0 - traj_batch.done.astype(jnp.float32)
d_t = (d_t * config.system.gamma).astype(jnp.float32)
advantages, targets = batch_truncated_generalized_advantage_estimation(
-    r_t, d_t, config.system.gae_lambda, v_t, time_major=True
+    r_t,
+    d_t,
+    config.system.gae_lambda,
+    v_t,
+    time_major=True,
+    standardize_advantages=config.system.standardize_advantages,
)

def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
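
The only difference from the AWR call sites is time_major=True: these PPO-style systems pass time-major rollouts, while AWR passes batch-major sequences (see the time_major docstring in stoix/utils/multistep.py below). Since the new standardization reduces over both axes, the result does not depend on that layout; a small illustrative check:

# Illustrative only: standardizing over both axes is layout-independent.
import jax
import jax.numpy as jnp

adv_tb = jnp.arange(12.0).reshape(3, 4)      # [T, B] layout (time_major=True)
adv_bt = adv_tb.T                            # [B, T] layout (time_major=False)
a = jax.nn.standardize(adv_tb, axis=(0, 1))
b = jax.nn.standardize(adv_bt, axis=(0, 1)).T
print(jnp.allclose(a, b))                    # True
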
7 changes: 6 additions & 1 deletion stoix/systems/ppo/ff_ppo.py
@@ -107,7 +107,12 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra
d_t = 1.0 - traj_batch.done.astype(jnp.float32)
d_t = (d_t * config.system.gamma).astype(jnp.float32)
advantages, targets = batch_truncated_generalized_advantage_estimation(
-    r_t, d_t, config.system.gae_lambda, v_t, time_major=True
+    r_t,
+    d_t,
+    config.system.gae_lambda,
+    v_t,
+    time_major=True,
+    standardize_advantages=config.system.standardize_advantages,
)

def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
7 changes: 6 additions & 1 deletion stoix/systems/ppo/ff_ppo_continuous.py
@@ -107,7 +107,12 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra
d_t = 1.0 - traj_batch.done.astype(jnp.float32)
d_t = (d_t * config.system.gamma).astype(jnp.float32)
advantages, targets = batch_truncated_generalized_advantage_estimation(
-    r_t, d_t, config.system.gae_lambda, v_t, time_major=True
+    r_t,
+    d_t,
+    config.system.gae_lambda,
+    v_t,
+    time_major=True,
+    standardize_advantages=config.system.standardize_advantages,
)

def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
7 changes: 6 additions & 1 deletion stoix/systems/ppo/rec_ppo.py
@@ -176,7 +176,12 @@ def _env_step(
d_t = 1.0 - traj_batch.done.astype(jnp.float32)
d_t = (d_t * config.system.gamma).astype(jnp.float32)
advantages, targets = batch_truncated_generalized_advantage_estimation(
-    r_t, d_t, config.system.gae_lambda, v_t, time_major=True
+    r_t,
+    d_t,
+    config.system.gae_lambda,
+    v_t,
+    time_major=True,
+    standardize_advantages=config.system.standardize_advantages,
)

def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
2 changes: 0 additions & 2 deletions stoix/utils/loss.py
@@ -15,7 +15,6 @@ def ppo_loss(
pi_log_prob_t: chex.Array, b_pi_log_prob_t: chex.Array, gae_t: chex.Array, epsilon: float
) -> chex.Array:
ratio = jnp.exp(pi_log_prob_t - b_pi_log_prob_t)
-gae_t = (gae_t - gae_t.mean()) / (gae_t.std() + 1e-8)
loss_actor1 = ratio * gae_t
loss_actor2 = (
jnp.clip(
@@ -38,7 +37,6 @@ def dpo_loss(
beta: float,
) -> chex.Array:
log_diff = pi_log_prob_t - b_pi_log_prob_t
-gae_t = (gae_t - gae_t.mean()) / (gae_t.std() + 1e-8)
ratio = jnp.exp(log_diff)
is_pos = (gae_t >= 0.0).astype(jnp.float32)
r1 = ratio - 1.0
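
These deletions are the other half of the change: ppo_loss and dpo_loss no longer standardize advantages themselves, and standardization is now opt-in inside the GAE helper (stoix/utils/multistep.py below). A minimal sketch, not Stoix code, of the behavioural difference: the old code normalised whatever slice of advantages the loss happened to receive, while the new flag normalises once over the whole rollout before it is split up.

# Illustrative only: per-slice and whole-rollout standardization differ.
import jax.numpy as jnp

def standardize(x, eps=1e-8):
    return (x - x.mean()) / (x.std() + eps)

adv = jnp.arange(8.0)                 # advantages for one rollout
whole = standardize(adv)              # new behaviour (flag set in the GAE call)
per_slice = jnp.concatenate(          # old behaviour if the rollout is split in two
    [standardize(adv[:4]), standardize(adv[4:])]
)
# The two differ because each slice uses its own mean and std.
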
5 changes: 5 additions & 0 deletions stoix/utils/multistep.py
@@ -16,6 +16,7 @@ def batch_truncated_generalized_advantage_estimation(
values: chex.Array,
stop_target_gradients: bool = True,
time_major: bool = False,
+standardize_advantages: bool = False,
) -> Tuple[chex.Array, chex.Array]:
"""Computes truncated generalized advantage estimates for a sequence length k.
@@ -39,6 +40,7 @@
to targets.
time_major: If True, the first dimension of the input tensors is the time
dimension.
+standardize_advantages: If True, standardize the advantages.
Returns:
Multistep truncated generalized advantage estimation at times [0, k-1].
@@ -84,6 +86,9 @@ def _body(
lambda x: jax.lax.stop_gradient(x), (advantage_t, target_values)
)

+if standardize_advantages:
+    advantage_t = jax.nn.standardize(advantage_t, axis=(0, 1))
+
return advantage_t, target_values


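
The actual standardization is delegated to jax.nn.standardize over both the batch and time axes, so every advantage produced for a rollout shares one mean and variance. As a self-contained illustrative check (not part of the commit), this matches the manual formula removed from stoix/utils/loss.py up to where the epsilon sits:

# Illustrative only; jax.nn.standardize adds its epsilon to the variance,
# whereas the old loss.py formula added 1e-8 to the standard deviation.
import jax
import jax.numpy as jnp

adv = jnp.arange(12.0).reshape(3, 4)         # advantages for one rollout
a = jax.nn.standardize(adv, axis=(0, 1))     # what multistep.py now does
b = (adv - adv.mean()) / (adv.std() + 1e-8)  # what loss.py used to do
print(jnp.allclose(a, b, atol=1e-3))         # True
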
