Skip to content

Commit

Permalink
fix: deadlock caused by deleting when buffer is full
Browse files Browse the repository at this point in the history
  • Loading branch information
Louay-Ben-nessir committed Nov 26, 2024
1 parent ee4834f commit 7e44d15
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 19 deletions.
14 changes: 9 additions & 5 deletions mava/systems/q_learning/sebulba/rec_iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,17 +580,21 @@ def run_experiment(_config: DictConfig) -> float:
**config.logger.checkpointing.save_args, # Checkpoint args
)


# Executor setup and launch.
inital_params = jax.device_put(learner_state.params, actor_devices[0]) # unreplicate

# The rollout queue/ the pipe between actor and learner

# Setup RateLimiter | todo: we collect batch_size_per_insert per step, but we sample sample_batch_size per sample
batch_size_per_insert = config.arch.num_envs * config.system.rollout_length
min_num_inserts = max(config.system.min_buffer_size // batch_size_per_insert, 1) #todo min buffer size here?
# Setup RateLimiter | todo: we could convert all of these calculations to use the batch size, but it is unclear how helpful that would be
batch_size_per_insert = config.arch.num_envs * config.system.rollout_length * config.arch.n_threads_per_executor * len(actor_devices)
min_num_inserts = max((config.system.min_buffer_size * config.system.sample_sequence_length) // batch_size_per_insert, 1)
rate_limiter = SampleToInsertRatio(config.system.samples_per_insert, min_num_inserts, config.system.sample_per_inser_tolerance)

pipe = Pipeline(config, learner_sharding, key, rate_limiter, init_transition)#todo chek key

# Setup Pipeline
pipe_lifetime = ThreadLifetime()
pipe = Pipeline(config, learner_sharding, key, rate_limiter, init_transition, pipe_lifetime)#todo chek key
pipe.start()

params_sources: List[ParamsSource] = []
actor_threads: List[threading.Thread] = []
Expand Down
47 changes: 33 additions & 14 deletions mava/utils/sebulba.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from jax.sharding import Sharding
from jumanji.types import TimeStep

# todo: remove the ppo dependencies when we make sebulba for other systems
from mava.systems.ppo.types import Params, PPOTransition
from mava.types import Metrics

Expand All @@ -50,7 +49,7 @@ def stop(self) -> None:


@jax.jit
def _stack_trajectory(trajectory: List[PPOTransition]) -> PPOTransition:
def _stack_trajectory(trajectory: Union[List[PPOTransition], List[Transition]]) -> Union[PPOTransition, Transition]:
"""Stack a list of parallel_env transitions into a single
transition of shape [rollout_len, num_envs, ...]."""
return tree.map(lambda *x: jnp.stack(x, axis=0).swapaxes(0, 1), *trajectory) # type: ignore
Expand Down Expand Up @@ -406,14 +405,14 @@ def __init__(


# Modified from https://github.com/instadeepai/sebulba/blob/main/sebulba/core.py
class OffPolicyPipeline:
class OffPolicyPipeline(threading.Thread):
"""
The `Pipeline` shards trajectories into learner devices,
ensuring trajectories are consumed in the right order to avoid being off-policy
and limit the max number of samples in device memory at one time to avoid OOM issues.
"""

def __init__(self,config : dict, learner_sharding: Sharding, key : jax.random.PRNGKey, rate_limiter : RateLimiter, init_transition : Transition):
def __init__(self,config : dict, learner_sharding: Sharding, key : jax.random.PRNGKey, rate_limiter : RateLimiter, init_transition : Transition, lifetime: ThreadLifetime):
"""
Initializes the pipeline with a maximum size and the devices to shard trajectories across.
Expand All @@ -422,9 +421,12 @@ def __init__(self,config : dict, learner_sharding: Sharding, key : jax.random.PR
learner_sharding: The sharding used for the learner's update function.
lifetime: A `ThreadLifetime` which is used to stop this thread.
"""
super().__init__(name="Pipeline")
self.cpu = jax.devices("cpu")[0]

self.tickets_queue: queue.Queue = queue.Queue()
self._queue: queue.Queue = queue.Queue()
self.lifetime = lifetime

self.num_buffers = len(config.arch.actor_device_ids) * config.arch.n_threads_per_executor

Expand All @@ -446,25 +448,40 @@ def __init__(self,config : dict, learner_sharding: Sharding, key : jax.random.PR

self.buffer_add = jax.jit(rb.add, device=self.cpu)
self.buffer_sample = jax.jit(rb.sample, device=self.cpu)

# How many times we inserted to all of the buffers
self.complete_adds_count = 0

self.key = key

#rate limiter
self.rate_limiter = rate_limiter

def run(self) -> None:
    """Serve queued tickets one at a time until the lifetime is stopped.

    Each ticket on `tickets_queue` is a `(start_condition, end_condition)` pair
    created by a caller of `put`. For every ticket this thread notifies
    `start_condition` to release that caller, then waits on `end_condition`
    until the caller signals that it has finished. Because the next ticket is
    only taken after the current one ends, only one `put` call is processing
    at a time, which keeps memory usage predictable.

    Both blocking calls use a one second timeout so that
    `self.lifetime.should_stop()` is re-checked regularly. Without the timeout
    on the `end_condition` wait, a producer that fails before notifying would
    leave this thread blocked forever and impossible to stop.
    """
    while not self.lifetime.should_stop():
        try:
            start_condition, end_condition = self.tickets_queue.get(timeout=1)
        except queue.Empty:
            # Nothing queued yet: loop round to re-check the stop flag.
            continue

        # Hold `end_condition` before releasing the producer, so its final
        # notify cannot fire before this thread is actually waiting for it.
        with end_condition:
            with start_condition:
                start_condition.notify()
            # `Condition.wait` returns False on timeout; keep waiting unless
            # the thread has been asked to stop in the meantime.
            while not end_condition.wait(timeout=1):
                if self.lifetime.should_stop():
                    return

def put(self, traj: Sequence[Transition], metrics: Tuple, actor_id : int) -> None:
start_condition, end_condition = (threading.Condition(), threading.Condition())
with start_condition:
self.tickets_queue.put((start_condition, end_condition))
start_condition.wait()

try:
self.rate_limiter.await_can_insert(timeout=QUEUE_PUT_TIMEOUT)
except TimeoutError:
print(
f"{Fore.RED}{Style.BRIGHT}Actor has timed out on insertion, "
f"this should not happen. A deadlock might be occurring{Style.RESET_ALL}"
)
if self.buffer_states[actor_id].is_full:
self.rate_limiter.delete()

# [Transition(num_envs)] * rollout_len -> Transition[done=(num_envs, rollout_len, ...)]
traj = _stack_trajectory(traj)
Expand All @@ -480,11 +497,13 @@ def put(self, traj: Sequence[Transition], metrics: Tuple, actor_id : int) -> Non

self._queue.put((time_dict, episode_metrics))

# check if all buffers have been added to
if all(count > self.complete_adds_count for count in self.buffer_adds_count):
self.complete_adds_count += 1
# check if any buffer has been added to
if any(count > self.rate_limiter.num_inserts() for count in self.buffer_adds_count):
self.rate_limiter.insert()

with end_condition:
end_condition.notify() # notify that we have finished

def get(
self, block: bool = True, timeout: Union[float, None] = None
) -> Tuple[PPOTransition, TimeStep, Dict]:
Expand All @@ -500,9 +519,9 @@ def get(
f"this should not happen. A deadlock might be occurring{Style.RESET_ALL}"
)

# sample the data
# Sample the data
sampled_batch = [self.buffer_sample(state, sample_key).experience for state in self.buffer_states]
sampled_batch = jax.tree_map(lambda *x : np.concatenate(x), *sampled_batch) #np not jnp
sampled_batch = jax.tree_map(lambda *x : np.concatenate(x), *sampled_batch)
sampled_batch = jax.device_put(sampled_batch, device=self.sharding)

self.rate_limiter.sample()
Expand Down

0 comments on commit 7e44d15

Please sign in to comment.