
Commit

fix: major changes to the ratelimiter configs and a separate buffer per actor
Louay-Ben-nessir committed Dec 4, 2024
1 parent 7e44d15 commit 6c8452f
Showing 8 changed files with 303 additions and 249 deletions.
10 changes: 5 additions & 5 deletions mava/configs/system/q_learning/rec_iql.yaml
@@ -11,13 +11,13 @@ add_agent_id: True
min_buffer_size: 32
update_batch_size: 1 # Number of vectorised gradient updates per device.

rollout_length: 2 # Number of environment steps per vectorised environment.
rollout_length: 2 # Number of environment steps per vectorised environment.
epochs: 2 # Number of learn epochs per training data batch.

# sizes
buffer_size: 5000 # size of the replay buffer. Note: total size is this * num_devices
buffer_size: 1000 # size of the replay buffer. Note: total size is this * num_devices
sample_batch_size: 32 # size of training data batch sampled from the buffer
sample_sequence_length: 20 # N transitions are sampled, giving N - 1 complete data points
sample_sequence_length: 32 # 32 transitions are sampled, giving 31 complete data points

# learning rates
q_lr: 3e-4 # the learning rate of the Q network optimizer
@@ -33,5 +33,5 @@ eps_min: 0.05
eps_decay: 1e5

# --- Sebulba parameters ---
samples_per_insert : 2 # The average number of times the learner should sample each item in the replay buffer.
sample_per_inser_tolerance : 6 # Maximum size of the "error" before calls should be blocked.
data_sample_mean: 150 # Average number of times the learner should sample each item from the replay buffer.
error_tolerance: 2 # Tolerance for how much the learner/actor can sample/insert before being blocked. Must be greater than 2 to avoid deadlocks.
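
Note on the new rate-limiter parameters above: the idea is to keep the learner's sample rate and the actors' insert rate near a target ratio. The sketch below is illustrative only and is not Mava's implementation; apart from data_sample_mean and error_tolerance, the class and argument names (for example min_size) are made up. It shows how such a limiter can block whichever side runs too far ahead.

import threading


class SampleToInsertRateLimiter:
    """Illustrative sketch of a samples-per-insert rate limiter (not Mava's code)."""

    def __init__(self, data_sample_mean: float, error_tolerance: float, min_size: int):
        self.data_sample_mean = data_sample_mean  # target average samples per inserted item
        self.error_tolerance = error_tolerance    # allowed drift, measured in inserts
        self.min_size = min_size                  # inserts required before sampling may start
        self.inserts = 0
        self.samples = 0
        self.cv = threading.Condition()

    def _error(self) -> float:
        # Positive when the actors are ahead of the learner, negative when behind.
        return self.inserts - self.samples / self.data_sample_mean

    def insert(self) -> None:
        with self.cv:
            # Actors may always fill up to min_size, then must stay within tolerance.
            self.cv.wait_for(
                lambda: self.inserts < self.min_size or self._error() < self.error_tolerance
            )
            self.inserts += 1
            self.cv.notify_all()

    def sample(self) -> None:
        with self.cv:
            # The learner waits for enough data and must not run ahead of the actors.
            self.cv.wait_for(
                lambda: self.inserts >= self.min_size and self._error() > -self.error_tolerance
            )
            self.samples += 1
            self.cv.notify_all()
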
3 changes: 1 addition & 2 deletions mava/evaluator.py
@@ -290,8 +290,7 @@ def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]:
if config.env.log_win_rate:
    metrics["won_episode"] = timesteps.extras["won_episode"]

# find the first instance of done to get the metrics at that timestep, we don't
# care about subsequent steps because we only the results from the first episode
# Find the first instance of done to get the metrics at that timestep.
done_idx = np.argmax(timesteps.last(), axis=0)
metrics = jax.tree_map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics)
del metrics["is_terminal_step"]  # unneeded for logging
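
The indexing above picks, for every parallel environment, the metrics recorded at its first terminal step. A small self-contained illustration of the same pattern (shapes and values are invented):

import numpy as np

# Hypothetical log: 4 timesteps x 3 parallel environments.
dones = np.array(
    [[False, False, True],
     [False, True, False],
     [True, False, False],
     [True, True, True]]
)
returns = np.arange(12).reshape(4, 3)  # fake per-step metric

# argmax over the time axis gives the index of the first True in each column.
done_idx = np.argmax(dones, axis=0)  # array([2, 1, 0])

# Advanced indexing then selects each environment's metric at that step.
first_episode_returns = returns[done_idx, np.arange(dones.shape[1])]  # array([6, 4, 2])
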
6 changes: 3 additions & 3 deletions mava/systems/q_learning/anakin/rec_qmix.py
@@ -44,7 +44,7 @@
TrainState,
Transition,
)
from mava.types import MarlEnv, Observation
from mava.types import MarlEnv, MavaObservation
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
from mava.utils.config import check_total_timesteps
@@ -241,7 +241,7 @@ def make_update_fns(
) -> Callable[[LearnerState[QMIXParams]], Tuple[LearnerState[QMIXParams], Tuple[Metrics, Metrics]]]:
def select_eps_greedy_action(
action_selection_state: ActionSelectionState,
obs: Observation,
obs: MavaObservation,
term_or_trunc: Array,
) -> Tuple[ActionSelectionState, Array]:
"""Select action to take in eps-greedy way. Batch and agent dims are included."""
@@ -310,7 +310,7 @@ def action_step(action_state: ActionState, _: Any) -> Tuple[ActionState, Dict]:

return new_act_state, next_timestep.extras["episode_metrics"]

def prep_inputs_to_scannedrnn(obs: Observation, term_or_trunc: chex.Array) -> chex.Array:
def prep_inputs_to_scannedrnn(obs: MavaObservation, term_or_trunc: chex.Array) -> chex.Array:
"""Prepares the inputs to the RNN network for either getting q values or the
eps-greedy distribution.
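
On prep_inputs_to_scannedrnn: helpers like this typically add a leading time axis and pair the observation with the reset flags so a time-scanned RNN can process a single environment step. A hedged sketch of that reshaping, with illustrative names and shapes:

import jax
import jax.numpy as jnp


def add_leading_time_dim(obs, term_or_trunc):
    """Illustrative only: map [batch, agent, ...] inputs to [time=1, batch, agent, ...]."""
    obs = jax.tree_util.tree_map(lambda x: x[jnp.newaxis, ...], obs)
    term_or_trunc = term_or_trunc[jnp.newaxis, ...]
    return obs, term_or_trunc
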

0 comments on commit 6c8452f
