diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 5999aa9c7..78d42eba3 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.2.0a9 (WIP) +Release 2.2.0a10 (WIP) -------------------------- **Support for options at reset, bug fixes and better error messages** @@ -59,6 +59,7 @@ Others: - Buffers do no call an additional ``.copy()`` when storing new transitions - Fixed ``ActorCriticPolicy.extract_features()`` signature by adding an optional ``features_extractor`` argument - Update dependencies (accept newer Shimmy/Sphinx version and remove ``sphinx_autodoc_typehints``) +- Fixed ``stable_baselines3/common/distributions.py`` type hints Documentation: ^^^^^^^^^^^^^^ diff --git a/pyproject.toml b/pyproject.toml index 2d0c61914..9476868c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,8 @@ exclude = [ "stable_baselines3/common/on_policy_algorithm.py", "stable_baselines3/common/vec_env/stacked_observations.py", "stable_baselines3/common/vec_env/subproc_vec_env.py", - "stable_baselines3/common/vec_env/patch_gym.py" + "stable_baselines3/common/vec_env/patch_gym.py", + "stable_baselines3/common/distributions.py", ] [tool.mypy] @@ -43,8 +44,7 @@ ignore_missing_imports = true follow_imports = "silent" show_error_codes = true exclude = """(?x)( - stable_baselines3/common/distributions.py$ - | stable_baselines3/common/off_policy_algorithm.py$ + stable_baselines3/common/off_policy_algorithm.py$ | stable_baselines3/common/policies.py$ | stable_baselines3/common/vec_env/__init__.py$ | stable_baselines3/common/vec_env/vec_normalize.py$ diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 8a8e9f903..149345d83 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -175,7 +175,7 @@ def log_prob(self, actions: th.Tensor) -> th.Tensor: log_prob = self.distribution.log_prob(actions) return sum_independent_dims(log_prob) - def entropy(self) -> th.Tensor: + def entropy(self) -> Optional[th.Tensor]: return sum_independent_dims(self.distribution.entropy()) def sample(self) -> th.Tensor: @@ -216,7 +216,7 @@ def __init__(self, action_dim: int, epsilon: float = 1e-6): super().__init__(action_dim) # Avoid NaN (prevents division by zero or log of zero) self.epsilon = epsilon - self.gaussian_actions = None + self.gaussian_actions: Optional[th.Tensor] = None def proba_distribution( self: SelfSquashedDiagGaussianDistribution, mean_actions: th.Tensor, log_std: th.Tensor @@ -339,7 +339,7 @@ def proba_distribution_net(self, latent_dim: int) -> nn.Module: def proba_distribution( self: SelfMultiCategoricalDistribution, action_logits: th.Tensor ) -> SelfMultiCategoricalDistribution: - self.distribution = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)] + self.distribution = [Categorical(logits=split) for split in th.split(action_logits, list(self.action_dims), dim=1)] return self def log_prob(self, actions: th.Tensor) -> th.Tensor: @@ -440,6 +440,13 @@ class StateDependentNoiseDistribution(Distribution): :param epsilon: small value to avoid NaN due to numerical imprecision. """ + bijector: Optional["TanhBijector"] + latent_sde_dim: Optional[int] + weights_dist: Normal + _latent_sde: th.Tensor + exploration_mat: th.Tensor + exploration_matrices: th.Tensor + def __init__( self, action_dim: int, @@ -454,10 +461,6 @@ def __init__( self.latent_sde_dim = None self.mean_actions = None self.log_std = None - self.weights_dist = None - self.exploration_mat = None - self.exploration_matrices = None - self._latent_sde = None self.use_expln = use_expln self.full_std = full_std self.epsilon = epsilon @@ -489,6 +492,7 @@ def get_std(self, log_std: th.Tensor) -> th.Tensor: if self.full_std: return std + assert self.latent_sde_dim is not None # Reduce the number of parameters: return th.ones(self.latent_sde_dim, self.action_dim).to(log_std.device) * std @@ -675,10 +679,13 @@ def make_proba_distribution( cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution return cls(get_action_dim(action_space), **dist_kwargs) elif isinstance(action_space, spaces.Discrete): - return CategoricalDistribution(action_space.n, **dist_kwargs) + return CategoricalDistribution(int(action_space.n), **dist_kwargs) elif isinstance(action_space, spaces.MultiDiscrete): return MultiCategoricalDistribution(list(action_space.nvec), **dist_kwargs) elif isinstance(action_space, spaces.MultiBinary): + assert isinstance( + action_space.n, int + ), f"Multi-dimensional MultiBinary({action_space.n}) action space is not supported. You can flatten it instead." return BernoulliDistribution(action_space.n, **dist_kwargs) else: raise NotImplementedError( @@ -702,7 +709,10 @@ def kl_divergence(dist_true: Distribution, dist_pred: Distribution) -> th.Tensor # MultiCategoricalDistribution is not a PyTorch Distribution subclass # so we need to implement it ourselves! if isinstance(dist_pred, MultiCategoricalDistribution): - assert np.allclose(dist_pred.action_dims, dist_true.action_dims), "Error: distributions must have the same input space" + assert isinstance(dist_true, MultiCategoricalDistribution) # already checked above, for mypy + assert np.allclose( + dist_pred.action_dims, dist_true.action_dims + ), f"Error: distributions must have the same input space: {dist_pred.action_dims} != {dist_true.action_dims}" return th.stack( [th.distributions.kl_divergence(p, q) for p, q in zip(dist_true.distribution, dist_pred.distribution)], dim=1, diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index a2d0e59c1..2b5251cf8 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -204,7 +204,7 @@ def get_action_dim(action_space: spaces.Space) -> int: # Number of binary actions assert isinstance( action_space.n, int - ), "Multi-dimensional MultiBinary action space is not supported. You can flatten it instead." + ), f"Multi-dimensional MultiBinary({action_space.n}) action space is not supported. You can flatten it instead." return int(action_space.n) else: raise NotImplementedError(f"{action_space} action space is not supported") diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index b7120ad62..b208680f4 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.2.0a9 +2.2.0a10 diff --git a/tests/test_spaces.py b/tests/test_spaces.py index e4a933976..e006c1f96 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -168,3 +168,9 @@ def test_float64_action_space(model_class, obs_space, action_space): initial_obs, _ = env.reset() action, _ = model.predict(initial_obs, deterministic=False) assert action.dtype == env.action_space.dtype + + +def test_multidim_binary_not_supported(): + env = DummyEnv(BOX_SPACE_FLOAT32, spaces.MultiBinary([2, 3])) + with pytest.raises(AssertionError, match=r"Multi-dimensional MultiBinary\(.*\) action space is not supported"): + A2C("MlpPolicy", env)