Skip to content

Commit

Permalink
Copy and paste video recorder to prevent the need to rewrite the vec …
Browse files Browse the repository at this point in the history
…vide recorder wrapper
  • Loading branch information
pseudo-rnd-thoughts committed Apr 3, 2024
1 parent d7ed302 commit 39f0900
Showing 1 changed file with 172 additions and 5 deletions.
177 changes: 172 additions & 5 deletions stable_baselines3/common/vec_env/vec_video_recorder.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,180 @@
import json
import os
from typing import Callable
import os.path
import tempfile
from typing import Callable, List, Optional

import numpy as np
from gymnasium import error, logger

from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv


# This is copy and pasted from Gymnasium v0.26.1
class VideoRecorder:
"""VideoRecorder renders a nice movie of a rollout, frame by frame.
It comes with an ``enabled`` option, so you can still use the same code on episodes where you don't want to record video.
Note:
You are responsible for calling :meth:`close` on a created VideoRecorder, or else you may leak an encoder process.
"""

def __init__(
self,
env,
path: Optional[str] = None,
metadata: Optional[dict] = None,
enabled: bool = True,
base_path: Optional[str] = None,
):
"""Video recorder renders a nice movie of a rollout, frame by frame.
Args:
env (Env): Environment to take video of.
path (Optional[str]): Path to the video file; will be randomly chosen if omitted.
metadata (Optional[dict]): Contents to save to the metadata file.
enabled (bool): Whether to actually record video, or just no-op (for convenience)
base_path (Optional[str]): Alternatively, path to the video file without extension, which will be added.
Raises:
Error: You can pass at most one of `path` or `base_path`
Error: Invalid path given that must have a particular file extension
"""
try:
# check that moviepy is now installed
import moviepy # noqa: F401
except ImportError as e:
raise error.DependencyNotInstalled("MoviePy is not installed, run `pip install moviepy`") from e

self._async = env.metadata.get("semantics.async")
self.enabled = enabled
self._closed = False

self.render_history: list[np.ndarray] = []
self.env = env

self.render_mode = env.render_mode

if "rgb_array_list" != self.render_mode and "rgb_array" != self.render_mode:
logger.warn(
f"Disabling video recorder because environment {env} was not initialized with any compatible video "
"mode between `rgb_array` and `rgb_array_list`"
)
# Disable since the environment has not been initialized with a compatible `render_mode`
self.enabled = False

# Don't bother setting anything else if not enabled
if not self.enabled:
return

if path is not None and base_path is not None:
raise error.Error("You can pass at most one of `path` or `base_path`.")

required_ext = ".mp4"
if path is None:
if base_path is not None:
# Base path given, append ext
path = base_path + required_ext
else:
# Otherwise, just generate a unique filename
with tempfile.NamedTemporaryFile(suffix=required_ext) as f:
path = f.name
self.path = path

path_base, actual_ext = os.path.splitext(self.path)

if actual_ext != required_ext:
raise error.Error(f"Invalid path given: {self.path} -- must have file extension {required_ext}.")

self.frames_per_sec = env.metadata.get("render_fps", 30)

self.broken = False

# Dump metadata
self.metadata = metadata or {}
self.metadata["content_type"] = "video/mp4"
self.metadata_path = f"{path_base}.meta.json"
self.write_metadata()

logger.info(f"Starting new video recorder writing to {self.path}")
self.recorded_frames: list[np.ndarray] = []

@property
def functional(self):
"""Returns if the video recorder is functional, is enabled and not broken."""
return self.enabled and not self.broken

def capture_frame(self):
"""Render the given `env` and add the resulting frame to the video."""
frame = self.env.render()
if isinstance(frame, List):
self.render_history += frame
frame = frame[-1]

if not self.functional:
return
if self._closed:
logger.warn("The video recorder has been closed and no frames will be captured anymore.")
return
logger.debug("Capturing video frame: path=%s", self.path)

if frame is None:
if self._async:
return
else:
# Indicates a bug in the environment: don't want to raise
# an error here.
logger.warn(
"Env returned None on `render()`. Disabling further rendering for video recorder by marking as "
f"disabled: path={self.path} metadata_path={self.metadata_path}"
)
self.broken = True
else:
self.recorded_frames.append(frame)

def close(self):
"""Flush all data to disk and close any open frame encoders."""
if not self.enabled or self._closed:
return

# First close the environment
self.env.close()

# Close the encoder
if len(self.recorded_frames) > 0:
try:
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
except ImportError as e:
raise error.DependencyNotInstalled("MoviePy is not installed, run `pip install moviepy`") from e

logger.debug(f"Closing video encoder: path={self.path}")
clip = ImageSequenceClip(self.recorded_frames, fps=self.frames_per_sec)
clip.write_videofile(self.path)
else:
# No frames captured. Set metadata.
if self.metadata is None:
self.metadata = {}
self.metadata["empty"] = True

self.write_metadata()

# Stop tracking this for autoclose
self._closed = True

def write_metadata(self):
"""Writes metadata to metadata path."""
with open(self.metadata_path, "w") as f:
json.dump(self.metadata, f)

def __del__(self):
"""Closes the environment correctly when the recorder is deleted."""
# Make sure we've closed up shop when garbage collecting
self.close()


class VecVideoRecorder(VecEnvWrapper):
"""
Wraps a VecEnv or VecEnvWrapper object to record rendered image as mp4 video.
Expand All @@ -20,7 +189,7 @@ class VecVideoRecorder(VecEnvWrapper):
:param name_prefix: Prefix to the video name
"""

# video_recorder: video_recorder.VideoRecorder
video_recorder: VideoRecorder

def __init__(
self,
Expand Down Expand Up @@ -71,9 +240,7 @@ def start_video_recorder(self) -> None:

video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}"
base_path = os.path.join(self.video_folder, video_name)
# self.video_recorder = video_recorder.VideoRecorder(
# env=self.env, base_path=base_path, metadata={"step_id": self.step_id}
# )
self.video_recorder = VideoRecorder(env=self.env, base_path=base_path, metadata={"step_id": self.step_id})

self.video_recorder.capture_frame()
self.recorded_frames = 1
Expand Down

0 comments on commit 39f0900

Please sign in to comment.