diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py index 718571f0c8..adc8103ad6 100644 --- a/stable_baselines3/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -7,7 +7,9 @@ from stable_baselines3.common.buffers import RolloutBuffer from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.common.vec_env import GymEnv +from stable_baselines3.common.callbacks import MaybeCallback from stable_baselines3.common.utils import explained_variance SelfA2C = TypeVar("SelfA2C", bound="A2C") diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 5e87599903..76809114f1 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -14,7 +14,7 @@ from gymnasium import spaces from stable_baselines3.common import utils -from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback +from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback, MaybeCallback from stable_baselines3.common.env_util import is_wrapped from stable_baselines3.common.logger import Logger from stable_baselines3.common.monitor import Monitor @@ -22,7 +22,7 @@ from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space, is_image_space_channels_first from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule, TensorDict +from stable_baselines3.common.type_aliases import Schedule, TensorDict from stable_baselines3.common.utils import ( check_for_correct_spaces, get_device, @@ -38,6 +38,7 @@ VecTransposeImage, is_vecenv_wrapped, unwrap_vec_normalize, + GymEnv, ) from stable_baselines3.common.vec_env.patch_gym import _convert_space, _patch_env diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index 5089bba2b0..54e31c9d1f 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -704,3 +704,6 @@ def _on_training_end(self) -> None: # Flush and close progress bar self.pbar.refresh() self.pbar.close() + + +MaybeCallback = Union[None, Callable, List[BaseCallback], BaseCallback] diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index e8dcac4a47..fe89b34be9 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -12,13 +12,13 @@ from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer -from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.callbacks import BaseCallback, MaybeCallback from stable_baselines3.common.noise import ActionNoise, VectorizedActionNoise from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, Schedule, TrainFreq, TrainFrequencyUnit +from stable_baselines3.common.type_aliases import RolloutReturn, Schedule, TrainFreq, TrainFrequencyUnit from stable_baselines3.common.utils import safe_mean, should_collect_more_steps -from stable_baselines3.common.vec_env import VecEnv +from stable_baselines3.common.vec_env import VecEnv, GymEnv from stable_baselines3.her.her_replay_buffer import HerReplayBuffer SelfOffPolicyAlgorithm = TypeVar("SelfOffPolicyAlgorithm", bound="OffPolicyAlgorithm") diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index dc2b5abd02..d68838310d 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -8,11 +8,11 @@ from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer -from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.callbacks import BaseCallback, MaybeCallback from stable_baselines3.common.policies import ActorCriticPolicy -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.type_aliases import Schedule from stable_baselines3.common.utils import obs_as_tensor, safe_mean -from stable_baselines3.common.vec_env import VecEnv +from stable_baselines3.common.vec_env import VecEnv, GymEnv SelfOnPolicyAlgorithm = TypeVar("SelfOnPolicyAlgorithm", bound="OnPolicyAlgorithm") diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index 2f98ee198b..ddcb331205 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -1,15 +1,11 @@ """Common aliases for type hints""" from enum import Enum -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, SupportsFloat, Tuple, Union +from typing import Any, Callable, Dict, NamedTuple, Optional, Protocol, SupportsFloat, Tuple, Union -import gymnasium as gym import numpy as np import torch as th -from stable_baselines3.common import callbacks, vec_env - -GymEnv = Union[gym.Env, vec_env.VecEnv] GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int] GymResetReturn = Tuple[GymObs, Dict] AtariResetReturn = Tuple[np.ndarray, Dict[str, Any]] @@ -17,7 +13,6 @@ AtariStepReturn = Tuple[np.ndarray, SupportsFloat, bool, bool, Dict[str, Any]] TensorDict = Dict[str, th.Tensor] OptimizerStateDict = Dict[str, Any] -MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback] # A schedule takes the remaining progress as input # and ouputs a scalar (e.g. learning rate, clip range, ...) diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 4950822d4b..7b74f3093e 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -22,7 +22,8 @@ SummaryWriter = None # type: ignore[misc, assignment] from stable_baselines3.common.logger import Logger, configure -from stable_baselines3.common.type_aliases import GymEnv, Schedule, TensorDict, TrainFreq, TrainFrequencyUnit +from stable_baselines3.common.type_aliases import Schedule, TensorDict, TrainFreq, TrainFrequencyUnit +from stable_baselines3.common.vec_env import GymEnv def set_random_seed(seed: int, using_cuda: bool = False) -> None: diff --git a/stable_baselines3/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py index 2c036373aa..dbf2d8259c 100644 --- a/stable_baselines3/common/vec_env/__init__.py +++ b/stable_baselines3/common/vec_env/__init__.py @@ -2,7 +2,7 @@ from copy import deepcopy from typing import Optional, Type, Union -from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper +from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper, GymEnv from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv from stable_baselines3.common.vec_env.stacked_observations import StackedObservations from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv @@ -16,7 +16,7 @@ # Avoid circular import if typing.TYPE_CHECKING: - from stable_baselines3.common.type_aliases import GymEnv + pass def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> Optional[VecEnvWrapper]: diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 16518a1028..e7215f30bb 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -316,6 +316,9 @@ def _get_indices(self, indices: VecEnvIndices) -> Iterable[int]: return indices +GymEnv = Union[gym.Env, VecEnv] + + class VecEnvWrapper(VecEnv): """ Vectorized environment base class diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py index c311b2357c..44fcff2629 100644 --- a/stable_baselines3/ddpg/ddpg.py +++ b/stable_baselines3/ddpg/ddpg.py @@ -4,7 +4,9 @@ from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.common.vec_env import GymEnv +from stable_baselines3.common.callbacks import MaybeCallback from stable_baselines3.td3.policies import TD3Policy from stable_baselines3.td3.td3 import TD3 diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 42e3d0df02..8476216d62 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -9,7 +9,9 @@ from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.common.vec_env import GymEnv +from stable_baselines3.common.callbacks import MaybeCallback from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy, QNetwork diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index ea7cf5ed4a..69dd17869b 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -9,7 +9,9 @@ from stable_baselines3.common.buffers import RolloutBuffer from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.common.vec_env import GymEnv +from stable_baselines3.common.callbacks import MaybeCallback from stable_baselines3.common.utils import explained_variance, get_schedule_fn SelfPPO = TypeVar("SelfPPO", bound="PPO") diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index bf0fa50282..3f42d0c7cd 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -9,7 +9,9 @@ from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy, ContinuousCritic -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.common.vec_env import GymEnv +from stable_baselines3.common.callbacks import MaybeCallback from stable_baselines3.common.utils import get_parameters_by_name, polyak_update from stable_baselines3.sac.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index a06ce67e01..43b4f728c2 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -9,7 +9,9 @@ from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy, ContinuousCritic -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.common.vec_env import GymEnv +from stable_baselines3.common.callbacks import MaybeCallback from stable_baselines3.common.utils import get_parameters_by_name, polyak_update from stable_baselines3.td3.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy