From 766b9e9f7ddd51c58d38fd090927c7242a249d2f Mon Sep 17 00:00:00 2001 From: Andrew James Date: Mon, 13 May 2024 10:28:23 -0500 Subject: [PATCH] Avoid torch type-error under torch.compile (#1922) * Avoid torch type-error under torch.compile * Update changelog and version * Update stable_baselines3/common/buffers.py Co-authored-by: Antonin RAFFIN --------- Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 1 + stable_baselines3/common/buffers.py | 2 +- stable_baselines3/version.txt | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 203f15e3a..16712fb79 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -14,6 +14,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ +- Cast type in compute gae method to avoid error when using torch compile (@amjames) `SB3-Contrib`_ ^^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 306b43571..651ecdb2d 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -424,7 +424,7 @@ def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarra last_gae_lam = 0 for step in reversed(range(self.buffer_size)): if step == self.buffer_size - 1: - next_non_terminal = 1.0 - dones + next_non_terminal = 1.0 - dones.astype(np.float32) next_values = last_values else: next_non_terminal = 1.0 - self.episode_starts[step + 1] diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index f90b1afc0..e96f44fb3 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.3.2 +2.4.0a0