Skip to content

Commit

Permalink
Fix video recorder and add test (#2063)
Browse files Browse the repository at this point in the history
* Fix video recorder and add test

* Update github CI

* Install ffmpeg

* Revert "Update github CI"

This reverts commit 07791e9.

* Skip VecVideoRecorder test on github
  • Loading branch information
araffin authored Dec 21, 2024
1 parent 0fd0db0 commit 57e8b97
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 14 deletions.
10 changes: 9 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 2.5.0a0 (WIP)
Release 2.5.0a1 (WIP)
--------------------------

Breaking Changes:
Expand Down Expand Up @@ -42,6 +42,14 @@ Documentation:
- Add FootstepNet Envs to the project page (@cgaspard3333)
- Added FRASA to the project page (@MarcDcls)

Release 2.4.1 (2024-12-20)
--------------------------

Bug Fixes:
^^^^^^^^^^
- Fixed a bug introduced in v2.4.0 where the ``VecVideoRecorder`` would override videos


Release 2.4.0 (2024-11-18)
--------------------------

Expand Down
23 changes: 12 additions & 11 deletions stable_baselines3/common/vec_env/vec_video_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class VecVideoRecorder(VecEnvWrapper):
:param name_prefix: Prefix to the video name
"""

video_name: str
video_path: str

def __init__(
self,
venv: VecEnv,
Expand All @@ -50,7 +53,7 @@ def __init__(

if isinstance(temp_env, DummyVecEnv) or isinstance(temp_env, SubprocVecEnv):
metadata = temp_env.get_attr("metadata")[0]
else:
else: # pragma: no cover # assume gym interface
metadata = temp_env.metadata

self.env.metadata = metadata
Expand All @@ -67,15 +70,12 @@ def __init__(
self.step_id = 0
self.video_length = video_length

self.video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}.mp4"
self.video_path = os.path.join(self.video_folder, self.video_name)

self.recording = False
self.recorded_frames: list[np.ndarray] = []

try:
import moviepy # noqa: F401
except ImportError as e:
except ImportError as e: # pragma: no cover
raise error.DependencyNotInstalled("MoviePy is not installed, run `pip install 'gymnasium[other]'`") from e

def reset(self) -> VecEnvObs:
Expand All @@ -85,6 +85,9 @@ def reset(self) -> VecEnvObs:
return obs

def _start_video_recorder(self) -> None:
# Update video name and path
self.video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}.mp4"
self.video_path = os.path.join(self.video_folder, self.video_name)
self._start_recording()
self._capture_frame()

Expand All @@ -109,8 +112,6 @@ def _capture_frame(self) -> None:
assert self.recording, "Cannot capture a frame, recording wasn't started."

frame = self.env.render()
if isinstance(frame, list):
frame = frame[-1]

if isinstance(frame, np.ndarray):
self.recorded_frames.append(frame)
Expand All @@ -123,12 +124,12 @@ def _capture_frame(self) -> None:
def close(self) -> None:
"""Closes the wrapper then the video recorder."""
VecEnvWrapper.close(self)
if self.recording:
if self.recording: # pragma: no cover
self._stop_recording()

def _start_recording(self) -> None:
"""Start a new recording. If it is already recording, stops the current recording before starting the new one."""
if self.recording:
if self.recording: # pragma: no cover
self._stop_recording()

self.recording = True
Expand All @@ -137,7 +138,7 @@ def _stop_recording(self) -> None:
"""Stop current recording and saves the video."""
assert self.recording, "_stop_recording was called, but no recording was started"

if len(self.recorded_frames) == 0:
if len(self.recorded_frames) == 0: # pragma: no cover
logger.warn("Ignored saving a video as there were zero frames to save.")
else:
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
Expand All @@ -150,5 +151,5 @@ def _stop_recording(self) -> None:

def __del__(self) -> None:
"""Warn the user in case last video wasn't saved."""
if len(self.recorded_frames) > 0:
if len(self.recorded_frames) > 0: # pragma: no cover
logger.warn("Unable to save last video! Did you call close()?")
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.5.0a0
2.5.0a1
44 changes: 43 additions & 1 deletion tests/test_vec_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,17 @@
import pytest
from gymnasium import spaces

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize, VecVideoRecorder

try:
import moviepy

have_moviepy = True
except ImportError:
have_moviepy = False

N_ENVS = 3
VEC_ENV_CLASSES = [DummyVecEnv, SubprocVecEnv]
Expand Down Expand Up @@ -624,3 +632,37 @@ def test_render(vec_env_class):
vec_env.render()

vec_env.close()


@pytest.mark.skipif(not have_moviepy, reason="moviepy is not installed")
def test_video_recorder(tmp_path):
env_id = "CartPole-v1"
video_folder = str(tmp_path)

vec_env = make_vec_env(env_id, n_envs=1)

# Wrap to check unwrapping works
vec_env = VecNormalize(vec_env)

# Record the video starting at the first step
vec_env = VecVideoRecorder(
vec_env,
video_folder,
record_video_trigger=lambda x: x % 65 == 0,
video_length=10,
name_prefix=f"agent-{env_id}",
)

model = PPO("MlpPolicy", vec_env, n_steps=64, n_epochs=1, verbose=0)

model.learn(total_timesteps=128)

# print all videos in video_folder, should be multiple step 0-100, step 1024-1124
video_files = list(map(str, tmp_path.glob("*.mp4")))

# Clean up
vec_env.close()

assert len(video_files) == 2
assert "agent-CartPole-v1-step-65-to-step-75.mp4" in video_files[0]
assert "agent-CartPole-v1-step-0-to-step-10.mp4" in video_files[1]

0 comments on commit 57e8b97

Please sign in to comment.