-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
09677cd
commit fa9bce0
Showing
4 changed files
with
149 additions
and
58 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
""" Logger for logging data to HDF5 files """ | ||
import os | ||
import uuid | ||
from typing import Dict | ||
|
||
import h5py | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
|
||
class HDF5Logger: | ||
def __init__(self, data_name: str, num_actions: int, max_timesteps: int, num_observations: int): | ||
self.data_name = data_name | ||
self.num_actions = num_actions | ||
self.max_timesteps = max_timesteps | ||
self.num_observations = num_observations | ||
self.max_threshold = 1e3 # Adjust this threshold as needed | ||
self.h5_file, self.h5_dict = self._create_h5_file() | ||
self.current_timestep = 0 | ||
|
||
def _create_h5_file(self): | ||
# Create a unique file ID | ||
idd = str(uuid.uuid4()) | ||
h5_file = h5py.File(f"{self.data_name}/{idd}.h5", "w") | ||
|
||
# Create datasets for logging actions and observations | ||
dset_actions = h5_file.create_dataset("actions", (self.max_timesteps, self.num_actions), dtype=np.float32) | ||
dset_2D_command = h5_file.create_dataset("observations/2D_command", (self.max_timesteps, 2), dtype=np.float32) | ||
dset_3D_command = h5_file.create_dataset("observations/3D_command", (self.max_timesteps, 3), dtype=np.float32) | ||
dset_q = h5_file.create_dataset("observations/q", (self.max_timesteps, self.num_actions), dtype=np.float32) | ||
dset_dq = h5_file.create_dataset("observations/dq", (self.max_timesteps, self.num_actions), dtype=np.float32) | ||
dset_ang_vel = h5_file.create_dataset("observations/ang_vel", (self.max_timesteps, 3), dtype=np.float32) | ||
dset_euler = h5_file.create_dataset("observations/euler", (self.max_timesteps, 3), dtype=np.float32) | ||
dset_t = h5_file.create_dataset("observations/t", (self.max_timesteps, 1), dtype=np.float32) | ||
dset_buffer = h5_file.create_dataset("observations/buffer", (self.max_timesteps, self.num_observations), dtype=np.float32) | ||
|
||
# Map datasets for easy access | ||
h5_dict = { | ||
"actions": dset_actions, | ||
"2D_command": dset_2D_command, | ||
"3D_command": dset_3D_command, | ||
"joint_pos": dset_q, | ||
"joint_vel": dset_dq, | ||
"ang_vel": dset_ang_vel, | ||
"euler_rotation": dset_euler, | ||
"t": dset_t, | ||
"buffer": dset_buffer, | ||
} | ||
return h5_file, h5_dict | ||
|
||
def log_data(self, data: Dict[str, np.ndarray]): | ||
if self.current_timestep >= self.max_timesteps: | ||
print(f"Warning: Exceeded maximum timesteps ({self.max_timesteps})") | ||
return | ||
|
||
for key, dataset in self.h5_dict.items(): | ||
if key in data: | ||
dataset[self.current_timestep] = data[key] | ||
|
||
self.current_timestep += 1 | ||
|
||
def close(self): | ||
for key, dataset in self.h5_dict.items(): | ||
max_val = np.max(np.abs(dataset[:])) | ||
if max_val > self.max_threshold: | ||
print(f"Warning: Found very large values in {key}: {max_val}") | ||
print("File will not be saved to prevent corrupted data") | ||
self.h5_file.close() | ||
# Delete the file | ||
os.remove(self.h5_file.filename) | ||
return | ||
|
||
self.h5_file.close() | ||
|
||
@staticmethod | ||
def visualize_h5(h5_file_path: str): | ||
""" | ||
Visualizes the data from an HDF5 file by plotting each variable one by one. | ||
Args: | ||
h5_file_path (str): Path to the HDF5 file. | ||
""" | ||
try: | ||
# Open the HDF5 file | ||
with h5py.File(h5_file_path, "r") as h5_file: | ||
# Extract all datasets | ||
for key in h5_file.keys(): | ||
group = h5_file[key] | ||
if isinstance(group, h5py.Group): | ||
for subkey in group.keys(): | ||
dataset = group[subkey][:] | ||
HDF5Logger._plot_dataset(f"{key}/{subkey}", dataset) | ||
else: | ||
dataset = group[:] | ||
HDF5Logger._plot_dataset(key, dataset) | ||
|
||
except Exception as e: | ||
print(f"Failed to visualize HDF5 file: {e}") | ||
|
||
@staticmethod | ||
def _plot_dataset(name: str, data: np.ndarray): | ||
""" | ||
Helper method to plot a single dataset. | ||
Args: | ||
name (str): Name of the dataset. | ||
data (np.ndarray): Data to be plotted. | ||
""" | ||
plt.figure(figsize=(10, 5)) | ||
if data.ndim == 2: # Handle multi-dimensional data | ||
for i in range(data.shape[1]): | ||
plt.plot(data[:, i], label=f"{name}[{i}]") | ||
else: | ||
plt.plot(data, label=name) | ||
|
||
plt.title(f"Visualization of {name}") | ||
plt.xlabel("Timesteps") | ||
plt.ylabel("Values") | ||
plt.legend(loc="upper right") | ||
plt.grid(True) | ||
plt.tight_layout() | ||
plt.show() | ||
|
||
|
||
if __name__ == "__main__": | ||
HDF5Logger.visualize_h5("stompypro/6dc85e02-fc8e-42e1-a396-b0bd578e0816.h5") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters