Skip to content

Commit

Permalink
Fix policies type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin committed Nov 6, 2023
1 parent a35c08c commit 451f87e
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 56 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 2.2.0a10 (WIP)
Release 2.2.0a11 (WIP)
--------------------------
**Support for options at reset, bug fixes and better error messages**

Expand Down Expand Up @@ -62,6 +62,7 @@ Others:
- Fixed ``stable_baselines3/common/off_policy_algorithm.py`` type hints
- Fixed ``stable_baselines3/common/distributions.py`` type hints
- Switched to PyTorch 2.1.0 in the CI (fixes type annotations)
- Fixed ``stable_baselines3/common/policies.py`` type hints

Documentation:
^^^^^^^^^^^^^^
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ exclude = [
"stable_baselines3/common/vec_env/patch_gym.py",
"stable_baselines3/common/off_policy_algorithm.py",
"stable_baselines3/common/distributions.py",
"stable_baselines3/common/policies.py",
]

[tool.mypy]
ignore_missing_imports = true
follow_imports = "silent"
show_error_codes = true
exclude = """(?x)(
stable_baselines3/common/policies.py$
| stable_baselines3/common/vec_env/__init__.py$
stable_baselines3/common/vec_env/__init__.py$
| stable_baselines3/common/vec_env/vec_normalize.py$
| tests/test_logger.py$
| tests/test_train_eval_mode.py$
Expand Down
67 changes: 41 additions & 26 deletions stable_baselines3/common/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
NatureCNN,
create_mlp,
)
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor

SelfBaseModel = TypeVar("SelfBaseModel", bound="BaseModel")
Expand Down Expand Up @@ -119,7 +119,7 @@ def make_features_extractor(self) -> BaseFeaturesExtractor:
"""Helper method to create a features extractor."""
return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs)

def extract_features(self, obs: th.Tensor, features_extractor: BaseFeaturesExtractor) -> th.Tensor:
def extract_features(self, obs: PyTorchObs, features_extractor: BaseFeaturesExtractor) -> th.Tensor:
"""
Preprocess the observation if needed and extract features.
Expand Down Expand Up @@ -219,6 +219,9 @@ def is_vectorized_observation(self, observation: Union[np.ndarray, Dict[str, np.
"""
vectorized_env = False
if isinstance(observation, dict):
assert isinstance(
self.observation_space, spaces.Dict
), f"The observation provided is a dict but the obs space is {self.observation_space}"
for key, obs in observation.items():
obs_space = self.observation_space.spaces[key]
vectorized_env = vectorized_env or is_vectorized_observation(maybe_transpose(obs, obs_space), obs_space)
Expand All @@ -228,7 +231,7 @@ def is_vectorized_observation(self, observation: Union[np.ndarray, Dict[str, np.
)
return vectorized_env

def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[th.Tensor, bool]:
def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[PyTorchObs, bool]:
"""
Convert an input observation to a PyTorch tensor that can be fed to a model.
Includes sugar-coating to handle different observations (e.g. normalizing images).
Expand All @@ -239,6 +242,9 @@ def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -
"""
vectorized_env = False
if isinstance(observation, dict):
assert isinstance(
self.observation_space, spaces.Dict
), f"The observation provided is a dict but the obs space is {self.observation_space}"
# need to copy the dict as the dict in VecFrameStack will become a torch tensor
observation = copy.deepcopy(observation)
for key, obs in observation.items():
Expand All @@ -249,7 +255,7 @@ def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -
obs_ = np.array(obs)
vectorized_env = vectorized_env or is_vectorized_observation(obs_, obs_space)
# Add batch dimension if needed
observation[key] = obs_.reshape((-1, *self.observation_space[key].shape))
observation[key] = obs_.reshape((-1, *self.observation_space[key].shape)) # type: ignore[misc]

elif is_image_space(self.observation_space):
# Handle the different cases for images
Expand All @@ -263,10 +269,10 @@ def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -
# Dict obs need to be handled separately
vectorized_env = is_vectorized_observation(observation, self.observation_space)
# Add batch dimension if needed
observation = observation.reshape((-1, *self.observation_space.shape))
observation = observation.reshape((-1, *self.observation_space.shape)) # type: ignore[misc]

observation = obs_as_tensor(observation, self.device)
return observation, vectorized_env
obs_tensor = obs_as_tensor(observation, self.device)
return obs_tensor, vectorized_env


class BasePolicy(BaseModel, ABC):
Expand Down Expand Up @@ -308,7 +314,7 @@ def init_weights(module: nn.Module, gain: float = 1) -> None:
module.bias.data.fill_(0.0)

@abstractmethod
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
"""
Get the action according to the policy for a given observation.
Expand Down Expand Up @@ -354,27 +360,28 @@ def predict(
"and documentation for more information: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecenv-api-vs-gym-api"
)

observation, vectorized_env = self.obs_to_tensor(observation)
obs_tensor, vectorized_env = self.obs_to_tensor(observation)

with th.no_grad():
actions = self._predict(observation, deterministic=deterministic)
actions = self._predict(obs_tensor, deterministic=deterministic)
# Convert to numpy, and reshape to the original action shape
actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape))
actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape)) # type: ignore[misc]

if isinstance(self.action_space, spaces.Box):
if self.squash_output:
# Rescale to proper domain when using squashing
actions = self.unscale_action(actions)
actions = self.unscale_action(actions) # type: ignore[assignment, arg-type]
else:
# Actions could be on arbitrary scale, so clip the actions to avoid
# out of bound error (e.g. if sampling from a Gaussian distribution)
actions = np.clip(actions, self.action_space.low, self.action_space.high)
actions = np.clip(actions, self.action_space.low, self.action_space.high) # type: ignore[assignment, arg-type]

# Remove batch dimension if needed
if not vectorized_env:
assert isinstance(actions, np.ndarray)
actions = actions.squeeze(axis=0)

return actions, state
return actions, state # type: ignore[return-value]

def scale_action(self, action: np.ndarray) -> np.ndarray:
"""
Expand All @@ -384,6 +391,9 @@ def scale_action(self, action: np.ndarray) -> np.ndarray:
:param action: Action to scale
:return: Scaled action
"""
assert isinstance(
self.action_space, spaces.Box
), f"Trying to scale an action using an action space that is not a Box(): {self.action_space}"
low, high = self.action_space.low, self.action_space.high
return 2.0 * ((action - low) / (high - low)) - 1.0

Expand All @@ -394,6 +404,9 @@ def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray:
:param scaled_action: Action to un-scale
"""
assert isinstance(
self.action_space, spaces.Box
), f"Trying to unscale an action using an action space that is not a Box(): {self.action_space}"
low, high = self.action_space.low, self.action_space.high
return low + (0.5 * (scaled_action + 1.0) * (high - low))

Expand Down Expand Up @@ -522,7 +535,7 @@ def __init__(
def _get_constructor_parameters(self) -> Dict[str, Any]:
data = super()._get_constructor_parameters()

default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None)
default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None) # type: ignore[arg-type, return-value]

data.update(
dict(
Expand Down Expand Up @@ -616,7 +629,7 @@ def _build(self, lr_schedule: Schedule) -> None:
module.apply(partial(self.init_weights, gain=gain))

# Setup optimizer with initial learning rate
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) # type: ignore[call-arg]

def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
"""
Expand All @@ -639,11 +652,11 @@ def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tenso
distribution = self._get_action_dist_from_latent(latent_pi)
actions = distribution.get_actions(deterministic=deterministic)
log_prob = distribution.log_prob(actions)
actions = actions.reshape((-1, *self.action_space.shape))
actions = actions.reshape((-1, *self.action_space.shape)) # type: ignore[misc]
return actions, values, log_prob

def extract_features(
self, obs: th.Tensor, features_extractor: Optional[BaseFeaturesExtractor] = None
def extract_features( # type: ignore[override]
self, obs: PyTorchObs, features_extractor: Optional[BaseFeaturesExtractor] = None
) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
"""
Preprocess the observation if needed and extract features.
Expand Down Expand Up @@ -691,7 +704,7 @@ def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:
else:
raise ValueError("Invalid action distribution")

def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
"""
Get the action according to the policy for a given observation.
Expand All @@ -701,7 +714,7 @@ def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Te
"""
return self.get_distribution(observation).get_actions(deterministic=deterministic)

def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
def evaluate_actions(self, obs: PyTorchObs, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
"""
Evaluate actions according to the current policy,
given the observations.
Expand All @@ -725,7 +738,7 @@ def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tenso
entropy = distribution.entropy()
return values, log_prob, entropy

def get_distribution(self, obs: th.Tensor) -> Distribution:
def get_distribution(self, obs: PyTorchObs) -> Distribution:
"""
Get the current policy distribution given the observations.
Expand All @@ -736,7 +749,7 @@ def get_distribution(self, obs: th.Tensor) -> Distribution:
latent_pi = self.mlp_extractor.forward_actor(features)
return self._get_action_dist_from_latent(latent_pi)

def predict_values(self, obs: th.Tensor) -> th.Tensor:
def predict_values(self, obs: PyTorchObs) -> th.Tensor:
"""
Get the estimated values according to the current policy given the observations.
Expand Down Expand Up @@ -921,6 +934,8 @@ class ContinuousCritic(BaseModel):
between the actor and the critic (this saves computation time)
"""

features_extractor: BaseFeaturesExtractor

def __init__(
self,
observation_space: spaces.Space,
Expand All @@ -944,10 +959,10 @@ def __init__(

self.share_features_extractor = share_features_extractor
self.n_critics = n_critics
self.q_networks = []
self.q_networks: List[nn.Module] = []
for idx in range(n_critics):
q_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
q_net = nn.Sequential(*q_net)
q_net_list = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
q_net = nn.Sequential(*q_net_list)
self.add_module(f"qf{idx}", q_net)
self.q_networks.append(q_net)

Expand Down
21 changes: 11 additions & 10 deletions stable_baselines3/common/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) ->


def preprocess_obs(
obs: th.Tensor,
obs: Union[th.Tensor, Dict[str, th.Tensor]],
observation_space: spaces.Space,
normalize_images: bool = True,
) -> Union[th.Tensor, Dict[str, th.Tensor]]:
Expand All @@ -105,6 +105,16 @@ def preprocess_obs(
(True by default)
:return:
"""
if isinstance(observation_space, spaces.Dict):
# Do not modify by reference the original observation
assert isinstance(obs, Dict), f"Expected dict, got {type(obs)}"
preprocessed_obs = {}
for key, _obs in obs.items():
preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images)
return preprocessed_obs # type: ignore[return-value]

assert isinstance(obs, th.Tensor), f"Expecting a torch Tensor, but got {type(obs)}"

if isinstance(observation_space, spaces.Box):
if normalize_images and is_image_space(observation_space):
return obs.float() / 255.0
Expand All @@ -126,15 +136,6 @@ def preprocess_obs(

elif isinstance(observation_space, spaces.MultiBinary):
return obs.float()

elif isinstance(observation_space, spaces.Dict):
# Do not modify by reference the original observation
assert isinstance(obs, Dict), f"Expected dict, got {type(obs)}"
preprocessed_obs = {}
for key, _obs in obs.items():
preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images)
return preprocessed_obs

else:
raise NotImplementedError(f"Preprocessing not implemented for {observation_space}")

Expand Down
1 change: 1 addition & 0 deletions stable_baselines3/common/type_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
TensorDict = Dict[str, th.Tensor]
OptimizerStateDict = Dict[str, Any]
MaybeCallback = Union[None, Callable, List["BaseCallback"], "BaseCallback"]
PyTorchObs = Union[th.Tensor, TensorDict]

# A schedule takes the remaining progress as input
# and ouputs a scalar (e.g. learning rate, clip range, ...)
Expand Down
10 changes: 5 additions & 5 deletions stable_baselines3/dqn/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
NatureCNN,
create_mlp,
)
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule


class QNetwork(BasePolicy):
Expand Down Expand Up @@ -56,7 +56,7 @@ def __init__(
q_net = create_mlp(self.features_dim, action_dim, self.net_arch, self.activation_fn)
self.q_net = nn.Sequential(*q_net)

def forward(self, obs: th.Tensor) -> th.Tensor:
def forward(self, obs: PyTorchObs) -> th.Tensor:
"""
Predict the q-values.
Expand All @@ -65,7 +65,7 @@ def forward(self, obs: th.Tensor) -> th.Tensor:
"""
return self.q_net(self.extract_features(obs, self.features_extractor))

def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
def _predict(self, observation: PyTorchObs, deterministic: bool = True) -> th.Tensor:
q_values = self(observation)
# Greedy action
action = q_values.argmax(dim=1).reshape(-1)
Expand Down Expand Up @@ -177,10 +177,10 @@ def make_q_net(self) -> QNetwork:
net_args = self._update_features_extractor(self.net_args, features_extractor=None)
return QNetwork(**net_args).to(self.device)

def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
def forward(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor:
return self._predict(obs, deterministic=deterministic)

def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
def _predict(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor:
return self.q_net._predict(obs, deterministic=deterministic)

def _get_constructor_parameters(self) -> Dict[str, Any]:
Expand Down
Loading

0 comments on commit 451f87e

Please sign in to comment.