Skip to content

Commit

Permalink
Fix distributions type hints (#1733)
Browse files Browse the repository at this point in the history
* Fix distributions type hints

* Add test for multim binary action space

* Fix test
  • Loading branch information
araffin authored Nov 6, 2023
1 parent 294f2b4 commit 018ea5a
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 15 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 2.2.0a9 (WIP)
Release 2.2.0a10 (WIP)
--------------------------
**Support for options at reset, bug fixes and better error messages**

Expand Down Expand Up @@ -59,6 +59,7 @@ Others:
- Buffers do no call an additional ``.copy()`` when storing new transitions
- Fixed ``ActorCriticPolicy.extract_features()`` signature by adding an optional ``features_extractor`` argument
- Update dependencies (accept newer Shimmy/Sphinx version and remove ``sphinx_autodoc_typehints``)
- Fixed ``stable_baselines3/common/distributions.py`` type hints

Documentation:
^^^^^^^^^^^^^^
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,16 @@ exclude = [
"stable_baselines3/common/on_policy_algorithm.py",
"stable_baselines3/common/vec_env/stacked_observations.py",
"stable_baselines3/common/vec_env/subproc_vec_env.py",
"stable_baselines3/common/vec_env/patch_gym.py"
"stable_baselines3/common/vec_env/patch_gym.py",
"stable_baselines3/common/distributions.py",
]

[tool.mypy]
ignore_missing_imports = true
follow_imports = "silent"
show_error_codes = true
exclude = """(?x)(
stable_baselines3/common/distributions.py$
| stable_baselines3/common/off_policy_algorithm.py$
stable_baselines3/common/off_policy_algorithm.py$
| stable_baselines3/common/policies.py$
| stable_baselines3/common/vec_env/__init__.py$
| stable_baselines3/common/vec_env/vec_normalize.py$
Expand Down
28 changes: 19 additions & 9 deletions stable_baselines3/common/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def log_prob(self, actions: th.Tensor) -> th.Tensor:
log_prob = self.distribution.log_prob(actions)
return sum_independent_dims(log_prob)

def entropy(self) -> th.Tensor:
def entropy(self) -> Optional[th.Tensor]:
return sum_independent_dims(self.distribution.entropy())

def sample(self) -> th.Tensor:
Expand Down Expand Up @@ -216,7 +216,7 @@ def __init__(self, action_dim: int, epsilon: float = 1e-6):
super().__init__(action_dim)
# Avoid NaN (prevents division by zero or log of zero)
self.epsilon = epsilon
self.gaussian_actions = None
self.gaussian_actions: Optional[th.Tensor] = None

def proba_distribution(
self: SelfSquashedDiagGaussianDistribution, mean_actions: th.Tensor, log_std: th.Tensor
Expand Down Expand Up @@ -339,7 +339,7 @@ def proba_distribution_net(self, latent_dim: int) -> nn.Module:
def proba_distribution(
self: SelfMultiCategoricalDistribution, action_logits: th.Tensor
) -> SelfMultiCategoricalDistribution:
self.distribution = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)]
self.distribution = [Categorical(logits=split) for split in th.split(action_logits, list(self.action_dims), dim=1)]
return self

def log_prob(self, actions: th.Tensor) -> th.Tensor:
Expand Down Expand Up @@ -440,6 +440,13 @@ class StateDependentNoiseDistribution(Distribution):
:param epsilon: small value to avoid NaN due to numerical imprecision.
"""

bijector: Optional["TanhBijector"]
latent_sde_dim: Optional[int]
weights_dist: Normal
_latent_sde: th.Tensor
exploration_mat: th.Tensor
exploration_matrices: th.Tensor

def __init__(
self,
action_dim: int,
Expand All @@ -454,10 +461,6 @@ def __init__(
self.latent_sde_dim = None
self.mean_actions = None
self.log_std = None
self.weights_dist = None
self.exploration_mat = None
self.exploration_matrices = None
self._latent_sde = None
self.use_expln = use_expln
self.full_std = full_std
self.epsilon = epsilon
Expand Down Expand Up @@ -489,6 +492,7 @@ def get_std(self, log_std: th.Tensor) -> th.Tensor:

if self.full_std:
return std
assert self.latent_sde_dim is not None
# Reduce the number of parameters:
return th.ones(self.latent_sde_dim, self.action_dim).to(log_std.device) * std

Expand Down Expand Up @@ -675,10 +679,13 @@ def make_proba_distribution(
cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution
return cls(get_action_dim(action_space), **dist_kwargs)
elif isinstance(action_space, spaces.Discrete):
return CategoricalDistribution(action_space.n, **dist_kwargs)
return CategoricalDistribution(int(action_space.n), **dist_kwargs)
elif isinstance(action_space, spaces.MultiDiscrete):
return MultiCategoricalDistribution(list(action_space.nvec), **dist_kwargs)
elif isinstance(action_space, spaces.MultiBinary):
assert isinstance(
action_space.n, int
), f"Multi-dimensional MultiBinary({action_space.n}) action space is not supported. You can flatten it instead."
return BernoulliDistribution(action_space.n, **dist_kwargs)
else:
raise NotImplementedError(
Expand All @@ -702,7 +709,10 @@ def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor
# MultiCategoricalDistribution is not a PyTorch Distribution subclass
# so we need to implement it ourselves!
if isinstance(dist_pred, MultiCategoricalDistribution):
assert np.allclose(dist_pred.action_dims, dist_true.action_dims), "Error: distributions must have the same input space"
assert isinstance(dist_true, MultiCategoricalDistribution) # already checked above, for mypy
assert np.allclose(
dist_pred.action_dims, dist_true.action_dims
), f"Error: distributions must have the same input space: {dist_pred.action_dims} != {dist_true.action_dims}"
return th.stack(
[th.distributions.kl_divergence(p, q) for p, q in zip(dist_true.distribution, dist_pred.distribution)],
dim=1,
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def get_action_dim(action_space: spaces.Space) -> int:
# Number of binary actions
assert isinstance(
action_space.n, int
), "Multi-dimensional MultiBinary action space is not supported. You can flatten it instead."
), f"Multi-dimensional MultiBinary({action_space.n}) action space is not supported. You can flatten it instead."
return int(action_space.n)
else:
raise NotImplementedError(f"{action_space} action space is not supported")
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.2.0a9
2.2.0a10
6 changes: 6 additions & 0 deletions tests/test_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,9 @@ def test_float64_action_space(model_class, obs_space, action_space):
initial_obs, _ = env.reset()
action, _ = model.predict(initial_obs, deterministic=False)
assert action.dtype == env.action_space.dtype


def test_multidim_binary_not_supported():
env = DummyEnv(BOX_SPACE_FLOAT32, spaces.MultiBinary([2, 3]))
with pytest.raises(AssertionError, match=r"Multi-dimensional MultiBinary\(.*\) action space is not supported"):
A2C("MlpPolicy", env)

0 comments on commit 018ea5a

Please sign in to comment.