Add rollout_buffer_class and rollout_buffer_kwargs to PPO.
ernestum committed Oct 20, 2023
1 parent 73da70f commit d464102
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions stable_baselines3/ppo/ppo.py
@@ -52,6 +52,8 @@ class PPO(OnPolicyAlgorithm):
instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
+:param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
+:param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation
:param target_kl: Limit the KL divergence between updates,
because the clipping is not enough to prevent large updates,
see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
@@ -92,6 +94,8 @@ def __init__(
max_grad_norm: float = 0.5,
use_sde: bool = False,
sde_sample_freq: int = -1,
+rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
+rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
target_kl: Optional[float] = None,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
@@ -113,6 +117,8 @@ def __init__(
max_grad_norm=max_grad_norm,
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
+rollout_buffer_class=rollout_buffer_class,
+rollout_buffer_kwargs=rollout_buffer_kwargs,
stats_window_size=stats_window_size,
tensorboard_log=tensorboard_log,
policy_kwargs=policy_kwargs,
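For context, a minimal usage sketch of the two new parameters (not part of the commit itself): a hypothetical RolloutBuffer subclass whose extra constructor argument is supplied through rollout_buffer_kwargs. The ScaledRewardRolloutBuffer class and its reward_scale argument are illustrative assumptions, not part of stable-baselines3.

    from stable_baselines3 import PPO
    from stable_baselines3.common.buffers import RolloutBuffer


    class ScaledRewardRolloutBuffer(RolloutBuffer):
        """Hypothetical buffer that scales rewards before storing them (illustrative only)."""

        def __init__(self, *args, reward_scale: float = 1.0, **kwargs):
            # reward_scale is our illustrative extra kwarg; everything else
            # is forwarded unchanged to the standard RolloutBuffer.
            super().__init__(*args, **kwargs)
            self.reward_scale = reward_scale

        def add(self, obs, action, reward, episode_start, value, log_prob):
            # Scale the reward before delegating to the standard buffer logic.
            super().add(obs, action, reward * self.reward_scale, episode_start, value, log_prob)


    model = PPO(
        "MlpPolicy",
        "CartPole-v1",
        rollout_buffer_class=ScaledRewardRolloutBuffer,
        rollout_buffer_kwargs=dict(reward_scale=0.1),
    )
    model.learn(total_timesteps=2048)

Note that rollout_buffer_kwargs is forwarded to the buffer constructor by the OnPolicyAlgorithm base class, which is why this diff only threads the two arguments through PPO's signature into super().__init__.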
