From d424fbe97a60c55fbc55d6eeecf8053d8bc3bd15 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Fri, 20 Oct 2023 21:15:04 +0200 Subject: [PATCH] Add rollout_buffer_class and rollout_buffer_kwargs to A2C. --- stable_baselines3/a2c/a2c.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py index fda20c9c06..1ea3524089 100644 --- a/stable_baselines3/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -41,6 +41,8 @@ class A2C(OnPolicyAlgorithm): instead of action noise exploration (default: False) :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) + :param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected. + :param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation. :param normalize_advantage: Whether to normalize or not the advantage :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over @@ -75,6 +77,8 @@ def __init__( use_rms_prop: bool = True, use_sde: bool = False, sde_sample_freq: int = -1, + rollout_buffer_class: Optional[Type[RolloutBuffer]] = None, + rollout_buffer_kwargs: Optional[Dict[str, Any]] = None, normalize_advantage: bool = False, stats_window_size: int = 100, tensorboard_log: Optional[str] = None, @@ -96,6 +100,8 @@ def __init__( max_grad_norm=max_grad_norm, use_sde=use_sde, sde_sample_freq=sde_sample_freq, + rollout_buffer_class=rollout_buffer_class, + rollout_buffer_kwargs=rollout_buffer_kwargs, stats_window_size=stats_window_size, tensorboard_log=tensorboard_log, policy_kwargs=policy_kwargs,