Skip to content

Commit

Permalink
Resolve cyclic imports by moving the MaybeCallback type to the callba…
Browse files Browse the repository at this point in the history
…cks and the GymEnv type to the base_vec_env.py
  • Loading branch information
ernestum committed Oct 23, 2023
1 parent d1e30bb commit 560eb50
Show file tree
Hide file tree
Showing 14 changed files with 38 additions and 23 deletions.
4 changes: 3 additions & 1 deletion stable_baselines3/a2c/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.vec_env import GymEnv
from stable_baselines3.common.callbacks import MaybeCallback
from stable_baselines3.common.utils import explained_variance

SelfA2C = TypeVar("SelfA2C", bound="A2C")
Expand Down
5 changes: 3 additions & 2 deletions stable_baselines3/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
from gymnasium import spaces

from stable_baselines3.common import utils
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback, MaybeCallback
from stable_baselines3.common.env_util import is_wrapped
from stable_baselines3.common.logger import Logger
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space, is_image_space_channels_first
from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule, TensorDict
from stable_baselines3.common.type_aliases import Schedule, TensorDict
from stable_baselines3.common.utils import (
check_for_correct_spaces,
get_device,
Expand All @@ -38,6 +38,7 @@
VecTransposeImage,
is_vecenv_wrapped,
unwrap_vec_normalize,
GymEnv,
)
from stable_baselines3.common.vec_env.patch_gym import _convert_space, _patch_env

Expand Down
3 changes: 3 additions & 0 deletions stable_baselines3/common/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,3 +704,6 @@ def _on_training_end(self) -> None:
# Flush and close progress bar
self.pbar.refresh()
self.pbar.close()


MaybeCallback = Union[None, Callable, List[BaseCallback], BaseCallback]
6 changes: 3 additions & 3 deletions stable_baselines3/common/off_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@

from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.callbacks import BaseCallback, MaybeCallback
from stable_baselines3.common.noise import ActionNoise, VectorizedActionNoise
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, Schedule, TrainFreq, TrainFrequencyUnit
from stable_baselines3.common.type_aliases import RolloutReturn, Schedule, TrainFreq, TrainFrequencyUnit
from stable_baselines3.common.utils import safe_mean, should_collect_more_steps
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.vec_env import VecEnv, GymEnv
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer

SelfOffPolicyAlgorithm = TypeVar("SelfOffPolicyAlgorithm", bound="OffPolicyAlgorithm")
Expand Down
6 changes: 3 additions & 3 deletions stable_baselines3/common/on_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@

from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.callbacks import BaseCallback, MaybeCallback
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.utils import obs_as_tensor, safe_mean
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.vec_env import VecEnv, GymEnv

SelfOnPolicyAlgorithm = TypeVar("SelfOnPolicyAlgorithm", bound="OnPolicyAlgorithm")

Expand Down
7 changes: 1 addition & 6 deletions stable_baselines3/common/type_aliases.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@
"""Common aliases for type hints"""

from enum import Enum
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, SupportsFloat, Tuple, Union
from typing import Any, Callable, Dict, NamedTuple, Optional, Protocol, SupportsFloat, Tuple, Union

import gymnasium as gym
import numpy as np
import torch as th

from stable_baselines3.common import callbacks, vec_env

GymEnv = Union[gym.Env, vec_env.VecEnv]
GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int]
GymResetReturn = Tuple[GymObs, Dict]
AtariResetReturn = Tuple[np.ndarray, Dict[str, Any]]
GymStepReturn = Tuple[GymObs, float, bool, bool, Dict]
AtariStepReturn = Tuple[np.ndarray, SupportsFloat, bool, bool, Dict[str, Any]]
TensorDict = Dict[str, th.Tensor]
OptimizerStateDict = Dict[str, Any]
MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback]

# A schedule takes the remaining progress as input
# and ouputs a scalar (e.g. learning rate, clip range, ...)
Expand Down
3 changes: 2 additions & 1 deletion stable_baselines3/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
SummaryWriter = None # type: ignore[misc, assignment]

from stable_baselines3.common.logger import Logger, configure
from stable_baselines3.common.type_aliases import GymEnv, Schedule, TensorDict, TrainFreq, TrainFrequencyUnit
from stable_baselines3.common.type_aliases import Schedule, TensorDict, TrainFreq, TrainFrequencyUnit
from stable_baselines3.common.vec_env import GymEnv


def set_random_seed(seed: int, using_cuda: bool = False) -> None:
Expand Down
4 changes: 2 additions & 2 deletions stable_baselines3/common/vec_env/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from copy import deepcopy
from typing import Optional, Type, Union

from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper, GymEnv
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.stacked_observations import StackedObservations
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
Expand All @@ -16,7 +16,7 @@

# Avoid circular import
if typing.TYPE_CHECKING:
from stable_baselines3.common.type_aliases import GymEnv
pass


def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> Optional[VecEnvWrapper]:
Expand Down
3 changes: 3 additions & 0 deletions stable_baselines3/common/vec_env/base_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,9 @@ def _get_indices(self, indices: VecEnvIndices) -> Iterable[int]:
return indices


GymEnv = Union[gym.Env, VecEnv]


class VecEnvWrapper(VecEnv):
"""
Vectorized environment base class
Expand Down
4 changes: 3 additions & 1 deletion stable_baselines3/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.vec_env import GymEnv
from stable_baselines3.common.callbacks import MaybeCallback
from stable_baselines3.td3.policies import TD3Policy
from stable_baselines3.td3.td3 import TD3

Expand Down
4 changes: 3 additions & 1 deletion stable_baselines3/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.vec_env import GymEnv
from stable_baselines3.common.callbacks import MaybeCallback
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update
from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy, QNetwork

Expand Down
4 changes: 3 additions & 1 deletion stable_baselines3/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.vec_env import GymEnv
from stable_baselines3.common.callbacks import MaybeCallback
from stable_baselines3.common.utils import explained_variance, get_schedule_fn

SelfPPO = TypeVar("SelfPPO", bound="PPO")
Expand Down
4 changes: 3 additions & 1 deletion stable_baselines3/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.vec_env import GymEnv
from stable_baselines3.common.callbacks import MaybeCallback
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from stable_baselines3.sac.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy

Expand Down
4 changes: 3 additions & 1 deletion stable_baselines3/td3/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.vec_env import GymEnv
from stable_baselines3.common.callbacks import MaybeCallback
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from stable_baselines3.td3.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy

Expand Down

0 comments on commit 560eb50

Please sign in to comment.