Skip to content

Commit

Permalink
Avoid torch type-error under torch.compile (#1922)
Browse files Browse the repository at this point in the history
* Avoid torch type-error under torch.compile

* Update changelog and version

* Update stable_baselines3/common/buffers.py

Co-authored-by: Antonin RAFFIN <[email protected]>

---------

Co-authored-by: Antonin Raffin <[email protected]>
  • Loading branch information
amjames and araffin authored May 13, 2024
1 parent 285e01f commit 766b9e9
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ New Features:

Bug Fixes:
^^^^^^^^^^
- Cast type in compute gae method to avoid error when using torch compile (@amjames)

`SB3-Contrib`_
^^^^^^^^^^^^^^
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def compute_returns_and_advantage(self, last_values: th.Tensor, dones: np.ndarra
last_gae_lam = 0
for step in reversed(range(self.buffer_size)):
if step == self.buffer_size - 1:
next_non_terminal = 1.0 - dones
next_non_terminal = 1.0 - dones.astype(np.float32)
next_values = last_values
else:
next_non_terminal = 1.0 - self.episode_starts[step + 1]
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.3.2
2.4.0a0

0 comments on commit 766b9e9

Please sign in to comment.