Fix policies type annotations #1735

Merged: 1 commit, Nov 6, 2023
docs/misc/changelog.rst (2 additions, 1 deletion)
@@ -3,7 +3,7 @@
 Changelog
 ==========
 
-Release 2.2.0a10 (WIP)
+Release 2.2.0a11 (WIP)
 --------------------------
 **Support for options at reset, bug fixes and better error messages**
 
@@ -62,6 +62,7 @@ Others:
 - Fixed ``stable_baselines3/common/off_policy_algorithm.py`` type hints
 - Fixed ``stable_baselines3/common/distributions.py`` type hints
 - Switched to PyTorch 2.1.0 in the CI (fixes type annotations)
+- Fixed ``stable_baselines3/common/policies.py`` type hints
 
 Documentation:
 ^^^^^^^^^^^^^^
pyproject.toml (2 additions, 2 deletions)
@@ -38,15 +38,15 @@ exclude = [
     "stable_baselines3/common/vec_env/patch_gym.py",
     "stable_baselines3/common/off_policy_algorithm.py",
     "stable_baselines3/common/distributions.py",
+    "stable_baselines3/common/policies.py",
 ]
 
 [tool.mypy]
 ignore_missing_imports = true
 follow_imports = "silent"
 show_error_codes = true
 exclude = """(?x)(
-    stable_baselines3/common/policies.py$
-    | stable_baselines3/common/vec_env/__init__.py$
+    stable_baselines3/common/vec_env/__init__.py$
     | stable_baselines3/common/vec_env/vec_normalize.py$
     | tests/test_logger.py$
     | tests/test_train_eval_mode.py$
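With policies.py removed from the mypy exclude regex (and added to the first exclude list above), the module is now type-checked. A minimal sketch of reproducing the check locally through mypy's Python API, assuming mypy is installed and the working directory is the repository root (running the mypy CLI directly works just as well):

# Sketch, not part of the diff: run mypy on the newly checked module
# via its Python API.
from mypy import api

stdout, stderr, exit_status = api.run(["stable_baselines3/common/policies.py"])
print(stdout)
print("exit status:", exit_status)  # 0 means no type errors were found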
stable_baselines3/common/policies.py (41 additions, 26 deletions)
@@ -30,7 +30,7 @@
     NatureCNN,
     create_mlp,
 )
-from stable_baselines3.common.type_aliases import Schedule
+from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
 from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor
 
 SelfBaseModel = TypeVar("SelfBaseModel", bound="BaseModel")
@@ -119,7 +119,7 @@ def make_features_extractor(self) -> BaseFeaturesExtractor:
         """Helper method to create a features extractor."""
         return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
 
-    def extract_features(self, obs: th.Tensor, features_extractor: BaseFeaturesExtractor) -> th.Tensor:
+    def extract_features(self, obs: PyTorchObs, features_extractor: BaseFeaturesExtractor) -> th.Tensor:
         """
         Preprocess the observation if needed and extract features.
 
@@ -219,6 +219,9 @@ def is_vectorized_observation(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> bool:
         """
         vectorized_env = False
         if isinstance(observation, dict):
+            assert isinstance(
+                self.observation_space, spaces.Dict
+            ), f"The observation provided is a dict but the obs space is {self.observation_space}"
             for key, obs in observation.items():
                 obs_space = self.observation_space.spaces[key]
                 vectorized_env = vectorized_env or is_vectorized_observation(maybe_transpose(obs, obs_space), obs_space)
@@ -228,7 +231,7 @@ def is_vectorized_observation(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> bool:
         )
         return vectorized_env
 
-    def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[th.Tensor, bool]:
+    def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[PyTorchObs, bool]:
         """
         Convert an input observation to a PyTorch tensor that can be fed to a model.
         Includes sugar-coating to handle different observations (e.g. normalizing images).
@@ -239,6 +242,9 @@ def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[th.Tensor, bool]:
         """
         vectorized_env = False
         if isinstance(observation, dict):
+            assert isinstance(
+                self.observation_space, spaces.Dict
+            ), f"The observation provided is a dict but the obs space is {self.observation_space}"
             # need to copy the dict as the dict in VecFrameStack will become a torch tensor
             observation = copy.deepcopy(observation)
             for key, obs in observation.items():
@@ -249,7 +255,7 @@ def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[th.Tensor, bool]:
                 obs_ = np.array(obs)
                 vectorized_env = vectorized_env or is_vectorized_observation(obs_, obs_space)
                 # Add batch dimension if needed
-                observation[key] = obs_.reshape((-1, *self.observation_space[key].shape))
+                observation[key] = obs_.reshape((-1, *self.observation_space[key].shape))  # type: ignore[misc]
 
         elif is_image_space(self.observation_space):
             # Handle the different cases for images
@@ -263,10 +269,10 @@ def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[th.Tensor, bool]:
             # Dict obs need to be handled separately
             vectorized_env = is_vectorized_observation(observation, self.observation_space)
             # Add batch dimension if needed
-            observation = observation.reshape((-1, *self.observation_space.shape))
+            observation = observation.reshape((-1, *self.observation_space.shape))  # type: ignore[misc]
 
-        observation = obs_as_tensor(observation, self.device)
-        return observation, vectorized_env
+        obs_tensor = obs_as_tensor(observation, self.device)
+        return obs_tensor, vectorized_env
 
 
 class BasePolicy(BaseModel, ABC):
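Renaming the result to obs_tensor makes the widened return type explicit: obs_to_tensor yields either a single tensor or a dict of tensors depending on the observation space, hence the PyTorchObs alias. A usage sketch (the algorithm and environment choice are illustrative, not taken from this PR):

# Sketch: obs_to_tensor adds a batch dimension to a single, unbatched obs.
# PPO and CartPole are illustrative assumptions.
import numpy as np
from stable_baselines3 import PPO

model = PPO("MlpPolicy", "CartPole-v1")
obs = np.zeros(4, dtype=np.float32)  # one CartPole observation, shape (4,)
obs_tensor, vectorized = model.policy.obs_to_tensor(obs)
print(obs_tensor.shape, vectorized)  # torch.Size([1, 4]) False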
@@ -308,7 +314,7 @@ def init_weights(module: nn.Module, gain: float = 1) -> None:
             module.bias.data.fill_(0.0)
 
     @abstractmethod
-    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
+    def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
         """
         Get the action according to the policy for a given observation.
 
@@ -354,27 +360,28 @@ def predict(
             "and documentation for more information: https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html#vecenv-api-vs-gym-api"
         )
 
-        observation, vectorized_env = self.obs_to_tensor(observation)
+        obs_tensor, vectorized_env = self.obs_to_tensor(observation)
 
         with th.no_grad():
-            actions = self._predict(observation, deterministic=deterministic)
+            actions = self._predict(obs_tensor, deterministic=deterministic)
         # Convert to numpy, and reshape to the original action shape
-        actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape))
+        actions = actions.cpu().numpy().reshape((-1, *self.action_space.shape))  # type: ignore[misc]
 
         if isinstance(self.action_space, spaces.Box):
             if self.squash_output:
                 # Rescale to proper domain when using squashing
-                actions = self.unscale_action(actions)
+                actions = self.unscale_action(actions)  # type: ignore[assignment, arg-type]
             else:
                 # Actions could be on arbitrary scale, so clip the actions to avoid
                 # out of bound error (e.g. if sampling from a Gaussian distribution)
-                actions = np.clip(actions, self.action_space.low, self.action_space.high)
+                actions = np.clip(actions, self.action_space.low, self.action_space.high)  # type: ignore[assignment, arg-type]
 
         # Remove batch dimension if needed
        if not vectorized_env:
+            assert isinstance(actions, np.ndarray)
             actions = actions.squeeze(axis=0)
 
-        return actions, state
+        return actions, state  # type: ignore[return-value]
 
     def scale_action(self, action: np.ndarray) -> np.ndarray:
         """
@@ -384,6 +391,9 @@ def scale_action(self, action: np.ndarray) -> np.ndarray:
         :param action: Action to scale
         :return: Scaled action
         """
+        assert isinstance(
+            self.action_space, spaces.Box
+        ), f"Trying to scale an action using an action space that is not a Box(): {self.action_space}"
         low, high = self.action_space.low, self.action_space.high
         return 2.0 * ((action - low) / (high - low)) - 1.0
 
@@ -394,6 +404,9 @@ def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray:
 
         :param scaled_action: Action to un-scale
         """
+        assert isinstance(
+            self.action_space, spaces.Box
+        ), f"Trying to unscale an action using an action space that is not a Box(): {self.action_space}"
         low, high = self.action_space.low, self.action_space.high
         return low + (0.5 * (scaled_action + 1.0) * (high - low))
 
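The two new asserts make explicit that this affine rescaling only makes sense for interval-bounded Box actions. The maps themselves, as a standalone worked sketch mirroring the formulas above:

# Sketch: the affine maps behind scale_action / unscale_action.
import numpy as np

low, high = np.array([-2.0]), np.array([2.0])
action = np.array([1.0])

scaled = 2.0 * ((action - low) / (high - low)) - 1.0  # [low, high] -> [-1, 1]
unscaled = low + 0.5 * (scaled + 1.0) * (high - low)  # [-1, 1] -> [low, high]
print(scaled)                         # [0.5]
assert np.allclose(unscaled, action)  # the two maps are exact inverses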
@@ -522,7 +535,7 @@ def __init__(
     def _get_constructor_parameters(self) -> Dict[str, Any]:
         data = super()._get_constructor_parameters()
 
-        default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None)
+        default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None)  # type: ignore[arg-type, return-value]
 
         data.update(
             dict(
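The ignore is needed because defaultdict(lambda: None) is typed as producing None values, which mypy cannot reconcile with the declared value type of dist_kwargs. The idiom itself is just a dict that returns None for any missing key; a standalone sketch:

# Sketch: defaultdict(lambda: None) lets optional kwargs be read uniformly,
# returning None instead of raising KeyError for keys that were never set.
import collections

default_none_kwargs = collections.defaultdict(lambda: None)
print(default_none_kwargs["use_sde"])  # None, no KeyError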
@@ -616,7 +629,7 @@ def _build(self, lr_schedule: Schedule) -> None:
             module.apply(partial(self.init_weights, gain=gain))
 
         # Setup optimizer with initial learning rate
-        self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
+        self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)  # type: ignore[call-arg]
 
     def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
         """
@@ -639,11 +652,11 @@ def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
         distribution = self._get_action_dist_from_latent(latent_pi)
         actions = distribution.get_actions(deterministic=deterministic)
         log_prob = distribution.log_prob(actions)
-        actions = actions.reshape((-1, *self.action_space.shape))
+        actions = actions.reshape((-1, *self.action_space.shape))  # type: ignore[misc]
         return actions, values, log_prob
 
-    def extract_features(
-        self, obs: th.Tensor, features_extractor: Optional[BaseFeaturesExtractor] = None
+    def extract_features(  # type: ignore[override]
+        self, obs: PyTorchObs, features_extractor: Optional[BaseFeaturesExtractor] = None
     ) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
         """
         Preprocess the observation if needed and extract features.
@@ -691,7 +704,7 @@ def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:
         else:
             raise ValueError("Invalid action distribution")
 
-    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
+    def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
         """
         Get the action according to the policy for a given observation.
 
@@ -701,7 +714,7 @@ def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
         """
         return self.get_distribution(observation).get_actions(deterministic=deterministic)
 
-    def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
+    def evaluate_actions(self, obs: PyTorchObs, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
         """
         Evaluate actions according to the current policy,
         given the observations.
@@ -725,7 +738,7 @@ def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
         entropy = distribution.entropy()
         return values, log_prob, entropy
 
-    def get_distribution(self, obs: th.Tensor) -> Distribution:
+    def get_distribution(self, obs: PyTorchObs) -> Distribution:
         """
         Get the current policy distribution given the observations.
 
@@ -736,7 +749,7 @@ def get_distribution(self, obs: th.Tensor) -> Distribution:
         latent_pi = self.mlp_extractor.forward_actor(features)
         return self._get_action_dist_from_latent(latent_pi)
 
-    def predict_values(self, obs: th.Tensor) -> th.Tensor:
+    def predict_values(self, obs: PyTorchObs) -> th.Tensor:
         """
         Get the estimated values according to the current policy given the observations.
 
@@ -921,6 +934,8 @@ class ContinuousCritic(BaseModel):
         between the actor and the critic (this saves computation time)
     """
 
+    features_extractor: BaseFeaturesExtractor
+
     def __init__(
         self,
         observation_space: spaces.Space,
@@ -944,10 +959,10 @@ def __init__(
 
         self.share_features_extractor = share_features_extractor
         self.n_critics = n_critics
-        self.q_networks = []
+        self.q_networks: List[nn.Module] = []
         for idx in range(n_critics):
-            q_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
-            q_net = nn.Sequential(*q_net)
+            q_net_list = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
+            q_net = nn.Sequential(*q_net_list)
             self.add_module(f"qf{idx}", q_net)
             self.q_networks.append(q_net)
 
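Splitting q_net_list from q_net and annotating q_networks as List[nn.Module] removes the variable reuse that tripped mypy, without changing behavior. The same construction pattern in isolation, with illustrative dimensions:

# Sketch of the pattern above; the dimensions are illustrative assumptions.
from typing import List

import torch.nn as nn
from stable_baselines3.common.torch_layers import create_mlp

features_dim, action_dim, n_critics = 8, 2, 2
q_networks: List[nn.Module] = []
for idx in range(n_critics):
    q_net_list = create_mlp(features_dim + action_dim, 1, net_arch=[32, 32])
    q_net = nn.Sequential(*q_net_list)  # create_mlp returns a list of layers
    q_networks.append(q_net)  # ContinuousCritic also registers it via add_module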
stable_baselines3/common/preprocessing.py (11 additions, 10 deletions)
@@ -90,7 +90,7 @@ def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) -> np.ndarray:
 
 
 def preprocess_obs(
-    obs: th.Tensor,
+    obs: Union[th.Tensor, Dict[str, th.Tensor]],
     observation_space: spaces.Space,
     normalize_images: bool = True,
 ) -> Union[th.Tensor, Dict[str, th.Tensor]]:
@@ -105,6 +105,16 @@
         (True by default)
     :return:
     """
+    if isinstance(observation_space, spaces.Dict):
+        # Do not modify by reference the original observation
+        assert isinstance(obs, Dict), f"Expected dict, got {type(obs)}"
+        preprocessed_obs = {}
+        for key, _obs in obs.items():
+            preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images)
+        return preprocessed_obs  # type: ignore[return-value]
+
+    assert isinstance(obs, th.Tensor), f"Expecting a torch Tensor, but got {type(obs)}"
+
     if isinstance(observation_space, spaces.Box):
         if normalize_images and is_image_space(observation_space):
             return obs.float() / 255.0
@@ -126,15 +136,6 @@
 
     elif isinstance(observation_space, spaces.MultiBinary):
         return obs.float()
-
-    elif isinstance(observation_space, spaces.Dict):
-        # Do not modify by reference the original observation
-        assert isinstance(obs, Dict), f"Expected dict, got {type(obs)}"
-        preprocessed_obs = {}
-        for key, _obs in obs.items():
-            preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images)
-        return preprocessed_obs
-
     else:
         raise NotImplementedError(f"Preprocessing not implemented for {observation_space}")
 
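Hoisting the Dict branch lets preprocess_obs assert the observation type once, recurse per key, and leave the remaining branches free to assume a plain tensor. A usage sketch with an illustrative dict space:

# Sketch: preprocess_obs recurses over a Dict space, one sub-space per key.
# The space and tensors below are illustrative assumptions.
import torch as th
from gymnasium import spaces
from stable_baselines3.common.preprocessing import preprocess_obs

obs_space = spaces.Dict({"pos": spaces.Box(-1, 1, shape=(3,)), "id": spaces.Discrete(4)})
obs = {"pos": th.zeros(1, 3), "id": th.tensor([2])}
processed = preprocess_obs(obs, obs_space)
print(processed["id"])  # Discrete obs are one-hot encoded: tensor([[0., 0., 1., 0.]])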
stable_baselines3/common/type_aliases.py (1 addition)
@@ -20,6 +20,7 @@
 TensorDict = Dict[str, th.Tensor]
 OptimizerStateDict = Dict[str, Any]
 MaybeCallback = Union[None, Callable, List["BaseCallback"], "BaseCallback"]
+PyTorchObs = Union[th.Tensor, TensorDict]
 
 # A schedule takes the remaining progress as input
 # and ouputs a scalar (e.g. learning rate, clip range, ...)
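The new alias names the "tensor or dict of tensors" union used throughout the signatures above. A minimal sketch of annotating a helper with it (the helper itself is hypothetical, not part of the library):

# Sketch: a hypothetical helper annotated with the new PyTorchObs alias.
import torch as th
from stable_baselines3.common.type_aliases import PyTorchObs

def batch_size(obs: PyTorchObs) -> int:
    # Dict observations share a common batch dimension across keys
    if isinstance(obs, dict):
        return next(iter(obs.values())).shape[0]
    return obs.shape[0]

print(batch_size(th.zeros(5, 3)))                              # 5
print(batch_size({"a": th.zeros(5, 3), "b": th.zeros(5, 2)}))  # 5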
stable_baselines3/dqn/policies.py (5 additions, 5 deletions)
@@ -12,7 +12,7 @@
     NatureCNN,
     create_mlp,
 )
-from stable_baselines3.common.type_aliases import Schedule
+from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
 
 
 class QNetwork(BasePolicy):
@@ -56,7 +56,7 @@ def __init__(
         q_net = create_mlp(self.features_dim, action_dim, self.net_arch, self.activation_fn)
         self.q_net = nn.Sequential(*q_net)
 
-    def forward(self, obs: th.Tensor) -> th.Tensor:
+    def forward(self, obs: PyTorchObs) -> th.Tensor:
         """
         Predict the q-values.
 
@@ -65,7 +65,7 @@ def forward(self, obs: th.Tensor) -> th.Tensor:
         """
         return self.q_net(self.extract_features(obs, self.features_extractor))
 
-    def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
+    def _predict(self, observation: PyTorchObs, deterministic: bool = True) -> th.Tensor:
         q_values = self(observation)
         # Greedy action
         action = q_values.argmax(dim=1).reshape(-1)
@@ -177,10 +177,10 @@ def make_q_net(self) -> QNetwork:
         net_args = self._update_features_extractor(self.net_args, features_extractor=None)
         return QNetwork(**net_args).to(self.device)
 
-    def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
+    def forward(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor:
         return self._predict(obs, deterministic=deterministic)
 
-    def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
+    def _predict(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor:
         return self.q_net._predict(obs, deterministic=deterministic)
 
     def _get_constructor_parameters(self) -> Dict[str, Any]:
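QNetwork._predict stays a pure argmax over the predicted Q-values; only the observation annotation widens to PyTorchObs. The greedy step in isolation:

# Sketch: the greedy-action selection used by QNetwork._predict.
import torch as th

q_values = th.tensor([[0.1, 0.9, 0.3],
                      [0.7, 0.2, 0.5]])  # illustrative (batch, n_actions) values
action = q_values.argmax(dim=1).reshape(-1)
print(action)  # tensor([1, 0])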