From 010ecbe6a05e3f0564d015b3de520209a64b4ccb Mon Sep 17 00:00:00 2001
From: Ali K
Date: Wed, 4 Dec 2024 18:47:39 -0800
Subject: [PATCH] minor cleanup

---
 examples/13_load_gpr_krec_dataset.py               | 78 ++++++-------------
 .../push_dataset_to_hub/gpr_krec_format.py         | 11 +--
 2 files changed, 26 insertions(+), 63 deletions(-)

diff --git a/examples/13_load_gpr_krec_dataset.py b/examples/13_load_gpr_krec_dataset.py
index a2eae42ab..bbd43fb44 100644
--- a/examples/13_load_gpr_krec_dataset.py
+++ b/examples/13_load_gpr_krec_dataset.py
@@ -23,8 +23,8 @@
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 
 NUM_ACTUATORS = 5
-KREC_VIDEO_WIDTH = 128
-KREC_VIDEO_HEIGHT = 128
+KREC_VIDEO_WIDTH = 640
+KREC_VIDEO_HEIGHT = 480
 TOLERANCE_S = 0.03
 
 REPO_ID = "gpr_test_krec"
@@ -62,55 +62,20 @@
 }
 
 
-def generate_test_video_frame(width: int, height: int, frame_idx: int) -> Image:
-    """
-    Generates a dummy video frame with a white square that moves based on the frame index.
-    :param width: Width of the video frame.
-    :param height: Height of the video frame.
-    :param frame_idx: Index of the frame to determine the position of the white square.
-    :return: PIL Image object.
-    """
-    frame = Image.new("RGB", (width, height), "black")  # Create a black frame
-    draw = ImageDraw.Draw(frame)
-    square_size = min(width, height) // 4
-    x = (frame_idx * 10) % (width - square_size)
-    y = (frame_idx * 10) % (height - square_size)
-    draw.rectangle(
-        [x, y, x + square_size, y + square_size], fill="white"
-    )  # Add a white square that moves
-    return frame
-
-
-def load_video_frame(video_frame_data: dict, video_readers: dict, root_dir: Path) -> torch.Tensor:
-    """Load a specific frame from a video file using timestamp information.
+def load_video_frames_batch(video_path: str, num_frames: int) -> np.ndarray:
+    """Load up to num_frames video frames at once using decord.
 
     Args:
-        video_frame_data: Dictionary containing 'path' and 'timestamp' keys
-        video_readers: Dictionary mapping video paths to VideoReader objects
-        root_dir: Root directory where videos are stored
-
+        video_path: Path to the video file
+        num_frames: Number of frames to load
+
     Returns:
-        torch.Tensor: Video frame in (C, H, W) format, normalized to [0,1]
+        np.ndarray: Batch of video frames in (N, H, W, C) format
     """
-    video_path = root_dir / video_frame_data['path']
-
-    # Reuse existing VideoReader or create new one
-    if str(video_path) not in video_readers:
-        video_readers[str(video_path)] = decord.VideoReader(str(video_path), ctx=decord.cpu(0))
-    vr = video_readers[str(video_path)]
-
-    # Convert timestamp to frame index
-    fps = 30  # len(vr) / vr.get_avg_duration()
-    frame_idx = int(video_frame_data['timestamp'] * fps)
-
-    # Load the specific frame, clamping frame_idx to valid range
-    frame_idx = min(max(frame_idx, 0), len(vr) - 1)  # Clamp between 0 and last frame
-    frame = vr[frame_idx].asnumpy()
-    frame = torch.from_numpy(frame).float()
-    frame = frame.permute(2, 0, 1)  # (H, W, C) -> (C, H, W)
-    frame = frame / 255.0
-
-    return frame
+    vr = decord.VideoReader(str(video_path), ctx=decord.cpu(0))
+    frame_indices = list(range(min(len(vr), num_frames)))
+    frames_batch = vr.get_batch(frame_indices).asnumpy()
+    return frames_batch
 
 
 def test_gpr_dataset(raw_dir: Path, videos_dir: Path, fps: int):
@@ -156,35 +121,36 @@ def test_gpr_dataset(raw_dir: Path, videos_dir: Path, fps: int):
         to_idx = episode_data_index["to"][ep_idx].item()
         num_frames = to_idx - from_idx
 
+        # Load all video frames for this episode at once
+        video_path = next(raw_dir.glob("*.krec.mkv"))  # Adjust this if you have multiple video files
+        video_frames_batch = load_video_frames_batch(str(video_path), num_frames)
+
         for frame_idx in range(num_frames):
             i = from_idx + frame_idx
             frame_data = hf_dataset[i]
 
-            video_frame = load_video_frame(
-                frame_data["observation.images.camera"],
-                video_readers=video_readers,  # Pass the video_readers dictionary
-                root_dir=raw_dir
-            )
-            # print(video_frame.shape)
             frame = {
                 key: frame_data[key].numpy().astype(np.float32)
                 for key in [
                     "observation.joint_pos",
-                    "observation.joint_vel", 
+                    "observation.joint_vel",
                     "observation.ang_vel",
                     "observation.euler_rotation",
                     "action",
                 ]
             }
-            frame["observation.images"] = np.array(video_frame)
+            # Use the pre-loaded video frame, clamping frame_idx if needed
+            clamped_idx = min(frame_idx, len(video_frames_batch) - 1)
+            frame["observation.images"] = video_frames_batch[clamped_idx]
             frame["timestamp"] = frame_data["timestamp"]
 
             dataset.add_frame(frame)
 
+        print(f"Saving episode {ep_idx}")
         dataset.save_episode(
             task="walk forward",
             encode_videos=True,
-        )  # You might want to customize this task description
+        )
         print(f"Done saving episode {ep_idx}")
 
     print("Consolidating dataset...")
diff --git a/lerobot/common/datasets/push_dataset_to_hub/gpr_krec_format.py b/lerobot/common/datasets/push_dataset_to_hub/gpr_krec_format.py
index 311a7e2f1..461cdaad8 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/gpr_krec_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/gpr_krec_format.py
@@ -169,7 +169,6 @@ def load_from_raw(
     euler_rotation = np.zeros((num_frames, 3), dtype=np.float32)
     prev_actions = np.zeros((num_frames, num_joints), dtype=np.float32)
    curr_actions = np.zeros((num_frames, num_joints), dtype=np.float32)
-    # video_frames = torch.zeros((num_frames, KREC_VIDEO_HEIGHT, KREC_VIDEO_WIDTH, 3), dtype=torch.uint8)  # Changed to store actual video frames
     video_frames = None
 
     if video:
@@ -181,11 +180,12 @@ def load_from_raw(
                 f"({KREC_VIDEO_HEIGHT}, {KREC_VIDEO_WIDTH})"
             )
 
-            # Create episode video directory with timestamp
-            timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-            ep_video_dir = raw_dir / "ep_videos"  # / timestamp
+            ep_video_dir = raw_dir / "ep_videos"
             tmp_imgs_dir = ep_video_dir / "tmp_images"
+            if ep_video_dir.exists():
+                shutil.rmtree(ep_video_dir)
             ep_video_dir.mkdir(parents=True, exist_ok=True)
+            tmp_imgs_dir.mkdir(parents=True, exist_ok=True)
 
             # Save frames as images and encode to video
             save_images_concurrently(video_frames_batch, tmp_imgs_dir)
@@ -206,9 +206,6 @@ def load_from_raw(
 
     # Fill data from KREC frames
     for frame_idx, frame in enumerate(krec_obj):
-        # Load video frame
-
-        # Joint positions and velocities
         for j, state in enumerate(frame.get_actuator_states()):
             joint_pos[frame_idx, j] = state.position
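
Note: the batched loader added above can be sanity-checked on its own. Below is a
minimal standalone sketch; the data/raw location and the 100-frame cap are
hypothetical, and only the decord calls mirror the patch:

    # Standalone check of the batched decord read (hypothetical paths).
    from pathlib import Path

    import decord
    import numpy as np

    def load_video_frames_batch(video_path: str, num_frames: int) -> np.ndarray:
        """Load up to num_frames frames in one decord batch call."""
        vr = decord.VideoReader(str(video_path), ctx=decord.cpu(0))
        frame_indices = list(range(min(len(vr), num_frames)))
        return vr.get_batch(frame_indices).asnumpy()  # (N, H, W, C), uint8

    if __name__ == "__main__":
        video_path = next(Path("data/raw").glob("*.krec.mkv"))  # hypothetical location
        frames = load_video_frames_batch(str(video_path), num_frames=100)
        print(frames.shape, frames.dtype)  # e.g. (100, 480, 640, 3) uint8

A single get_batch call avoids the per-frame seek that the old timestamp-based
load_video_frame performed for every row, which is the motivation for the switch
in examples/13_load_gpr_krec_dataset.py.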