diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py index 718571f0c..64846bd18 100644 --- a/stable_baselines3/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -5,10 +5,12 @@ from torch.nn import functional as F from stable_baselines3.common.buffers import RolloutBuffer +from stable_baselines3.common.callbacks import MaybeCallback from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.type_aliases import Schedule from stable_baselines3.common.utils import explained_variance +from stable_baselines3.common.vec_env import GymEnv SelfA2C = TypeVar("SelfA2C", bound="A2C") diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 5e8759990..cb9438a9e 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -14,7 +14,7 @@ from gymnasium import spaces from stable_baselines3.common import utils -from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback +from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, MaybeCallback, ProgressBarCallback from stable_baselines3.common.env_util import is_wrapped from stable_baselines3.common.logger import Logger from stable_baselines3.common.monitor import Monitor @@ -22,7 +22,7 @@ from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space, is_image_space_channels_first from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule, TensorDict +from stable_baselines3.common.type_aliases import Schedule, TensorDict from stable_baselines3.common.utils import ( check_for_correct_spaces, get_device, @@ -33,6 +33,7 @@ ) from stable_baselines3.common.vec_env import ( DummyVecEnv, + GymEnv, VecEnv, VecNormalize, VecTransposeImage, diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index 5089bba2b..a604524d4 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -1,7 +1,7 @@ import os import warnings from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union import gymnasium as gym import numpy as np @@ -19,7 +19,9 @@ # if the progress bar is used tqdm = None -from stable_baselines3.common import base_class +if TYPE_CHECKING: + from stable_baselines3.common import base_class + from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization @@ -704,3 +706,6 @@ def _on_training_end(self) -> None: # Flush and close progress bar self.pbar.refresh() self.pbar.close() + + +MaybeCallback = Union[None, Callable, List[BaseCallback], BaseCallback] diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index e8dcac4a4..981b742c3 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -12,13 +12,13 @@ from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer -from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.callbacks import BaseCallback, MaybeCallback from stable_baselines3.common.noise import ActionNoise, VectorizedActionNoise from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, Schedule, TrainFreq, TrainFrequencyUnit +from stable_baselines3.common.type_aliases import RolloutReturn, Schedule, TrainFreq, TrainFrequencyUnit from stable_baselines3.common.utils import safe_mean, should_collect_more_steps -from stable_baselines3.common.vec_env import VecEnv +from stable_baselines3.common.vec_env import GymEnv, VecEnv from stable_baselines3.her.her_replay_buffer import HerReplayBuffer SelfOffPolicyAlgorithm = TypeVar("SelfOffPolicyAlgorithm", bound="OffPolicyAlgorithm") diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index dc2b5abd0..c9f93ef40 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -8,11 +8,11 @@ from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer -from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.callbacks import BaseCallback, MaybeCallback from stable_baselines3.common.policies import ActorCriticPolicy -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.type_aliases import Schedule from stable_baselines3.common.utils import obs_as_tensor, safe_mean -from stable_baselines3.common.vec_env import VecEnv +from stable_baselines3.common.vec_env import GymEnv, VecEnv SelfOnPolicyAlgorithm = TypeVar("SelfOnPolicyAlgorithm", bound="OnPolicyAlgorithm") diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index 2f98ee198..ddcb33120 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -1,15 +1,11 @@ """Common aliases for type hints""" from enum import Enum -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, SupportsFloat, Tuple, Union +from typing import Any, Callable, Dict, NamedTuple, Optional, Protocol, SupportsFloat, Tuple, Union -import gymnasium as gym import numpy as np import torch as th -from stable_baselines3.common import callbacks, vec_env - -GymEnv = Union[gym.Env, vec_env.VecEnv] GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int] GymResetReturn = Tuple[GymObs, Dict] AtariResetReturn = Tuple[np.ndarray, Dict[str, Any]] @@ -17,7 +13,6 @@ AtariStepReturn = Tuple[np.ndarray, SupportsFloat, bool, bool, Dict[str, Any]] TensorDict = Dict[str, th.Tensor] OptimizerStateDict = Dict[str, Any] -MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback] # A schedule takes the remaining progress as input # and ouputs a scalar (e.g. learning rate, clip range, ...) diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 4950822d4..7b74f3093 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -22,7 +22,8 @@ SummaryWriter = None # type: ignore[misc, assignment] from stable_baselines3.common.logger import Logger, configure -from stable_baselines3.common.type_aliases import GymEnv, Schedule, TensorDict, TrainFreq, TrainFrequencyUnit +from stable_baselines3.common.type_aliases import Schedule, TensorDict, TrainFreq, TrainFrequencyUnit +from stable_baselines3.common.vec_env import GymEnv def set_random_seed(seed: int, using_cuda: bool = False) -> None: diff --git a/stable_baselines3/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py index 2c036373a..ce25ed807 100644 --- a/stable_baselines3/common/vec_env/__init__.py +++ b/stable_baselines3/common/vec_env/__init__.py @@ -1,8 +1,7 @@ -import typing from copy import deepcopy from typing import Optional, Type, Union -from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper +from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, GymEnv, VecEnv, VecEnvWrapper from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv from stable_baselines3.common.vec_env.stacked_observations import StackedObservations from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv @@ -14,10 +13,6 @@ from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder -# Avoid circular import -if typing.TYPE_CHECKING: - from stable_baselines3.common.type_aliases import GymEnv - def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> Optional[VecEnvWrapper]: """ diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 5e06a5c0c..41f350cfc 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -341,6 +341,9 @@ def _get_indices(self, indices: VecEnvIndices) -> Iterable[int]: return indices +GymEnv = Union[gym.Env, VecEnv] + + class VecEnvWrapper(VecEnv): """ Vectorized environment base class diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py index c311b2357..4d457dee8 100644 --- a/stable_baselines3/ddpg/ddpg.py +++ b/stable_baselines3/ddpg/ddpg.py @@ -3,8 +3,10 @@ import torch as th from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.callbacks import MaybeCallback from stable_baselines3.common.noise import ActionNoise -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.common.vec_env import GymEnv from stable_baselines3.td3.policies import TD3Policy from stable_baselines3.td3.td3 import TD3 diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 42e3d0df0..611a4d6ec 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -7,10 +7,12 @@ from torch.nn import functional as F from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.callbacks import MaybeCallback from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.type_aliases import Schedule from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update +from stable_baselines3.common.vec_env import GymEnv from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy, QNetwork SelfDQN = TypeVar("SelfDQN", bound="DQN") diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index ea7cf5ed4..511677587 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -7,10 +7,12 @@ from torch.nn import functional as F from stable_baselines3.common.buffers import RolloutBuffer +from stable_baselines3.common.callbacks import MaybeCallback from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.type_aliases import Schedule from stable_baselines3.common.utils import explained_variance, get_schedule_fn +from stable_baselines3.common.vec_env import GymEnv SelfPPO = TypeVar("SelfPPO", bound="PPO") diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index bf0fa5028..2f1e5461b 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -6,11 +6,13 @@ from torch.nn import functional as F from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.callbacks import MaybeCallback from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy, ContinuousCritic -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.type_aliases import Schedule from stable_baselines3.common.utils import get_parameters_by_name, polyak_update +from stable_baselines3.common.vec_env import GymEnv from stable_baselines3.sac.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy SelfSAC = TypeVar("SelfSAC", bound="SAC") diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index a06ce67e0..391a8728a 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -6,11 +6,13 @@ from torch.nn import functional as F from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.callbacks import MaybeCallback from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy, ContinuousCritic -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.type_aliases import Schedule from stable_baselines3.common.utils import get_parameters_by_name, polyak_update +from stable_baselines3.common.vec_env import GymEnv from stable_baselines3.td3.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy SelfTD3 = TypeVar("SelfTD3", bound="TD3")