diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 203f15e3a..16712fb79 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -14,6 +14,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ +- Cast type in compute gae method to avoid error when using torch compile (@amjames) `SB3-Contrib`_ ^^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 306b43571..651ecdb2d 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -424,7 +424,7 @@ def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarra last_gae_lam = 0 for step in reversed(range(self.buffer_size)): if step == self.buffer_size - 1: - next_non_terminal = 1.0 - dones + next_non_terminal = 1.0 - dones.astype(np.float32) next_values = last_values else: next_non_terminal = 1.0 - self.episode_starts[step + 1] diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index f90b1afc0..e96f44fb3 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.3.2 +2.4.0a0