diff --git a/mava/configs/system/q_learning/rec_iql.yaml b/mava/configs/system/q_learning/rec_iql.yaml
index 80dd32e0a..63865e4ff 100644
--- a/mava/configs/system/q_learning/rec_iql.yaml
+++ b/mava/configs/system/q_learning/rec_iql.yaml
@@ -11,13 +11,13 @@ add_agent_id: True
 min_buffer_size: 32
 update_batch_size: 1 # Number of vectorised gradient updates per device.
-rollout_length: 2 # Number of environment steps per vectorised environment.
+rollout_length: 2 # Number of environment steps per vectorised environment.
 epochs: 2 # Number of learn epochs per training data batch.

 # sizes
-buffer_size: 5000 # size of the replay buffer. Note: total size is this * num_devices
+buffer_size: 1000 # size of the replay buffer. Note: total size is this * num_devices
 sample_batch_size: 32 # size of training data batch sampled from the buffer
-sample_sequence_length: 20 # N transitions are sampled, giving N - 1 complete data points
+sample_sequence_length: 32 # N transitions are sampled, giving N - 1 complete data points

 # learning rates
 q_lr: 3e-4 # the learning rate of the Q network network optimizer
@@ -33,5 +33,5 @@ eps_min: 0.05
 eps_decay: 1e5

 # --- Sebulba parameters ---
-samples_per_insert : 2 # The average number of times the learner should sample each item in the replay buffer.
-sample_per_inser_tolerance : 6 # Maximum size of the "error" before calls should be blocked.
+data_sample_mean: 150 # Average number of times the learner should sample each item from the replay buffer.
+error_tolerance: 2 # Tolerance for how much the learner/actor can sample/insert before being blocked. Must be at least 2 to avoid deadlocks.
diff --git a/mava/evaluator.py b/mava/evaluator.py
index 6b2fda203..61bdabbdc 100644
--- a/mava/evaluator.py
+++ b/mava/evaluator.py
@@ -290,8 +290,7 @@ def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]:
             if config.env.log_win_rate:
                 metrics["won_episode"] = timesteps.extras["won_episode"]

-            # find the first instance of done to get the metrics at that timestep, we don't
-            # care about subsequent steps because we only the results from the first episode
+            # Find the first instance of done to get the metrics at that timestep.
             done_idx = np.argmax(timesteps.last(), axis=0)
             metrics = jax.tree_map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics)
             del metrics["is_terminal_step"]  # uneeded for logging
diff --git a/mava/systems/q_learning/anakin/rec_qmix.py b/mava/systems/q_learning/anakin/rec_qmix.py
index 7dcccf75c..45bc83a57 100644
--- a/mava/systems/q_learning/anakin/rec_qmix.py
+++ b/mava/systems/q_learning/anakin/rec_qmix.py
@@ -44,7 +44,7 @@
     TrainState,
     Transition,
 )
-from mava.types import MarlEnv, Observation
+from mava.types import MarlEnv, MavaObservation
 from mava.utils import make_env as environments
 from mava.utils.checkpointing import Checkpointer
 from mava.utils.config import check_total_timesteps
@@ -241,7 +241,7 @@ def make_update_fns(
 ) -> Callable[[LearnerState[QMIXParams]], Tuple[LearnerState[QMIXParams], Tuple[Metrics, Metrics]]]:
     def select_eps_greedy_action(
         action_selection_state: ActionSelectionState,
-        obs: Observation,
+        obs: MavaObservation,
         term_or_trunc: Array,
     ) -> Tuple[ActionSelectionState, Array]:
         """Select action to take in eps-greedy way.
Batch and agent dims are included.""" @@ -310,7 +310,7 @@ def action_step(action_state: ActionState, _: Any) -> Tuple[ActionState, Dict]: return new_act_state, next_timestep.extras["episode_metrics"] - def prep_inputs_to_scannedrnn(obs: Observation, term_or_trunc: chex.Array) -> chex.Array: + def prep_inputs_to_scannedrnn(obs: MavaObservation, term_or_trunc: chex.Array) -> chex.Array: """Prepares the inputs to the RNN network for either getting q values or the eps-greedy distribution. diff --git a/mava/systems/q_learning/sebulba/rec_iql.py b/mava/systems/q_learning/sebulba/rec_iql.py index a1c7d1718..a1ef4e5ac 100644 --- a/mava/systems/q_learning/sebulba/rec_iql.py +++ b/mava/systems/q_learning/sebulba/rec_iql.py @@ -13,17 +13,14 @@ # limitations under the License. import copy -import time +import queue import threading import warnings from collections import defaultdict -from typing import Any, Callable, Dict, Tuple, List, Sequence -from numpy.typing import NDArray -import queue from queue import Queue +from typing import Any, Dict, List, Sequence, Tuple import chex -import flashbax as fbx import hydra import jax import jax.lax as lax @@ -31,35 +28,28 @@ import numpy as np import optax from colorama import Fore, Style - from flax.core.scope import FrozenVariableDict from flax.linen import FrozenDict from jax import Array, tree from jax.experimental import mesh_utils from jax.experimental.shard_map import shard_map from jax.sharding import Mesh, NamedSharding, PartitionSpec, Sharding - from omegaconf import DictConfig, OmegaConf from rich.pretty import pprint from mava.evaluator import get_sebulba_eval_fn as get_eval_fn from mava.evaluator import make_rec_eval_act_fn - from mava.networks import RecQNetwork, ScannedRNN -from mava.utils.sebulba import ParamsSource, OffPolicyPipeline as Pipeline, RecordTimeTo, ThreadLifetime, SampleToInsertRatio -from mava.systems.q_learning.types import ( - ActionSelectionState, - Metrics, - QNetParams, - Transition, - SebulbaLearnerState as LearnerState -) -from mava.types import Observation, SebulbaLearnerFn +from mava.systems.q_learning.types import Metrics, QNetParams, Transition +from mava.systems.q_learning.types import SebulbaLearnerState as LearnerState +from mava.types import Observation, SebulbaLearnerFn from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer -from mava.utils.config import check_total_timesteps, check_sebulba_config +from mava.utils.config import check_sebulba_config, check_total_timesteps from mava.utils.jax_utils import switch_leading_axes from mava.utils.logger import LogEvent, MavaLogger +from mava.utils.sebulba import OffPolicyPipeline as Pipeline +from mava.utils.sebulba import ParamsSource, RecordTimeTo, SampleToInsertRatio, ThreadLifetime from mava.wrappers.episode_metrics import get_final_step_metrics from mava.wrappers.gym import GymToJumanji @@ -70,47 +60,52 @@ def rollout( config: DictConfig, rollout_queue: Pipeline, params_source: ParamsSource, - apply_fn , + q_net: RecQNetwork, actor_device: int, seeds: List[int], thread_lifetime: ThreadLifetime, - actor_id : int, + actor_id: int, ) -> None: - """Runs rollouts to collect trajectories from the environment. + """Collects trajectories from the environment by running rollouts. Args: - key (chex.PRNGKey): The PRNGkey. - config (DictConfig): Configuration settings for the environment and rollout. - rollout_queue (Pipeline): Queue for sending collected rollouts to the learner. 
- params_source (ParamsSource): Source for fetching the latest network parameters - from the learner. - apply_fns (Tuple): Functions for running the actor and critic networks. - actor_device (Device): Actor device to use for rollout. - seeds (List[int]): Seeds for initializing the environment. - thread_lifetime (ThreadLifetime): Manages the thread's lifecycle. + key: The PRNG key for stochasticity. + env: The environment to interact with. + config: Configuration settings for rollout and environment. + rollout_queue: Queue for sending collected trajectories to the learner. + params_source: Provides the latest network parameters from the learner. + q_net: The Q-network. + actor_device: Index of the actor device to use for rollout. + seeds: Seeds for environment initialization. + thread_lifetime: Controls the thread's lifecycle. + actor_id: Unique identifier for the actor. """ name = threading.current_thread().name print(f"{Fore.BLUE}{Style.BRIGHT}Thread {name} started{Style.RESET_ALL}") - num_agents, num_envs = config.system.num_agents, config.arch.num_envs + num_agents = config.system.num_agents move_to_device = lambda x: jax.device_put(x, device=actor_device) @jax.jit def select_eps_greedy_action( - params : FrozenDict , hidden_state, obs: Observation, term_or_trunc: Array, key, t: int - ) -> Tuple[ActionSelectionState, Array]: - """Select action to take in epsilon-greedy way. Batch and agent dims are included. + params: FrozenDict, + hidden_state: jax.Array, + obs: Observation, + term_or_trunc: Array, + key: chex.PRNGKey, + t: int, + ) -> Tuple[Array, Array, int]: + """Selects an action epsilon-greedily. Args: - ---- - action_selection_state: Tuple of online parameters, previous hidden state, - environment timestep (used to calculate epsilon) and a random key. - obs: The observation from the previous timestep. - term_or_trunc: The flag timestep.last() from the previous timestep. + params: Network parameters. + hidden_state: Current RNN hidden state. + obs: Observation from the environment. + term_or_trunc: Termination or truncation flag. + key: PRNG key for sampling. + t: Current timestep (used for epsilon decay). Returns: - ------- - A tuple of the updated action selection state and the chosen action. - + Tuple containing the chosen action, next hidden state, and updated timestep. """ eps = jnp.maximum( @@ -120,25 +115,25 @@ def select_eps_greedy_action( obs = tree.map(lambda x: x[jnp.newaxis, ...], obs) term_or_trunc = tree.map(lambda x: x[jnp.newaxis, ...], term_or_trunc) - next_hidden_state, eps_greedy_dist = apply_fn( - params, hidden_state, (obs, term_or_trunc), eps + next_hidden_state, eps_greedy_dist = q_net.apply( + params, hidden_state, (obs, term_or_trunc), eps ) action = eps_greedy_dist.sample(seed=key) action = action[0, ...] # (1, B, A) -> (B, A) - - return action, next_hidden_state, t + config.arch.num_envs + + return action, next_hidden_state, t + config.arch.num_envs next_timestep = env.reset(seed=seeds) dones = next_timestep.last()[..., jnp.newaxis] - + # Initialise hidden states. hstate = ScannedRNN.initialize_carry( (config.arch.num_envs, num_agents), config.network.hidden_state_dim ) hstate_tpu = tree.map(move_to_device, hstate) step_count = 0 - + # Loop till the desired num_updates is reached. 
    while not thread_lifetime.should_stop():
        # Rollout
@@ -149,35 +144,37 @@ def select_eps_greedy_action(
             for _ in range(config.system.rollout_length):
                 with RecordTimeTo(actor_timings["get_params_time"]):
                     params = params_source.get()  # Get the latest parameters from the learner
-
+
                 timestep = next_timestep
                 obs_tpu = tree.map(move_to_device, timestep.observation)
-
+
                 last_dones = tree.map(move_to_device, dones)
-
+
                 # Get action and value
                 with RecordTimeTo(actor_timings["compute_action_time"]):
                     key, act_key = jax.random.split(key)
-                    action, hstate_tpu, step_count = select_eps_greedy_action(params, hstate_tpu, obs_tpu, last_dones, act_key, step_count)
+                    action, hstate_tpu, step_count = select_eps_greedy_action(
+                        params, hstate_tpu, obs_tpu, last_dones, act_key, step_count
+                    )
                     cpu_action = jax.device_get(action)

                 # Step environment
                 with RecordTimeTo(actor_timings["env_step_time"]):
                     next_timestep = env.step(cpu_action)

-                #Prepare the transation
+                # Prepare the transition
                 terminal = (1 - timestep.discount[..., 0, jnp.newaxis]).astype(bool)
-                dones = next_timestep.last()[..., jnp.newaxis]
+                dones = next_timestep.last()[..., jnp.newaxis]

                 # Append data to storage
                 traj.append(
                     Transition(
-                        timestep.observation,
-                        action,
-                        next_timestep.reward,
-                        terminal,
-                        dones,
-                        next_timestep.extras["real_next_obs"]
+                        timestep.observation,
+                        action,
+                        next_timestep.reward,
+                        terminal,
+                        dones,
+                        next_timestep.extras["real_next_obs"],
                     )
                 )
@@ -194,8 +191,9 @@ def select_eps_greedy_action(

     env.close()

+
 def get_learner_step_fn(
-    apply_fn ,
+    q_net: RecQNetwork,
     update_fn: optax.TransformUpdateFn,
     config: DictConfig,
 ) -> SebulbaLearnerFn[LearnerState, Transition]:
@@ -205,38 +203,43 @@ def _update_step(
         learner_state: LearnerState,
         traj_batch: Transition,
     ) -> Tuple[LearnerState, Metrics]:
-        """A single update of the network.
+        """Performs a single network update.

-        This function calculates advantages and targets based on the trajectories
-        from the actor and updates the actor and critic networks based on the losses.
+        Calculates targets based on the input trajectories and updates the Q-network
+        parameters accordingly.

         Args:
-            learner_state (LearnerState): contains all the items needed for learning.
-            traj_batch (PPOTransition): the batch of data to learn with.
+            learner_state: Current learner state.
+            traj_batch: Batch of transitions for training.
         """
-
-
+
        def prep_inputs_to_scannedrnn(obs: Observation, term_or_trunc: chex.Array) -> chex.Array:
-            """Prepares the inputs to the RNN network for either getting q values or the
-            eps-greedy distribution.
+            """Prepares inputs for the ScannedRNN network.

-            Mostly swaps leading axes because the replay buffer outputs (B, T, ... )
-            and the RNN takes in (T, B, ...).
+            Switches leading axes of observations and termination/truncation flags to match the
+            (T, B, ...) format expected by the RNN. The replay buffer outputs data in (B, T, ...)
+            format.
+
+            Args:
+                obs: Observation data.
+                term_or_trunc: Termination/truncation flags.
+
+            Returns:
+                Tuple containing the initial hidden state and the formatted input data.
             """
             hidden_state = ScannedRNN.initialize_carry(
-                (config.system.sample_batch_size, obs.agents_view.shape[2]), config.network.hidden_state_dim
+                (obs.agents_view.shape[0], obs.agents_view.shape[2]),
+                config.network.hidden_state_dim,
             )
            # the rb outputs (B, T, ... ) the RNN takes in (T, B, ...)
obs = switch_leading_axes(obs) # (B, T) -> (T, B) term_or_trunc = switch_leading_axes(term_or_trunc) # (B, T) -> (T, B) obs_term_or_trunc = (obs, term_or_trunc) - - return hidden_state, obs_term_or_trunc + return hidden_state, obs_term_or_trunc - def _update_epoch(update_state: Tuple, _: Any) -> Tuple: - """Update the network for a single epoch.""" + """Update the network for a single epoch.""" def q_loss_fn( q_online_params: FrozenVariableDict, @@ -249,7 +252,7 @@ def q_loss_fn( hidden_state, obs_term_or_trunc = prep_inputs_to_scannedrnn(obs, term_or_trunc) # get online q values of all actions - _, q_online = apply_fn( + _, q_online = q_net.apply( q_online_params, hidden_state, obs_term_or_trunc, method="get_q_values" ) q_online = switch_leading_axes(q_online) # (T, B, ...) -> (B, T, ...) @@ -268,10 +271,9 @@ def q_loss_fn( } return q_loss, loss_info - + params, opt_states, t_train, traj_batch = update_state - # Get data aligned with current/next timestep data_first = tree.map(lambda x: x[:, :-1, ...], traj_batch) data_next = tree.map(lambda x: x[:, 1:, ...], traj_batch) @@ -291,16 +293,16 @@ def q_loss_fn( next_terminal = data_next.terminal # Scan over each sample - hidden_state, next_obs_term_or_trunc = prep_inputs_to_scannedrnn( + hidden_state, next_obs_term_or_trunc = prep_inputs_to_scannedrnn( next_obs, next_term_or_trunc ) # eps defaults to 0 - _, next_online_greedy_dist = apply_fn( + _, next_online_greedy_dist = q_net.apply( params.online, hidden_state, next_obs_term_or_trunc ) - _, next_q_vals_target = apply_fn( + _, next_q_vals_target = q_net.apply( params.target, hidden_state, next_obs_term_or_trunc, method="get_q_values" ) @@ -309,7 +311,8 @@ def q_loss_fn( # Double q-value selection next_q_val = jnp.squeeze( - jnp.take_along_axis(next_q_vals_target, next_action[..., jnp.newaxis], axis=-1), axis=-1 + jnp.take_along_axis(next_q_vals_target, next_action[..., jnp.newaxis], axis=-1), + axis=-1, ) next_q_val = switch_leading_axes(next_q_val) # (T, B, ...) -> (B, T, ...) @@ -319,7 +322,9 @@ def q_loss_fn( # Update Q function. q_grad_fn = jax.grad(q_loss_fn, has_aux=True) - q_grads, q_loss_info = q_grad_fn(params.online, obs, term_or_trunc, action, target_q_val) + q_grads, q_loss_info = q_grad_fn( + params.online, obs, term_or_trunc, action, target_q_val + ) # Mean over the device and batch dimension. q_grads, q_loss_info = lax.pmean((q_grads, q_loss_info), axis_name="learner_devices") @@ -337,13 +342,13 @@ def q_loss_fn( # Repack params and opt_states. next_params = QNetParams(next_online_params, next_target_params) - + # Repack. next_state = (next_params, next_opt_state, t_train + 1, traj_batch) - + return next_state, q_loss_info - - update_state = (*learner_state , traj_batch) + + update_state = (*learner_state, traj_batch) update_state, loss_info = jax.lax.scan( _update_epoch, update_state, None, config.system.epochs ) @@ -352,7 +357,6 @@ def q_loss_fn( learner_state = LearnerState(params, opt_states, train_step) return learner_state, loss_info - def learner_fn( learner_state: LearnerState, traj_batch: Transition ) -> Tuple[LearnerState, Metrics]: @@ -360,24 +364,22 @@ def learner_fn( This function represents the learner, it updates the network parameters by iteratively applying the `_update_step` function for a fixed number of - updates. The `_update_step` function is vectorized over a batch of inputs. + updates. The `_update_step` function is vectorized across learner devices. Args: learner_state (NamedTuple): - params (Params): The initial model parameters. 
                - opt_states (OptStates): The initial optimizer state.
-                - key (chex.PRNGKey): The random number generator state.
-                - env_state (LogEnvState): The environment state.
-                - timesteps (TimeStep): The last timestep of the rollout.
+                - step_counter (int): Number of learning steps.
+            traj_batch (Transition): The collected training data.
         """
-        # This function is shard mapped on the batch axis, but `_update_step` needs
-        # the first axis to be time #todo is this comment still relevent ?
         learner_state, loss_info = _update_step(learner_state, traj_batch)

         return learner_state, loss_info

     return learner_fn


+
 def learner_thread(
     learn_fn: SebulbaLearnerFn[LearnerState, Transition],
     learner_state: LearnerState,
@@ -388,7 +390,8 @@ def learner_thread(
 ) -> None:
     for _ in range(config.arch.num_evaluation):
         # Create the lists to store metrics and timings for this learning iteration.
-        metrics: List[Tuple[Dict, Dict]] = []
+        ep_metrics: List[Dict] = []
+        train_metrics: List[Dict] = []
         rollout_times: List[Dict] = []
         learn_times: Dict[str, List[float]] = defaultdict(list)

@@ -397,14 +400,15 @@
             # Get the trajectory batch from the pipeline
             # This is blocking so it will wait until the pipeline has data.
             with RecordTimeTo(learn_times["rollout_get_time"]):
-                traj_batch, (rollout_time, ep_metrics) = pipeline.get(block=True)
-
+                traj_batch, (rollout_time, ep_metric) = pipeline.get()
             # Update the networks
             with RecordTimeTo(learn_times["learning_time"]):
-                learner_state, train_metrics = learn_fn(learner_state, traj_batch)
+                learner_state, train_metric = learn_fn(learner_state, traj_batch)

-            metrics.append((ep_metrics, train_metrics))
-            rollout_times.append(rollout_time)
+            train_metrics.append(train_metric)
+            if ep_metric is not None:
+                ep_metrics.append(ep_metric)
+                rollout_times.append(rollout_time)

         # Update all the params sources so all actors can get the latest params
         params = jax.block_until_ready(learner_state.params)
@@ -412,13 +416,16 @@
             source.update(params.online)

         # Pass all the metrics and params to the main thread (evaluator) for logging and evaluation
-        ep_metrics, train_metrics = tree.map(lambda *x: np.asarray(x), *metrics)
-        rollout_times: Dict[str, NDArray] = tree.map(lambda *x: np.mean(x), *rollout_times)
-        timing_dict = rollout_times | learn_times
+        if ep_metrics:
+            ep_metrics = tree.map(lambda *x: np.asarray(x), *ep_metrics)
+        train_metrics = tree.map(lambda *x: np.asarray(x), *train_metrics)
+
+        timing_dict = tree.map(lambda *x: np.mean(x), *rollout_times) | learn_times
         timing_dict = tree.map(np.mean, timing_dict, is_leaf=lambda x: isinstance(x, list))

         eval_queue.put((ep_metrics, train_metrics, learner_state, timing_dict))


+
 def learner_setup(
     key: chex.PRNGKey, config: DictConfig, learner_devices: List
 ) -> Tuple[
@@ -426,6 +433,7 @@
     RecQNetwork,
     LearnerState,
     Sharding,
+    Transition,
 ]:
     """Initialise learner_fn, network and learner state."""
@@ -435,23 +443,22 @@
     action_space = env.single_action_space
     config.system.num_agents = len(action_space)
     config.system.num_actions = int(action_space[0].n)
-
-    devices = mesh_utils.create_device_mesh((len(learner_devices), ), devices=learner_devices)
+
+    devices = mesh_utils.create_device_mesh((len(learner_devices),), devices=learner_devices)
     mesh = Mesh(devices, axis_names=("learner_devices"))
     model_spec = PartitionSpec()
     data_spec = PartitionSpec("learner_devices")
     learner_sharding = NamedSharding(mesh, model_spec)
-
     key, q_key = jax.random.split(key, 2)

     # Shape legend:
     # T: Time
(dummy dimension size = 1) # B: Batch (dummy dimension size = 1) # A: Agent # Make dummy inputs to init recurrent Q network -> need shape (T, B, A, ...) - init_agents_view = jnp.array(env.single_observation_space.sample()) + init_agents_view = jnp.array(env.single_observation_space.sample()) init_action_mask = jnp.ones((config.system.num_agents, config.system.num_actions)) - init_obs = Observation(init_agents_view, init_action_mask) # (A, ...) + init_obs = Observation(init_agents_view, init_action_mask) # (A, ...) # (B, T, A, ...) init_obs_batched = tree.map(lambda x: x[jnp.newaxis, jnp.newaxis, ...], init_obs) dones = jnp.zeros((1, 1, 1), dtype=bool) # (T, B, 1) @@ -491,11 +498,11 @@ def learner_setup( reward=jnp.zeros((config.system.num_agents,), dtype=float), terminal=jnp.zeros((1,), dtype=bool), # one flag for all agents term_or_trunc=jnp.zeros((1,), dtype=bool), - next_obs=init_obs + next_obs=init_obs, ) - + learn_state_spec = LearnerState(model_spec, model_spec, model_spec) - learn = get_learner_step_fn(q_net.apply, opt.update, config) + learn = get_learner_step_fn(q_net, opt.update, config) learn = jax.jit( shard_map( learn, @@ -515,19 +522,15 @@ def learner_setup( restored_params, _ = loaded_checkpoint.restore_params(input_params=params) # Update the params params = restored_params - + # Duplicate learner across Learner devices. - params, opt_state = jax.device_put( - (params, opt_state), learner_sharding - ) - + params, opt_state = jax.device_put((params, opt_state), learner_sharding) + # Initial learner state. - init_learner_state = LearnerState( - params, opt_state, 0 - ) - + init_learner_state = LearnerState(params, opt_state, 0) + env.close() - return learn, q_net.apply, init_learner_state, learner_sharding , init_transition + return learn, q_net, init_learner_state, learner_sharding, init_transition def run_experiment(_config: DictConfig) -> float: @@ -546,11 +549,14 @@ def run_experiment(_config: DictConfig) -> float: np_rng = np.random.default_rng(config.system.seed) # Setup learner. - learn, apply_fn, learner_state, learner_sharding, init_transition = learner_setup(key, config, learner_devices) + learn, q_net, learner_state, learner_sharding, init_transition = learner_setup( + key, config, learner_devices + ) # Setup evaluator. # One key per device for evaluation. 
- eval_act_fn = make_rec_eval_act_fn(apply_fn, config) + eval_act_fn = make_rec_eval_act_fn(q_net.apply, config) + evaluator, evaluator_envs = get_eval_fn( environments.make_gym_env, eval_act_fn, config, np_rng, absolute_metric=False ) @@ -563,9 +569,30 @@ def run_experiment(_config: DictConfig) -> float: config.system.rollout_length * config.arch.num_envs * config.system.num_updates_per_eval + * len(config.arch.actor_device_ids) + * config.arch.n_threads_per_executor ) - # Setup logger + # Setup RateLimiter + insert_to_sample_ratio = ( + config.system.rollout_length + * config.arch.num_envs + * len(config.arch.actor_device_ids) + * config.arch.n_threads_per_executor + ) / (config.system.sample_sequence_length * config.system.sample_batch_size) + + config.sample_per_insert = config.system.data_sample_mean * insert_to_sample_ratio + config.tolerance = config.sample_per_insert * config.system.error_tolerance + + min_num_inserts = max( + config.system.sample_sequence_length // config.system.rollout_length, + config.system.min_buffer_size // config.system.rollout_length, + 1, + ) + + rate_limiter = SampleToInsertRatio(config.sample_per_insert, min_num_inserts, config.tolerance) + + # Setup logger logger = MavaLogger(config) print_cfg: Dict = OmegaConf.to_container(config, resolve=True) print_cfg["arch"]["devices"] = jax.devices() @@ -580,20 +607,12 @@ def run_experiment(_config: DictConfig) -> float: **config.logger.checkpointing.save_args, # Checkpoint args ) - # Executor setup and launch. inital_params = jax.device_put(learner_state.params, actor_devices[0]) # unreplicate - # The rollout queue/ the pipe between actor and learner - - # Setup RateLimiter | todo we can convert all of this calucations to use the batch size but idk how helpful that would be - batch_size_per_insert = config.arch.num_envs * config.system.rollout_length * config.arch.n_threads_per_executor * len(actor_devices) - min_num_inserts = max((config.system.min_buffer_size * config.system.sample_sequence_length) // batch_size_per_insert, 1) - rate_limiter = SampleToInsertRatio(config.system.samples_per_insert, min_num_inserts, config.system.sample_per_inser_tolerance) - # Setup Pipeline pipe_lifetime = ThreadLifetime() - pipe = Pipeline(config, learner_sharding, key, rate_limiter, init_transition, pipe_lifetime)#todo chek key + pipe = Pipeline(config, learner_sharding, key, rate_limiter, init_transition, pipe_lifetime) pipe.start() params_sources: List[ParamsSource] = [] @@ -624,11 +643,11 @@ def run_experiment(_config: DictConfig) -> float: config, pipe, params_source, - apply_fn, + q_net, actor_device, seeds, actor_lifetime, - actor_id + actor_id, ), name=f"Actor-{actor_device}-{thread_id}", ) @@ -664,20 +683,23 @@ def run_experiment(_config: DictConfig) -> float: time_metrics |= {"timestep": t, "pipline_size": pipe.qsize()} logger.log(time_metrics, t, eval_step, LogEvent.MISC) - episode_metrics, ep_completed = get_final_step_metrics(episode_metrics) - episode_metrics["steps_per_second"] = steps_per_rollout / time_metrics["rollout_time"] - if ep_completed: - logger.log(episode_metrics, t, eval_step, LogEvent.ACT) + if episode_metrics: + episode_metrics, ep_completed = get_final_step_metrics(episode_metrics) + episode_metrics["steps_per_second"] = steps_per_rollout / time_metrics["rollout_time"] + if ep_completed: + logger.log(episode_metrics, t, eval_step, LogEvent.ACT) - train_metrics["learner_step"] = (eval_step + 1) * config.system.num_updates_per_eval - train_metrics["learner_steps_per_second"] = ( - 
config.system.num_updates_per_eval - ) / time_metrics["learner_time_per_eval"] - logger.log(train_metrics, t, eval_step, LogEvent.TRAIN) + train_metrics["learner_step"] = (eval_step + 1) * config.system.num_updates_per_eval + train_metrics["learner_steps_per_second"] = ( + config.system.num_updates_per_eval + ) / time_metrics["learner_time_per_eval"] + logger.log(train_metrics, t, eval_step, LogEvent.TRAIN) learner_state_cpu = jax.device_get(learner_state) key, eval_key = jax.random.split(key, 2) - eval_metrics = evaluator(learner_state_cpu.params.online, eval_key, {"hidden_state" : eval_hs}) + eval_metrics = evaluator( + learner_state_cpu.params.online, eval_key, {"hidden_state": eval_hs} + ) logger.log(eval_metrics, t, eval_step, LogEvent.EVAL) episode_return = np.mean(eval_metrics["episode_return"]) @@ -704,10 +726,13 @@ def run_experiment(_config: DictConfig) -> float: ) key, eval_key = jax.random.split(key, 2) eval_hs = ScannedRNN.initialize_carry( - (min(config.arch.num_absolute_metric_eval_episodes, config.arch.num_envs), config.system.num_agents), + ( + min(config.arch.num_absolute_metric_eval_episodes, config.arch.num_envs), + config.system.num_agents, + ), config.network.hidden_state_dim, ) - eval_metrics = abs_metric_evaluator(best_params_cpu, eval_key, {"hidden_state" : eval_hs}) + eval_metrics = abs_metric_evaluator(best_params_cpu, eval_key, {"hidden_state": eval_hs}) t = int(steps_per_rollout * (eval_step + 1)) logger.log(eval_metrics, t, eval_step, LogEvent.ABSOLUTE) @@ -734,7 +759,6 @@ def run_experiment(_config: DictConfig) -> float: return eval_performance - @hydra.main( config_path="../../../configs/default", config_name="rec_iql_sebulba.yaml", diff --git a/mava/systems/q_learning/types.py b/mava/systems/q_learning/types.py index 5087453f1..442fa67b4 100644 --- a/mava/systems/q_learning/types.py +++ b/mava/systems/q_learning/types.py @@ -21,7 +21,7 @@ from jumanji.env import State from typing_extensions import NamedTuple, TypeAlias -from mava.types import Observation +from mava.types import MavaObservation, Observation Metrics = Dict[str, Array] @@ -29,14 +29,14 @@ class Transition(NamedTuple): """Transition for recurrent Q-learning.""" - obs: Observation + obs: MavaObservation action: Array reward: Array terminal: Array term_or_trunc: Array # Even though we use a trajectory buffer we need to store both obs and next_obs. # This is because of how the `AutoResetWrapper` returns obs at the end of an episode. 
- next_obs: Observation + next_obs: MavaObservation BufferState: TypeAlias = TrajectoryBufferState[Transition] @@ -109,8 +109,10 @@ class TrainState(NamedTuple, Generic[QLearningParams]): train_steps: Array key: PRNGKey + class SebulbaLearnerState(NamedTuple): """State of the learner for the Sebulba architecture.""" - params : QNetParams - opt_states : optax.OptState - step_counter : int \ No newline at end of file + + params: QNetParams + opt_states: optax.OptState + step_counter: int diff --git a/mava/utils/config.py b/mava/utils/config.py index 767ced753..81664493b 100644 --- a/mava/utils/config.py +++ b/mava/utils/config.py @@ -34,9 +34,9 @@ def check_sebulba_config(config: DictConfig) -> None: int(config.arch.num_envs / len(config.arch.learner_device_ids)) * config.system.rollout_length ) - + # PPO specifique check - if "num_minibatches" in config.system: + if "num_minibatches" in config.system: assert num_eval_samples % config.system.num_minibatches == 0, ( f"Number of training samples per evaluator ({num_eval_samples})" + f"must be divisible by num_minibatches ({config.system.num_minibatches})." diff --git a/mava/utils/sebulba.py b/mava/utils/sebulba.py index 9205efc99..f418f2c80 100644 --- a/mava/utils/sebulba.py +++ b/mava/utils/sebulba.py @@ -16,20 +16,21 @@ import queue import threading import time -from typing import Any, Dict, List, Sequence, Tuple, Union, Optional +from math import ceil +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import jax import jax.numpy as jnp import numpy as np from colorama import Fore, Style from flashbax import make_trajectory_buffer -from mava.systems.q_learning.types import Transition - from jax import tree from jax.sharding import Sharding from jumanji.types import TimeStep +from omegaconf import DictConfig from mava.systems.ppo.types import Params, PPOTransition +from mava.systems.q_learning.types import Transition from mava.types import Metrics QUEUE_PUT_TIMEOUT = 100 @@ -49,7 +50,9 @@ def stop(self) -> None: @jax.jit -def _stack_trajectory(trajectory: Union[List[PPOTransition], List[Transition]]) -> Union[PPOTransition, Transition]: +def _stack_trajectory( + trajectory: Union[List[PPOTransition], List[Transition]], +) -> Union[PPOTransition, Transition]: """Stack a list of parallel_env transitions into a single transition of shape [rollout_len, num_envs, ...].""" return tree.map(lambda *x: jnp.stack(x, axis=0).swapaxes(0, 1), *trajectory) # type: ignore @@ -215,14 +218,14 @@ def __init__( self.max_diff = max_diff self.min_size_to_sample = min_size_to_sample - self.inserts = 0 + self.inserts = 0.0 self.samples = 0 self.deletes = 0 self.mutex = threading.Lock() self.condition = threading.Condition(self.mutex) - def num_inserts(self) -> int: + def num_inserts(self) -> float: """Returns the number of inserts.""" with self.mutex: return self.inserts @@ -237,10 +240,10 @@ def num_deletes(self) -> int: with self.mutex: return self.deletes - def insert(self) -> None: + def insert(self, insert_fraction: float = 1) -> None: """Increment the number of inserts and notify all waiting threads.""" with self.mutex: - self.inserts += 1 + self.inserts += insert_fraction self.condition.notify_all() # Notify all waiting threads def delete(self) -> None: @@ -260,9 +263,9 @@ def can_insert(self, num_inserts: int) -> bool: # Assume lock is already held by the caller if num_inserts <= 0: return False - if self.inserts + num_inserts - self.deletes <= self.min_size_to_sample: + if ceil(self.inserts) + num_inserts - self.deletes <= 
self.min_size_to_sample:
             return True
-        diff = (num_inserts + self.inserts) * self.samples_per_insert - self.samples
+        diff = (num_inserts + ceil(self.inserts)) * self.samples_per_insert - self.samples
         return diff <= self.max_diff

     def can_sample(self, num_samples: int) -> bool:
@@ -270,9 +273,9 @@
         # Assume lock is already held by the caller
         if num_samples <= 0:
             return False
-        if self.inserts - self.deletes < self.min_size_to_sample:
+        if ceil(self.inserts) - self.deletes < self.min_size_to_sample:
             return False
-        diff = self.inserts * self.samples_per_insert - self.samples - num_samples
+        diff = ceil(self.inserts) * self.samples_per_insert - self.samples - num_samples
         return diff >= self.min_diff

     def await_can_insert(self, num_inserts: int = 1, timeout: Optional[float] = None) -> bool:
@@ -297,7 +300,7 @@ def __repr__(self) -> str:
             f"min_size_to_sample={self.min_size_to_sample}, "
             f"min_diff={self.min_diff}, max_diff={self.max_diff})"
         )
-
+

 class SampleToInsertRatio(RateLimiter):
     """Maintains a specified ratio between samples and inserts.
@@ -405,20 +408,32 @@
 # Modified from https://github.com/instadeepai/sebulba/blob/main/sebulba/core.py
-class OffPolicyPipeline(threading.Thread):
+class OffPolicyPipeline(threading.Thread):
     """
     The `Pipeline` shards trajectories into learner devices, ensuring trajectories are consumed in the
     right order to avoid being off-policy and limit the max number of samples in device memory at one
     time to avoid OOM issues.
     """

-    def __init__(self,config : dict, learner_sharding: Sharding, key : jax.random.PRNGKey, rate_limiter : RateLimiter, init_transition : Transition, lifetime: ThreadLifetime):
+    def __init__(
+        self,
+        config: DictConfig,
+        learner_sharding: Sharding,
+        key: jax.random.PRNGKey,
+        rate_limiter: RateLimiter,
+        init_transition: Transition,
+        lifetime: ThreadLifetime,
+    ):
         """
         Initializes the pipeline with a maximum size and the devices to shard trajectories across.

         Args:
-            max_size: The maximum number of trajectories to keep in the pipeline.
+            config: Configuration settings for buffers.
             learner_sharding: The sharding used for the learner's update function.
+            key: The PRNG key for stochasticity.
+            rate_limiter: A `RateLimiter` used to manage how often we are allowed to
+                sample from the buffers.
+            init_transition: A sample transition used to initialize the buffers.
             lifetime: A `ThreadLifetime` which is used to stop this thread.
         """
         super().__init__(name="Pipeline")
@@ -429,31 +444,32 @@ def __init__(self,config : dict, learner_sharding: Sharding, key : jax.random.PR
         self.lifetime = lifetime

         self.num_buffers = len(config.arch.actor_device_ids) * config.arch.n_threads_per_executor
-
+        self.rate_limiter = rate_limiter
         self.sharding = learner_sharding
-        self.last_actor_metrics = None
-        self.move_to_device = lambda tree: jax.tree.map(lambda x: jax.device_put(x, self.cpu), tree)
-
-        # Setup Buffer
+        self.key = key
+
+        assert config.system.sample_batch_size % self.num_buffers == 0, (
+            f"The sample batch size ({config.system.sample_batch_size}) must be divisible "
+            f"by the total number of actors ({self.num_buffers})."
+        )
+
+        # Setup Buffers
         rb = make_trajectory_buffer(
-            sample_sequence_length=config.system.sample_sequence_length + 1,
-            period=1,  # sample any unique trajectory
-            add_batch_size=config.arch.num_envs,
-            sample_batch_size=config.system.sample_batch_size // self.num_buffers, #todo add an assert ?
- max_length_time_axis=config.system.buffer_size, - min_length_time_axis=config.system.min_buffer_size, - ) - self.buffer_states = [rb.init(init_transition) for _ in range(self.num_buffers)] + sample_sequence_length=config.system.sample_sequence_length + 1, + period=1, + add_batch_size=config.arch.num_envs, + sample_batch_size=config.system.sample_batch_size // self.num_buffers, + max_length_time_axis=config.system.buffer_size, + min_length_time_axis=config.system.min_buffer_size, + ) + self.buffer_states = [rb.init(init_transition) for _ in range(self.num_buffers)] self.buffer_adds_count = [0] * self.num_buffers - + + # Setup functions self.buffer_add = jax.jit(rb.add, device=self.cpu) self.buffer_sample = jax.jit(rb.sample, device=self.cpu) - - self.key = key + self.move_to_device = lambda tree: jax.tree.map(lambda x: jax.device_put(x, self.cpu), tree) - #rate limiter - self.rate_limiter = rate_limiter - def run(self) -> None: """This function ensures that trajectories on the queue are consumed in the right order. The start_condition and end_condition are used to ensure that only 1 thread is processing an @@ -469,14 +485,14 @@ def run(self) -> None: except queue.Empty: continue - def put(self, traj: Sequence[Transition], metrics: Tuple, actor_id : int) -> None: + def put(self, traj: Sequence[Transition], metrics: Tuple, actor_id: int) -> None: start_condition, end_condition = (threading.Condition(), threading.Condition()) with start_condition: self.tickets_queue.put((start_condition, end_condition)) - start_condition.wait() + start_condition.wait() - try: - self.rate_limiter.await_can_insert(timeout=QUEUE_PUT_TIMEOUT) + try: + self.rate_limiter.await_can_insert(timeout=QUEUE_PUT_TIMEOUT) except TimeoutError: print( f"{Fore.RED}{Style.BRIGHT}Actor has timed out on insertion, " @@ -485,28 +501,23 @@ def put(self, traj: Sequence[Transition], metrics: Tuple, actor_id : int) -> Non # [Transition(num_envs)] * rollout_len -> Transition[done=(num_envs, rollout_len, ...)] traj = _stack_trajectory(traj) - traj = jax.device_put(traj, device=self.sharding) + traj = jax.device_get(traj) time_dict, episode_metrics = metrics # [{'metric1' : value1, ...} * rollout_len -> {'metric1' : [value1, value2, ...], ...} episode_metrics = _stack_trajectory(episode_metrics) - self.buffer_states[actor_id] = self.buffer_add(self.buffer_states[actor_id], traj) self.buffer_adds_count[actor_id] += 1 self._queue.put((time_dict, episode_metrics)) - # check if any buffer has beed added - if any(count > self.rate_limiter.num_inserts() for count in self.buffer_adds_count): - self.rate_limiter.insert() + self.rate_limiter.insert(1 / self.num_buffers) with end_condition: end_condition.notify() # notify that we have finished - def get( - self, block: bool = True, timeout: Union[float, None] = None - ) -> Tuple[PPOTransition, TimeStep, Dict]: + def get(self, timeout: Union[float, None] = None) -> Tuple[Transition, Any]: """Get a trajectory from the pipeline.""" self.key, sample_key = jax.random.split(self.key) @@ -520,16 +531,20 @@ def get( ) # Sample the data - sampled_batch = [self.buffer_sample(state, sample_key).experience for state in self.buffer_states] - sampled_batch = jax.tree_map(lambda *x : np.concatenate(x), *sampled_batch) - sampled_batch = jax.device_put(sampled_batch, device=self.sharding) + # Potential deadlock risk here. Although it hasn't occurred during testing. + # if an unexplained deadlock happens, it is likely due to this section. 
+ sampled_batch: List[Transition] = [ + self.buffer_sample(state, sample_key).experience for state in self.buffer_states + ] + transitions: Transition = jax.tree_map(lambda *x: np.concatenate(x), *sampled_batch) + transitions = jax.device_put(transitions, device=self.sharding) self.rate_limiter.sample() if not self._queue.empty(): - self.last_actor_metrics = self._queue.get() + return transitions, self._queue.get() - return sampled_batch, self.last_actor_metrics + return transitions, (None, None) def clear(self) -> None: """Clear the pipeline.""" @@ -538,6 +553,7 @@ def clear(self) -> None: self._queue.get(block=False) except queue.Empty: break + def qsize(self) -> int: """Returns the number of trajectories in the pipeline.""" - return self._queue.qsize() \ No newline at end of file + return self._queue.qsize() diff --git a/mava/wrappers/gym.py b/mava/wrappers/gym.py index f2511c301..3bd48dbfa 100644 --- a/mava/wrappers/gym.py +++ b/mava/wrappers/gym.py @@ -29,7 +29,7 @@ from gymnasium.vector.utils import write_to_shared_memory from numpy.typing import NDArray -from mava.types import Observation, ObservationGlobalState +from mava.types import MavaObservation, Observation, ObservationGlobalState if TYPE_CHECKING: # https://github.com/python/mypy/issues/6239 from dataclasses import dataclass @@ -54,7 +54,7 @@ class TimeStep: step_type: StepType reward: NDArray discount: NDArray - observation: Union[Observation, ObservationGlobalState] + observation: MavaObservation extras: Dict = field(default_factory=dict) def first(self) -> NDArray: @@ -78,12 +78,13 @@ def __init__( use_shared_rewards: bool = True, add_global_state: bool = False, ): - """Initialise the gym wrapper + """Initialize the gym wrapper Args: env (gymnasium.env): gymnasium env instance. use_shared_rewards (bool, optional): Use individual or shared rewards. Defaults to False. - add_global_state (bool, optional) : Create global observations. Defaults to False. + add_global_state (bool, optional) : Add global state information + to observations. 
""" super().__init__(env) self._env = env @@ -141,7 +142,7 @@ def get_global_obs(self, obs: NDArray) -> NDArray: class SmacWrapper(UoeWrapper): - """A wrapper that converts actions step to integers.""" + """A wrapper that converts actions to integers.""" def reset( self, seed: Optional[int] = None, options: Optional[dict] = None @@ -251,7 +252,7 @@ def __init__(self, env: gymnasium.vector.VectorEnv): self.env = env self.single_action_space = env.unwrapped.single_action_space self.single_observation_space = env.unwrapped.single_observation_space - self.num_agents = len(self.env.single_action_space) + self.num_agents = len(self.env.single_action_space) def reset(self, seed: Optional[list[int]] = None, options: Optional[dict] = None) -> TimeStep: obs, info = self.env.reset(seed=seed, options=options) # type: ignore @@ -260,9 +261,9 @@ def reset(self, seed: Optional[list[int]] = None, options: Optional[dict] = None step_type = np.full(num_envs, StepType.FIRST) rewards = np.zeros((num_envs, self.num_agents), dtype=float) - teminated = np.zeros((num_envs, self.num_agents), dtype=float) + terminated = np.zeros((num_envs, self.num_agents), dtype=float) - timestep = self._create_timestep(obs, step_type, teminated, rewards, info) + timestep = self._create_timestep(obs, step_type, terminated, rewards, info) return timestep @@ -271,21 +272,26 @@ def step(self, action: list) -> TimeStep: ep_done = np.logical_or(terminated, truncated) step_type = np.where(ep_done, StepType.LAST, StepType.MID) - terminated = np.repeat(terminated[..., np.newaxis], repeats=self.num_agents ,axis=-1) # (B,) --> (B, N) + terminated = np.repeat( + terminated[..., np.newaxis], repeats=self.num_agents, axis=-1 + ) # (B,) --> (B, N) timestep = self._create_timestep(obs, step_type, terminated, rewards, info) return timestep def _format_observation( - self, obs: NDArray, action_mask: Tuple[NDArray], global_obs: Optional[Tuple[NDArray]] = None + self, + obs: NDArray, + action_mask: Tuple[NDArray], + global_obs: Tuple[Union[NDArray, None]] = (None,), ) -> Union[Observation, ObservationGlobalState]: """Create an observation from the raw observation and environment state.""" action_mask = np.stack(action_mask) obs_data = {"agents_view": obs, "action_mask": action_mask} - - if global_obs[0] is not None: + + if global_obs[0] is not None: global_obs = np.array(global_obs) obs_data["global_state"] = global_obs return ObservationGlobalState(**obs_data) @@ -295,16 +301,23 @@ def _format_observation( def _create_timestep( self, obs: NDArray, step_type: NDArray, terminated: NDArray, rewards: NDArray, info: Dict ) -> TimeStep: - observation = self._format_observation(obs, info["action_mask"], info.get("global_obs", (None,))) - # Filter out the masks and auxiliary data #this is beyond ugly + observation = self._format_observation( + obs, info["action_mask"], info.get("global_obs", (None,)) + ) + # Filter out the masks and auxiliary data extras = {} - extras["episode_metrics"] = {key: value for key, value in info["metrics"].items() if key[0] != "_"} - extras["real_next_obs"] = self._format_observation(info["real_next_obs"], info["real_next_action_mask"] , info["real_next_global_obs"]) + extras["episode_metrics"] = { + key: value for key, value in info["metrics"].items() if key[0] != "_" + } + extras["real_next_obs"] = self._format_observation( # type: ignore + info["real_next_obs"], info["real_next_action_mask"], info["real_next_global_obs"] + ) + if "won_episode" in info: extras["won_episode"] = info["won_episode"] return TimeStep( - 
step_type=step_type, # type: ignore + step_type=step_type, reward=rewards, discount=1.0 - terminated, observation=observation, @@ -315,7 +328,7 @@ def close(self) -> None: self.env.close() -# Copied form Gymnasium/blob/main/gymnasium/vector/async_vector_env.py +# Copied from Gymnasium/blob/main/gymnasium/vector/async_vector_env.py # Modified to work with multiple agents # Note: The worker handles auto-resetting the environments. # Each environment resets when all of its agents have either terminated or been truncated. @@ -338,16 +351,16 @@ def async_multiagent_worker( # CCR001 if command == "reset": observation, info = env.reset(**data) - info["real_next_obs"] = observation - info["real_next_action_mask"] = info["action_mask"] - info["real_next_global_obs"] = info.get("global_obs", None) + info["real_next_obs"] = observation + info["real_next_action_mask"] = info["action_mask"] + info["real_next_global_obs"] = info.get("global_obs", None) if shared_memory: write_to_shared_memory(observation_space, index, observation, shared_memory) observation = None pipe.send(((observation, info), True)) elif command == "step": # Modified the step function to align with 'AutoResetWrapper'. - # The environment resets immediately upon termination or truncation. + # The environment resets when all agents have either terminated or truncated. ( observation, reward, @@ -355,13 +368,13 @@ def async_multiagent_worker( # CCR001 truncated, info, ) = env.step(data) - info["real_next_obs"] = observation - info["real_next_action_mask"] = info["action_mask"] - info["real_next_global_obs"] = info.get("global_obs", None) - if np.logical_or(terminated, truncated).all(): + info["real_next_obs"] = observation + info["real_next_action_mask"] = info["action_mask"] + info["real_next_global_obs"] = info.get("global_obs", None) + if np.logical_or(terminated, truncated).all(): observation, new_info = env.reset() info["action_mask"] = new_info["action_mask"] - info["global_obs"] = new_info.get("global_obs", None) + info["global_obs"] = new_info.get("global_obs", None) if shared_memory: write_to_shared_memory(observation_space, index, observation, shared_memory) @@ -415,4 +428,4 @@ def async_multiagent_worker( # CCR001 error_queue.put((index, error_type, error_message, trace)) pipe.send((None, False)) finally: - env.close() \ No newline at end of file + env.close()
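
For reference, the rate-limiter settings that run_experiment derives from the config above work out as in the following standalone sketch. The system values come from rec_iql.yaml in this diff; the arch values (num_envs, actor_device_ids, n_threads_per_executor) are assumed for illustration only, since they are not part of this change.

# Sketch of the SampleToInsertRatio inputs built in run_experiment (assumed arch values).
rollout_length = 2                # system.rollout_length
sample_sequence_length = 32       # system.sample_sequence_length
sample_batch_size = 32            # system.sample_batch_size
min_buffer_size = 32              # system.min_buffer_size
data_sample_mean = 150            # system.data_sample_mean
error_tolerance = 2               # system.error_tolerance

num_envs = 32                     # assumed arch.num_envs
n_actor_devices = 1               # assumed len(arch.actor_device_ids)
n_threads_per_executor = 2        # assumed arch.n_threads_per_executor

# Transitions written per pipeline insert vs. transitions consumed per learner sample.
transitions_per_insert = rollout_length * num_envs * n_actor_devices * n_threads_per_executor
transitions_per_sample = sample_sequence_length * sample_batch_size

insert_to_sample_ratio = transitions_per_insert / transitions_per_sample  # 0.125
samples_per_insert = data_sample_mean * insert_to_sample_ratio            # 18.75
tolerance = samples_per_insert * error_tolerance                          # 37.5

min_num_inserts = max(
    sample_sequence_length // rollout_length,
    min_buffer_size // rollout_length,
    1,
)                                                                         # 16

# These three numbers are what SampleToInsertRatio(samples_per_insert, min_num_inserts, tolerance) receives.
print(samples_per_insert, min_num_inserts, tolerance)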
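The data layout handled by prep_inputs_to_scannedrnn and the current/next slicing in _update_step can also be seen in isolation below. This is an illustrative sketch with made-up shapes: the buffer returns batch-major sequences of sample_sequence_length + 1 transitions, which yield N - 1 aligned (current, next) pairs, while the ScannedRNN consumes time-major input.

import jax.numpy as jnp
from jax import tree

# Illustrative shapes only: B batch, T = sample_sequence_length + 1 sampled transitions, A agents.
B, T, A = 4, 33, 3
traj = {
    "reward": jnp.zeros((B, T, A)),
    "action": jnp.zeros((B, T, A), dtype=jnp.int32),
}

# Align current and next timesteps: N transitions give N - 1 complete data points.
data_first = tree.map(lambda x: x[:, :-1, ...], traj)
data_next = tree.map(lambda x: x[:, 1:, ...], traj)

# The replay buffer is batch-major (B, T, ...); the RNN scans over time, so swap to (T, B, ...).
switch_leading_axes = lambda x: x.swapaxes(0, 1)
rnn_reward = switch_leading_axes(data_first["reward"])

print(data_first["reward"].shape, data_next["reward"].shape, rnn_reward.shape)
# (4, 32, 3) (4, 32, 3) (32, 4, 3)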
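The take_along_axis call in _update_step implements double Q-learning action selection: the online network picks the greedy next action and the target network evaluates it. The line that combines reward, discount and next_q_val into target_q_val sits outside the hunks shown here, so the sketch below uses the standard one-step form with made-up numbers.

import jax.numpy as jnp

# Online network chooses the greedy next action; target network evaluates it.
next_q_online = jnp.array([[1.0, 3.0, 2.0]])  # (B, num_actions), from the online net
next_q_target = jnp.array([[0.5, 2.5, 4.0]])  # (B, num_actions), from the target net

next_action = jnp.argmax(next_q_online, axis=-1)  # (B,)
next_q_val = jnp.squeeze(
    jnp.take_along_axis(next_q_target, next_action[..., jnp.newaxis], axis=-1), axis=-1
)  # (B,)

# Standard one-step target: reward + gamma * (1 - terminal) * Q_target(s', argmax_a Q_online(s', a))
gamma, reward, terminal = 0.99, jnp.array([1.0]), jnp.array([0.0])
target_q_val = reward + (1.0 - terminal) * gamma * next_q_val
print(target_q_val)  # [3.475]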