diff --git a/lerobot/common/datasets/push_dataset_to_hub/gpt_krec_format.py b/lerobot/common/datasets/push_dataset_to_hub/gpt_krec_format.py
new file mode 100644
index 000000000..a7dfa835b
--- /dev/null
+++ b/lerobot/common/datasets/push_dataset_to_hub/gpt_krec_format.py
@@ -0,0 +1,321 @@
+"""
+This script loads a GPR dataset from KREC files and converts it to the LeRobot dataset format.
+
+Example Usage:
+    python lerobot/common/datasets/push_dataset_to_hub/gpt_krec_format.py --raw_dir /path/to/krec/files
+
+    python lerobot/common/datasets/push_dataset_to_hub/gpt_krec_format.py --raw_dir /home/kasm-user/ali_repos/kmodel/data/datasets/krec_data/dec_3__11_10am_og_krecs_edited/2024-12-03_17-47-30/
+"""
+
+import argparse
+import os
+from datetime import datetime
+from pathlib import Path
+from pprint import pprint
+
+import krec
+import numpy as np
+import torch
+from datasets import Dataset, Features, Sequence, Value
+from scipy.spatial.transform import Rotation as R
+from tqdm import tqdm
+
+from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
+from lerobot.common.datasets.push_dataset_to_hub.utils import (
+    calculate_episode_data_index,
+    concatenate_episodes,
+)
+from lerobot.common.datasets.utils import hf_transform_to_torch
+
+
+def get_krec_file_type(file_path: str) -> str:
+    """Determine if the file is a direct KREC file or an MKV-embedded KREC.
+
+    Returns:
+        'krec' for .krec files
+        'mkv' for .krec.mkv files
+    Raises:
+        RuntimeError for invalid extensions
+    """
+    if file_path.endswith(".krec"):
+        return "krec"
+    elif file_path.endswith(".krec.mkv"):
+        return "mkv"
+    else:
+        error_msg = (
+            f"Invalid file extension. Expected '.krec' or '.krec.mkv', got: {file_path}"
+        )
+        raise RuntimeError(error_msg)
+
+
+def load_krec_direct(krec_file_path: str) -> krec.KRec:
+    """Load a KREC file directly."""
+    return krec.KRec.load(krec_file_path)
+
+
+def load_krec_from_mkv(mkv_file_path: str) -> krec.KRec:
+    """Load a KREC file from an MKV file into a manually created temp directory."""
+    if not os.path.exists(mkv_file_path):
+        raise FileNotFoundError(f"File not found: {mkv_file_path}")
+
+    # Create a parent temp directory if it doesn't exist
+    parent_temp_dir = os.path.join(os.path.dirname(mkv_file_path), "temp")
+    os.makedirs(parent_temp_dir, exist_ok=True)
+
+    # Create a timestamped subdirectory inside the parent temp directory
+    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+    temp_dir = os.path.join(parent_temp_dir, f"temp_{timestamp}")
+    os.makedirs(temp_dir, exist_ok=True)
+
+    base_name = os.path.basename(mkv_file_path).split(".krec.mkv")[0]
+    krec_file_path = os.path.join(temp_dir, f"{base_name}_from_mkv.krec")
+
+    # Extract and load from the temp directory
+    krec.extract_from_video(mkv_file_path, krec_file_path)
+    return krec.KRec.load(krec_file_path)
+
+
+def load_krec(file_path: str) -> krec.KRec:
+    """Smart loader that handles both direct KREC and MKV-embedded KREC files."""
+    file_type = get_krec_file_type(file_path)
+
+    if file_type == "krec":
+        return load_krec_direct(file_path)
+    else:  # file_type == 'mkv'
+        return load_krec_from_mkv(file_path)
+
+
+def convert_quaternion_to_euler(quat):
+    """Convert a quaternion (x, y, z, w) to Euler angles (roll, pitch, yaw)."""
+    # Accept numpy arrays or CPU torch tensors
+    quat = np.asarray(quat, dtype=np.float64)
+    # Normalize before converting
+    quat = quat / np.linalg.norm(quat)
+    euler = R.from_quat(quat).as_euler('xyz')
+
+    return euler
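+
+# Illustrative usage of the helper above (not part of the conversion pipeline,
+# shown only as a sanity-check sketch): the identity quaternion (x=y=z=0, w=1)
+# maps to zero roll/pitch/yaw.
+#
+#     convert_quaternion_to_euler(np.array([0.0, 0.0, 0.0, 1.0]))
+#     # -> array([0., 0., 0.])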
+
+
+def check_format(raw_dir) -> bool:
+    """Verify KREC files have expected structure"""
+    print(f"[DEBUG] Checking format for directory: {raw_dir}")
+    krec_paths = list(raw_dir.glob("*.krec.mkv"))
+    assert len(krec_paths) > 0, "No KREC files found"
+    print(f"[DEBUG] Found {len(krec_paths)} KREC files")
+
+    for krec_path in krec_paths:
+        print(f"[DEBUG] Checking file: {krec_path}")
+        krec_obj = load_krec_from_mkv(str(krec_path))
+        first_frame = krec_obj[0]
+
+        # Verify required data exists
+        assert len(first_frame.get_actuator_states()) > 0, "No actuator states found"
+        assert len(first_frame.get_actuator_commands()) > 0, "No actuator commands found"
+        assert first_frame.get_imu_values() is not None, "No IMU values found"
+
+
+def load_from_raw(
+    raw_dir: Path,
+    videos_dir: Path,
+    fps: int,
+    video: bool,
+    episodes: list[int] | None = None,
+    encoding: dict | None = None,
+):
+    """Load data from KREC files into standardized format"""
+    print(f"[DEBUG] Loading raw data from: {raw_dir}")
+    krec_files = sorted(raw_dir.glob("*.krec.mkv"))
+    num_episodes = len(krec_files)
+    print(f"[DEBUG] Found {len(krec_files)} total KREC files")
+
+    ep_dicts = []
+    ep_ids = episodes if episodes else range(num_episodes)
+    print(f"[DEBUG] Processing episodes: {list(ep_ids)}")
+
+    for ep_idx in tqdm(ep_ids):
+        ep_path = krec_files[ep_idx]
+        print(f"[DEBUG] Processing episode {ep_idx} from file: {ep_path}")
+        krec_obj = load_krec_from_mkv(str(ep_path))
+
+        num_frames = len(krec_obj)
+        first_frame = krec_obj[0]
+        num_joints = len(first_frame.get_actuator_states())
+
+        # Initialize tensors for this episode
+        joint_pos = torch.zeros((num_frames, num_joints), dtype=torch.float32)
+        joint_vel = torch.zeros((num_frames, num_joints), dtype=torch.float32)
+        ang_vel = torch.zeros((num_frames, 3), dtype=torch.float32)
+        euler_rotation = torch.zeros((num_frames, 3), dtype=torch.float32)
+        prev_actions = torch.zeros((num_frames, num_joints), dtype=torch.float32)
+        curr_actions = torch.zeros((num_frames, num_joints), dtype=torch.float32)
+
+        # Fill data from KREC frames
+        for frame_idx, frame in enumerate(krec_obj):
+            # Joint positions and velocities
+            for j, state in enumerate(frame.get_actuator_states()):
+                joint_pos[frame_idx, j] = state.position
+                joint_vel[frame_idx, j] = state.velocity
+
+            # Actions (commands)
+            for j, cmd in enumerate(frame.get_actuator_commands()):
+                curr_actions[frame_idx, j] = cmd.position
+
+            # IMU data
+            imu = frame.get_imu_values()
+            if imu and imu.gyro:
+                ang_vel[frame_idx] = torch.tensor(
+                    [imu.gyro.x, imu.gyro.y, imu.gyro.z], dtype=torch.float32
+                )
+            if imu and imu.quaternion:
+                quat = torch.tensor(
+                    [imu.quaternion.x, imu.quaternion.y, imu.quaternion.z, imu.quaternion.w]
+                )
+                curr_euler = convert_quaternion_to_euler(quat)
+                euler_rotation[frame_idx] = torch.tensor(curr_euler, dtype=torch.float32)
+
+        # Set previous actions (shifted by 1)
+        prev_actions[1:] = curr_actions[:-1]
+        prev_actions[0] = curr_actions[0]  # First frame uses same action
+
+        # Create done signal (True for last frame)
+        done = torch.zeros(num_frames, dtype=torch.bool)
+        done[-1] = True
+
+        ep_dict = {
+            "observation.joint_pos": joint_pos,
+            "observation.joint_vel": joint_vel,
+            "observation.ang_vel": ang_vel,
+            "observation.euler_rotation": euler_rotation,
+            "prev_actions": prev_actions,
+            "action": curr_actions,
+            "episode_index": torch.tensor([ep_idx] * num_frames),
+            "frame_index": torch.arange(0, num_frames, 1),
+            "timestamp": torch.arange(0, num_frames, 1) / fps,
+            "next.done": done,
+        }
+        ep_dicts.append(ep_dict)
+
+    print(f"[DEBUG] Concatenating {len(ep_dicts)} episodes")
+    data_dict = concatenate_episodes(ep_dicts)
+    total_frames = data_dict["frame_index"].shape[0]
+    data_dict["index"] = torch.arange(0, total_frames, 1)
+
+    return data_dict
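+
+# For reference, load_from_raw returns a flat dictionary concatenated across
+# episodes. With two hypothetical 500-frame episodes from a 20-actuator robot
+# (numbers are illustrative, not taken from a real recording):
+#
+#     data_dict["observation.joint_pos"].shape  ->  torch.Size([1000, 20])
+#     data_dict["action"].shape                 ->  torch.Size([1000, 20])
+#     data_dict["episode_index"].shape          ->  torch.Size([1000])
+#     data_dict["next.done"].sum()              ->  tensor(2)  # one True per episode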
+
+
+def to_hf_dataset(data_dict, video) -> Dataset:
+    """Convert to HuggingFace dataset format"""
+    print("[DEBUG] Converting to HuggingFace dataset format")
+    print(f"[DEBUG] Input data_dict keys: {list(data_dict.keys())}")
+    features = {
+        "observation.joint_pos": Sequence(
+            length=data_dict["observation.joint_pos"].shape[1],
+            feature=Value(dtype="float32", id=None),
+        ),
+        "observation.joint_vel": Sequence(
+            length=data_dict["observation.joint_vel"].shape[1],
+            feature=Value(dtype="float32", id=None),
+        ),
+        "observation.ang_vel": Sequence(
+            length=data_dict["observation.ang_vel"].shape[1],
+            feature=Value(dtype="float32", id=None),
+        ),
+        "observation.euler_rotation": Sequence(
+            length=data_dict["observation.euler_rotation"].shape[1],
+            feature=Value(dtype="float32", id=None),
+        ),
+        "prev_actions": Sequence(
+            length=data_dict["prev_actions"].shape[1],
+            feature=Value(dtype="float32", id=None),
+        ),
+        "action": Sequence(
+            length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
+        ),
+        "episode_index": Value(dtype="int64", id=None),
+        "frame_index": Value(dtype="int64", id=None),
+        "timestamp": Value(dtype="float32", id=None),
+        "next.done": Value(dtype="bool", id=None),
+        "index": Value(dtype="int64", id=None),
+    }
+
+    print("[DEBUG] Creating HuggingFace dataset")
+    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
+    print(f"[DEBUG] Dataset size: {len(hf_dataset)}")
+    print("[DEBUG] Setting transform function")
+    hf_dataset.set_transform(hf_transform_to_torch)
+    return hf_dataset
+
+
+def from_raw_to_lerobot_format(
+    raw_dir: Path,
+    videos_dir: Path,
+    fps: int | None = None,
+    video: bool = True,
+    episodes: list[int] | None = None,
+    encoding: dict | None = None,
+):
+    """Main function to convert raw data to LeRobot format"""
+    print("[DEBUG] Starting conversion from raw to LeRobot format")
+    print("[DEBUG] Parameters:")
+    print(f"[DEBUG] - raw_dir: {raw_dir}")
+    print(f"[DEBUG] - videos_dir: {videos_dir}")
+    print(f"[DEBUG] - fps: {fps}")
+    print(f"[DEBUG] - video: {video}")
+    print(f"[DEBUG] - episodes: {episodes}")
+    print(f"[DEBUG] - encoding: {encoding}")
+    check_format(raw_dir)
+
+    if fps is None:
+        fps = 50  # Default FPS for the GPR dataset
+        print(f"[DEBUG] Using default FPS: {fps}")
+
+    data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
+    hf_dataset = to_hf_dataset(data_dict, video)
+    print("[DEBUG] Calculating episode data index")
+    episode_data_index = calculate_episode_data_index(hf_dataset)
+
+    info = {
+        "codebase_version": CODEBASE_VERSION,
+        "fps": fps,
+        "video": video,
+    }
+    print(f"[DEBUG] Final info: {info}")
+
+    return hf_dataset, episode_data_index, info
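+
+# Rough sketch of how the returned values are typically consumed downstream in
+# the push_dataset_to_hub flow (names and indexing are illustrative):
+#
+#     hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, videos_dir)
+#     frame = hf_dataset[0]                      # dict of torch tensors for the first frame
+#     ep0_from = episode_data_index["from"][0]   # first frame index of episode 0
+#     ep0_to = episode_data_index["to"][0]       # one-past-last frame index of episode 0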
format...") + hf_dataset, episode_data_index, info = from_raw_to_lerobot_format( + raw_dir=raw_dir, videos_dir=videos_dir, fps=args.fps, video=args.video + ) + print("Conversion completed!") + print("\nDataset info:") + pprint(hf_dataset)