Commit

Merge branch 'gpr_dataset_integration' of github.com:kscalelabs/lerobot into gpr_dataset_integration
budzianowski committed Dec 5, 2024
2 parents 7ef358f + 010ecbe commit 3047ee7
Showing 2 changed files with 26 additions and 43 deletions.
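
The main change below replaces per-frame, timestamp-based decoding with a single batched read via decord's get_batch. As a standalone sketch of that pattern, assuming the same decord calls used in the diff (the function name, path, and frame count here are placeholders, not part of the commit):

from pathlib import Path

import decord
import numpy as np


def load_first_n_frames(video_path: Path, num_frames: int) -> np.ndarray:
    # Decode the first `num_frames` frames in one call; returns (N, H, W, C) uint8 frames.
    vr = decord.VideoReader(str(video_path), ctx=decord.cpu(0))
    frame_indices = list(range(min(len(vr), num_frames)))
    return vr.get_batch(frame_indices).asnumpy()


# Hypothetical usage with a placeholder file:
# frames = load_first_n_frames(Path("episode_0.krec.mkv"), num_frames=100)
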
58 changes: 22 additions & 36 deletions examples/13_load_gpr_krec_dataset.py
@@ -33,8 +33,8 @@
from_raw_to_lerobot_format

NUM_ACTUATORS = 5
-KREC_VIDEO_WIDTH = 128
-KREC_VIDEO_HEIGHT = 128
+KREC_VIDEO_WIDTH = 640
+KREC_VIDEO_HEIGHT = 480
TOLERANCE_S = 0.03
REPO_ID = "gpr_test_krec"

@@ -91,36 +91,20 @@ def generate_test_video_frame(width: int, height: int, frame_idx: int) -> Image:
return frame


-def load_video_frame(video_frame_data: dict, video_readers: dict, root_dir: Path) -> torch.Tensor:
-    """Load a specific frame from a video file using timestamp information.
+def load_video_frames_batch(video_path: str, num_frames: int) -> np.ndarray:
+    """Load all video frames at once using decord.
Args:
-    video_frame_data: Dictionary containing 'path' and 'timestamp' keys
-    video_readers: Dictionary mapping video paths to VideoReader objects
-    root_dir: Root directory where videos are stored
+    video_path: Path to the video file
+    num_frames: Number of frames to load
Returns:
-    torch.Tensor: Video frame in (C, H, W) format, normalized to [0,1]
+    np.ndarray: Batch of video frames in (N, H, W, C) format
"""
-video_path = root_dir / video_frame_data['path']
-
-# Reuse existing VideoReader or create new one
-if str(video_path) not in video_readers:
-    video_readers[str(video_path)] = decord.VideoReader(str(video_path), ctx=decord.cpu(0))
-vr = video_readers[str(video_path)]
-
-# Convert timestamp to frame index
-fps = 30  # len(vr) / vr.get_avg_duration()
-frame_idx = int(video_frame_data['timestamp'] * fps)
-
-# Load the specific frame, clamping frame_idx to valid range
-frame_idx = min(max(frame_idx, 0), len(vr) - 1)  # Clamp between 0 and last frame
-frame = vr[frame_idx].asnumpy()
-frame = torch.from_numpy(frame).float()
-frame = frame.permute(2, 0, 1)  # (H, W, C) -> (C, H, W)
-frame = frame / 255.0
-
-return frame
+vr = decord.VideoReader(str(video_path), ctx=decord.cpu(0))
+frame_indices = list(range(min(len(vr), num_frames)))
+frames_batch = vr.get_batch(frame_indices).asnumpy()
+return frames_batch


def test_gpr_dataset(raw_dir: Path, videos_dir: Path, fps: int):
@@ -164,34 +148,36 @@ def test_gpr_dataset(raw_dir: Path, videos_dir: Path, fps: int):
to_idx = episode_data_index["to"][ep_idx].item()
num_frames = to_idx - from_idx

+# Load all video frames for this episode at once
+video_path = next(raw_dir.glob("*.krec.mkv"))  # Adjust this if you have multiple video files
+video_frames_batch = load_video_frames_batch(str(video_path), num_frames)

for frame_idx in range(num_frames):
i = from_idx + frame_idx
frame_data = hf_dataset[i]
-video_frame = load_video_frame(
-    frame_data["observation.images.camera"],
-    video_readers=video_readers,  # Pass the video_readers dictionary
-    root_dir=raw_dir
-)

frame = {
key: frame_data[key].numpy().astype(np.float32)
for key in [
"observation.joint_pos",
"observation.joint_vel",
"observation.joint_vel",
"observation.ang_vel",
"observation.euler_rotation",
"action",
]
}
frame["observation.images"] = np.array(video_frame)
# Use the pre-loaded video frame, clamping frame_idx if needed
clamped_idx = min(frame_idx, len(video_frames_batch)-1)
frame["observation.images"] = video_frames_batch[clamped_idx]
frame["timestamp"] = frame_data["timestamp"]

dataset.add_frame(frame)

print(f"Saving episode {ep_idx}")
dataset.save_episode(
task="walk forward",
+    encode_videos=True,
-)  # You might want to customize this task description
+)
print(f"Done saving episode {ep_idx}")

print("Consolidating dataset...")
11 changes: 4 additions & 7 deletions lerobot/common/datasets/push_dataset_to_hub/gpr_krec_format.py
@@ -169,7 +169,6 @@ def load_from_raw(
euler_rotation = np.zeros((num_frames, 3), dtype=np.float32)
prev_actions = np.zeros((num_frames, num_joints), dtype=np.float32)
curr_actions = np.zeros((num_frames, num_joints), dtype=np.float32)
-# video_frames = torch.zeros((num_frames, KREC_VIDEO_HEIGHT, KREC_VIDEO_WIDTH, 3), dtype=torch.uint8)  # Changed to store actual video frames

video_frames = None
if video:
@@ -181,11 +180,12 @@
f"({KREC_VIDEO_HEIGHT}, {KREC_VIDEO_WIDTH})"
)

-# Create episode video directory with timestamp
-timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-ep_video_dir = raw_dir / "ep_videos"  # / timestamp
+ep_video_dir = raw_dir / "ep_videos"
tmp_imgs_dir = ep_video_dir / "tmp_images"
+if ep_video_dir.exists():
+    shutil.rmtree(ep_video_dir)
ep_video_dir.mkdir(parents=True, exist_ok=True)
tmp_imgs_dir.mkdir(parents=True, exist_ok=True)

# Save frames as images and encode to video
save_images_concurrently(video_frames_batch, tmp_imgs_dir)
@@ -206,9 +206,6 @@

# Fill data from KREC frames
for frame_idx, frame in enumerate(krec_obj):
-# Load video frame
-
-
# Joint positions and velocities
for j, state in enumerate(frame.get_actuator_states()):
joint_pos[frame_idx, j] = state.position
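
The gpr_krec_format.py change above drops the timestamped subfolder and instead reuses a fixed ep_videos directory that is wiped before each run. A minimal standalone sketch of that reset pattern, with a placeholder path (the script itself derives the directory from raw_dir):

import shutil
from pathlib import Path

ep_video_dir = Path("/tmp/ep_videos")  # placeholder for illustration
tmp_imgs_dir = ep_video_dir / "tmp_images"

if ep_video_dir.exists():
    shutil.rmtree(ep_video_dir)  # drop frames/videos left over from a previous run
tmp_imgs_dir.mkdir(parents=True, exist_ok=True)  # recreates ep_video_dir as well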
