Skip to content

Commit

Permalink
Resolve cyclic imports by moving the MaybeCallback type to the callba…
Browse files Browse the repository at this point in the history
…cks and the GymEnv type to the base_vec_env.py
  • Loading branch information
ernestum committed Oct 23, 2023
1 parent 3848f3a commit 0379421
Show file tree
Hide file tree
Showing 14 changed files with 41 additions and 29 deletions.
4 changes: 3 additions & 1 deletion stable_baselines3/a2c/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
from torch.nn import functional as F

from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import MaybeCallback
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.utils import explained_variance
from stable_baselines3.common.vec_env import GymEnv

SelfA2C = TypeVar("SelfA2C", bound="A2C")

Expand Down
5 changes: 3 additions & 2 deletions stable_baselines3/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
from gymnasium import spaces

from stable_baselines3.common import utils
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, MaybeCallback, ProgressBarCallback
from stable_baselines3.common.env_util import is_wrapped
from stable_baselines3.common.logger import Logger
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space, is_image_space_channels_first
from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule, TensorDict
from stable_baselines3.common.type_aliases import Schedule, TensorDict
from stable_baselines3.common.utils import (
check_for_correct_spaces,
get_device,
Expand All @@ -33,6 +33,7 @@
)
from stable_baselines3.common.vec_env import (
DummyVecEnv,
GymEnv,
VecEnv,
VecNormalize,
VecTransposeImage,
Expand Down
9 changes: 7 additions & 2 deletions stable_baselines3/common/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import warnings
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import gymnasium as gym
import numpy as np
Expand All @@ -19,7 +19,9 @@
# if the progress bar is used
tqdm = None

from stable_baselines3.common import base_class
if TYPE_CHECKING:
from stable_baselines3.common import base_class

from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization

Expand Down Expand Up @@ -704,3 +706,6 @@ def _on_training_end(self) -> None:
# Flush and close progress bar
self.pbar.refresh()
self.pbar.close()


MaybeCallback = Union[None, Callable, List[BaseCallback], BaseCallback]
6 changes: 3 additions & 3 deletions stable_baselines3/common/off_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@

from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.callbacks import BaseCallback, MaybeCallback
from stable_baselines3.common.noise import ActionNoise, VectorizedActionNoise
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, Schedule, TrainFreq, TrainFrequencyUnit
from stable_baselines3.common.type_aliases import RolloutReturn, Schedule, TrainFreq, TrainFrequencyUnit
from stable_baselines3.common.utils import safe_mean, should_collect_more_steps
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.vec_env import GymEnv, VecEnv
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer

SelfOffPolicyAlgorithm = TypeVar("SelfOffPolicyAlgorithm", bound="OffPolicyAlgorithm")
Expand Down
6 changes: 3 additions & 3 deletions stable_baselines3/common/on_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@

from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.callbacks import BaseCallback, MaybeCallback
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.utils import obs_as_tensor, safe_mean
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.vec_env import GymEnv, VecEnv

SelfOnPolicyAlgorithm = TypeVar("SelfOnPolicyAlgorithm", bound="OnPolicyAlgorithm")

Expand Down
7 changes: 1 addition & 6 deletions stable_baselines3/common/type_aliases.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@
"""Common aliases for type hints"""

from enum import Enum
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, SupportsFloat, Tuple, Union
from typing import Any, Callable, Dict, NamedTuple, Optional, Protocol, SupportsFloat, Tuple, Union

import gymnasium as gym
import numpy as np
import torch as th

from stable_baselines3.common import callbacks, vec_env

GymEnv = Union[gym.Env, vec_env.VecEnv]
GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int]
GymResetReturn = Tuple[GymObs, Dict]
AtariResetReturn = Tuple[np.ndarray, Dict[str, Any]]
GymStepReturn = Tuple[GymObs, float, bool, bool, Dict]
AtariStepReturn = Tuple[np.ndarray, SupportsFloat, bool, bool, Dict[str, Any]]
TensorDict = Dict[str, th.Tensor]
OptimizerStateDict = Dict[str, Any]
MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback]

# A schedule takes the remaining progress as input
# and ouputs a scalar (e.g. learning rate, clip range, ...)
Expand Down
3 changes: 2 additions & 1 deletion stable_baselines3/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
SummaryWriter = None # type: ignore[misc, assignment]

from stable_baselines3.common.logger import Logger, configure
from stable_baselines3.common.type_aliases import GymEnv, Schedule, TensorDict, TrainFreq, TrainFrequencyUnit
from stable_baselines3.common.type_aliases import Schedule, TensorDict, TrainFreq, TrainFrequencyUnit
from stable_baselines3.common.vec_env import GymEnv


def set_random_seed(seed: int, using_cuda: bool = False) -> None:
Expand Down
7 changes: 1 addition & 6 deletions stable_baselines3/common/vec_env/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import typing
from copy import deepcopy
from typing import Optional, Type, Union

from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, GymEnv, VecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.stacked_observations import StackedObservations
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
Expand All @@ -14,10 +13,6 @@
from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage
from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder

# Avoid circular import
if typing.TYPE_CHECKING:
from stable_baselines3.common.type_aliases import GymEnv


def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> Optional[VecEnvWrapper]:
"""
Expand Down
3 changes: 3 additions & 0 deletions stable_baselines3/common/vec_env/base_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,9 @@ def _get_indices(self, indices: VecEnvIndices) -> Iterable[int]:
return indices


GymEnv = Union[gym.Env, VecEnv]


class VecEnvWrapper(VecEnv):
"""
Vectorized environment base class
Expand Down
4 changes: 3 additions & 1 deletion stable_baselines3/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import torch as th

from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.callbacks import MaybeCallback
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.vec_env import GymEnv
from stable_baselines3.td3.policies import TD3Policy
from stable_baselines3.td3.td3 import TD3

Expand Down
4 changes: 3 additions & 1 deletion stable_baselines3/dqn/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
from torch.nn import functional as F

from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.callbacks import MaybeCallback
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update
from stable_baselines3.common.vec_env import GymEnv
from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy, QNetwork

SelfDQN = TypeVar("SelfDQN", bound="DQN")
Expand Down
4 changes: 3 additions & 1 deletion stable_baselines3/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
from torch.nn import functional as F

from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import MaybeCallback
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
from stable_baselines3.common.vec_env import GymEnv

SelfPPO = TypeVar("SelfPPO", bound="PPO")

Expand Down
4 changes: 3 additions & 1 deletion stable_baselines3/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from torch.nn import functional as F

from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.callbacks import MaybeCallback
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from stable_baselines3.common.vec_env import GymEnv
from stable_baselines3.sac.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy

SelfSAC = TypeVar("SelfSAC", bound="SAC")
Expand Down
4 changes: 3 additions & 1 deletion stable_baselines3/td3/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from torch.nn import functional as F

from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.callbacks import MaybeCallback
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy, ContinuousCritic
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from stable_baselines3.common.vec_env import GymEnv
from stable_baselines3.td3.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy

SelfTD3 = TypeVar("SelfTD3", bound="TD3")
Expand Down

0 comments on commit 0379421

Please sign in to comment.