Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
alik-git committed Dec 3, 2024
1 parent 76a9c65 commit f1cbc9e
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 35 deletions.
53 changes: 27 additions & 26 deletions sim/compare_h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,53 +15,56 @@
"""

import argparse
import h5py
import numpy as np
from pathlib import Path
import argparse


def load_h5_file(file_path):
"""Load H5 file and return a dictionary of datasets"""
data = {}
with h5py.File(file_path, 'r') as f:
with h5py.File(file_path, "r") as f:
# Recursively load all datasets
def load_group(group, prefix=''):
def load_group(group, prefix=""):
for key in group.keys():
path = f"{prefix}/{key}" if prefix else key
if isinstance(group[key], h5py.Group):
load_group(group[key], path)
else:
data[path] = group[key][:]

load_group(f)
return data


def compare_h5_files(issac_path, mujoco_path):
"""Compare two H5 files and print differences"""
print(f"\nLoading files:")
print(f"Isaac: {issac_path}")
print(f"Mujoco: {mujoco_path}")

# Load both files
issac_data = load_h5_file(issac_path)
mujoco_data = load_h5_file(mujoco_path)

print("\nFile lengths:")
print(f"Isaac datasets: {len(issac_data)}")
print(f"Mujoco datasets: {len(mujoco_data)}")

print("\nDataset shapes:")
print("\nIsaac shapes:")
for key, value in issac_data.items():
print(f"{key}: {value.shape}")

print("\nMujoco shapes:")
for key, value in mujoco_data.items():
print(f"{key}: {value.shape}")

# Find common keys
common_keys = set(issac_data.keys()) & set(mujoco_data.keys())
print(f"\nCommon datasets: {len(common_keys)}")

# Find uncommon keys
issac_only_keys = set(issac_data.keys()) - common_keys
mujoco_only_keys = set(mujoco_data.keys()) - common_keys
Expand All @@ -71,47 +74,45 @@ def compare_h5_files(issac_path, mujoco_path):
print(f"\nMujoco only datasets: {len(mujoco_only_keys)}")
for key in mujoco_only_keys:
print(f"Mujoco only: {key}")

# Compare data for common keys
for key in common_keys:
issac_arr = issac_data[key]
mujoco_arr = mujoco_data[key]

print(f"\n========== For {key} ===============")

if issac_arr.shape != mujoco_arr.shape:
print(f"\n{key} - Shape mismatch: Isaac {issac_arr.shape} vs Mujoco {mujoco_arr.shape}")

# Calculate differences
min_shape = min(issac_arr.shape[0], mujoco_arr.shape[0])
if issac_arr.shape != mujoco_arr.shape:
raise ValueError(f"Shapes do not match for {key}. Cannot compare datasets with different shapes.")
diff = np.abs(issac_arr[:min_shape] - mujoco_arr[:min_shape])
max_diff = np.max(diff)
mean_diff = np.mean(diff)

print(f"Max difference: {max_diff:.6f}")
print(f"Mean difference: {mean_diff:.6f}\n")

start_idx = 0
display_timesteps = 10
end_idx = start_idx + display_timesteps
np.set_printoptions(formatter={'float': '{:0.6f}'.format}, suppress=True)

np.set_printoptions(formatter={"float": "{:0.6f}".format}, suppress=True)
print("Isaac:\n", issac_arr[start_idx:end_idx])
print("Mujoco:\n", mujoco_arr[start_idx:end_idx])


if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Compare two H5 files from Isaac and Mujoco simulations')
parser.add_argument('--isaac-file', required=True, help='Path to Isaac simulation H5 file')
parser.add_argument('--mujoco-file', required=True, help='Path to Mujoco simulation H5 file')
parser = argparse.ArgumentParser(description="Compare two H5 files from Isaac and Mujoco simulations")
parser.add_argument("--isaac-file", required=True, help="Path to Isaac simulation H5 file")
parser.add_argument("--mujoco-file", required=True, help="Path to Mujoco simulation H5 file")

args = parser.parse_args()

print(f"Isaac path: {args.isaac_file}")
print(f"Mujoco path: {args.mujoco_file}")

compare_h5_files(args.isaac_file, args.mujoco_file)


compare_h5_files(args.isaac_file, args.mujoco_file)
31 changes: 22 additions & 9 deletions sim/h5_logger.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
""" Logger for logging data to HDF5 files """
import os
import uuid
from datetime import datetime
from typing import Dict

import h5py
import numpy as np
from datetime import datetime
import matplotlib.pyplot as plt # dependency issues with python 3.8


class HDF5Logger:
def __init__(self, data_name: str, num_actions: int, max_timesteps: int, num_observations: int, h5_out_dir: str = "sim/resources/"):
def __init__(
self,
data_name: str,
num_actions: int,
max_timesteps: int,
num_observations: int,
h5_out_dir: str = "sim/resources/",
):
self.data_name = data_name
self.num_actions = num_actions
self.max_timesteps = max_timesteps
Expand All @@ -32,16 +40,22 @@ def _create_h5_file(self):
h5_file = h5py.File(h5_file_path, "w")

# Create datasets for logging actions and observations
dset_prev_actions = h5_file.create_dataset("prev_actions", (self.max_timesteps, self.num_actions), dtype=np.float32)
dset_prev_actions = h5_file.create_dataset(
"prev_actions", (self.max_timesteps, self.num_actions), dtype=np.float32
)
dset_2D_command = h5_file.create_dataset("observations/2D_command", (self.max_timesteps, 2), dtype=np.float32)
dset_3D_command = h5_file.create_dataset("observations/3D_command", (self.max_timesteps, 3), dtype=np.float32)
dset_q = h5_file.create_dataset("observations/q", (self.max_timesteps, self.num_actions), dtype=np.float32)
dset_dq = h5_file.create_dataset("observations/dq", (self.max_timesteps, self.num_actions), dtype=np.float32)
dset_ang_vel = h5_file.create_dataset("observations/ang_vel", (self.max_timesteps, 3), dtype=np.float32)
dset_euler = h5_file.create_dataset("observations/euler", (self.max_timesteps, 3), dtype=np.float32)
dset_t = h5_file.create_dataset("observations/t", (self.max_timesteps, 1), dtype=np.float32)
dset_buffer = h5_file.create_dataset("observations/buffer", (self.max_timesteps, self.num_observations), dtype=np.float32)
dset_curr_actions = h5_file.create_dataset("curr_actions", (self.max_timesteps, self.num_actions), dtype=np.float32)
dset_buffer = h5_file.create_dataset(
"observations/buffer", (self.max_timesteps, self.num_observations), dtype=np.float32
)
dset_curr_actions = h5_file.create_dataset(
"curr_actions", (self.max_timesteps, self.num_actions), dtype=np.float32
)

# Map datasets for easy access
h5_dict = {
Expand All @@ -62,11 +76,11 @@ def log_data(self, data: Dict[str, np.ndarray]):
if self.current_timestep >= self.max_timesteps:
print(f"Warning: Exceeded maximum timesteps ({self.max_timesteps})")
return

for key, dataset in self.h5_dict.items():
if key in data:
dataset[self.current_timestep] = data[key]

self.current_timestep += 1

def close(self):
Expand All @@ -79,7 +93,7 @@ def close(self):
# Delete the file
os.remove(self.h5_file.filename)
return

self.h5_file.close()

@staticmethod
Expand Down Expand Up @@ -116,7 +130,6 @@ def _plot_dataset(name: str, data: np.ndarray):
name (str): Name of the dataset.
data (np.ndarray): Data to be plotted.
"""
import matplotlib.pyplot as plt # dependency issues with python 3.8
plt.figure(figsize=(10, 5))
if data.ndim == 2: # Handle multi-dimensional data
for i in range(data.shape[1]):
Expand Down

0 comments on commit f1cbc9e

Please sign in to comment.