Skip to content

Commit

Permalink
feat(components): get clproto message type from attribute (#175)
Browse files Browse the repository at this point in the history
  • Loading branch information
domire8 authored Dec 8, 2024
1 parent 9231a41 commit dec18cd
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Release Versions:
- feat(controllers): add TF listener in BaseControllerInterface (#169)
- feat(controllers): add TF broadcaster in BaseControllerInterface (#170)
- test(controllers): add TF listener and broadcaster tests (#172)
- feat(components): get clproto message type from attribute (#175)

## 5.0.2

Expand Down
4 changes: 2 additions & 2 deletions source/modulo_components/modulo_components/component.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from threading import Thread
from typing import TypeVar
from typing import Optional, TypeVar

import clproto
from modulo_components.component_interface import ComponentInterface
Expand Down Expand Up @@ -74,7 +74,7 @@ def on_execute_callback(self) -> bool:
return True

def add_output(self, signal_name: str, data: str, message_type: MsgT,
clproto_message_type=clproto.MessageType.UNKNOWN_MESSAGE, default_topic="", fixed_topic=False,
clproto_message_type: Optional[clproto.MessageType] = None, default_topic="", fixed_topic=False,
publish_on_step=True):
"""
Add and configure an output signal of the component.
Expand Down
19 changes: 11 additions & 8 deletions source/modulo_components/modulo_components/component_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ def remove_output(self, signal_name):
self.get_logger().debug(f"Removing signal '{signal_name}'.")

def __create_output(self, signal_name: str, data: str, message_type: MsgT,
clproto_message_type: clproto.MessageType, default_topic: str, fixed_topic: bool,
clproto_message_type: Union[clproto.MessageType, None], default_topic: str, fixed_topic: bool,
publish_on_step: bool) -> str:
"""
Helper function to parse the signal name and add an output without Publisher to the dict of outputs.
Expand All @@ -438,23 +438,26 @@ def __create_output(self, signal_name: str, data: str, message_type: MsgT,
:return: The parsed signal name
"""
try:
if message_type == EncodedState and clproto_message_type == clproto.MessageType.UNKNOWN_MESSAGE:
raise AddSignalError(f"Provide a valid clproto message type for outputs of type EncodedState.")
self.declare_output(signal_name, default_topic, fixed_topic)
parsed_signal_name = parse_topic_name(signal_name)
if message_type == Bool or message_type == Float64 or \
message_type == Float64MultiArray or message_type == Int32 or message_type == String:
translator = modulo_writers.write_std_message
elif message_type == EncodedState:
translator = partial(modulo_writers.write_clproto_message,
clproto_message_type=clproto_message_type)
cl_msg_type = clproto_message_type if clproto_message_type else modulo_writers.get_clproto_msg_type(
self.__getattribute__(data))
if cl_msg_type == clproto.MessageType.UNKNOWN_MESSAGE:
raise AddSignalError(f"Provide a valid clproto message type for output '{
signal_name}' of type EncodedState.")
translator = partial(modulo_writers.write_clproto_message, clproto_message_type=cl_msg_type)
elif hasattr(message_type, 'get_fields_and_field_types'):
def write_ros_msg(message, data):
for field in message.get_fields_and_field_types().keys():
setattr(message, field, getattr(data, field))
translator = write_ros_msg
else:
raise AddSignalError("The provided message type is not supported to create a component output.")
raise AddSignalError(
f"The provided message type is not supported to create component output '{signal_name}'.")
self.declare_output(signal_name, default_topic, fixed_topic)
parsed_signal_name = parse_topic_name(signal_name)
self.__outputs[parsed_signal_name] = {"attribute": data, "message_type": message_type,
"translator": translator}
self.__periodic_outputs[parsed_signal_name] = publish_on_step
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TypeVar
from typing import Optional, TypeVar

import clproto
from lifecycle_msgs.msg import State
Expand Down Expand Up @@ -361,7 +361,7 @@ def __configure_outputs(self) -> bool:
return success

def add_output(self, signal_name: str, data: str, message_type: MsgT,
clproto_message_type=clproto.MessageType.UNKNOWN_MESSAGE, default_topic="", fixed_topic=False,
clproto_message_type: Optional[clproto.MessageType] = None, default_topic="", fixed_topic=False,
publish_on_step=True):
"""
Add an output signal of the component.
Expand Down
6 changes: 2 additions & 4 deletions source/modulo_components/test/python/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ def publish(self):

component = component_type("minimal_cartesian_output")
component._output = random_pose
component.add_output("cartesian_pose", "_output", EncodedState, clproto.MessageType.CARTESIAN_STATE_MESSAGE,
topic, publish_on_step=publish_on_step)
component.add_output("cartesian_pose", "_output", EncodedState, default_topic=topic, publish_on_step=publish_on_step)
component.publish = publish.__get__(component)
return component

Expand All @@ -65,8 +64,7 @@ def publish(self):

component = component_type("minimal_joint_output")
component._output = random_joint
component.add_output("joint_state", "_output", EncodedState, clproto.MessageType.JOINT_STATE_MESSAGE,
topic, publish_on_step=publish_on_step)
component.add_output("joint_state", "_output", EncodedState, default_topic=topic, publish_on_step=publish_on_step)
component.publish = publish.__get__(component)
return component

Expand Down
44 changes: 44 additions & 0 deletions source/modulo_core/modulo_core/translators/message_writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,50 @@
StateT = TypeVar('StateT')


def get_clproto_msg_type(state: StateT) -> clproto.MessageType:
if not isinstance(state, sr.State) or not hasattr(state, 'get_type') or not callable(state.get_type):
return clproto.MessageType.UNKNOWN_MESSAGE

state_type = state.get_type()
if state_type == sr.StateType.STATE:
return clproto.MessageType.STATE_MESSAGE
elif state_type == sr.StateType.SPATIAL_STATE:
return clproto.MessageType.SPATIAL_STATE_MESSAGE
elif state_type == sr.StateType.CARTESIAN_STATE:
return clproto.MessageType.CARTESIAN_STATE_MESSAGE
elif state_type == sr.StateType.CARTESIAN_POSE:
return clproto.MessageType.CARTESIAN_POSE_MESSAGE
elif state_type == sr.StateType.CARTESIAN_TWIST:
return clproto.MessageType.CARTESIAN_TWIST_MESSAGE
elif state_type == sr.StateType.CARTESIAN_ACCELERATION:
return clproto.MessageType.CARTESIAN_ACCELERATION_MESSAGE
elif state_type == sr.StateType.CARTESIAN_WRENCH:
return clproto.MessageType.CARTESIAN_WRENCH_MESSAGE
elif state_type == sr.StateType.JACOBIAN:
return clproto.MessageType.JACOBIAN_MESSAGE
elif state_type == sr.StateType.JOINT_STATE:
return clproto.MessageType.JOINT_STATE_MESSAGE
elif state_type == sr.StateType.JOINT_POSITIONS:
return clproto.MessageType.JOINT_POSITIONS_MESSAGE
elif state_type == sr.StateType.JOINT_VELOCITIES:
return clproto.MessageType.JOINT_VELOCITIES_MESSAGE
elif state_type == sr.StateType.JOINT_ACCELERATIONS:
return clproto.MessageType.JOINT_ACCELERATIONS_MESSAGE
elif state_type == sr.StateType.JOINT_TORQUES:
return clproto.MessageType.JOINT_TORQUES_MESSAGE
elif state_type == sr.StateType.GEOMETRY_SHAPE:
return clproto.MessageType.SHAPE_MESSAGE
elif state_type == sr.StateType.GEOMETRY_ELLIPSOID:
return clproto.MessageType.ELLIPSOID_MESSAGE
elif state_type == sr.StateType.PARAMETER:
return clproto.MessageType.PARAMETER_MESSAGE
elif state_type == sr.StateType.DIGITAL_IO_STATE:
return clproto.MessageType.DIGITAL_IO_STATE_MESSAGE
elif state_type == sr.StateType.ANALOG_IO_STATE:
return clproto.MessageType.ANALOG_IO_STATE_MESSAGE
return clproto.MessageType.UNKNOWN_MESSAGE


def write_xyz(message: Union[geometry.Point, geometry.Vector3], vector: np.array):
"""
Helper function to write a vector to a Point or Vector3 message.
Expand Down

0 comments on commit dec18cd

Please sign in to comment.