Skip to content

Commit

Permalink
cleanup and organize
Browse files Browse the repository at this point in the history
  • Loading branch information
alik-git committed Dec 27, 2024
1 parent 1c0aabc commit 95b51b9
Showing 1 changed file with 117 additions and 81 deletions.
198 changes: 117 additions & 81 deletions krecviz/urdf_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,58 +8,133 @@

import argparse
import logging
import math
import sys
from pathlib import Path

import math
import numpy as np
import rerun as rr # pip install rerun-sdk
import rerun as rr
import scipy.spatial.transform as st
import trimesh
from PIL import Image
from urdf_parser_py import urdf as urdf_parser # type: ignore[import-untyped]


# Separate debug-print functions.

def debug_print_log_view_coordinates(entity_path_val: str, entity_val: rr.ViewCoordinates, timeless_val: bool) -> None:
"""
Print debug info before calling rr.log(...) for the root view coordinates.
"""
print("======================")
print("rerun_log")
print(f"entity_path = self.add_entity_path_prefix(\"\") with value '{entity_path_val}'")
print(f"entity = rr.ViewCoordinates.RIGHT_HAND_Z_UP with value {entity_val}")
print(f"timeless = {timeless_val}")


def debug_print_log_joint(entity_path_w_prefix: str,
joint: urdf_parser.Joint,
translation: list[float] | None,
rotation: list[list[float]] | None) -> None:
"""
Print debug info before logging the Transform3D of a joint.
"""
print("======================")
print("rerun_log")
print(f"entity_path = entity_path_w_prefix with value '{entity_path_w_prefix}'")
print("Original joint RPY values:")
if joint.origin is not None and joint.origin.rpy is not None:
print(f" => rpy = {[round(float(x), 3) for x in joint.origin.rpy]}")
else:
print(" => rpy = None")

print("entity = rr.Transform3D with:")
print(" translation:", [f"{x:>8.3f}" for x in translation] if translation else None)
print(" mat3x3:")
if rotation:
for row in rotation:
print(" [" + ", ".join(f"{x:>8.3f}" for x in row) + "]")
else:
print(" None")


def debug_print_unsupported_geometry(entity_path_val: str, log_text: str) -> None:
"""
Print debug info for the 'Unsupported geometry' case before logging rr.TextLog.
"""
print("======================")
print("rerun_log")
print(f"entity_path = self.add_entity_path_prefix(\"\") with value '{entity_path_val}'")
print(f"entity = rr.TextLog(...) with value '{log_text}'")


def debug_print_log_trimesh(entity_path: str,
mesh3d_entity: rr.Mesh3D,
timeless_val: bool,
mesh: trimesh.Trimesh) -> None:
"""
Print debug info prior to rr.log(...) a single Trimesh.
"""
print("======================")
print("rerun_log log_trimesh")
print(f"entity_path = entity_path with value '{entity_path}'")
print("entity = rr.Mesh3D(...) with these numeric values:")

# Print only the first three vertex positions for brevity
first_three_vertices = mesh.vertices[:3].tolist()
print(" => vertex_positions (first 3):")
for vertex in first_three_vertices:
print(f" [{', '.join(f'{x:>7.3f}' for x in vertex)}]")

print(f"timeless = {timeless_val}")


def debug_print_final_link_transform(link_name: str,
chain: list[str],
final_tf: np.ndarray) -> None:
"""
Print the final transform accumulated for a link.
"""
print(f"Link '{link_name}': BFS chain = {chain}")
print(" => final_tf (4x4) =")
for row in final_tf:
print(" [{: 8.3f} {: 8.3f} {: 8.3f} {: 8.3f}]".format(*row))
print()


# --- CHANGED ---
# We now have a small custom Euler-to-matrix function that is logically equivalent
# to the old `st.Rotation.from_euler("xyz", rpy).as_matrix()`.
def rotation_from_euler_xyz(rpy):
"""
Given a 3-element list/tuple [rx, ry, rz] of Euler angles in radians,
build the corresponding 3x3 rotation matrix for an 'XYZ' rotation sequence.
In the 'xyz' convention, we first rotate by rx around the X-axis,
then by ry around the Y-axis,
then by rz around the Z-axis.
"""
rx, ry, rz = rpy

# Precompute sines/cosines
cx, sx = math.cos(rx), math.sin(rx)
cy, sy = math.cos(ry), math.sin(ry)
cz, sz = math.cos(rz), math.sin(rz)

# Rotation around X-axis
R_x = np.array([
[1, 0, 0],
[0, cx, -sx],
[0, sx, cx],
], dtype=np.float64)

# Rotation around Y-axis
R_y = np.array([
[ cy, 0, sy],
[ 0, 1, 0],
[-sy, 0, cy],
], dtype=np.float64)

# Rotation around Z-axis
R_z = np.array([
[ cz, -sz, 0],
[ sz, cz, 0],
[ 0, 0, 1],
], dtype=np.float64)

# Final rotation = Rz @ Ry @ Rx
R_final = R_z @ R_y @ R_x
return R_final
return R_z @ R_y @ R_x


class URDFLogger:
Expand Down Expand Up @@ -97,11 +172,10 @@ def log(self) -> None:
entity_val = rr.ViewCoordinates.RIGHT_HAND_Z_UP
timeless_val = True

print("======================")
print("rerun_log")
print(f"entity_path = self.add_entity_path_prefix(\"\") with value '{entity_path_val}'")
print(f"entity = rr.ViewCoordinates.RIGHT_HAND_Z_UP with value {entity_val}")
print(f"timeless = {timeless_val}")
# --- CHANGED ---
# Now we call our debug-print function instead of inlining the prints:
debug_print_log_view_coordinates(entity_path_val, entity_val, timeless_val)

rr.log(
entity_path=entity_path_val,
entity=entity_val,
Expand All @@ -118,7 +192,7 @@ def log(self) -> None:
entity_path = self.link_entity_path(link)
self.log_link(entity_path, link)

# -- AFTER we log everything, let's print a final transform for each link:
# Print final transforms
self.print_final_link_transforms()

def log_link(self, entity_path: str, link: urdf_parser.Link) -> None:
Expand All @@ -135,35 +209,17 @@ def log_joint(self, entity_path: str, joint: urdf_parser.Joint) -> None:
translation = [float(x) for x in joint.origin.xyz]

if joint.origin is not None and joint.origin.rpy is not None:
# We call our custom rotation_from_euler_xyz
rotation_matrix = rotation_from_euler_xyz(joint.origin.rpy)
# Convert to a Python list-of-lists
rotation = [[float(x) for x in row] for row in rotation_matrix]

entity_path_w_prefix = self.add_entity_path_prefix(entity_path)
if isinstance(translation, list) and isinstance(rotation, list):
self.entity_to_transform[entity_path_w_prefix] = (translation, rotation)

# Prepare debug prints
print("======================")
print("rerun_log")
print(f"entity_path = entity_path_w_prefix with value '{entity_path_w_prefix}'")
print("Original joint RPY values:")
if joint.origin is not None and joint.origin.rpy is not None:
print(f" => rpy = {[round(float(x), 3) for x in joint.origin.rpy]}")
else:
print(" => rpy = None")

transform_3d = rr.Transform3D(translation=translation , mat3x3=rotation)
print("entity = rr.Transform3D with:")
print(" translation:", [f"{x:>8.3f}" for x in translation] if translation else None)
print(" mat3x3:")
if rotation:
for row in rotation:
print(" [" + ", ".join(f"{x:>8.3f}" for x in row) + "]")
else:
print(" None")
# --- CHANGED ---
debug_print_log_joint(entity_path_w_prefix, joint, translation, rotation)

transform_3d = rr.Transform3D(translation=translation, mat3x3=rotation)
rr.log(
entity_path=entity_path_w_prefix,
entity=transform_3d,
Expand All @@ -184,6 +240,7 @@ def log_visual(self, entity_path: str, visual: urdf_parser.Visual) -> None:
if visual.origin is not None and visual.origin.rpy is not None:
transform[:3, :3] = rotation_from_euler_xyz(visual.origin.rpy)

# Geometry handling (same as original)
if isinstance(visual.geometry, urdf_parser.Mesh):
resolved_path = self.resolve_ros_path(visual.geometry.filename)
mesh_scale = visual.geometry.scale
Expand All @@ -202,12 +259,13 @@ def log_visual(self, entity_path: str, visual: urdf_parser.Visual) -> None:
radius=visual.geometry.radius,
)
else:
raise ValueError(f"Unsupported geometry type: {type(visual.geometry)}")
log_text = f"Unsupported geometry type: {type(visual.geometry)}"
entity_path_val = self.add_entity_path_prefix("")
print("======================")
print("rerun_log")
print(f"entity_path = self.add_entity_path_prefix(\"\") with value '{entity_path_val}'")
print(f"entity = rr.TextLog(...) with value '{log_text}'")

# --- CHANGED ---
debug_print_unsupported_geometry(entity_path_val, log_text)

rr.log(
entity_path=entity_path_val,
entity=rr.TextLog(log_text),
Expand Down Expand Up @@ -272,22 +330,13 @@ def print_final_link_transforms(self) -> None:
print("\n========== FINAL ACCUMULATED TRANSFORMS PER LINK ==========")
for link in self.urdf.links:
if link.name == root_link:
# The root link is presumably identity in the URDF
print(f"Link '{link.name}': Root link => final transform is identity.\n")
continue

chain = self.urdf.get_chain(root_link, link.name)
# e.g. [root_link, joint0_name, link0_name, joint1_name, link1_name, ...]

# Accumulate transform from 'root_link' to 'link.name'
# We'll build an np.eye(4), multiply each joint's origin along the way
chain = self.urdf.get_chain(root_link, link.name)
final_tf = np.eye(4)
# We skip the 0th element (which is root_link),
# and step in pairs: (joint_i, link_i). If the chain is [a, j0, b, j1, c],
# then chain[1] is j0, chain[2] is b, chain[3] is j1, chain[4] is c, etc.
for i in range(1, len(chain), 2):
joint_name = chain[i]
# find that joint in self.urdf.joints
j = None
for jt in self.urdf.joints:
if jt.name == joint_name:
Expand All @@ -297,23 +346,18 @@ def print_final_link_transforms(self) -> None:
print(f" (!) Could not find joint named '{joint_name}' in URDF?")
continue

# Build local 4x4 from j.origin
xyz = j.origin.xyz if j.origin and j.origin.xyz else [0,0,0]
rpy = j.origin.rpy if j.origin and j.origin.rpy else [0,0,0]
xyz = j.origin.xyz if j.origin and j.origin.xyz else [0, 0, 0]
rpy = j.origin.rpy if j.origin and j.origin.rpy else [0, 0, 0]
local_rot = rotation_from_euler_xyz(rpy)
local_tf = np.eye(4)
local_tf[:3, :3] = local_rot
local_tf[:3, 3] = xyz

final_tf = final_tf @ local_tf

# Now 'final_tf' is the transform from root_link to link.name
# Print it
print(f"Link '{link.name}': BFS chain = {chain}")
print(" => final_tf (4x4) =")
for row in final_tf:
print(" [{: 8.3f} {: 8.3f} {: 8.3f} {: 8.3f}]".format(*row))
print()
# --- CHANGED ---
# We extracted the actual print statement into a helper function:
debug_print_final_link_transform(link.name, chain, final_tf)


def scene_to_trimeshes(scene: trimesh.Scene) -> list[trimesh.Trimesh]:
Expand Down Expand Up @@ -345,13 +389,16 @@ def log_trimesh(entity_path: str, mesh: trimesh.Trimesh) -> None:
vertex_colors = mesh.visual.vertex_colors
elif isinstance(mesh.visual, trimesh.visual.texture.TextureVisuals):
trimesh_material = mesh.visual.material

if mesh.visual.uv is not None:
vertex_texcoords = mesh.visual.uv
# Trimesh uses OpenGL convention for UV, flip the V coordinate for Rerun
# Trimesh uses the OpenGL convention for UV coordinates, so we need to flip the V coordinate
# since Rerun uses the Vulkan/Metal/DX12/WebGPU convention.
vertex_texcoords[:, 1] = 1.0 - vertex_texcoords[:, 1]

# Handle PBR materials or simple texture
if hasattr(trimesh_material, "baseColorTexture") and trimesh_material.baseColorTexture is not None:
# baseColorTexture is a PIL image or array
img = np.asarray(trimesh_material.baseColorTexture)
if img.ndim == 2:
img = np.stack([img] * 3, axis=-1)
Expand All @@ -369,12 +416,8 @@ def log_trimesh(entity_path: str, mesh: trimesh.Trimesh) -> None:
except Exception:
pass

# Prepare debug prints
print("======================")
print("rerun_log log_trimesh")
print(f"entity_path = entity_path with value '{entity_path}'")

# Build the rr.Mesh3D
# --- CHANGED ---
# Prepare the rr.Mesh3D, then debug-print all the info in a helper function.
mesh3d_entity = rr.Mesh3D(
vertex_positions=mesh.vertices,
triangle_indices=mesh.faces,
Expand All @@ -384,16 +427,8 @@ def log_trimesh(entity_path: str, mesh: trimesh.Trimesh) -> None:
vertex_texcoords=vertex_texcoords,
)

# Print numeric data for debugging
# Print only the first three vertex positions
first_three_vertices = mesh.vertices[:3].tolist()
print("entity = rr.Mesh3D(...) with these numeric values:")
print(" => vertex_positions (first 3):")
for vertex in first_three_vertices:
print(f" [{', '.join(f'{x:>7.3f}' for x in vertex)}]")

timeless_val = True
print(f"timeless = {timeless_val}")
debug_print_log_trimesh(entity_path, mesh3d_entity, timeless_val, mesh)

rr.log(
entity_path=entity_path,
Expand Down Expand Up @@ -426,6 +461,7 @@ def main() -> None:

filepath = Path(args.filepath).resolve()
is_file = filepath.is_file()
# Changed from the old code to handle uppercase/lowercase URDF:
is_urdf_file = ".urdf" in filepath.name.lower()

if not is_file or not is_urdf_file:
Expand Down

0 comments on commit 95b51b9

Please sign in to comment.