Skip to content

Commit

Permalink
convert krec to lerobot compatible format (untested, no video)
Browse files Browse the repository at this point in the history
  • Loading branch information
alik-git committed Dec 4, 2024
1 parent bf00e0d commit 0fb6526
Showing 1 changed file with 321 additions and 0 deletions.
321 changes: 321 additions & 0 deletions lerobot/common/datasets/push_dataset_to_hub/gpt_krec_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,321 @@
"""
This script loads a GPR dataset from KREC files and converts it to lerobot dataset format.
Example Usage:
python lerobot/common/datasets/push_dataset_to_hub/gpt_krec_format.py --raw_dir /path/to/krec/files
python lerobot/common/datasets/push_dataset_to_hub/gpt_krec_format.py --raw_dir /home/kasm-user/ali_repos/kmodel/data/datasets/krec_data/dec_3__11_10am_og_krecs_edited/2024-12-03_17-47-30/
"""

import argparse
from pathlib import Path
from pprint import pprint

import os
from datetime import datetime

import h5py
import numpy as np
import torch
from datasets import Dataset, Features, Sequence, Value
from tqdm import tqdm

import krec
from scipy.spatial.transform import Rotation as R


from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
concatenate_episodes,
)
from lerobot.common.datasets.utils import hf_transform_to_torch

def get_krec_file_type(file_path: "str | os.PathLike[str]") -> str:
    """Classify a recording path as a bare KREC file or an MKV-embedded KREC.

    Only the file name is inspected; the file itself is never opened.

    Args:
        file_path: Path to classify. Both ``str`` and ``os.PathLike`` objects
            (for example ``pathlib.Path``) are accepted.

    Returns:
        ``"krec"`` for paths ending in ``.krec`` and ``"mkv"`` for paths
        ending in ``.krec.mkv``.

    Raises:
        RuntimeError: If the path ends with neither extension.
    """
    # Normalise path-like objects up front so callers holding a ``Path``
    # (as produced by ``Path.glob``) do not have to convert it themselves.
    file_path = os.fspath(file_path)

    if file_path.endswith(".krec"):
        return "krec"
    if file_path.endswith(".krec.mkv"):
        return "mkv"

    raise RuntimeError(
        f"Invalid file extension. Expected '.krec' or '.krec.mkv', got: {file_path}"
    )


def load_krec_direct(krec_file_path: str) -> krec.KRec:
    """Read a standalone ``.krec`` file via ``krec.KRec.load`` and return it."""
    recording = krec.KRec.load(krec_file_path)
    return recording


def load_krec_from_mkv(mkv_file_path: str) -> krec.KRec:
    """Extract the KREC data embedded in an MKV file and load it.

    The extracted file is written to
    ``<mkv directory>/temp/temp_<timestamp>/<name>_from_mkv.krec`` and is
    left on disk after loading; nothing here removes it.

    Args:
        mkv_file_path: Path to a ``.krec.mkv`` file.

    Returns:
        The recording loaded from the extracted ``.krec`` file.

    Raises:
        FileNotFoundError: If ``mkv_file_path`` does not exist.
    """
    if not os.path.exists(mkv_file_path):
        raise FileNotFoundError(f"File not found: {mkv_file_path}")

    source_dir, source_name = os.path.split(mkv_file_path)

    # One timestamped folder per call (second resolution) under a shared
    # "temp" directory next to the source file; makedirs creates both levels.
    stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    extract_dir = os.path.join(source_dir, "temp", f"temp_{stamp}")
    os.makedirs(extract_dir, exist_ok=True)

    stem = source_name.split(".krec.mkv")[0]
    extracted_path = os.path.join(extract_dir, f"{stem}_from_mkv.krec")

    # Pull the KREC payload out of the video container, then load it.
    krec.extract_from_video(mkv_file_path, extracted_path)
    return krec.KRec.load(extracted_path)


def load_krec(file_path: str) -> krec.KRec:
    """Load a recording from either a ``.krec`` or a ``.krec.mkv`` file.

    The extension decides which loader is used; an unsupported extension
    propagates the ``RuntimeError`` raised by ``get_krec_file_type``.
    """
    is_plain_krec = get_krec_file_type(file_path) == "krec"
    loader = load_krec_direct if is_plain_krec else load_krec_from_mkv
    return loader(file_path)

def convert_quaternion_to_euler(quat):
    """Convert a quaternion in (x, y, z, w) order to Euler angles.

    Args:
        quat: Array-like holding the four quaternion components ordered
            ``(x, y, z, w)``. It does not need to be unit length; it is
            normalised here before conversion.

    Returns:
        The result of SciPy's ``Rotation.as_euler("xyz")`` for the normalised
        quaternion: three angles in radians.

    Raises:
        ValueError: If the quaternion has zero or non-finite norm, so no
            rotation can be derived from it.
    """
    # Work on a float64 NumPy array so the maths does not depend on how the
    # caller's container (list, tuple, tensor, ...) interoperates with NumPy.
    quat = np.asarray(quat, dtype=np.float64)

    norm = np.linalg.norm(quat)
    # Dividing by a zero norm would silently turn every angle into NaN, so
    # fail loudly instead of letting NaNs leak into the converted dataset.
    if not np.isfinite(norm) or norm == 0.0:
        raise ValueError(
            f"Cannot convert quaternion with zero or non-finite norm: {quat.tolist()}"
        )

    return R.from_quat(quat / norm).as_euler("xyz")


def check_format(raw_dir) -> bool:
    """Verify that every ``*.krec.mkv`` file in ``raw_dir`` has the expected structure.

    Each file is loaded through ``load_krec_from_mkv`` and its first frame is
    checked for actuator states, actuator commands and IMU values.

    Args:
        raw_dir: Directory to scan; must provide ``glob`` (e.g. ``pathlib.Path``).

    Returns:
        ``True`` when every file passes. Any failed check raises instead.

    Raises:
        AssertionError: If no ``*.krec.mkv`` file is found, a recording has no
            frames, or its first frame lacks one of the required fields.
    """
    print(f"[DEBUG] Checking format for directory: {raw_dir}")
    krec_paths = list(raw_dir.glob("*.krec.mkv"))
    assert len(krec_paths) > 0, "No KREC files found"
    print(f"[DEBUG] Found {len(krec_paths)} KREC files")

    for krec_path in krec_paths:
        print(f"[DEBUG] Checking file: {krec_path}")
        krec_obj = load_krec_from_mkv(str(krec_path))

        # Guard the index below so an empty recording is reported by name
        # rather than surfacing as an opaque indexing error.
        assert len(krec_obj) > 0, f"No frames found in {krec_path}"
        first_frame = krec_obj[0]

        # Verify required data exists
        assert len(first_frame.get_actuator_states()) > 0, "No actuator states found"
        assert len(first_frame.get_actuator_commands()) > 0, "No actuator commands found"
        assert first_frame.get_imu_values() is not None, "No IMU values found"

    # Honour the declared ``-> bool`` return type.
    return True


def load_from_raw(
    raw_dir: Path,
    videos_dir: Path,
    fps: int,
    video: bool,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
):
    """Read ``*.krec.mkv`` recordings from ``raw_dir`` into one dict of tensors.

    Every recording (in sorted file order) is one episode. Per frame the
    function stores actuator positions and velocities, commanded positions
    (the action), IMU gyro readings and the IMU quaternion converted to Euler
    angles. Missing IMU fields leave the corresponding row at zero.

    Args:
        raw_dir: Directory containing the ``*.krec.mkv`` files.
        videos_dir: Not used by this function.
        fps: Frame rate used to derive ``timestamp`` from the frame index.
        video: Not used by this function.
        episodes: Indices into the sorted file list to convert. ``None`` or an
            empty list selects every file.
        encoding: Not used by this function.

    Returns:
        The output of ``concatenate_episodes`` over all per-episode dicts,
        with an added ``"index"`` entry numbering all frames from zero.
    """
    print(f"[DEBUG] Loading raw data from: {raw_dir}")
    recordings = sorted(raw_dir.glob("*.krec.mkv"))
    num_episodes = len(recordings)
    print(f"[DEBUG] Found {len(recordings)} total KREC files")

    selected = episodes if episodes else range(num_episodes)
    print(f"[DEBUG] Processing episodes: {list(selected)}")

    ep_dicts = []
    for ep_idx in tqdm(selected):
        recording_path = recordings[ep_idx]
        print(f"[DEBUG] Processing episode {ep_idx} from file: {recording_path}")
        recording = load_krec_from_mkv(str(recording_path))

        num_frames = len(recording)
        # The joint count is taken from the first frame's actuator states.
        num_joints = len(recording[0].get_actuator_states())

        def _zeros(width):
            return torch.zeros((num_frames, width), dtype=torch.float32)

        # Per-episode buffers, one row per frame.
        joint_pos = _zeros(num_joints)
        joint_vel = _zeros(num_joints)
        ang_vel = _zeros(3)
        euler_rotation = _zeros(3)
        curr_actions = _zeros(num_joints)

        for t, frame in enumerate(recording):
            # Joint positions and velocities
            for j, state in enumerate(frame.get_actuator_states()):
                joint_pos[t, j] = state.position
                joint_vel[t, j] = state.velocity

            # Commanded positions are the actions
            for j, cmd in enumerate(frame.get_actuator_commands()):
                curr_actions[t, j] = cmd.position

            # IMU data; rows stay zero when a reading is absent
            imu = frame.get_imu_values()
            if not imu:
                continue
            if imu.gyro:
                gyro = imu.gyro
                ang_vel[t] = torch.tensor([gyro.x, gyro.y, gyro.z], dtype=torch.float32)
            if imu.quaternion:
                q = imu.quaternion
                quat = torch.tensor([q.x, q.y, q.z, q.w])
                euler_rotation[t] = torch.tensor(
                    convert_quaternion_to_euler(quat), dtype=torch.float32
                )

        # Previous action is the current action shifted by one frame; the
        # first frame reuses its own action.
        prev_actions = torch.cat([curr_actions[:1], curr_actions[:-1]])

        # Done flag is True only on the last frame of the episode.
        frame_index = torch.arange(num_frames)
        done = frame_index == num_frames - 1

        ep_dicts.append(
            {
                "observation.joint_pos": joint_pos,
                "observation.joint_vel": joint_vel,
                "observation.ang_vel": ang_vel,
                "observation.euler_rotation": euler_rotation,
                "prev_actions": prev_actions,
                "action": curr_actions,
                "episode_index": torch.tensor([ep_idx] * num_frames),
                "frame_index": frame_index,
                "timestamp": torch.arange(num_frames) / fps,
                "next.done": done,
            }
        )

    print(f"[DEBUG] Concatenating {len(ep_dicts)} episodes")
    data_dict = concatenate_episodes(ep_dicts)

    # Global frame counter across all concatenated episodes.
    total_frames = data_dict["frame_index"].shape[0]
    data_dict["index"] = torch.arange(total_frames)

    return data_dict


def to_hf_dataset(data_dict, video) -> Dataset:
    """Wrap the concatenated episode dict in a HuggingFace ``Dataset``.

    Vector-valued entries become fixed-length float32 sequences whose length
    is read from the second dimension of the corresponding tensor; scalar
    bookkeeping entries get plain value features. ``hf_transform_to_torch``
    is installed as the dataset transform.

    Args:
        data_dict: Mapping holding the keys listed in the schema below.
        video: Not used by this function.

    Returns:
        The constructed dataset.
    """
    print("[DEBUG] Converting to HuggingFace dataset format")
    print(f"[DEBUG] Input data_dict keys: {list(data_dict)}")

    # Per-frame vector entries, in schema order.
    vector_keys = (
        "observation.joint_pos",
        "observation.joint_vel",
        "observation.ang_vel",
        "observation.euler_rotation",
        "prev_actions",
        "action",
    )
    features = {
        key: Sequence(
            length=data_dict[key].shape[1],
            feature=Value(dtype="float32", id=None),
        )
        for key in vector_keys
    }

    # Per-frame scalar entries.
    scalar_dtypes = (
        ("episode_index", "int64"),
        ("frame_index", "int64"),
        ("timestamp", "float32"),
        ("next.done", "bool"),
        ("index", "int64"),
    )
    for key, dtype in scalar_dtypes:
        features[key] = Value(dtype=dtype, id=None)

    print("[DEBUG] Creating HuggingFace dataset")
    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    print(f"[DEBUG] Dataset size: {len(hf_dataset)}")

    print("[DEBUG] Setting transform function")
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset


def from_raw_to_lerobot_format(
    raw_dir: Path,
    videos_dir: Path,
    fps: int | None = None,
    video: bool = True,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
):
    """Convert a directory of raw KREC recordings into LeRobot dataset pieces.

    Runs the format check, loads the raw data, builds the HuggingFace dataset
    and computes the episode data index.

    Args:
        raw_dir: Directory containing the raw recordings.
        videos_dir: Forwarded to ``load_from_raw``.
        fps: Frame rate; 50 is used when ``None``.
        video: Forwarded to the loaders and recorded in the info dict.
        episodes: Optional subset of episode indices to convert.
        encoding: Forwarded to ``load_from_raw``.

    Returns:
        A tuple ``(hf_dataset, episode_data_index, info)`` where ``info``
        holds the codebase version, the fps and the video flag.
    """
    print("[DEBUG] Starting conversion from raw to LeRobot format")
    print("[DEBUG] Parameters:")
    reported = (
        ("raw_dir", raw_dir),
        ("videos_dir", videos_dir),
        ("fps", fps),
        ("video", video),
        ("episodes", episodes),
        ("encoding", encoding),
    )
    for label, value in reported:
        print(f"[DEBUG] - {label}: {value}")

    check_format(raw_dir)

    if fps is None:
        # Fall back to the dataset's default frame rate.
        fps = 50
        print(f"[DEBUG] Using default FPS: {fps}")

    hf_dataset = to_hf_dataset(
        load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding),
        video,
    )
    print("[DEBUG] Calculating episode data index")
    episode_data_index = calculate_episode_data_index(hf_dataset)

    info = {"codebase_version": CODEBASE_VERSION, "fps": fps, "video": video}
    print(f"[DEBUG] Final info: {info}")

    return hf_dataset, episode_data_index, info


if __name__ == "__main__":
    # Command-line entry point: parse options, run the conversion and print
    # the resulting dataset.
    cli = argparse.ArgumentParser(description="Convert GPR KREC dataset to LeRobot format")
    cli.add_argument(
        "--raw_dir",
        type=str,
        required=True,
        help="Directory containing raw KREC files",
    )
    cli.add_argument(
        "--videos_dir",
        type=str,
        default="data/temp",
        help="Directory for video output (default: data/temp)",
    )
    cli.add_argument(
        "--fps",
        type=int,
        default=50,
        help="Frames per second (default: 50)",
    )
    cli.add_argument(
        "--video",
        action="store_true",
        help="Enable video processing (default: False)",
    )
    options = cli.parse_args()

    # The video output directory is created up front, whether or not
    # --video was given.
    output_videos_dir = Path(options.videos_dir)
    output_videos_dir.mkdir(parents=True, exist_ok=True)

    print("Converting raw KREC data to LeRobot format...")
    hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
        raw_dir=Path(options.raw_dir),
        videos_dir=output_videos_dir,
        fps=options.fps,
        video=options.video,
    )
    print("Conversion completed!")
    print("\nDataset info:")
    pprint(hf_dataset)

0 comments on commit 0fb6526

Please sign in to comment.