diff --git a/README.md b/README.md index c243da496..9b51bbf7c 100644 --- a/README.md +++ b/README.md @@ -29,18 +29,19 @@ The plan that both robots execute is a relativly simple pick and place plan: The code for this plan can be seen below. ``` -from pycram.bullet_world import BulletWorld, Object +from pycram.worlds.bullet_world import BulletWorld +from pycram.world_concepts.world_object import Object from pycram.process_module import simulated_robot from pycram.designators.motion_designator import * from pycram.designators.location_designator import * from pycram.designators.action_designator import * from pycram.designators.object_designator import * -from pycram.enums import ObjectType +from pycram.datastructures.enums import ObjectType, Arms, Grasp, WorldMode -world = BulletWorld() +world = BulletWorld(WorldMode.GUI) kitchen = Object("kitchen", ObjectType.ENVIRONMENT, "kitchen.urdf") robot = Object("pr2", ObjectType.ROBOT, "pr2.urdf") -cereal = Object("cereal", ObjectType.BREAKFAST_CEREAL, "breakfast_cereal.stl", position=[1.4, 1, 0.95]) +cereal = Object("cereal", ObjectType.BREAKFAST_CEREAL, "breakfast_cereal.stl", pose=Pose([1.4, 1, 0.95])) cereal_desig = ObjectDesignatorDescription(names=["cereal"]) kitchen_desig = ObjectDesignatorDescription(names=["kitchen"]) @@ -56,7 +57,7 @@ with simulated_robot: NavigateAction(target_locations=[pickup_pose.pose]).resolve().perform() - PickUpAction(object_designator_description=cereal_desig, arms=[pickup_arm], grasps=["front"]).resolve().perform() + PickUpAction(object_designator_description=cereal_desig, arms=[pickup_arm], grasps=[Grasp.FRONT]).resolve().perform() ParkArmsAction([Arms.BOTH]).resolve().perform() @@ -69,6 +70,8 @@ with simulated_robot: PlaceAction(cereal_desig, target_locations=[place_island.pose], arms=[pickup_arm]).resolve().perform() ParkArmsAction([Arms.BOTH]).resolve().perform() + +world.exit() ``` diff --git a/binder/README.md b/binder/README.md index b10281a0b..d9d6ece00 100644 --- a/binder/README.md +++ b/binder/README.md @@ -99,8 +99,7 @@ RUN cd pycram \ && cd src/neem_interface_python \ && git clone https://github.com/benjaminalt/neem-interface.git src/neem-interface -RUN pip install --requirement ${PYCRAM_WS}/src/pycram/requirements.txt --user -RUN pip install --requirement ${PYCRAM_WS}/src/pycram/src/neem_interface_python/requirements.txt --user \ +RUN pip install --requirement ${PYCRAM_WS}/src/pycram/requirements.txt --user \ && pip cache purge ``` @@ -448,15 +447,3 @@ with simulated_robot: arms=["left"], grasps=["left", "right"]).resolve().perform() ``` - - - - - - - - - - - - diff --git a/binder/pycram-http.rosinstall b/binder/pycram-http.rosinstall index 4ad9ec467..15a9e5238 100644 --- a/binder/pycram-http.rosinstall +++ b/binder/pycram-http.rosinstall @@ -7,18 +7,10 @@ repositories: type: git url: http://github.com/code-iai/iai_robots.git version: master - pr2_common: - type: git - url: https://github.com/PR2/pr2_common.git - version: b34703bcca2b07cadbc3777d3c504c232a0c0c28 kdl_ik_services: type: git url: https://github.com/cram2/kdl_ik_service.git verison: master - pr2_kinematics: - type: git - url: https://github.com/PR2/pr2_kinematics.git - version: kinetic-devel orocos_kinematics_dynamics: type: git url: https://github.com/orocos/orocos_kinematics_dynamics.git diff --git a/config/__init__.py b/config/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/config/multiverse_conf.py b/config/multiverse_conf.py new file mode 100644 index 000000000..6463cf151 --- /dev/null +++ b/config/multiverse_conf.py @@ -0,0 +1,72 @@ +import datetime + +from typing_extensions import Type + +from .world_conf import WorldConfig +from pycram.description import ObjectDescription +from pycram.helper import find_multiverse_resources_path +from pycram.object_descriptors.mjcf import ObjectDescription as MJCF + + +class MultiverseConfig(WorldConfig): + # Multiverse Configuration + resources_path = find_multiverse_resources_path() + """ + The path to the Multiverse resources directory. + """ + + # Multiverse Socket Configuration + HOST: str = "tcp://127.0.0.1" + SERVER_HOST: str = HOST + SERVER_PORT: str = 7000 + BASE_CLIENT_PORT: int = 9000 + + # Multiverse Client Configuration + READER_MAX_WAIT_TIME_FOR_DATA: datetime.timedelta = datetime.timedelta(milliseconds=1000) + """ + The maximum wait time for the data in seconds. + """ + + # Multiverse Simulation Configuration + simulation_time_step: datetime.timedelta = datetime.timedelta(milliseconds=10) + simulation_frequency: int = int(1 / simulation_time_step.total_seconds()) + """ + The time step of the simulation in seconds and the frequency of the simulation in Hz. + """ + + simulation_wait_time_factor: float = 1.0 + """ + The factor to multiply the simulation wait time with, this is used to adjust the simulation wait time to account for + the time taken by the simulation to process the request, this depends on the computational power of the machine + running the simulation. + """ + + use_static_mode: bool = True + """ + If True, the simulation will always be in paused state unless the simulate() function is called, this behaves + similar to bullet_world which uses the bullet physics engine. + """ + + use_controller: bool = False + use_controller = use_controller and not use_static_mode + """ + Only used when use_static_mode is False. This turns on the controller for the robot joints. + """ + + default_description_type: Type[ObjectDescription] = MJCF + """ + The default description type for the objects. + """ + + use_physics_simulator_state: bool = True + """ + Whether to use the physics simulator state when restoring or saving the world state. + """ + + clear_cache_at_start = False + + let_pycram_move_attached_objects = False + let_pycram_handle_spawning = False + + position_tolerance = 2e-2 + prismatic_joint_position_tolerance = 2e-2 diff --git a/config/world_conf.py b/config/world_conf.py new file mode 100644 index 000000000..76db4b330 --- /dev/null +++ b/config/world_conf.py @@ -0,0 +1,93 @@ +import math +import os + +from typing_extensions import Tuple, Type +from pycram.description import ObjectDescription +from pycram.object_descriptors.urdf import ObjectDescription as URDF + + +class WorldConfig: + + """ + A class to store the configuration of the world, this can be inherited to create a new configuration class for a + specific world (e.g. multiverse has MultiverseConfig which inherits from this class). + """ + + resources_path = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'resources') + resources_path = os.path.abspath(resources_path) + """ + Global reference for the resources path, this is used to search for the description files of the robot and + the objects. + """ + + cache_dir_name: str = 'cached' + """ + The name of the cache directory. + """ + + cache_dir: str = os.path.join(resources_path, cache_dir_name) + """ + Global reference for the cache directory, this is used to cache the description files of the robot and the objects. + """ + + clear_cache_at_start: bool = True + """ + Whether to clear the cache directory at the start. + """ + + prospection_world_prefix: str = "prospection_" + """ + The prefix for the prospection world name. + """ + + simulation_frequency: int = 240 + """ + The simulation frequency (Hz), used for calculating the equivalent real time in the simulation. + """ + + update_poses_from_sim_on_get: bool = True + """ + Whether to update the poses from the simulator when getting the object poses. + """ + + default_description_type: Type[ObjectDescription] = URDF + """ + The default description type for the objects. + """ + + use_physics_simulator_state: bool = False + """ + Whether to use the physics simulator state when restoring or saving the world state. + Currently with PyBullet, this causes a bug where ray_test does not work correctly after restoring the state using the + simulator, so it is recommended to set this to False in PyBullet. + """ + + let_pycram_move_attached_objects: bool = True + let_pycram_handle_spawning: bool = True + let_pycram_handle_world_sync: bool = True + """ + Whether to let PyCRAM handle the movement of attached objects, the spawning of objects, + and the world synchronization. + """ + + position_tolerance: float = 1e-2 + orientation_tolerance: float = 10 * math.pi / 180 + prismatic_joint_position_tolerance: float = 1e-2 + revolute_joint_position_tolerance: float = 5 * math.pi / 180 + """ + The acceptable error for the position and orientation of an object/link, and the joint positions. + """ + + use_percentage_of_goal: bool = False + acceptable_percentage_of_goal: float = 0.5 + """ + Whether to use a percentage of the goal as the acceptable error. + """ + + raise_goal_validator_error: bool = False + """ + Whether to raise an error if the goals are not achieved. + """ + @classmethod + def get_pose_tolerance(cls) -> Tuple[float, float]: + return cls.position_tolerance, cls.orientation_tolerance diff --git a/demos/pycram_bullet_world_demo/demo.py b/demos/pycram_bullet_world_demo/demo.py index a60df7771..f06fc3d38 100644 --- a/demos/pycram_bullet_world_demo/demo.py +++ b/demos/pycram_bullet_world_demo/demo.py @@ -8,10 +8,12 @@ from pycram.object_descriptors.urdf import ObjectDescription from pycram.world_concepts.world_object import Object from pycram.datastructures.dataclasses import Color +from pycram.ros.viz_marker_publisher import VizMarkerPublisher extension = ObjectDescription.get_file_extension() world = BulletWorld(WorldMode.GUI) + robot = Object("pr2", ObjectType.ROBOT, f"pr2{extension}", pose=Pose([1, 2, 0])) apartment = Object("apartment", ObjectType.ENVIRONMENT, f"apartment{extension}") @@ -94,3 +96,5 @@ def move_and_detect(obj_type): PlaceAction(spoon_desig, [spoon_target_pose], [pickup_arm]).resolve().perform() ParkArmsAction([Arms.BOTH]).resolve().perform() + +world.exit() diff --git a/demos/pycram_multiverse_demo/demo.py b/demos/pycram_multiverse_demo/demo.py new file mode 100644 index 000000000..6e71ec112 --- /dev/null +++ b/demos/pycram_multiverse_demo/demo.py @@ -0,0 +1,100 @@ +from pycram.datastructures.dataclasses import Color +from pycram.datastructures.enums import ObjectType, Arms, Grasp +from pycram.datastructures.pose import Pose +from pycram.designators.action_designator import ParkArmsAction, MoveTorsoAction, TransportAction, NavigateAction, \ + LookAtAction, DetectAction, OpenAction, PickUpAction, CloseAction, PlaceAction +from pycram.designators.location_designator import CostmapLocation, AccessingLocation +from pycram.designators.object_designator import BelieveObject, ObjectPart +from pycram.object_descriptors.urdf import ObjectDescription +from pycram.process_module import simulated_robot, with_simulated_robot +from pycram.world_concepts.world_object import Object +from pycram.worlds.multiverse import Multiverse + + +@with_simulated_robot +def move_and_detect(obj_type: ObjectType, pick_pose: Pose): + NavigateAction(target_locations=[Pose([1.7, 2, 0])]).resolve().perform() + + LookAtAction(targets=[pick_pose]).resolve().perform() + + object_desig = DetectAction(BelieveObject(types=[obj_type])).resolve().perform() + + return object_desig + +world = Multiverse(simulation_name='pycram_test') +extension = ObjectDescription.get_file_extension() +robot = Object('pr2', ObjectType.ROBOT, f'pr2{extension}', pose=Pose([1.3, 2, 0.01])) +apartment = Object("apartment", ObjectType.ENVIRONMENT, f"apartment{extension}") + +milk = Object("milk", ObjectType.MILK, f"milk.stl", pose=Pose([2.4, 2, 1.02]), + color=Color(1, 0, 0, 1)) + +spoon = Object("spoon", ObjectType.SPOON, "spoon.stl", pose=Pose([2.5, 2.2, 0.85]), + color=Color(0, 0, 1, 1)) +apartment.attach(spoon, 'cabinet10_drawer1') + +robot_desig = BelieveObject(names=[robot.name]) +apartment_desig = BelieveObject(names=[apartment.name]) + +with simulated_robot: + + # Transport the milk + ParkArmsAction([Arms.BOTH]).resolve().perform() + + MoveTorsoAction([0.25]).resolve().perform() + + NavigateAction(target_locations=[Pose([1.7, 2, 0])]).resolve().perform() + + LookAtAction(targets=[Pose([2.6, 2.15, 1])]).resolve().perform() + + milk_desig = DetectAction(BelieveObject(types=[milk.obj_type])).resolve().perform() + + TransportAction(milk_desig, [Arms.LEFT], [Pose([2.4, 3, 1.02])]).resolve().perform() + + # Find and navigate to the drawer containing the spoon + handle_desig = ObjectPart(names=["cabinet10_drawer1_handle"], part_of=apartment_desig.resolve()) + drawer_open_loc = AccessingLocation(handle_desig=handle_desig.resolve(), + robot_desig=robot_desig.resolve()).resolve() + + NavigateAction([drawer_open_loc.pose]).resolve().perform() + + OpenAction(object_designator_description=handle_desig, arms=[drawer_open_loc.arms[0]]).resolve().perform() + spoon.detach(apartment) + + # Detect and pickup the spoon + LookAtAction([apartment.get_link_pose("cabinet10_drawer1_handle")]).resolve().perform() + + spoon_desig = DetectAction(BelieveObject(types=[ObjectType.SPOON])).resolve().perform() + + pickup_arm = Arms.LEFT if drawer_open_loc.arms[0] == Arms.RIGHT else Arms.RIGHT + PickUpAction(spoon_desig, [pickup_arm], [Grasp.TOP]).resolve().perform() + + ParkArmsAction([Arms.LEFT if pickup_arm == Arms.LEFT else Arms.RIGHT]).resolve().perform() + + CloseAction(object_designator_description=handle_desig, arms=[drawer_open_loc.arms[0]]).resolve().perform() + + ParkArmsAction([Arms.BOTH]).resolve().perform() + + MoveTorsoAction([0.15]).resolve().perform() + + # Find a pose to place the spoon, move and then place it + spoon_target_pose = Pose([2.35, 2.6, 0.95], [0, 0, 0, 1]) + placing_loc = CostmapLocation(target=spoon_target_pose, reachable_for=robot_desig.resolve()).resolve() + + NavigateAction([placing_loc.pose]).resolve().perform() + + PlaceAction(spoon_desig, [spoon_target_pose], [pickup_arm]).resolve().perform() + + ParkArmsAction([Arms.BOTH]).resolve().perform() + +world.exit() + + +def debug_place_spoon(): + robot.set_pose(Pose([1.66, 2.56, 0.01], [0.0, 0.0, -0.04044101807153309, 0.9991819274072855])) + spoon.set_pose(Pose([1.9601757566599975, 2.06718019258626, 1.0427727691273496], + [0.11157527804553227, -0.7076243776942466, 0.12773644958149588, 0.685931554070963])) + robot.attach(spoon, 'r_gripper_tool_frame') + pickup_arm = Arms.RIGHT + spoon_desig = BelieveObject(names=[spoon.name]) + return pickup_arm, spoon_desig diff --git a/demos/pycram_ur5_demo/demo.py b/demos/pycram_ur5_demo/demo.py index 42d6709d3..491ad704f 100644 --- a/demos/pycram_ur5_demo/demo.py +++ b/demos/pycram_ur5_demo/demo.py @@ -7,9 +7,9 @@ from pycram.worlds.bullet_world import BulletWorld from pycram.datastructures.world import Object from pycram.datastructures.pose import Pose -from pycram.ros.force_torque_sensor import ForceTorqueSensor -from pycram.ros.joint_state_publisher import JointStatePublisher -from pycram.ros.tf_broadcaster import TFBroadcaster +from pycram.ros_utils.force_torque_sensor import ForceTorqueSensor +from pycram.ros_utils.joint_state_publisher import JointStatePublisher +from pycram.ros_utils.tf_broadcaster import TFBroadcaster SCRIPT_DIR = os.path.abspath(os.path.dirname(__file__)) PYCRAM_DIR = os.path.join(SCRIPT_DIR, os.pardir, os.pardir) diff --git a/demos/pycram_virtual_building_demos/setup/launch_robot.py b/demos/pycram_virtual_building_demos/setup/launch_robot.py index eaa9a08a8..2e51c45e0 100644 --- a/demos/pycram_virtual_building_demos/setup/launch_robot.py +++ b/demos/pycram_virtual_building_demos/setup/launch_robot.py @@ -3,6 +3,9 @@ import rospy import rospkg +from pycram.ros.logging import loginfo +from pycram.ros.ros_tools import create_ros_pack + def launch_pr2(): """ @@ -39,14 +42,14 @@ def launch_robot(launch_file, package='pycram', launch_folder='/launch/'): :param launch_folder: Location of the launch file inside the package """ - rospath = rospkg.RosPack() + rospath = create_ros_pack() uuid = roslaunch.rlutil.get_or_generate_uuid(None, False) roslaunch.configure_logging(uuid) launch = roslaunch.parent.ROSLaunchParent(uuid, [rospath.get_path(package) + launch_folder + launch_file]) launch.start() - rospy.loginfo(f'{launch_file} started') + loginfo(f'{launch_file} started') # Wait for ik server to launch time.sleep(2) diff --git a/doc/source/index.rst b/doc/source/index.rst index 17b3a7942..3e956cc1e 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -52,19 +52,19 @@ The code for this plan can be seen below. .. code-block:: python - from pycram.world.bullet_world import BulletWorld - from pycram.world_concepts.world_concepts import Object + from pycram.worlds.bullet_world import BulletWorld + from pycram.world_concepts.world_object import Object from pycram.process_module import simulated_robot from pycram.designators.motion_designator import * from pycram.designators.location_designator import * from pycram.designators.action_designator import * from pycram.designators.object_designator import * - from pycram.datastructures.enums import ObjectType, Arms, Grasps + from pycram.datastructures.enums import ObjectType, Arms, Grasp, WorldMode - world = BulletWorld() + world = BulletWorld(WorldMode.GUI) kitchen = Object("kitchen", ObjectType.ENVIRONMENT, "kitchen.urdf") robot = Object("pr2", ObjectType.ROBOT, "pr2.urdf") - cereal = Object("cereal", ObjectType.BREAKFAST_CEREAL, "breakfast_cereal.stl", position=[1.4, 1, 0.95]) + cereal = Object("cereal", ObjectType.BREAKFAST_CEREAL, "breakfast_cereal.stl", pose=Pose([1.4, 1, 0.95])) cereal_desig = ObjectDesignatorDescription(names=["cereal"]) kitchen_desig = ObjectDesignatorDescription(names=["kitchen"]) @@ -80,7 +80,7 @@ The code for this plan can be seen below. NavigateAction(target_locations=[pickup_pose.pose]).resolve().perform() - PickUpAction(object_designator_description=cereal_desig, arms=[pickup_arm], grasps=[Grasps.FRONT]).resolve().perform() + PickUpAction(object_designator_description=cereal_desig, arms=[pickup_arm], grasps=[Grasp.FRONT]).resolve().perform() ParkArmsAction([Arms.BOTH]).resolve().perform() diff --git a/doc/source/installation.rst b/doc/source/installation.rst index 7923702f2..a36b3f00a 100644 --- a/doc/source/installation.rst +++ b/doc/source/installation.rst @@ -68,7 +68,7 @@ Now you can install PyCRAM into your ROS workspace. cd .. catkin_make source devel/setup.bash - echo "~/workspace/ros/devel/setup.bash" >> ~/.bashrc + echo "source ~/workspace/ros/devel/setup.bash" >> ~/.bashrc The cloning and setting up can take several minutes. After the command finishes you should see a number of repositories in your ROS workspace. @@ -105,8 +105,19 @@ Then install the Python packages in the requirements.txt file .. code-block:: shell sudo pip3 install -r requirements.txt - sudo pip3 install -r src/neem_interface_python/requirements.txt +This installs the packages into ``/usr/local/lib``. If you prefer to not clutter your system-wide python installation, +you can also install the packages into the catkin workspace as follows: + +.. code-block:: shell + + # install packages into catkin workspace instead of ~/.local + export PYTHONUSERBASE=~/workspace/ros/devel + # don't install packages that are available in system + export PIP_IGNORE_INSTALLED=0 + + pip3 install -r requirements.txt + pip3 install -r src/neem_interface_python/requirements.txt Building your ROS workspace =========================== diff --git a/doc/source/remarks.rst b/doc/source/remarks.rst index 452f54313..0debe0d8a 100644 --- a/doc/source/remarks.rst +++ b/doc/source/remarks.rst @@ -33,14 +33,3 @@ Dirty Terminals If your terminal gets polluted by PyBullet complaining about incomplete URDF descriptions, you need to first fix your URDF files by inserting the missing tags and second delete the `resources/cached` folder. - -Missing pr2_arm_kinematics -========================== - -Aptitudes autoremove likes to also remove the arm kinematics. Reinstall the missing libraries with - -.. code-block:: shell - - sudo apt-get install ros-noetic-moveit - -Then rebuild your workspace. diff --git a/doc/source/troubleshooting.rst b/doc/source/troubleshooting.rst index f14b4abce..768818778 100644 --- a/doc/source/troubleshooting.rst +++ b/doc/source/troubleshooting.rst @@ -108,32 +108,3 @@ real_robot environments. This is also explained in the `Action Designator Exampl with simulated_robot: NavigateAction([Pose()]).resolve().perform() - -Missing pr2_arm_kinematics -========================== - -Aptitudes autoremove likes to also remove the arm kinematics. The error message looks similar to this. Important is the -missing library libmoveit_kinematics_base.so. This can be fixed by reinstalling the missing libraries. - -.. code-block:: shell - - process[pr2_left_arm_kinematics-3]: started with pid [26862] - pr2_arm_kinematics_node: error while loading shared libraries: libmoveit_kinematics_base.so.1.1.12: cannot open shared object file: No such file or directory - process[pr2_right_arm_kinematics-4]: started with pid [26863] - [pr2_left_arm_kinematics-3] process has died [pid 26862, exit code 127, cmd ~/pycram/devel/lib/pr2_arm_kinematics/pr2_arm_kinematics_node __name:=pr2_left_arm_kinematics __log:=~/.ros/log/ba5e95de-384f-11ee-ab53-97c8787037e5/pr2_left_arm_kinematics-3.log]. - log file: ~/.ros/log/ba5e95de-384f-11ee-ab53-97c8787037e5/pr2_left_arm_kinematics-3*.log - pr2_arm_kinematics_node: error while loading shared libraries: libmoveit_kinematics_base.so.1.1.12: cannot open shared object file: No such file or directory - [pr2_right_arm_kinematics-4] process has died [pid 26863, exit code 127, cmd ~/pycram/devel/lib/pr2_arm_kinematics/pr2_arm_kinematics_node __name:=pr2_right_arm_kinematics __log:=~/.ros/log/ba5e95de-384f-11ee-ab53-97c8787037e5/pr2_right_arm_kinematics-4.log]. - log file: ~/.ros/log/ba5e95de-384f-11ee-ab53-97c8787037e5/pr2_right_arm_kinematics-4*.log - IK server ready. - - -Reinstall the missing libraries with - -.. code-block:: shell - - sudo apt-get install ros-noetic-moveit - -Then rebuild your workspace. - - diff --git a/examples/cram_plan_tutorial.md b/examples/cram_plan_tutorial.md index 7606b9d3d..7ca106a34 100644 --- a/examples/cram_plan_tutorial.md +++ b/examples/cram_plan_tutorial.md @@ -28,7 +28,7 @@ from pycram.designators.location_designator import * from pycram.process_module import simulated_robot from pycram.designators.object_designator import * import anytree -import pycram.plan_failures +import pycram.failures ``` Next we will create a bullet world with a PR2 in a kitchen containing milk and cereal. diff --git a/examples/improving_actions.md b/examples/improving_actions.md index 3e796ed74..4cb46a28e 100644 --- a/examples/improving_actions.md +++ b/examples/improving_actions.md @@ -44,7 +44,7 @@ from random_events.product_algebra import Event, SimpleEvent import pycram.orm.base from pycram.designators.action_designator import MoveTorsoActionPerformable -from pycram.plan_failures import PlanFailure +from pycram.failures import PlanFailure from pycram.designators.object_designator import ObjectDesignatorDescription from pycram.worlds.bullet_world import BulletWorld from pycram.world_concepts.world_object import Object @@ -54,7 +54,7 @@ from pycram.datastructures.pose import Pose from pycram.ros.viz_marker_publisher import VizMarkerPublisher from pycram.process_module import ProcessModule, simulated_robot from pycram.designators.specialized_designators.probabilistic.probabilistic_action import MoveAndPickUp, Arms, Grasp -from pycram.tasktree import task_tree, reset_tree +from pycram.tasktree import task_tree, TaskTree ProcessModule.execution_delay = False np.random.seed(69) @@ -76,7 +76,7 @@ Now we construct an empty world with just a floating milk, where we can learn ab ```python world = BulletWorld(WorldMode.DIRECT) print(world.prospection_world) -robot = Object(robot_description.name, ObjectType.ROBOT, robot_description.name + ".urdf") +robot = Object("pr2", ObjectType.ROBOT, "pr2.urdf") milk = Object("milk", ObjectType.MILK, "milk.stl", pose=Pose([1.3, 1, 0.9])) viz_marker_publisher = VizMarkerPublisher() milk_description = ObjectDesignatorDescription(types=[ObjectType.MILK]).ground() @@ -102,20 +102,20 @@ After finishing the experiments, we insert the results into the database. ```python pycram.orm.base.ProcessMetaData().description = "Experimenting with Pick Up Actions" -fpa.sample_amount = 500 -print(world.current_world) +fpa.sample_amount = 100 with simulated_robot: - print(world.current_world) fpa.batch_rollout() -task_tree.insert(session) -reset_tree() +task_tree.root.insert(session) session.commit() +task_tree.reset_tree() ``` Let's query the data that is needed to learn a pick up action and have a look at it. ```python samples = pd.read_sql(fpa.query_for_database(), engine) +samples["arm"] = samples["arm"].astype(str) +samples["grasp"] = samples["grasp"].astype(str) samples ``` @@ -147,7 +147,7 @@ Let's make a monte carlo estimate on the success probability of the new model. ```python fpa.policy = model -fpa.sample_amount = 100 +fpa.sample_amount = 500 with simulated_robot: fpa.batch_rollout() ``` diff --git a/examples/interface_examples/giskard.md b/examples/interface_examples/giskard.md index 34dc5334f..075230857 100644 --- a/examples/interface_examples/giskard.md +++ b/examples/interface_examples/giskard.md @@ -15,7 +15,7 @@ jupyter: # Giskard interface in PyCRAM This notebook should provide you with an example on how to use the Giskard interface. This includes how to use the -interface, how it actually works and how to extend it. +interface, how it actually works, and how to extend it. We start by installing and launching Giskard. For the installation please follow the instructions [here](https://github.com/SemRoCo/giskardpy). @@ -67,7 +67,7 @@ in the BulletWorld with the pose and joint state of the real robot. You might need to change to topic names to fit the topic names as published by your robot. ```python -from pycram.ros.robot_state_updater import RobotStateUpdater +from pycram.ros_utils.robot_state_updater import RobotStateUpdater r = RobotStateUpdater("/tf", "/joint_states") ``` diff --git a/examples/intro.md b/examples/intro.md index c45f4efac..c9d3cb5cc 100644 --- a/examples/intro.md +++ b/examples/intro.md @@ -30,10 +30,10 @@ A BulletWorld can be created by simply creating an object of the BulletWorld cla ```python from pycram.worlds.bullet_world import BulletWorld from pycram.world_concepts.world_object import Object -from pycram.datastructures.enums import ObjectType +from pycram.datastructures.enums import ObjectType, WorldMode from pycram.datastructures.pose import Pose -world = BulletWorld() +world = BulletWorld(mode=WorldMode.GUI) ``` The BulletWorld allows to render images from arbitrary positions. In the following example we render images with the @@ -342,6 +342,8 @@ cereal = Object("cereal", ObjectType.BREAKFAST_CEREAL, "breakfast_cereal.stl", p ``` ```python +from pycram.datastructures.enums import Grasp + cereal_desig = ObjectDesignatorDescription(names=["cereal"]) kitchen_desig = ObjectDesignatorDescription(names=["kitchen"]) robot_desig = ObjectDesignatorDescription(names=["pr2"]).resolve() @@ -355,7 +357,7 @@ with simulated_robot: NavigateAction(target_locations=[pickup_pose.pose]).resolve().perform() - PickUpAction(object_designator_description=cereal_desig, arms=[pickup_arm], grasps=["front"]).resolve().perform() + PickUpAction(object_designator_description=cereal_desig, arms=[pickup_arm], grasps=[Grasp.FRONT]).resolve().perform() ParkArmsAction([Arms.BOTH]).resolve().perform() @@ -380,10 +382,10 @@ Task trees are a hierarchical representation of all Actions involved in a plan. inspect and restructure the execution order of Actions in the plan. ```python -import pycram.task +import pycram.tasktree import anytree -tt = pycram.task.task_tree +tt = pycram.tasktree.task_tree print(anytree.RenderTree(tt)) ``` diff --git a/examples/language.md b/examples/language.md index 0f036e673..197c879bb 100644 --- a/examples/language.md +++ b/examples/language.md @@ -180,7 +180,7 @@ with simulated_robot: ## Combination of Expressions You can also combine different language expressions to further structure your plans. If you combine sequential and -parallel expression please keep in mind that sequential expressions bind stringer than parallel ones. For example: +parallel expression please keep in mind that sequential expressions bind stronger than parallel ones. For example: ``` navigate | park + move_torso @@ -254,7 +254,7 @@ We will see how exceptions are handled at a simple example. from pycram.designators.action_designator import * from pycram.process_module import simulated_robot from pycram.language import Code -from pycram.plan_failures import PlanFailure +from pycram.failures import PlanFailure def code_test(): diff --git a/examples/location_designator.md b/examples/location_designator.md index 6ba478932..1fa54ef03 100644 --- a/examples/location_designator.md +++ b/examples/location_designator.md @@ -68,9 +68,9 @@ print(pose) ## Reachable -Next we want to locations from where the robot can reach a specific point, like an object the robot should pick up. This +Next we want to have locations from where the robot can reach a specific point, like an object the robot should pick up. This can also be done with the {meth}`~pycram.designators.location_designator.CostmapLocation` description, but this time we need to provide an additional argument. -The additional argument is the robo which should be able to reach the pose. +The additional argument is the robot which should be able to reach the pose. Since a robot is needed we will use the PR2 and use a milk as a target point for the robot to reach. The torso of the PR2 will be set to 0.2 since otherwise the arms of the robot will be too low to reach on the countertop. @@ -97,7 +97,7 @@ print(location_description.resolve()) As you can see we get a pose near the countertop where the robot can be placed without colliding with it. Furthermore, we get a list of arms with which the robot can reach the given object. -## Visibile +## Visible The {meth}`~pycram.designators.location_designator.CostmapLocation` can also find position from which the robot can see a given object or location. This is very similar to how reachable locations are described, meaning we provide a object designator or a pose and a robot @@ -215,7 +215,7 @@ print(access_location.pose) ## Giskard Location -Some robots like the HSR or the Stretch2 need a full-body ik solver to utilize the whole body. For this case robots +Some robots like the HSR or the Stretch2 need a full-body ik solver to utilize the whole body. For this case the {meth}`~pycram.designators.specialized_designators.location.giskard_location.GiskardLocation` can be used. This location designator uses giskard as an ik solver to find a pose for the robot to reach a target pose. diff --git a/examples/migrate_neems.md b/examples/migrate_neems.md index d49634354..9fc8e63c5 100644 --- a/examples/migrate_neems.md +++ b/examples/migrate_neems.md @@ -23,7 +23,7 @@ connect your pycram process to it. After you recorded your data locally you can migrate the data using the `migrate_neems` function. -First, lets create an in memory database engine called `source_engine` where we record our current process. +First, lets create an in-memory database engine called `source_engine` where we record our current process. ```python import sqlalchemy.orm diff --git a/examples/minimal_task_tree.md b/examples/minimal_task_tree.md index b057fe463..ff93d7680 100644 --- a/examples/minimal_task_tree.md +++ b/examples/minimal_task_tree.md @@ -31,13 +31,13 @@ from pycram.designators.object_designator import * from pycram.datastructures.pose import Pose from pycram.datastructures.enums import ObjectType, WorldMode import anytree -import pycram.plan_failures +import pycram.failures ``` Next we will create a bullet world with a PR2 in a kitchen containing milk and cereal. ```python -world = BulletWorld(WorldMode.GUI) +world = BulletWorld(WorldMode.DIRECT) pr2 = Object("pr2", ObjectType.ROBOT, "pr2.urdf") kitchen = Object("kitchen", ObjectType.ENVIRONMENT, "kitchen.urdf") milk = Object("milk", ObjectType.MILK, "milk.stl", pose=Pose([1.3, 1, 0.9])) @@ -82,59 +82,58 @@ plan() Now we get the task tree from its module and render it. Rendering can be done with any render method described in the anytree package. We will use ascii rendering here for ease of displaying. ```python -tt = pycram.task.task_tree -print(anytree.RenderTree(tt)) +tt = pycram.tasktree.task_tree +print(anytree.RenderTree(tt.root)) ``` As we see every task in the plan got recorded correctly. It is noticeable that the tree begins with a NoOperation node. This is done because several, not connected, plans that get executed after each other should still appear in the task tree. Hence, a NoOperation node is the root of any tree. If we re-execute the plan we would see them appear in the same tree even though they are not connected. ```python -world.reset_bullet_world() +world.reset_current_world() plan() -print(anytree.RenderTree(tt)) +print(anytree.RenderTree(tt.root)) ``` Projecting a plan in a new environment with its own task tree that only exists while the projected plan is running can be done with the ``with`` keyword. When this is done, both the bullet world and task tree are saved and new, freshly reset objects are available. At the end of a with block the old state is restored. The root for such things is then called ``simulation()``. ```python with pycram.tasktree.SimulatedTaskTree() as stt: - print(anytree.RenderTree(pycram.task.task_tree)) -print(anytree.RenderTree(pycram.task.task_tree)) + print(anytree.RenderTree(pycram.tasktree.task_tree.root)) +print(anytree.RenderTree(pycram.tasktree.task_tree.root)) ``` Task tree can be manipulated with ordinary anytree manipulation. If we for example want to discard the second plan, we would write ```python tt.root.children = (tt.root.children[0],) -print(anytree.RenderTree(tt, style=anytree.render.AsciiStyle())) +print(anytree.RenderTree(tt.root, style=anytree.render.AsciiStyle())) ``` - We can now re-execute this (modified) plan by executing the leaf in pre-ordering iteration using the anytree functionality. This will not append the re-execution to the task tree. ```python world.reset_world() with simulated_robot: [node.code.execute() for node in tt.root.leaves] -print(anytree.RenderTree(pycram.task.task_tree, style=anytree.render.AsciiStyle())) +print(anytree.RenderTree(tt.root, style=anytree.render.AsciiStyle())) ``` Nodes in the task tree contain additional information about the status and time of a task. ```python -print(pycram.task.task_tree.children[0]) +print(pycram.tasktree.task_tree.root.children[0]) ``` The task tree can also be reset to an empty one by invoking ```python -pycram.tasktree.reset_tree() -print(anytree.RenderTree(pycram.task.task_tree, style=anytree.render.AsciiStyle())) +pycram.tasktree.task_tree.reset_tree() +print(anytree.RenderTree(pycram.tasktree.task_tree.root, style=anytree.render.AsciiStyle())) ``` If a plan fails using the PlanFailure exception, the plan will not stop. Instead, the error will be logged and saved in the task tree as a failed subtask. First let's create a simple failing plan and execute it. ```python -@pycram.task.with_tree +@pycram.tasktree.with_tree def failing_plan(): raise pycram.plan_failures.PlanFailure("Oopsie!") @@ -147,8 +146,8 @@ except pycram.plan_failures.PlanFailure as e: We can now investigate the nodes of the tree, and we will see that the tree indeed contains a failed task. ```python -print(anytree.RenderTree(pycram.task.task_tree, style=anytree.render.AsciiStyle())) -print(pycram.tasktree.task_tree.children[0]) +print(anytree.RenderTree(pycram.tasktree.task_tree.root, style=anytree.render.AsciiStyle())) +print(pycram.tasktree.task_tree.root.children[0]) ``` ```python diff --git a/examples/motion_designator.md b/examples/motion_designator.md index e059a6172..a1eb4d8a1 100644 --- a/examples/motion_designator.md +++ b/examples/motion_designator.md @@ -16,7 +16,7 @@ jupyter: Motion designators are similar to action designators, but unlike action designators, motion designators represent atomic low-level motions. Motion designators only take the parameter that they should execute and not a list of possible -parameters, like the other designators. Like action designators, motion designators can be performed, performing motion +parameters, like the other designators. Like action designators, motion designators can be performed. Performing a motion designator verifies the parameter and passes the designator to the respective process module. Since motion designators perform a motion on the robot, we need a robot which we can use. Therefore, we will create a diff --git a/examples/orm_example.md b/examples/orm_example.md index c183a2be9..b39f2cb0f 100644 --- a/examples/orm_example.md +++ b/examples/orm_example.md @@ -1,4 +1,4 @@ ---- +from pycram.designators.action_designator import ActionAbstract--- jupyter: jupytext: text_representation: @@ -15,10 +15,9 @@ jupyter: # Hands on Object Relational Mapping in PyCram This tutorial will walk you through the serialization of a minimal plan in pycram. -First we will import sqlalchemy, create an in memory database and connect a session to it. +First we will import sqlalchemy, create an in-memory database and connect a session to it. ```python -import sqlalchemy import sqlalchemy.orm engine = sqlalchemy.create_engine("sqlite+pysqlite:///:memory:", echo=False) @@ -29,7 +28,6 @@ session Next we create the database schema using the sqlalchemy functionality. For that we need to import the base class of pycram.orm. ```python -import pycram.orm.base import pycram.orm.action_designator pycram.orm.base.Base.metadata.create_all(engine) session.commit() @@ -48,9 +46,10 @@ from pycram.worlds.bullet_world import BulletWorld from pycram.world_concepts.world_object import Object from pycram.designators.object_designator import * from pycram.datastructures.pose import Pose +from pycram.orm.base import ProcessMetaData import anytree -world = BulletWorld(WorldMode.GUI) +world = BulletWorld(WorldMode.DIRECT) pr2 = Object("pr2", ObjectType.ROBOT, "pr2.urdf") kitchen = Object("kitchen", ObjectType.ENVIRONMENT, "kitchen.urdf") milk = Object("milk", ObjectType.MILK, "milk.stl", pose=Pose([1.3, 1, 0.9])) @@ -85,9 +84,9 @@ def plan(): plan() # set description of what we are doing -pycram.orm.base.ProcessMetaData().description = "Tutorial for getting familiar with the ORM." +ProcessMetaData().description = "Tutorial for getting familiar with the ORM." task_tree = pycram.tasktree.task_tree -print(anytree.RenderTree(task_tree)) +print(anytree.RenderTree(task_tree.root)) ``` Next we serialize the task tree by recursively inserting from its root. @@ -163,25 +162,16 @@ class ORMSaying(Action): text: Mapped[str] # define brand new action designator - +# Since this class is derived from ActionAbstract, we do not need to manually define the insert() and to_sql() function, the mapping is done automatically. We just have to tell the class, which ORMclass it is supposed to use. @dataclass -class SayingActionPerformable(ActionDesignatorDescription.Action): +class SayingActionPerformable(ActionAbstract): text: str + orm_class = ORMSaying @with_tree def perform(self) -> None: print(self.text) - - def to_sql(self) -> ORMSaying: - return ORMSaying(self.text) - - def insert(self, session: Session, *args, **kwargs) -> ORMSaying: - action = super().insert(session) - session.add(action) - session.commit() - return action - ``` Now we got our new ActionDesignator called Saying and its ORM version. Since this class got created after all other classes got inserted into the database (in the beginning of the notebook) we have to insert it manually. diff --git a/examples/orm_querying_examples.md b/examples/orm_querying_examples.md index 610c14f99..b24c7f740 100644 --- a/examples/orm_querying_examples.md +++ b/examples/orm_querying_examples.md @@ -20,35 +20,29 @@ In this tutorial, we will get to see more examples of ORM querying. First, we will gather a lot of data. In order to achieve that we will write a randomized experiment for grasping a couple of objects. In the experiment the robot will try to grasp a randomized object using random poses and torso heights. - ```python from tf import transformations import itertools -import time from typing import Optional, List, Tuple - import numpy as np - import sqlalchemy.orm -import tf import tqdm - import pycram.orm.base from pycram.worlds.bullet_world import BulletWorld from pycram.world_concepts.world_object import Object as BulletWorldObject -from pycram.designators.action_designator import MoveTorsoAction, PickUpAction, NavigateAction, ParkArmsAction, ParkArmsActionPerformable, MoveTorsoActionPerformable +from pycram.designators.action_designator import MoveTorsoAction, PickUpAction, NavigateAction, ParkArmsAction, + ParkArmsActionPerformable, MoveTorsoActionPerformable from pycram.designators.object_designator import ObjectDesignatorDescription -from pycram.plan_failures import PlanFailure +from pycram.failures import PlanFailure from pycram.process_module import ProcessModule -from pycram.datastructures.enums import Arms, ObjectType, Grasp - +from pycram.datastructures.enums import Arms, ObjectType, Grasp, WorldMode from pycram.process_module import simulated_robot -import sqlalchemy.orm -import pycram.orm -from pycram.orm.base import Position, RobotState -from pycram.orm.tasktree import TaskTreeNode from pycram.orm.action_designator import PickUpAction as ORMPickUpAction +from pycram.orm.base import RobotState, Position, ProcessMetaData, Pose as ORMPose +from pycram.orm.tasktree import TaskTreeNode from pycram.orm.object_designator import Object +from pycram.tasktree import task_tree, TaskTree +import pycram.orm import sqlalchemy.sql import pandas as pd @@ -57,7 +51,7 @@ from pycram.datastructures.pose import Pose np.random.seed(420) ProcessModule.execution_delay = False -pycram.orm.base.ProcessMetaData().description = "Tutorial for learning from experience in a Grasping action." +ProcessMetaData().description = "Tutorial for learning from experience in a Grasping action." class GraspingExplorer: @@ -81,16 +75,17 @@ class GraspingExplorer: self.robots: List[Tuple[str, str]] = [("pr2", "pr2.urdf")] if not objects: - self.objects: List[Tuple[str, ObjectType, str]] = [("cereal", ObjectType.BREAKFAST_CEREAL, "breakfast_cereal.stl"), - ("bowl", ObjectType.BOWL, "bowl.stl"), - ("milk", ObjectType.MILK, "milk.stl"), - ("spoon", ObjectType.SPOON, "spoon.stl")] + self.objects: List[Tuple[str, ObjectType, str]] = [ + ("cereal", ObjectType.BREAKFAST_CEREAL, "breakfast_cereal.stl"), + ("bowl", ObjectType.BOWL, "bowl.stl"), + ("milk", ObjectType.MILK, "milk.stl"), + ("spoon", ObjectType.SPOON, "spoon.stl")] if not arms: - self.arms: List[str] = [Arms.LEFT, Arms.RIGHT] + self.arms: List[Arms] = [Arms.LEFT, Arms.RIGHT] if not grasps: - self.grasps: List[str] = [Grasp.LEFT, Grasp.RIGHT, Grasp.FRONT, Grasp.TOP] + self.grasps: List[Grasp] = [Grasp.LEFT, Grasp.RIGHT, Grasp.FRONT, Grasp.TOP] # store trials per scenario self.samples_per_scenario: int = samples_per_scenario @@ -111,7 +106,7 @@ class GraspingExplorer: progress_bar = tqdm.tqdm( total=np.prod([len(p) for p in self.hyper_parameters]) * self.samples_per_scenario) - self.world = BulletWorld("DIRECT") + self.world = BulletWorld(WorldMode.DIRECT) # for every robot for robot, robot_urdf in self.robots: @@ -171,8 +166,8 @@ class GraspingExplorer: self.total_tries += 1 # insert into database - pycram.tasktree.task_tree.insert(session, use_progress_bar=False) - pycram.tasktree.reset_tree() + task_tree.root.insert(session, use_progress_bar=False) + task_tree.reset_tree() progress_bar.update() progress_bar.set_postfix(success_rate=(self.total_tries - self.total_failures) / @@ -207,17 +202,16 @@ Let's say we want to select positions of robots that were able to grasp a specif from sqlalchemy import select from pycram.datastructures.enums import ObjectType -milk = BulletWorldObject("Milk", ObjectType.MILK, "milk.stl") +milk = BulletWorldObject("milk", ObjectType.MILK, "milk.stl") # query all relative robot positions in regard to an objects position # make sure to order the joins() correctly query = (select(ORMPickUpAction.arm, ORMPickUpAction.grasp, RobotState.torso_height, Position.x, Position.y) - .join(TaskTreeNode.code) - .join(Code.designator.of_type(ORMPickUpAction)) + .join(TaskTreeNode.action.of_type(ORMPickUpAction)) .join(ORMPickUpAction.robot_state) .join(RobotState.pose) - .join(pycram.orm.base.Pose.position) - .join(ORMPickUpAction.object).where(Object.type == milk.type) + .join(ORMPose.position) + .join(ORMPickUpAction.object).where(Object.obj_type == milk.obj_type) .where(TaskTreeNode.status == "SUCCEEDED")) print(query) @@ -235,19 +229,16 @@ The effect of this function can also be seen in the printed query of above's out Another interesting query: Let's say we want to select the torso height and positions of robots relative to the object they were trying to grasp: ```python -from pycram.orm.base import Pose as ORMPose - robot_pose = sqlalchemy.orm.aliased(ORMPose) object_pose = sqlalchemy.orm.aliased(ORMPose) robot_position = sqlalchemy.orm.aliased(Position) object_position = sqlalchemy.orm.aliased(Position) -query = (select(TaskTreeNode.status, Object.type, +query = (select(TaskTreeNode.status, Object.obj_type, sqlalchemy.label("relative torso height", object_position.z - RobotState.torso_height), sqlalchemy.label("x", robot_position.x - object_position.x), sqlalchemy.label("y", robot_position.y - object_position.y)) - .join(TaskTreeNode.code) - .join(Code.designator.of_type(ORMPickUpAction)) + .join(TaskTreeNode.action.of_type(ORMPickUpAction)) .join(ORMPickUpAction.robot_state) .join(robot_pose, RobotState.pose) .join(robot_position, robot_pose.position) diff --git a/launch/ik_and_description.launch b/launch/ik_and_description.launch index 628fd7f82..4484eaf86 100644 --- a/launch/ik_and_description.launch +++ b/launch/ik_and_description.launch @@ -44,6 +44,12 @@ textfile="$(find pycram)/resources/robots/stretch_description.urdf"/> + + + + + diff --git a/package.xml b/package.xml index f3b332d63..b13066606 100644 --- a/package.xml +++ b/package.xml @@ -53,7 +53,10 @@ geometry_msgs geometry_msgs + pr2_arm_kinematics + moveit_kinematics + pr2_common diff --git a/pycram-https.rosinstall b/pycram-https.rosinstall index 64bbf0430..94d873ea8 100644 --- a/pycram-https.rosinstall +++ b/pycram-https.rosinstall @@ -11,15 +11,7 @@ repositories: type: git url: https://github.com/cram2/pycram.git version: dev - pr2_common: - type: git - url: https://@github.com/PR2/pr2_common.git - version: b34703bcca2b07cadbc3777d3c504c232a0c0c28 kdl_ik_services: type: git url: https://github.com/cram2/kdl_ik_service.git version: master - pr2_kinematics: - type: git - url: https://github.com/PR2/pr2_kinematics.git - version: kinetic-devel diff --git a/pycram.rosinstall b/pycram.rosinstall index ccbb0cd22..8628e2c22 100644 --- a/pycram.rosinstall +++ b/pycram.rosinstall @@ -11,15 +11,7 @@ repositories: type: git url: git@github.com:cram2/pycram.git version: dev - pr2_common: - type: git - url: git@github.com:PR2/pr2_common.git - version: b34703bcca2b07cadbc3777d3c504c232a0c0c28 kdl_ik_services: type: git url: git@github.com:cram2/kdl_ik_service.git version: master - pr2_kinematics: - type: git - url: git@github.com:PR2/pr2_kinematics.git - version: kinetic-devel diff --git a/requirements-resolver.txt b/requirements-resolver.txt index 9f31a1cb0..efcd1c22b 100644 --- a/requirements-resolver.txt +++ b/requirements-resolver.txt @@ -1,3 +1,2 @@ -r requirements.txt -probabilistic_model>=5.0.3 -random_events>=3.0.4 + diff --git a/requirements.txt b/requirements.txt index ba876360f..d4a612b1d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,9 @@ -r requirements-setuptools.txt gitpython>=3.1.37 -pybullet~=3.2.5 -rospkg~=1.4.0 -roslibpy~=1.2.1 -# rospy~=1.14.11 +pycram_bullet==3.2.8 pathlib~=1.0.1 numpy==1.24.3 pytransform3d~=1.9.1 -# tf~=1.12.1 -# actionlib~=1.12.1 -urdf-parser-py~=0.0.3 graphviz anytree>=2.8.0 SQLAlchemy>=2.0.9 @@ -18,4 +12,14 @@ psutil==5.9.7 lxml==4.9.1 typing_extensions==4.9.0 owlready2>=0.45 -catkin_pkg \ No newline at end of file +Pillow>=10.3.0 +pynput~=1.7.7 +playsound~=1.3.0 +pydub~=0.25.1 +gTTS~=2.5.3 +dm_control +trimesh +deprecated +probabilistic_model>=5.1.0 +random_events>=3.0.7 +sympy diff --git a/resources/kitchen-small.urdf b/resources/kitchen-small.urdf new file mode 100644 index 000000000..18764ed70 --- /dev/null +++ b/resources/kitchen-small.urdf @@ -0,0 +1,2247 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + false + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 1000000000 + 1000000000 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/resources/IAI_kitchen.urdf b/resources/objects/IAI_kitchen.urdf similarity index 100% rename from resources/IAI_kitchen.urdf rename to resources/objects/IAI_kitchen.urdf diff --git a/resources/Static_CokeBottle.stl b/resources/objects/Static_CokeBottle.stl similarity index 100% rename from resources/Static_CokeBottle.stl rename to resources/objects/Static_CokeBottle.stl diff --git a/resources/Static_MilkPitcher.stl b/resources/objects/Static_MilkPitcher.stl similarity index 100% rename from resources/Static_MilkPitcher.stl rename to resources/objects/Static_MilkPitcher.stl diff --git a/resources/apartment.urdf b/resources/objects/apartment.urdf similarity index 96% rename from resources/apartment.urdf rename to resources/objects/apartment.urdf index 60b5a4d01..2900accb1 100644 --- a/resources/apartment.urdf +++ b/resources/objects/apartment.urdf @@ -4,6 +4,27 @@ + + + + + + + + + + + + + + + + + + + + + @@ -480,6 +501,7 @@ + @@ -501,6 +523,7 @@ + @@ -533,7 +556,7 @@ - + @@ -562,6 +585,7 @@ + @@ -600,6 +624,7 @@ + @@ -678,6 +703,7 @@ + @@ -706,6 +732,7 @@ + @@ -744,6 +771,7 @@ + @@ -765,6 +793,7 @@ + @@ -784,6 +813,7 @@ + @@ -812,6 +842,7 @@ + @@ -858,6 +889,7 @@ + @@ -894,6 +926,7 @@ + @@ -914,6 +947,7 @@ + @@ -936,6 +970,7 @@ + @@ -972,6 +1007,7 @@ + @@ -992,6 +1028,7 @@ + @@ -1030,6 +1067,7 @@ + @@ -1068,6 +1106,7 @@ + @@ -1104,6 +1143,7 @@ + @@ -1124,6 +1164,7 @@ + @@ -1162,6 +1203,7 @@ + @@ -1200,6 +1242,7 @@ + @@ -1237,6 +1280,7 @@ + @@ -1257,6 +1301,7 @@ + @@ -1335,6 +1380,7 @@ + @@ -1355,6 +1401,7 @@ + @@ -1376,6 +1423,7 @@ + @@ -1412,6 +1460,7 @@ + @@ -1432,6 +1481,7 @@ + @@ -1470,6 +1520,7 @@ + @@ -1508,6 +1559,7 @@ + @@ -1544,6 +1596,7 @@ + @@ -1564,6 +1617,7 @@ + @@ -1602,6 +1656,7 @@ + @@ -1640,6 +1695,7 @@ + @@ -1676,6 +1732,7 @@ + @@ -1696,6 +1753,7 @@ + @@ -1734,6 +1792,7 @@ + @@ -1772,6 +1831,7 @@ + @@ -1809,6 +1869,7 @@ + @@ -1827,6 +1888,7 @@ + @@ -1869,6 +1931,7 @@ + @@ -2010,6 +2073,7 @@ + @@ -2028,6 +2092,7 @@ + @@ -2046,6 +2111,7 @@ + @@ -2064,6 +2130,7 @@ + @@ -2082,6 +2149,7 @@ + @@ -2100,6 +2168,7 @@ + diff --git a/resources/apartment_bowl.stl b/resources/objects/apartment_bowl.stl similarity index 100% rename from resources/apartment_bowl.stl rename to resources/objects/apartment_bowl.stl diff --git a/resources/apartment_without_walls.urdf b/resources/objects/apartment_without_walls.urdf similarity index 96% rename from resources/apartment_without_walls.urdf rename to resources/objects/apartment_without_walls.urdf index a807692d1..26f05aec7 100644 --- a/resources/apartment_without_walls.urdf +++ b/resources/objects/apartment_without_walls.urdf @@ -4,6 +4,27 @@ + + + + + + + + + + + + + + + + + + + + + @@ -395,6 +416,7 @@ + @@ -416,6 +438,7 @@ + @@ -448,7 +471,7 @@ - + @@ -477,6 +500,7 @@ + @@ -515,6 +539,7 @@ + @@ -593,6 +618,7 @@ + @@ -621,6 +647,7 @@ + @@ -659,6 +686,7 @@ + @@ -680,6 +708,7 @@ + @@ -699,6 +728,7 @@ + @@ -727,6 +757,7 @@ + @@ -773,6 +804,7 @@ + @@ -809,6 +841,7 @@ + @@ -829,6 +862,7 @@ + @@ -851,6 +885,7 @@ + @@ -887,6 +922,7 @@ + @@ -907,6 +943,7 @@ + @@ -945,6 +982,7 @@ + @@ -983,6 +1021,7 @@ + @@ -1019,6 +1058,7 @@ + @@ -1039,6 +1079,7 @@ + @@ -1077,6 +1118,7 @@ + @@ -1115,6 +1157,7 @@ + @@ -1152,6 +1195,7 @@ + @@ -1172,6 +1216,7 @@ + @@ -1250,6 +1295,7 @@ + @@ -1270,6 +1316,7 @@ + @@ -1291,6 +1338,7 @@ + @@ -1327,6 +1375,7 @@ + @@ -1347,6 +1396,7 @@ + @@ -1385,6 +1435,7 @@ + @@ -1423,6 +1474,7 @@ + @@ -1459,6 +1511,7 @@ + @@ -1479,6 +1532,7 @@ + @@ -1517,6 +1571,7 @@ + @@ -1555,6 +1610,7 @@ + @@ -1591,6 +1647,7 @@ + @@ -1611,6 +1668,7 @@ + @@ -1649,6 +1707,7 @@ + @@ -1687,6 +1746,7 @@ + @@ -1724,6 +1784,7 @@ + @@ -1742,6 +1803,7 @@ + @@ -1784,6 +1846,7 @@ + @@ -1925,6 +1988,7 @@ + @@ -1943,6 +2007,7 @@ + @@ -1961,6 +2026,7 @@ + @@ -1979,6 +2045,7 @@ + @@ -1997,6 +2064,7 @@ + @@ -2015,6 +2083,7 @@ + diff --git a/resources/bowl.stl b/resources/objects/bowl.stl similarity index 100% rename from resources/bowl.stl rename to resources/objects/bowl.stl diff --git a/resources/box.urdf b/resources/objects/box.urdf similarity index 100% rename from resources/box.urdf rename to resources/objects/box.urdf diff --git a/resources/breakfast_cereal.stl b/resources/objects/breakfast_cereal.stl similarity index 100% rename from resources/breakfast_cereal.stl rename to resources/objects/breakfast_cereal.stl diff --git a/resources/broken_kitchen.urdf b/resources/objects/broken_kitchen.urdf similarity index 100% rename from resources/broken_kitchen.urdf rename to resources/objects/broken_kitchen.urdf diff --git a/resources/jeroen_cup.stl b/resources/objects/jeroen_cup.stl similarity index 100% rename from resources/jeroen_cup.stl rename to resources/objects/jeroen_cup.stl diff --git a/resources/kitchen.urdf b/resources/objects/kitchen.urdf similarity index 100% rename from resources/kitchen.urdf rename to resources/objects/kitchen.urdf diff --git a/resources/milk.stl b/resources/objects/milk.stl similarity index 100% rename from resources/milk.stl rename to resources/objects/milk.stl diff --git a/resources/plane.obj b/resources/objects/plane.obj similarity index 100% rename from resources/plane.obj rename to resources/objects/plane.obj diff --git a/resources/plane.urdf b/resources/objects/plane.urdf similarity index 100% rename from resources/plane.urdf rename to resources/objects/plane.urdf diff --git a/resources/spoon.stl b/resources/objects/spoon.stl similarity index 100% rename from resources/spoon.stl rename to resources/objects/spoon.stl diff --git a/resources/robots/pr2.urdf b/resources/robots/pr2.urdf index 671407ef6..439087f3f 100644 --- a/resources/robots/pr2.urdf +++ b/resources/robots/pr2.urdf @@ -55,6 +55,9 @@ + + + @@ -826,7 +829,7 @@ - + @@ -857,7 +860,7 @@ - + @@ -1372,7 +1375,7 @@ - + @@ -1452,7 +1455,7 @@ - + @@ -1571,7 +1574,7 @@ - + @@ -1651,7 +1654,7 @@ - + @@ -1690,7 +1693,7 @@ - + @@ -1800,7 +1803,7 @@ - + @@ -1998,7 +2001,7 @@ - + @@ -2033,7 +2036,7 @@ - + @@ -2337,7 +2340,7 @@ - + @@ -2456,7 +2459,7 @@ - + @@ -2536,7 +2539,7 @@ - + @@ -2575,7 +2578,7 @@ - + @@ -2685,7 +2688,7 @@ - + @@ -2883,7 +2886,7 @@ - + @@ -2918,7 +2921,7 @@ - + diff --git a/resources/robots/turtlebot.urdf b/resources/robots/turtlebot.urdf new file mode 100644 index 000000000..b3a996afd --- /dev/null +++ b/resources/robots/turtlebot.urdf @@ -0,0 +1,371 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Gazebo/DarkGrey + + + 0.1 + 0.1 + 500000.0 + 10.0 + 0.001 + 0.1 + 1 0 0 + Gazebo/FlatBlack + + + 0.1 + 0.1 + 500000.0 + 10.0 + 0.001 + 0.1 + 1 0 0 + Gazebo/FlatBlack + + + 0.1 + 0.1 + 1000000.0 + 100.0 + 0.001 + 1.0 + Gazebo/FlatBlack + + + 0.1 + 0.1 + 1000000.0 + 100.0 + 0.001 + 1.0 + Gazebo/FlatBlack + + + + true + false + + Gazebo/Grey + + + + cmd_vel + odom + odom + world + true + base_footprint + false + true + true + false + 30 + wheel_left_joint + wheel_right_joint + 0.287 + 0.066 + 1 + 10 + na + + + + + true + imu_link + imu_link + imu + imu_service + 0.0 + 0 + + + gaussian + + 0.0 + 2e-4 + 0.0000075 + 0.0000008 + + + 0.0 + 1.7e-2 + 0.1 + 0.001 + + + + + + + Gazebo/FlatBlack + + 0 0 0 0 0 0 + false + 5 + + + + 360 + 1 + 0.0 + 6.28319 + + + + 0.120 + 3.5 + 0.015 + + + gaussian + 0.0 + 0.01 + + + + scan + base_scan + + + + + + + true + false + + 1.085595 + + 640 + 480 + R8G8B8 + + + 0.03 + 100 + + + + true + 30.0 + camera + camera_rgb_optical_frame + rgb/image_raw + rgb/camera_info + 0.07 + 0.0 + 0.0 + 0.0 + 0.0 + 0.0 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/resources/worlds/pycram_test.muv b/resources/worlds/pycram_test.muv new file mode 100644 index 000000000..bfe6acb7d --- /dev/null +++ b/resources/worlds/pycram_test.muv @@ -0,0 +1,52 @@ +resources: + - ../cached + - ../robots + - ../worlds + - ../objects + +worlds: + pycram_test: + rtf_desired: 1 + prospection_pycram_test: + rtf_desired: 1 + +simulations: + pycram_test: + simulator: mujoco + world: + name: world + path: apartment/mjcf/apartment.xml + apply: + body: + gravcomp: 1 + config: + max_time_step: 0.002 + min_time_step: 0.001 + prospection_pycram_test: + simulator: mujoco + world: + name: prospection_world + path: apartment/mjcf/apartment.xml + apply: + body: + gravcomp: 1 + config: + max_time_step: 0.002 + min_time_step: 0.001 + +multiverse_server: + host: "tcp://127.0.0.1" + port: 7000 + +multiverse_clients: + pycram_test: + port: 7500 + send: + body: ["position", "quaternion", "relative_velocity"] + joint: ["joint_rvalue", "joint_tvalue", "joint_linear_velocity", "joint_angular_velocity", "joint_force", "joint_torque"] + + prospection_pycram_test: + port: 7600 + send: + body: [ "position", "quaternion", "relative_velocity" ] + joint: [ "joint_rvalue", "joint_tvalue", "joint_linear_velocity", "joint_angular_velocity", "joint_force", "joint_torque" ] \ No newline at end of file diff --git a/src/pycram/__init__.py b/src/pycram/__init__.py index 477a6d659..9fde2cdac 100644 --- a/src/pycram/__init__.py +++ b/src/pycram/__init__.py @@ -1,4 +1,19 @@ -import pycram.process_modules +from . import process_modules +from . import robot_descriptions # from .specialized_designators import * +from .datastructures.world import World +import signal __version__ = "0.0.2" + + +def signal_handler(sig, frame): + if World.current_world: + World.current_world.exit() + print("Exiting...") + exit(0) + +signal.signal(signal.SIGINT, signal_handler) + + + diff --git a/src/pycram/cache_manager.py b/src/pycram/cache_manager.py index 1c7b86cea..3e6a9889d 100644 --- a/src/pycram/cache_manager.py +++ b/src/pycram/cache_manager.py @@ -1,8 +1,9 @@ import glob import os import pathlib +import shutil -from typing_extensions import List, TYPE_CHECKING +from typing_extensions import List, TYPE_CHECKING, Optional if TYPE_CHECKING: from .description import ObjectDescription @@ -14,31 +15,52 @@ class CacheManager: The CacheManager is responsible for caching object description files and managing the cache directory. """ - mesh_extensions: List[str] = [".obj", ".stl"] + cache_cleared: bool = False """ - The file extensions of mesh files. + Indicate whether the cache directory has been cleared at least once since beginning or not. """ - def __init__(self, cache_dir: str, data_directory: List[str]): + def __init__(self, cache_dir: str, data_directory: List[str], clear_cache: bool = True): """ - Initializes the CacheManager. + Initialize the CacheManager. :param cache_dir: The directory where the cached files are stored. :param data_directory: The directory where all resource files are stored. + :param clear_cache: If True, the cache directory will be cleared. """ self.cache_dir = cache_dir - self.data_directory = data_directory + self.data_directories = data_directory + if clear_cache: + self.clear_cache() + + def clear_cache(self): + """ + Clear the cache directory. + """ + self.delete_cache_dir() + self.create_cache_dir_if_not_exists() + self.cache_cleared = True + + def delete_cache_dir(self): + """ + Delete the cache directory. + """ + if pathlib.Path(self.cache_dir).exists(): + shutil.rmtree(self.cache_dir) def update_cache_dir_with_object(self, path: str, ignore_cached_files: bool, - object_description: 'ObjectDescription', object_name: str) -> str: + object_description: 'ObjectDescription', object_name: str, + scale_mesh: Optional[float] = None) -> str: """ - Checks if the file is already in the cache directory, if not it will be preprocessed and saved in the cache. + Check if the file is already in the cache directory, if not preprocess and save in the cache. :param path: The path of the file to preprocess and save in the cache directory. :param ignore_cached_files: If True, the file will be preprocessed and saved in the cache directory even if it is already cached. :param object_description: The object description of the file. :param object_name: The name of the object. + :param scale_mesh: The scale of the mesh. + :return: The path of the cached file. """ path_object = pathlib.Path(path) extension = path_object.suffix @@ -46,49 +68,24 @@ def update_cache_dir_with_object(self, path: str, ignore_cached_files: bool, self.create_cache_dir_if_not_exists() # save correct path in case the file is already in the cache directory - cache_path = self.cache_dir + object_description.get_file_name(path_object, extension, object_name) + cache_path = os.path.join(self.cache_dir, object_description.get_file_name(path_object, extension, object_name)) if not self.is_cached(path, object_description) or ignore_cached_files: # if file is not yet cached preprocess the description file and save it in the cache directory. path = self.look_for_file_in_data_dir(path_object) - self.generate_description_and_write_to_cache(path, object_name, extension, cache_path, object_description) + object_description.generate_description_from_file(path, object_name, extension, cache_path, scale_mesh) return cache_path - def generate_description_and_write_to_cache(self, path: str, name: str, extension: str, cache_path: str, - object_description: 'ObjectDescription') -> None: - """ - Generates the description from the file at the given path and writes it to the cache directory. - - :param path: The path of the file to preprocess. - :param name: The name of the object. - :param extension: The file extension of the file to preprocess. - :param cache_path: The path of the file in the cache directory. - :param object_description: The object description of the file. - """ - description_string = object_description.generate_description_from_file(path, name, extension) - self.write_to_cache(description_string, cache_path) - - @staticmethod - def write_to_cache(description_string: str, cache_path: str) -> None: - """ - Writes the description string to the cache directory. - - :param description_string: The description string to write to the cache directory. - :param cache_path: The path of the file in the cache directory. - """ - with open(cache_path, "w") as file: - file.write(description_string) - def look_for_file_in_data_dir(self, path_object: pathlib.Path) -> str: """ - Looks for a file in the data directory of the World. If the file is not found in the data directory, this method - raises a FileNotFoundError. + Look for a file in the data directory of the World. If the file is not found in the data directory, raise a + FileNotFoundError. :param path_object: The pathlib object of the file to look for. """ name = path_object.name - for data_dir in self.data_directory: + for data_dir in self.data_directories: data_path = pathlib.Path(data_dir).joinpath("**") for file in glob.glob(str(data_path), recursive=True): file_path = pathlib.Path(file) @@ -97,18 +94,18 @@ def look_for_file_in_data_dir(self, path_object: pathlib.Path) -> str: return str(file_path) raise FileNotFoundError( - f"File {name} could not be found in the resource directory {self.data_directory}") + f"File {name} could not be found in the resource directory {self.data_directories}") def create_cache_dir_if_not_exists(self): """ - Creates the cache directory if it does not exist. + Create the cache directory if it does not exist. """ if not pathlib.Path(self.cache_dir).exists(): os.mkdir(self.cache_dir) def is_cached(self, path: str, object_description: 'ObjectDescription') -> bool: """ - Checks if the file in the given path is already cached or if + Check if the file in the given path is already cached or if there is already a cached file with the given name, this is the case if a .stl, .obj file or a description from the parameter server is used. @@ -116,26 +113,26 @@ def is_cached(self, path: str, object_description: 'ObjectDescription') -> bool: :param object_description: The object description of the file. :return: True if there already exists a cached file, False in any other case. """ - return True if self.check_with_extension(path) else self.check_without_extension(path, object_description) + return self.check_with_extension(path) or self.check_without_extension(path, object_description) def check_with_extension(self, path: str) -> bool: """ - Checks if the file in the given ath exists in the cache directory including file extension. + Check if the file in the given ath exists in the cache directory including file extension. :param path: The path of the file to check. """ file_name = pathlib.Path(path).name - full_path = pathlib.Path(self.cache_dir + file_name) + full_path = pathlib.Path(os.path.join(self.cache_dir, file_name)) return full_path.exists() def check_without_extension(self, path: str, object_description: 'ObjectDescription') -> bool: """ - Checks if the file in the given path exists in the cache directory without file extension, - the extension is added after the file name manually in this case. + Check if the file in the given path exists in the cache directory the given file extension. + Instead, replace the given extension with the extension of the used ObjectDescription and check for that one. :param path: The path of the file to check. :param object_description: The object description of the file. """ file_stem = pathlib.Path(path).stem - full_path = pathlib.Path(self.cache_dir + file_stem + object_description.get_file_extension()) + full_path = pathlib.Path(os.path.join(self.cache_dir, file_stem + object_description.get_file_extension())) return full_path.exists() diff --git a/src/pycram/config b/src/pycram/config new file mode 120000 index 000000000..899f69898 --- /dev/null +++ b/src/pycram/config @@ -0,0 +1 @@ +../../config \ No newline at end of file diff --git a/src/pycram/costmaps.py b/src/pycram/costmaps.py index 462c60233..9ce0ecfdf 100644 --- a/src/pycram/costmaps.py +++ b/src/pycram/costmaps.py @@ -1,26 +1,35 @@ # used for delayed evaluation of typing until python 3.11 becomes mainstream from __future__ import annotations -from typing_extensions import Tuple, List, Optional - -import matplotlib.pyplot as plt from dataclasses import dataclass +import matplotlib.pyplot as plt import numpy as np import psutil -import rospy +import random_events +import tf from matplotlib import colors from nav_msgs.msg import OccupancyGrid, MapMetaData - +from probabilistic_model.probabilistic_circuit.nx.distributions import UniformDistribution +from probabilistic_model.probabilistic_circuit.nx.probabilistic_circuit import ProbabilisticCircuit, ProductUnit +from random_events.interval import Interval, reals, closed_open, closed +from random_events.product_algebra import Event, SimpleEvent +from random_events.variable import Continuous +from typing_extensions import Tuple, List, Optional, Iterator + +from .ros.ros_tools import wait_for_message +from .datastructures.dataclasses import AxisAlignedBoundingBox +from .datastructures.pose import Pose from .datastructures.world import UseProspectionWorld -from .world_concepts.world_object import Object +from .datastructures.world import World from .description import Link from .local_transformer import LocalTransformer +from .world_concepts.world_object import Object + from .datastructures.pose import Pose, Transform from .datastructures.world import World from .datastructures.dataclasses import AxisAlignedBoundingBox, BoxVisualShape, Color - -import pybullet as p +from tf.transformations import quaternion_from_matrix @dataclass @@ -111,42 +120,21 @@ def visualize(self) -> None: # Creation of the visual shapes, for documentation of the visual shapes # please look here: https://docs.google.com/document/d/10sXEhzFRSnvFcl3XxNGhnD4N2SedqwdAvK3dsihxVUA/edit#heading=h.q1gn7v6o58bf for box in boxes: - visual = p.createVisualShape(p.GEOM_BOX, - halfExtents=[(box[1] * self.resolution) / 2, (box[2] * self.resolution) / 2, - 0.001], - rgbaColor=[1, 0, 0, 0.6], - visualFramePosition=[(box[0][0] + box[1] / 2) * self.resolution, - (box[0][1] + box[2] / 2) * self.resolution, 0.]) + box = BoxVisualShape(Color(1, 0, 0, 0.6), + [(box[0][0] + box[1] / 2) * self.resolution, + (box[0][1] + box[2] / 2) * self.resolution, 0.], + [(box[1] * self.resolution) / 2, (box[2] * self.resolution) / 2, 0.001]) + visual = self.world.create_visual_shape(box) cells.append(visual) # Set to 127 for since this is the maximal amount of links in a multibody for cell_parts in self._chunks(cells, 127): - # Dummy paramater since these are needed to spawn visual shapes as a - # multibody. - link_poses = [[0, 0, 0] for c in cell_parts] - link_orientations = [[0, 0, 0, 1] for c in cell_parts] - link_masses = [1.0 for c in cell_parts] - link_parent = [0 for c in cell_parts] - link_joints = [p.JOINT_FIXED for c in cell_parts] - link_collision = [-1 for c in cell_parts] - link_joint_axis = [[1, 0, 0] for c in cell_parts] - # The position at which the multibody will be spawned. Offset such that - # the origin referes to the centre of the costmap. - # origin_pose = self.origin.position_as_list() - # base_pose = [origin_pose[0] - self.height / 2 * self.resolution, - # origin_pose[1] - self.width / 2 * self.resolution, origin_pose[2]] - offset = [[-self.height / 2 * self.resolution, -self.width / 2 * self.resolution, 0.05], [0, 0, 0, 1]] - new_pose = p.multiplyTransforms(self.origin.position_as_list(), self.origin.orientation_as_list(), - offset[0], offset[1]) - - map_obj = p.createMultiBody(baseVisualShapeIndex=-1, linkVisualShapeIndices=cell_parts, - basePosition=new_pose[0], baseOrientation=new_pose[1], linkPositions=link_poses, - # [0, 0, 1, 0] - linkMasses=link_masses, linkOrientations=link_orientations, - linkInertialFramePositions=link_poses, - linkInertialFrameOrientations=link_orientations, linkParentIndices=link_parent, - linkJointTypes=link_joints, linkJointAxis=link_joint_axis, - linkCollisionShapeIndices=link_collision) + origin_transform = (Transform(self.origin.position_as_list(), self.origin.orientation_as_list()) + .get_homogeneous_matrix()) + offset_transform = (Transform(offset[0], offset[1]).get_homogeneous_matrix()) + new_pose_transform = np.dot(origin_transform, offset_transform) + new_pose = Pose(new_pose_transform[:3, 3].tolist(), quaternion_from_matrix(new_pose_transform)) + map_obj = self.world.create_multi_body_from_visual_shapes(cell_parts, new_pose) self.vis_ids.append(map_obj) def _chunks(self, lst: List, n: int) -> List: @@ -165,7 +153,7 @@ def close_visualization(self) -> None: Removes the visualization from the World. """ for v_id in self.vis_ids: - self.world.remove_object_by_id(v_id) + self.world.remove_visual_object(v_id) self.vis_ids = [] def _find_consectuive_line(self, start: Tuple[int, int], map: np.ndarray) -> int: @@ -366,7 +354,7 @@ def _get_map() -> np.ndarray: :return: The costmap as a numpy array. """ print("Waiting for Map") - map = rospy.wait_for_message("/map", OccupancyGrid) + map = wait_for_message("/map", OccupancyGrid) print("Recived Map") return np.array(map.data) @@ -379,7 +367,7 @@ def _get_map_metadata() -> MapMetaData: :return: The meta-data for the costmap array. """ print("Waiting for Map Meta Data") - meta = rospy.wait_for_message("/map_metadata", MapMetaData) + meta = wait_for_message("/map_metadata", MapMetaData) print("Recived Meta Data") return meta @@ -448,8 +436,8 @@ def _create_from_world(self, size: int, resolution: float) -> np.ndarray: indices = np.concatenate(np.dstack(np.mgrid[int(-size / 2):int(size / 2), int(-size / 2):int(size / 2)]), axis=0) * resolution + np.array(origin_position[:2]) # Add the z-coordinate to the grid, which is either 0 or 10 - indices_0 = np.pad(indices, (0, 1), mode='constant', constant_values=0)[:-1] - indices_10 = np.pad(indices, (0, 1), mode='constant', constant_values=10)[:-1] + indices_0 = np.pad(indices, (0, 1), mode='constant', constant_values=5)[:-1] + indices_10 = np.pad(indices, (0, 1), mode='constant', constant_values=0)[:-1] # Zips both arrays such that there are tuples for every coordinate that # only differ in the z-coordinate rays = np.dstack(np.dstack((indices_0, indices_10))).T @@ -461,10 +449,9 @@ def _create_from_world(self, size: int, resolution: float) -> np.ndarray: i = 0 j = 0 for n in self._chunks(np.array(rays), 16380): - # with UseProspectionWorld(): - r_t = self.world.ray_test_batch(n[:, 0], n[:, 1], num_threads=0) + r_t = World.current_world.ray_test_batch(n[:, 0], n[:, 1], num_threads=0) while r_t is None: - r_t = self.world.ray_test_batch(n[:, 0], n[:, 1], num_threads=0) + r_t = World.current_world.ray_test_batch(n[:, 0], n[:, 1], num_threads=0) j += len(n) if World.robot: attached_objs_id = [o.id for o in self.world.robot.attachments.keys()] @@ -787,11 +774,10 @@ def generate_map(self) -> None: def get_aabb_for_link(self) -> AxisAlignedBoundingBox: """ - Returns the axis aligned bounding box (AABB) of the link provided when creating this costmap. To try and let the - AABB as close to the actual object as possible, the Object will be rotated such that the link will be in the - identity orientation. - :return: Two points in world coordinate space, which span a rectangle + :return: The axis aligned bounding box (AABB) of the link provided when creating this costmap. To try and let + the AABB as close to the actual object as possible, the Object will be rotated such that the link will be in the + identity orientation. """ prospection_object = World.current_world.get_prospection_object_for_object(self.object) with UseProspectionWorld(): @@ -802,6 +788,160 @@ def get_aabb_for_link(self) -> AxisAlignedBoundingBox: return self.link.get_axis_aligned_bounding_box() +class AlgebraicSemanticCostmap(SemanticCostmap): + """ + Class for a semantic costmap that is based on an algebraic set-description of the valid area. + """ + x: Continuous = Continuous("x") + """ + The variable for height. + """ + + y: Continuous = Continuous("y") + """ + The variable for width. + """ + + original_valid_area: Optional[SimpleEvent] + """ + The original rectangle of the valid area. + """ + + valid_area: Optional[Event] + """ + A description of the valid positions as set. + """ + + number_of_samples: int + """ + The number of samples to generate for the iter. + """ + + def __init__(self, object, urdf_link_name, world=None, number_of_samples=1000): + super().__init__(object, urdf_link_name, world=world) + self.number_of_samples = number_of_samples + + def check_valid_area_exists(self): + assert self.valid_area is not None, ("The map has to be created before semantics can be applied. " + "Call 'generate_map first'") + + def left(self, margin = 0.) -> Event: + """ + Create an event left of the origins Y-Coordinate. + :param margin: The margin of the events left bound. + :return: The left event. + """ + self.check_valid_area_exists() + y_origin = self.origin.position.y + left = self.original_valid_area[self.y].simple_sets[0].lower + left += margin + event = SimpleEvent( + {self.x: reals(), self.y: random_events.interval.open(left, y_origin)}).as_composite_set() + return event + + def right(self, margin = 0.) -> Event: + """ + Create an event right of the origins Y-Coordinate. + :param margin: The margin of the events right bound. + :return: The right event. + """ + self.check_valid_area_exists() + y_origin = self.origin.position.y + right = self.original_valid_area[self.y].simple_sets[0].upper + right -= margin + event = SimpleEvent({self.x: reals(), self.y: closed_open(y_origin, right)}).as_composite_set() + return event + + def top(self, margin = 0.) -> Event: + """ + Create an event above the origins X-Coordinate. + :param margin: The margin of the events upper bound. + :return: The top event. + """ + self.check_valid_area_exists() + x_origin = self.origin.position.x + top = self.original_valid_area[self.x].simple_sets[0].upper + top -= margin + event = SimpleEvent( + {self.x: random_events.interval.closed_open(x_origin, top), self.y: reals()}).as_composite_set() + return event + + def bottom(self, margin = 0.) -> Event: + """ + Create an event below the origins X-Coordinate. + :param margin: The margin of the events lower bound. + :return: The bottom event. + """ + self.check_valid_area_exists() + x_origin = self.origin.position.x + lower = self.original_valid_area[self.x].simple_sets[0].lower + lower += margin + event = SimpleEvent( + {self.x: random_events.interval.open(lower, x_origin), self.y: reals()}).as_composite_set() + return event + + def inner(self, margin=0.2): + min_x = self.original_valid_area[self.x].simple_sets[0].lower + max_x = self.original_valid_area[self.x].simple_sets[0].upper + min_y = self.original_valid_area[self.y].simple_sets[0].lower + max_y = self.original_valid_area[self.y].simple_sets[0].upper + + min_x += margin + max_x -= margin + min_y += margin + max_y -= margin + + inner_event = SimpleEvent({self.x: closed(min_x, max_x), + self.y: closed(min_y, max_y)}).as_composite_set() + return inner_event + + def border(self, margin=0.2): + return ~self.inner(margin) + + def generate_map(self) -> None: + super().generate_map() + valid_area = Event() + for rectangle in self.partitioning_rectangles(): + # rectangle.scale(1/self.resolution, 1/self.resolution) + rectangle.translate(self.origin.position.x, self.origin.position.y) + valid_area.simple_sets.add(SimpleEvent({self.x: closed(rectangle.x_lower, rectangle.x_upper), + self.y: closed(rectangle.y_lower, rectangle.y_upper)})) + + assert len(valid_area.simple_sets) == 1, ("The map at the basis of a Semantic costmap must be an axis aligned" + "bounding box") + self.valid_area = valid_area + self.original_valid_area = self.valid_area.simple_sets[0] + + def as_distribution(self) -> ProbabilisticCircuit: + p_xy = ProductUnit() + u_x = UniformDistribution(self.x, self.original_valid_area[self.x].simple_sets[0]) + u_y = UniformDistribution(self.y, self.original_valid_area[self.y].simple_sets[0]) + p_xy.add_subcircuit(u_x) + p_xy.add_subcircuit(u_y) + + conditional, _ = p_xy.conditional(self.valid_area) + return conditional.probabilistic_circuit + + def sample_to_pose(self, sample: np.ndarray) -> Pose: + """ + Convert a sample from the costmap to a pose. + :param sample: The sample to convert + :return: The pose corresponding to the sample + """ + x = sample[0] + y = sample[1] + position = [x, y, self.origin.position.z] + angle = np.arctan2(position[1] - self.origin.position.y, position[0] - self.origin.position.x) + np.pi + orientation = list(tf.transformations.quaternion_from_euler(0, 0, angle, axes="sxyz")) + return Pose(position, orientation, self.origin.frame) + + def __iter__(self) -> Iterator[Pose]: + model = self.as_distribution() + samples = model.sample(self.number_of_samples) + for sample in samples: + yield self.sample_to_pose(sample) + + cmap = colors.ListedColormap(['white', 'black', 'green', 'red', 'blue']) diff --git a/src/pycram/datastructures/dataclasses.py b/src/pycram/datastructures/dataclasses.py index d83bcd00f..ddf5d7b21 100644 --- a/src/pycram/datastructures/dataclasses.py +++ b/src/pycram/datastructures/dataclasses.py @@ -1,10 +1,14 @@ from __future__ import annotations +from abc import ABC, abstractmethod +from copy import deepcopy, copy from dataclasses import dataclass + from typing_extensions import List, Optional, Tuple, Callable, Dict, Any, Union, TYPE_CHECKING -from .enums import JointType, Shape + +from .enums import JointType, Shape, VirtualMobileBaseJointName from .pose import Pose, Point -from abc import ABC, abstractmethod +from ..validation.error_checkers import calculate_joint_position_error, is_error_acceptable if TYPE_CHECKING: from ..description import Link @@ -14,7 +18,7 @@ def get_point_as_list(point: Point) -> List[float]: """ - Returns the point as a list. + Return the point as a list. :param point: The point. :return: The point as a list @@ -37,7 +41,7 @@ class Color: @classmethod def from_list(cls, color: List[float]): """ - Sets the rgba_color from a list of RGBA values. + Set the rgba_color from a list of RGBA values. :param color: The list of RGBA values """ @@ -51,7 +55,7 @@ def from_list(cls, color: List[float]): @classmethod def from_rgb(cls, rgb: List[float]): """ - Sets the rgba_color from a list of RGB values. + Set the rgba_color from a list of RGB values. :param rgb: The list of RGB values """ @@ -60,7 +64,7 @@ def from_rgb(cls, rgb: List[float]): @classmethod def from_rgba(cls, rgba: List[float]): """ - Sets the rgba_color from a list of RGBA values. + Set the rgba_color from a list of RGBA values. :param rgba: The list of RGBA values """ @@ -68,7 +72,7 @@ def from_rgba(cls, rgba: List[float]): def get_rgba(self) -> List[float]: """ - Returns the rgba_color as a list of RGBA values. + Return the rgba_color as a list of RGBA values. :return: The rgba_color as a list of RGBA values """ @@ -76,7 +80,7 @@ def get_rgba(self) -> List[float]: def get_rgb(self) -> List[float]: """ - Returns the rgba_color as a list of RGB values. + Return the rgba_color as a list of RGB values. :return: The rgba_color as a list of RGB values """ @@ -98,7 +102,7 @@ class AxisAlignedBoundingBox: @classmethod def from_min_max(cls, min_point: List[float], max_point: List[float]): """ - Sets the axis-aligned bounding box from a minimum and maximum point. + Set the axis-aligned bounding box from a minimum and maximum point. :param min_point: The minimum point :param max_point: The maximum point @@ -107,48 +111,36 @@ def from_min_max(cls, min_point: List[float], max_point: List[float]): def get_min_max_points(self) -> Tuple[Point, Point]: """ - Returns the axis-aligned bounding box as a tuple of minimum and maximum points. - :return: The axis-aligned bounding box as a tuple of minimum and maximum points """ return self.get_min_point(), self.get_max_point() def get_min_point(self) -> Point: """ - Returns the axis-aligned bounding box as a minimum point. - :return: The axis-aligned bounding box as a minimum point """ return Point(self.min_x, self.min_y, self.min_z) def get_max_point(self) -> Point: """ - Returns the axis-aligned bounding box as a maximum point. - :return: The axis-aligned bounding box as a maximum point """ return Point(self.max_x, self.max_y, self.max_z) def get_min_max(self) -> Tuple[List[float], List[float]]: """ - Returns the axis-aligned bounding box as a tuple of minimum and maximum points. - :return: The axis-aligned bounding box as a tuple of minimum and maximum points """ return self.get_min(), self.get_max() def get_min(self) -> List[float]: """ - Returns the minimum point of the axis-aligned bounding box. - :return: The minimum point of the axis-aligned bounding box """ return [self.min_x, self.min_y, self.min_z] def get_max(self) -> List[float]: """ - Returns the maximum point of the axis-aligned bounding box. - :return: The maximum point of the axis-aligned bounding box """ return [self.max_x, self.max_y, self.max_z] @@ -156,12 +148,19 @@ def get_max(self) -> List[float]: @dataclass class CollisionCallbacks: + """ + Dataclass for storing the collision callbacks which are callables that get called when there is a collision + or when a collision is no longer there. + """ on_collision_cb: Callable no_collision_cb: Optional[Callable] = None @dataclass class MultiBody: + """ + Dataclass for storing the information of a multibody which consists of a base and multiple links with joints. + """ base_visual_shape_index: int base_pose: Pose link_visual_shape_indices: List[int] @@ -176,13 +175,16 @@ class MultiBody: @dataclass class VisualShape(ABC): + """ + Abstract dataclass for storing the information of a visual shape. + """ rgba_color: Color visual_frame_position: List[float] @abstractmethod def shape_data(self) -> Dict[str, Any]: """ - Returns the shape data of the visual shape (e.g. half extents for a box, radius for a sphere). + :return: the shape data of the visual shape (e.g. half extents for a box, radius for a sphere) as a dictionary. """ pass @@ -190,13 +192,16 @@ def shape_data(self) -> Dict[str, Any]: @abstractmethod def visual_geometry_type(self) -> Shape: """ - Returns the visual geometry type of the visual shape (e.g. box, sphere). + :return: The visual geometry type of the visual shape (e.g. box, sphere) as a Shape object. """ pass @dataclass class BoxVisualShape(VisualShape): + """ + Dataclass for storing the information of a box visual shape + """ half_extents: List[float] def shape_data(self) -> Dict[str, List[float]]: @@ -213,6 +218,9 @@ def size(self) -> List[float]: @dataclass class SphereVisualShape(VisualShape): + """ + Dataclass for storing the information of a sphere visual shape + """ radius: float def shape_data(self) -> Dict[str, float]: @@ -225,6 +233,9 @@ def visual_geometry_type(self) -> Shape: @dataclass class CapsuleVisualShape(VisualShape): + """ + Dataclass for storing the information of a capsule visual shape + """ radius: float length: float @@ -238,6 +249,9 @@ def visual_geometry_type(self) -> Shape: @dataclass class CylinderVisualShape(CapsuleVisualShape): + """ + Dataclass for storing the information of a cylinder visual shape + """ @property def visual_geometry_type(self) -> Shape: @@ -246,6 +260,9 @@ def visual_geometry_type(self) -> Shape: @dataclass class MeshVisualShape(VisualShape): + """ + Dataclass for storing the information of a mesh visual shape + """ scale: List[float] file_name: str @@ -259,6 +276,9 @@ def visual_geometry_type(self) -> Shape: @dataclass class PlaneVisualShape(VisualShape): + """ + Dataclass for storing the information of a plane visual shape + """ normal: List[float] def shape_data(self) -> Dict[str, List[float]]: @@ -271,28 +291,409 @@ def visual_geometry_type(self) -> Shape: @dataclass class State(ABC): + """ + Abstract dataclass for storing the state of an entity (e.g. world, object, link, joint). + """ pass @dataclass class LinkState(State): + """ + Dataclass for storing the state of a link. + """ constraint_ids: Dict[Link, int] + def __eq__(self, other: 'LinkState'): + return self.all_constraints_exist(other) and self.all_constraints_are_equal(other) + + def all_constraints_exist(self, other: 'LinkState') -> bool: + """ + Check if all constraints exist in the other link state. + + :param other: The state of the other link. + :return: True if all constraints exist, False otherwise. + """ + return (all([cid_k in other.constraint_ids.keys() for cid_k in self.constraint_ids.keys()]) + and len(self.constraint_ids.keys()) == len(other.constraint_ids.keys())) + + def all_constraints_are_equal(self, other: 'LinkState') -> bool: + """ + Check if all constraints are equal to the ones in the other link state. + + :param other: The state of the other link. + :return: True if all constraints are equal, False otherwise. + """ + return all([cid == other_cid for cid, other_cid in zip(self.constraint_ids.values(), + other.constraint_ids.values())]) + + def __copy__(self): + return LinkState(constraint_ids=copy(self.constraint_ids)) + @dataclass class JointState(State): + """ + Dataclass for storing the state of a joint. + """ position: float + acceptable_error: float + + def __eq__(self, other: 'JointState'): + error = calculate_joint_position_error(self.position, other.position) + return is_error_acceptable(error, other.acceptable_error) + + def __copy__(self): + return JointState(position=self.position, acceptable_error=self.acceptable_error) @dataclass class ObjectState(State): + """ + Dataclass for storing the state of an object. + """ pose: Pose attachments: Dict[Object, Attachment] link_states: Dict[int, LinkState] joint_states: Dict[int, JointState] + acceptable_pose_error: Tuple[float, float] + + def __eq__(self, other: 'ObjectState'): + return (self.pose_is_almost_equal(other) + and self.all_attachments_exist(other) and self.all_attachments_are_equal(other) + and self.link_states == other.link_states + and self.joint_states == other.joint_states) + + def pose_is_almost_equal(self, other: 'ObjectState') -> bool: + """ + Check if the pose of the object is almost equal to the pose of another object. + + :param other: The state of the other object. + :return: True if the poses are almost equal, False otherwise. + """ + return self.pose.almost_equal(other.pose, other.acceptable_pose_error[0], other.acceptable_pose_error[1]) + + def all_attachments_exist(self, other: 'ObjectState') -> bool: + """ + Check if all attachments exist in the other object state. + + :param other: The state of the other object. + :return: True if all attachments exist, False otherwise. + """ + return (all([att_k in other.attachments.keys() for att_k in self.attachments.keys()]) + and len(self.attachments.keys()) == len(other.attachments.keys())) + + def all_attachments_are_equal(self, other: 'ObjectState') -> bool: + """ + Check if all attachments are equal to the ones in the other object state. + + :param other: The state of the other object. + :return: True if all attachments are equal, False otherwise + """ + return all([att == other_att for att, other_att in zip(self.attachments.values(), other.attachments.values())]) + + def __copy__(self): + return ObjectState(pose=self.pose.copy(), attachments=copy(self.attachments), + link_states=copy(self.link_states), + joint_states=copy(self.joint_states), + acceptable_pose_error=deepcopy(self.acceptable_pose_error)) @dataclass class WorldState(State): - simulator_state_id: int + """ + Dataclass for storing the state of the world. + """ + simulator_state_id: Optional[int] object_states: Dict[str, ObjectState] + + def __eq__(self, other: 'WorldState'): + return (self.simulator_state_is_equal(other) and self.all_objects_exist(other) + and self.all_objects_states_are_equal(other)) + + def simulator_state_is_equal(self, other: 'WorldState') -> bool: + """ + Check if the simulator state is equal to the simulator state of another world state. + + :param other: The state of the other world. + :return: True if the simulator states are equal, False otherwise. + """ + return self.simulator_state_id == other.simulator_state_id + + def all_objects_exist(self, other: 'WorldState') -> bool: + """ + Check if all objects exist in the other world state. + + :param other: The state of the other world. + :return: True if all objects exist, False otherwise. + """ + return (all([obj_name in other.object_states.keys() for obj_name in self.object_states.keys()]) + and len(self.object_states.keys()) == len(other.object_states.keys())) + + def all_objects_states_are_equal(self, other: 'WorldState') -> bool: + """ + Check if all object states are equal to the ones in the other world state. + + :param other: The state of the other world. + :return: True if all object states are equal, False otherwise. + """ + return all([obj_state == other_obj_state + for obj_state, other_obj_state in zip(self.object_states.values(), + other.object_states.values())]) + + def __copy__(self): + return WorldState(simulator_state_id=self.simulator_state_id, + object_states=deepcopy(self.object_states)) + + +@dataclass +class LateralFriction: + """ + Dataclass for storing the information of the lateral friction. + """ + lateral_friction: float + lateral_friction_direction: List[float] + + +@dataclass +class ContactPoint: + """ + Dataclass for storing the information of a contact point between two objects. + """ + link_a: Link + link_b: Link + position_on_object_a: Optional[List[float]] = None + position_on_object_b: Optional[List[float]] = None + normal_on_b: Optional[List[float]] = None # normal on object b pointing towards object a + distance: Optional[float] = None + normal_force: Optional[List[float]] = None # normal force applied during last step simulation + lateral_friction_1: Optional[LateralFriction] = None + lateral_friction_2: Optional[LateralFriction] = None + force_x_in_world_frame: Optional[float] = None + force_y_in_world_frame: Optional[float] = None + force_z_in_world_frame: Optional[float] = None + + def __str__(self): + return f"ContactPoint: {self.link_a.object.name} - {self.link_b.object.name}" + + def __repr__(self): + return self.__str__() + + +ClosestPoint = ContactPoint +""" +The closest point between two objects which has the same structure as ContactPoint. +""" + + +class ContactPointsList(list): + """ + A list of contact points. + """ + def get_links_that_got_removed(self, previous_points: 'ContactPointsList') -> List[Link]: + """ + Return the links that are not in the current points list but were in the initial points list. + + :param previous_points: The initial points list. + :return: A list of Link instances that represent the links that got removed. + """ + initial_links_in_contact = previous_points.get_links_in_contact() + current_links_in_contact = self.get_links_in_contact() + return [link for link in initial_links_in_contact if link not in current_links_in_contact] + + def get_links_in_contact(self) -> List[Link]: + """ + Get the links in contact. + + :return: A list of Link instances that represent the links in contact. + """ + return [point.link_b for point in self] + + def check_if_two_objects_are_in_contact(self, obj_a: Object, obj_b: Object) -> bool: + """ + Check if two objects are in contact. + + :param obj_a: An instance of the Object class that represents the first object. + :param obj_b: An instance of the Object class that represents the second object. + :return: True if the objects are in contact, False otherwise. + """ + return (any([point.link_b.object == obj_b and point.link_a.object == obj_a for point in self]) or + any([point.link_a.object == obj_b and point.link_b.object == obj_a for point in self])) + + def get_normals_of_object(self, obj: Object) -> List[List[float]]: + """ + Get the normals of the object. + + :param obj: An instance of the Object class that represents the object. + :return: A list of float vectors that represent the normals of the object. + """ + return self.get_points_of_object(obj).get_normals() + + def get_normals(self) -> List[List[float]]: + """ + Get the normals of the points. + + :return: A list of float vectors that represent the normals of the contact points. + """ + return [point.normal_on_b for point in self] + + def get_links_in_contact_of_object(self, obj: Object) -> List[Link]: + """ + Get the links in contact of the object. + + :param obj: An instance of the Object class that represents the object. + :return: A list of Link instances that represent the links in contact of the object. + """ + return [point.link_b for point in self if point.link_b.object == obj] + + def get_points_of_object(self, obj: Object) -> 'ContactPointsList': + """ + Get the points of the object. + + :param obj: An instance of the Object class that represents the object that the points are related to. + :return: A ContactPointsList instance that represents the contact points of the object. + """ + return ContactPointsList([point for point in self if point.link_b.object == obj]) + + def get_objects_that_got_removed(self, previous_points: 'ContactPointsList') -> List[Object]: + """ + Return the object that is not in the current points list but was in the initial points list. + + :param previous_points: The initial points list. + :return: A list of Object instances that represent the objects that got removed. + """ + initial_objects_in_contact = previous_points.get_objects_that_have_points() + current_objects_in_contact = self.get_objects_that_have_points() + return [obj for obj in initial_objects_in_contact if obj not in current_objects_in_contact] + + def get_new_objects(self, previous_points: 'ContactPointsList') -> List[Object]: + """ + Return the object that is not in the initial points list but is in the current points list. + + :param previous_points: The initial points list. + :return: A list of Object instances that represent the new objects. + """ + initial_objects_in_contact = previous_points.get_objects_that_have_points() + current_objects_in_contact = self.get_objects_that_have_points() + return [obj for obj in current_objects_in_contact if obj not in initial_objects_in_contact] + + def is_object_in_the_list(self, obj: Object) -> bool: + """ + Check if the object is one of the objects that have points in the list. + + :param obj: An instance of the Object class that represents the object. + :return: True if the object is in the list, False otherwise. + """ + return obj in self.get_objects_that_have_points() + + def get_names_of_objects_that_have_points(self) -> List[str]: + """ + Return the names of the objects that have points in the list. + + :return: A list of strings that represent the names of the objects that have points in the list. + """ + return [obj.name for obj in self.get_objects_that_have_points()] + + def get_objects_that_have_points(self) -> List[Object]: + """ + Return the objects that have points in the list. + + :return: A list of Object instances that represent the objects that have points in the list. + """ + return list({point.link_b.object for point in self}) + + def __str__(self): + return f"ContactPointsList: {', '.join([point.__str__() for point in self])}" + + def __repr__(self): + return self.__str__() + + +ClosestPointsList = ContactPointsList +""" +The list of closest points which has same structure as ContactPointsList. +""" + + +@dataclass +class TextAnnotation: + """ + Dataclass for storing text annotations that can be displayed in the simulation. + """ + text: str + position: List[float] + id: int + color: Color = Color(0, 0, 0, 1) + size: float = 0.1 + + +@dataclass +class VirtualMobileBaseJoints: + """ + Dataclass for storing the names, types and axes of the virtual mobile base joints of a mobile robot. + """ + translation_x: Optional[str] = VirtualMobileBaseJointName.LINEAR_X.value + translation_y: Optional[str] = VirtualMobileBaseJointName.LINEAR_Y.value + angular_z: Optional[str] = VirtualMobileBaseJointName.ANGULAR_Z.value + + @property + def names(self) -> List[str]: + """ + Return the names of the virtual mobile base joints. + """ + return [self.translation_x, self.translation_y, self.angular_z] + + def get_types(self) -> Dict[str, JointType]: + """ + Return the joint types of the virtual mobile base joints. + """ + return {self.translation_x: JointType.PRISMATIC, + self.translation_y: JointType.PRISMATIC, + self.angular_z: JointType.REVOLUTE} + + def get_axes(self) -> Dict[str, Point]: + """ + Return the axes (i.e. The axis on which the joint moves) of the virtual mobile base joints. + """ + return {self.translation_x: Point(1, 0, 0), + self.translation_y: Point(0, 1, 0), + self.angular_z: Point(0, 0, 1)} + + +@dataclass +class MultiverseMetaData: + """Meta data for the Multiverse Client, the simulation_name should be non-empty and unique for each simulation""" + world_name: str = "world" + simulation_name: str = "cram" + length_unit: str = "m" + angle_unit: str = "rad" + mass_unit: str = "kg" + time_unit: str = "s" + handedness: str = "rhs" + + +@dataclass +class RayResult: + """ + A dataclass to store the ray result. The ray result contains the body name that the ray intersects with and the + distance from the ray origin to the intersection point. + """ + body_name: str + distance: float + + def intersected(self) -> bool: + """ + Check if the ray intersects with a body. + return: Whether the ray intersects with a body. + """ + return self.distance >= 0 and self.body_name != "" + + +@dataclass +class MultiverseContactPoint: + """ + A dataclass to store the contact point returned from Multiverse. + """ + body_name: str + contact_force: List[float] + contact_torque: List[float] diff --git a/src/pycram/datastructures/enums.py b/src/pycram/datastructures/enums.py index 6efa774e6..530513c52 100644 --- a/src/pycram/datastructures/enums.py +++ b/src/pycram/datastructures/enums.py @@ -2,15 +2,24 @@ from enum import Enum, auto +from ..failures import UnsupportedJointType -class Arms(Enum): + +class ExecutionType(Enum): + """Enum for Execution Process Module types.""" + REAL = auto() + SIMULATED = auto() + SEMI_REAL = auto() + + +class Arms(int, Enum): """Enum for Arms.""" - LEFT = auto() - RIGHT = auto() - BOTH = auto() + LEFT = 0 + RIGHT = 1 + BOTH = 2 -class TaskStatus(Enum): +class TaskStatus(int, Enum): """ Enum for readable descriptions of a tasks' status. """ @@ -34,7 +43,7 @@ class JointType(Enum): FLOATING = 7 -class Grasp(Enum): +class Grasp(int, Enum): """ Enum for Grasp orientations. """ @@ -44,7 +53,7 @@ class Grasp(Enum): TOP = 3 -class ObjectType(Enum): +class ObjectType(int, Enum): """ Enum for Object types to easier identify different objects """ @@ -61,7 +70,7 @@ class ObjectType(Enum): HUMAN = auto() -class State(Enum): +class State(int, Enum): """ Enumeration which describes the result of a language expression. """ @@ -127,11 +136,21 @@ class FilterConfig(Enum): """ butterworth = 1 + +class PerceptionTechniques(Enum): + """ + Enum for techniques for perception tasks. + """ + ALL = auto() + HUMAN = auto() + TYPES = auto() + + class ImageEnum(Enum): """ - enum for picture id to be shown - on robot display + Enum for image switch view on hsrb display. + """ HI = 0 TALK = 1 @@ -152,6 +171,113 @@ class ImageEnum(Enum): JREPEAT = 16 SOFA = 17 INSPECT = 18 + CHAIR = 37 + + +class VirtualMobileBaseJointName(Enum): + """ + Enum for the joint names of the virtual mobile base. + """ + LINEAR_X = "odom_vel_lin_x_joint" + LINEAR_Y = "odom_vel_lin_y_joint" + ANGULAR_Z = "odom_vel_ang_z_joint" + + +class MJCFGeomType(Enum): + """ + Enum for the different geom types in a MuJoCo XML file. + """ + BOX = "box" + CYLINDER = "cylinder" + CAPSULE = "capsule" + SPHERE = "sphere" + PLANE = "plane" + MESH = "mesh" + ELLIPSOID = "ellipsoid" + HFIELD = "hfield" + SDF = "sdf" + + +MJCFBodyType = MJCFGeomType +""" +Alias for MJCFGeomType. As the body type is the same as the geom type. +""" + + +class MJCFJointType(Enum): + """ + Enum for the different joint types in a MuJoCo XML file. + """ + FREE = "free" + BALL = "ball" + SLIDE = "slide" + HINGE = "hinge" + FIXED = "fixed" # Added for compatibility with PyCRAM, but not a real joint type in MuJoCo. + + +class MultiverseAPIName(Enum): + """ + Enum for the different APIs of the Multiverse. + """ + GET_CONTACT_BODIES = "get_contact_bodies" + GET_CONSTRAINT_EFFORT = "get_constraint_effort" + ATTACH = "attach" + DETACH = "detach" + GET_RAYS = "get_rays" + EXIST = "exist" + PAUSE = "pause" + UNPAUSE = "unpause" + SAVE = "save" + LOAD = "load" +class MultiverseProperty(Enum): + def __str__(self): + return self.value + + +class MultiverseBodyProperty(MultiverseProperty): + """ + Enum for the different properties of a body the Multiverse. + """ + POSITION = "position" + ORIENTATION = "quaternion" + RELATIVE_VELOCITY = "relative_velocity" + + +class MultiverseJointProperty(MultiverseProperty): + pass + + +class MultiverseJointPosition(MultiverseJointProperty): + """ + Enum for the Position names of the different joint types in the Multiverse. + """ + REVOLUTE_JOINT_POSITION = "joint_rvalue" + PRISMATIC_JOINT_POSITION = "joint_tvalue" + + @classmethod + def from_pycram_joint_type(cls, joint_type: JointType) -> 'MultiverseJointPosition': + if joint_type in [JointType.REVOLUTE, JointType.CONTINUOUS]: + return MultiverseJointPosition.REVOLUTE_JOINT_POSITION + elif joint_type == JointType.PRISMATIC: + return MultiverseJointPosition.PRISMATIC_JOINT_POSITION + else: + raise UnsupportedJointType(joint_type) + + +class MultiverseJointCMD(MultiverseJointProperty): + """ + Enum for the Command names of the different joint types in the Multiverse. + """ + REVOLUTE_JOINT_CMD = "cmd_joint_rvalue" + PRISMATIC_JOINT_CMD = "cmd_joint_tvalue" + @classmethod + def from_pycram_joint_type(cls, joint_type: JointType) -> 'MultiverseJointCMD': + if joint_type in [JointType.REVOLUTE, JointType.CONTINUOUS]: + return MultiverseJointCMD.REVOLUTE_JOINT_CMD + elif joint_type == JointType.PRISMATIC: + return MultiverseJointCMD.PRISMATIC_JOINT_CMD + else: + raise UnsupportedJointType(joint_type) diff --git a/src/pycram/datastructures/pose.py b/src/pycram/datastructures/pose.py index 4ca28b267..e5e226861 100644 --- a/src/pycram/datastructures/pose.py +++ b/src/pycram/datastructures/pose.py @@ -3,15 +3,19 @@ import math import datetime -from typing_extensions import List, Union, Optional + +from tf.transformations import euler_from_quaternion +from typing_extensions import List, Union, Optional, Sized, Self import numpy as np -import rospy import sqlalchemy.orm from geometry_msgs.msg import PoseStamped, TransformStamped, Vector3, Point from geometry_msgs.msg import (Pose as GeoPose, Quaternion as GeoQuaternion) from tf import transformations from ..orm.base import Pose as ORMPose, Position, Quaternion, ProcessMetaData +from ..ros.data_types import Time +from ..validation.error_checkers import calculate_pose_error +from ..ros.logging import logwarn, logerr def get_normalized_quaternion(quaternion: np.ndarray) -> GeoQuaternion: @@ -47,7 +51,7 @@ class Pose(PoseStamped): """ def __init__(self, position: Optional[List[float]] = None, orientation: Optional[List[float]] = None, - frame: str = "map", time: rospy.Time = None): + frame: str = "map", time: Time = None): """ Poses can be initialized by a position and orientation given as lists, this is optional. By default, Poses are initialized with the position being [0, 0, 0], the orientation being [0, 0, 0, 1] and the frame being 'map'. @@ -68,7 +72,7 @@ def __init__(self, position: Optional[List[float]] = None, orientation: Optional self.header.frame_id = frame - self.header.stamp = time if time else rospy.Time.now() + self.header.stamp = time if time else Time().now() self.frame = frame @@ -85,6 +89,32 @@ def from_pose_stamped(pose_stamped: PoseStamped) -> Pose: p.pose = pose_stamped.pose return p + def get_position_diff(self, target_pose: Self) -> Point: + """ + Get the difference between the target and the current positions. + + :param target_pose: The target pose. + :return: The difference between the two positions. + """ + return Point(target_pose.position.x - self.position.x, target_pose.position.y - self.position.y, + target_pose.position.z - self.position.z) + + def get_z_angle_difference(self, target_pose: Self) -> float: + """ + Get the difference between two z angles. + + :param target_pose: The target pose. + :return: The difference between the two z angles. + """ + return target_pose.z_angle - self.z_angle + + @property + def z_angle(self) -> float: + """ + The z angle of the orientation of this Pose in radians. + """ + return euler_from_quaternion(self.orientation_as_list())[2] + @property def frame(self) -> str: """ @@ -119,7 +149,7 @@ def position(self, value) -> None: """ if (not isinstance(value, list) and not isinstance(value, tuple) and not isinstance(value, GeoPose) and not isinstance(value, Point)): - rospy.logerr("Position can only be a list or geometry_msgs/Pose") + logerr("Position can only be a list or geometry_msgs/Pose") raise TypeError("Position can only be a list/tuple or geometry_msgs/Pose") if isinstance(value, list) or isinstance(value, tuple) and len(value) == 3: self.pose.position.x = value[0] @@ -144,21 +174,22 @@ def orientation(self, value) -> None: :param value: New orientation, either a list or geometry_msgs/Quaternion """ - if not isinstance(value, list) and not isinstance(value, tuple) and not isinstance(value, GeoQuaternion): - rospy.logwarn("Orientation can only be a list or geometry_msgs/Quaternion") + if not isinstance(value, Sized) and not isinstance(value, GeoQuaternion): + logwarn("Orientation can only be an iterable (list, tuple, ...etc.) or a geometry_msgs/Quaternion") return - if isinstance(value, list) or isinstance(value, tuple) and len(value) == 4: + if isinstance(value, Sized) and len(value) == 4: orientation = np.array(value) - else: + elif isinstance(value, GeoQuaternion): orientation = np.array([value.x, value.y, value.z, value.w]) + else: + logerr("Orientation has to be a list or geometry_msgs/Quaternion") + raise TypeError("Orientation has to be a list or geometry_msgs/Quaternion") # This is used instead of np.linalg.norm since numpy is too slow on small arrays self.pose.orientation = get_normalized_quaternion(orientation) def to_list(self) -> List[List[float]]: """ - Returns the position and orientation of this pose as a list containing two list. - :return: The position and orientation as lists """ return [[self.pose.position.x, self.pose.position.y, self.pose.position.z], @@ -186,16 +217,12 @@ def copy(self) -> Pose: def position_as_list(self) -> List[float]: """ - Returns only the position as a list of xyz. - - :return: The position as a list + :return: The position as a list of xyz values. """ return [self.position.x, self.position.y, self.position.z] def orientation_as_list(self) -> List[float]: """ - Returns only the orientation as a list of a quaternion - :return: The orientation as a quaternion with xyzw """ return [self.pose.orientation.x, self.pose.orientation.y, self.pose.orientation.z, self.pose.orientation.w] @@ -230,6 +257,22 @@ def __eq__(self, other: Pose) -> bool: return self_position == other_position and self_orient == other_orient and self.frame == other.frame + def almost_equal(self, other: Pose, position_tolerance_in_meters: float = 1e-3, + orientation_tolerance_in_degrees: float = 1) -> bool: + """ + Checks if the given Pose is almost equal to this Pose. The position and orientation can have a certain + tolerance. The position tolerance is given in meters and the orientation tolerance in degrees. The position + error is calculated as the euclidian distance between the positions and the orientation error as the angle + between the quaternions. + + :param other: The other Pose which should be compared + :param position_tolerance_in_meters: The tolerance for the position in meters + :param orientation_tolerance_in_degrees: The tolerance for the orientation in degrees + :return: True if the Poses are almost equal, False otherwise + """ + error = calculate_pose_error(self, other) + return error[0] <= position_tolerance_in_meters and error[1] <= orientation_tolerance_in_degrees * math.pi / 180 + def set_position(self, new_position: List[float]) -> None: """ Sets the position of this Pose to the given position. Position has to be given as a vector in cartesian space. @@ -285,25 +328,6 @@ def multiply_quaternions(self, quaternion: List) -> None: self.orientation = (x, y, z, w) - def set_orientation_from_euler(self, axis: List, euler_angles: List[float]) -> None: - """ - Convert axis-angle to quaternion. - - :param axis: (x, y, z) tuple representing rotation axis. - :param angle: rotation angle in degree - :return: The quaternion representing the axis angle - """ - angle = math.radians(euler_angles) - axis_length = math.sqrt(sum([i ** 2 for i in axis])) - normalized_axis = tuple(i / axis_length for i in axis) - - x = normalized_axis[0] * math.sin(angle / 2) - y = normalized_axis[1] * math.sin(angle / 2) - z = normalized_axis[2] * math.sin(angle / 2) - w = math.cos(angle / 2) - - return (x, y, z, w) - class Transform(TransformStamped): """ @@ -319,7 +343,7 @@ class Transform(TransformStamped): Rotation: A quaternion representing the conversion of rotation between both frames """ def __init__(self, translation: Optional[List[float]] = None, rotation: Optional[List[float]] = None, - frame: Optional[str] = "map", child_frame: Optional[str] = "", time: rospy.Time = None): + frame: Optional[str] = "map", child_frame: Optional[str] = "", time: Time = None): """ Transforms take a translation, rotation, frame and child_frame as optional arguments. If nothing is given the Transform will be initialized with [0, 0, 0] for translation, [0, 0, 0, 1] for rotation, 'map' for frame and an @@ -342,10 +366,31 @@ def __init__(self, translation: Optional[List[float]] = None, rotation: Optional self.header.frame_id = frame self.child_frame_id = child_frame - self.header.stamp = time if time else rospy.Time.now() + self.header.stamp = time if time else Time().now() self.frame = frame + def apply_transform_to_array_of_points(self, points: np.ndarray) -> np.ndarray: + """ + Applies this Transform to an array of points. The points are given as a Nx3 matrix, where N is the number of + points. The points are transformed from the child_frame_id to the frame_id of this Transform. + + :param points: The points that should be transformed, given as a Nx3 matrix. + """ + homogeneous_transform = self.get_homogeneous_matrix() + # add the homogeneous coordinate, by adding a column of ones to the position vectors, becoming 4xN matrix + homogenous_points = np.concatenate((points, np.ones((points.shape[0], 1))), axis=1).T + rays_end_positions = homogeneous_transform @ homogenous_points + return rays_end_positions[:3, :].T + + def get_homogeneous_matrix(self) -> np.ndarray: + """ + :return: The homogeneous matrix of this Transform + """ + translation = transformations.translation_matrix(self.translation_as_list()) + rotation = transformations.quaternion_matrix(self.rotation_as_list()) + return np.dot(translation, rotation) + @classmethod def from_pose_and_child_frame(cls, pose: Pose, child_frame_name: str) -> Transform: return cls(pose.position_as_list(), pose.orientation_as_list(), pose.frame, child_frame_name, @@ -386,7 +431,7 @@ def frame(self, value: str) -> None: self.header.frame_id = value @property - def translation(self) -> None: + def translation(self) -> Vector3: """ Property that points to the translation of this Transform """ @@ -401,7 +446,7 @@ def translation(self, value) -> None: :param value: The new value for the translation, either a list or geometry_msgs/Vector3 """ if not isinstance(value, list) and not isinstance(value, Vector3): - rospy.logwarn("Value of a translation can only be a list of a geometry_msgs/Vector3") + logwarn("Value of a translation can only be a list of a geometry_msgs/Vector3") return if isinstance(value, list) and len(value) == 3: self.transform.translation.x = value[0] @@ -411,7 +456,7 @@ def translation(self, value) -> None: self.transform.translation = value @property - def rotation(self) -> None: + def rotation(self) -> Quaternion: """ Property that points to the rotation of this Transform """ @@ -426,7 +471,7 @@ def rotation(self, value): :param value: The new value for the rotation, either a list or geometry_msgs/Quaternion """ if not isinstance(value, list) and not isinstance(value, GeoQuaternion): - rospy.logwarn("Value of the rotation can only be a list or a geometry.msgs/Quaternion") + logwarn("Value of the rotation can only be a list or a geometry.msgs/Quaternion") return if isinstance(value, list) and len(value) == 4: rotation = np.array(value) @@ -449,16 +494,12 @@ def copy(self) -> Transform: def translation_as_list(self) -> List[float]: """ - Returns the translation of this Transform as a list. - :return: The translation as a list of xyz """ return [self.transform.translation.x, self.transform.translation.y, self.transform.translation.z] def rotation_as_list(self) -> List[float]: """ - Returns the rotation of this Transform as a list representing a quaternion. - :return: The rotation of this Transform as a list with xyzw """ return [self.transform.rotation.x, self.transform.rotation.y, self.transform.rotation.z, @@ -494,7 +535,7 @@ def __mul__(self, other: Transform) -> Union[Transform, None]: :return: The resulting Transform from the multiplication """ if not isinstance(other, Transform): - rospy.logerr(f"Can only multiply two Transforms") + logerr(f"Can only multiply two Transforms") return self_trans = transformations.translation_matrix(self.translation_as_list()) self_rot = transformations.quaternion_matrix(self.rotation_as_list()) @@ -553,5 +594,3 @@ def set_rotation(self, new_rotation: List[float]) -> None: :param new_rotation: The new rotation as a quaternion with xyzw """ self.rotation = new_rotation - - diff --git a/src/pycram/datastructures/world.py b/src/pycram/datastructures/world.py index 70714a6e7..9097b7e85 100644 --- a/src/pycram/datastructures/world.py +++ b/src/pycram/datastructures/world.py @@ -6,100 +6,39 @@ import time from abc import ABC, abstractmethod from copy import copy -from queue import Queue - import numpy as np -import rospy from geometry_msgs.msg import Point -from typing_extensions import List, Optional, Dict, Tuple, Callable, TYPE_CHECKING -from typing_extensions import Union +from typing_extensions import List, Optional, Dict, Tuple, Callable, TYPE_CHECKING, Union, Type from ..cache_manager import CacheManager -from .enums import JointType, ObjectType, WorldMode -from ..world_concepts.event import Event +from ..config.world_conf import WorldConfig +from ..datastructures.dataclasses import (Color, AxisAlignedBoundingBox, CollisionCallbacks, + MultiBody, VisualShape, BoxVisualShape, CylinderVisualShape, + SphereVisualShape, + CapsuleVisualShape, PlaneVisualShape, MeshVisualShape, + ObjectState, WorldState, ClosestPointsList, + ContactPointsList, VirtualMobileBaseJoints) +from ..datastructures.enums import JointType, ObjectType, WorldMode, Arms +from ..datastructures.pose import Pose, Transform +from ..datastructures.world_entity import StateEntity +from ..failures import ProspectionObjectNotFound, WorldObjectNotFound from ..local_transformer import LocalTransformer -from .pose import Pose, Transform +from ..robot_description import RobotDescription +from ..ros.data_types import Time +from ..ros.logging import logwarn +from ..validation.goal_validator import (MultiPoseGoalValidator, + PoseGoalValidator, JointPositionGoalValidator, + MultiJointPositionGoalValidator, GoalValidator, + validate_joint_position, validate_multiple_joint_positions, + validate_object_pose, validate_multiple_object_poses) from ..world_concepts.constraints import Constraint -from .dataclasses import (Color, AxisAlignedBoundingBox, CollisionCallbacks, - MultiBody, VisualShape, BoxVisualShape, CylinderVisualShape, SphereVisualShape, - CapsuleVisualShape, PlaneVisualShape, MeshVisualShape, - ObjectState, State, WorldState) +from ..world_concepts.event import Event if TYPE_CHECKING: from ..world_concepts.world_object import Object - from ..description import Link, Joint - - -class StateEntity: - """ - The StateEntity class is used to store the state of an object or the physics simulator. This is used to save and - restore the state of the World. - """ - - def __init__(self): - self._saved_states: Dict[int, State] = {} - - @property - def saved_states(self) -> Dict[int, State]: - """ - Returns the saved states of this entity. - """ - return self._saved_states - - def save_state(self, state_id: int) -> int: - """ - Saves the state of this entity with the given state id. - - :param state_id: The unique id of the state. - """ - self._saved_states[state_id] = self.current_state - return state_id - - @property - @abstractmethod - def current_state(self) -> State: - """ - Returns the current state of this entity. - - :return: The current state of this entity. - """ - pass - - @current_state.setter - @abstractmethod - def current_state(self, state: State) -> None: - """ - Sets the current state of this entity. - - :param state: The new state of this entity. - """ - pass - - def restore_state(self, state_id: int) -> None: - """ - Restores the state of this entity from a saved state using the given state id. - - :param state_id: The unique id of the state. - """ - self.current_state = self.saved_states[state_id] - - def remove_saved_states(self) -> None: - """ - Removes all saved states of this entity. - """ - self._saved_states = {} - - -class WorldEntity(StateEntity, ABC): - """ - A data class that represents an entity of the world, such as an object or a link. - """ - - def __init__(self, _id: int, world: Optional[World] = None): - StateEntity.__init__(self) - self.id = _id - self.world: World = world if world is not None else World.current_world + from ..description import Link, Joint, ObjectDescription + from ..object_descriptors.generic import ObjectDescription as GenericObjectDescription class World(StateEntity, ABC): @@ -109,17 +48,17 @@ class World(StateEntity, ABC): current_world which is managed by the World class itself. """ - simulation_frequency: float + conf: Type[WorldConfig] = WorldConfig """ - Global reference for the simulation frequency (Hz), used in calculating the equivalent real time in the simulation. + The configurations of the world, the default configurations are defined in world_conf.py in the config folder. """ current_world: Optional[World] = None """ - Global reference to the currently used World, usually this is the - graphical one. However, if you are inside a UseProspectionWorld() environment the current_world points to the - prospection world. In this way you can comfortably use the current_world, which should point towards the World - used at the moment. + Global reference to the currently used World, usually this is the + graphical one. However, if you are inside a UseProspectionWorld() environment the current_world points to the + prospection world. In this way you can comfortably use the current_world, which should point towards the World + used at the moment. """ robot: Optional[Object] = None @@ -128,51 +67,49 @@ class World(StateEntity, ABC): the URDF with the name of the URDF on the parameter server. """ - data_directory: List[str] = [os.path.join(os.path.dirname(__file__), '..', '..', '..', 'resources')] - """ - Global reference for the data directories, this is used to search for the description files of the robot - and the objects. - """ - - cache_dir = data_directory[0] + '/cached/' + cache_manager: CacheManager = CacheManager(conf.cache_dir, [conf.resources_path], False) """ - Global reference for the cache directory, this is used to cache the description files of the robot and the objects. + Global reference for the cache manager, this is used to cache the description files of the robot and the objects. """ - def __init__(self, mode: WorldMode, is_prospection_world: bool, simulation_frequency: float): + def __init__(self, mode: WorldMode, is_prospection_world: bool = False, clear_cache: bool = False): """ - Creates a new simulation, the mode decides if the simulation should be a rendered window or just run in the - background. There can only be one rendered simulation. - The World object also initializes the Events for attachment, detachment and for manipulating the world. + Create a new simulation, the mode decides if the simulation should be a rendered window or just run in the + background. There can only be one rendered simulation. + The World object also initializes the Events for attachment, detachment and for manipulating the world. - :param mode: Can either be "GUI" for rendered window or "DIRECT" for non-rendered. The default parameter is "GUI" - :param is_prospection_world: For internal usage, decides if this World should be used as a prospection world. + :param mode: Can either be "GUI" for rendered window or "DIRECT" for non-rendered. The default parameter is + "GUI" + :param is_prospection_world: For internal usage, decides if this World should be used as a prospection world. + :param clear_cache: Whether to clear the cache directory. """ StateEntity.__init__(self) + if clear_cache or (self.conf.clear_cache_at_start and not self.cache_manager.cache_cleared): + self.cache_manager.clear_cache() + + GoalValidator.raise_error = self.conf.raise_goal_validator_error + if World.current_world is None: World.current_world = self - World.simulation_frequency = simulation_frequency - self.cache_manager = CacheManager(self.cache_dir, self.data_directory) + self.object_lock: threading.Lock = threading.Lock() self.id: Optional[int] = -1 # This is used to connect to the physics server (allows multiple clients) self._init_world(mode) + self.objects: List[Object] = [] + # List of all Objects in the World + self.is_prospection_world: bool = is_prospection_world self._init_and_sync_prospection_world() self.local_transformer = LocalTransformer() self._update_local_transformer_worlds() - self.objects: List[Object] = [] - # List of all Objects in the World - - - self.mode: WorldMode = mode # The mode of the simulation, can be "GUI" or "DIRECT" @@ -182,16 +119,104 @@ def __init__(self, mode: WorldMode, is_prospection_world: bool, simulation_frequ self._current_state: Optional[WorldState] = None + self._init_goal_validators() + + self.original_state_id = self.save_state() + + @classmethod + def get_cache_dir(cls) -> str: + """ + Return the cache directory. + """ + return cls.cache_manager.cache_dir + + def add_object(self, obj: Object) -> None: + """ + Add an object to the world. + + :param obj: The object to be added. + """ + self.object_lock.acquire() + self.objects.append(obj) + self.add_object_to_original_state(obj) + self.object_lock.release() + + @property + def robot_description(self) -> RobotDescription: + """ + Return the current robot description. + """ + return RobotDescription.current_robot_description + + @property + def robot_has_actuators(self) -> bool: + """ + Return whether the robot has actuators. + """ + return self.robot_description.has_actuators + + def get_actuator_for_joint(self, joint: Joint) -> str: + """ + Get the actuator name for a given joint. + """ + return self.robot_joint_actuators[joint.name] + + def joint_has_actuator(self, joint: Joint) -> bool: + """ + Return whether the joint has an actuator. + """ + return joint.name in self.robot_joint_actuators + + @property + def robot_joint_actuators(self) -> Dict[str, str]: + """ + Return the joint actuators of the robot. + """ + return self.robot_description.joint_actuators + + def _init_goal_validators(self): + """ + Initialize the goal validators for the World objects' poses, positions, and orientations. + """ + + # Objects Pose goal validators + self.pose_goal_validator = PoseGoalValidator(self.get_object_pose, self.conf.get_pose_tolerance(), + self.conf.acceptable_percentage_of_goal) + self.multi_pose_goal_validator = MultiPoseGoalValidator( + lambda x: list(self.get_multiple_object_poses(x).values()), + self.conf.get_pose_tolerance(), self.conf.acceptable_percentage_of_goal) + + # Joint Goal validators + self.joint_position_goal_validator = JointPositionGoalValidator( + self.get_joint_position, + acceptable_revolute_joint_position_error=self.conf.revolute_joint_position_tolerance, + acceptable_prismatic_joint_position_error=self.conf.prismatic_joint_position_tolerance, + acceptable_percentage_of_goal_achieved=self.conf.acceptable_percentage_of_goal) + self.multi_joint_position_goal_validator = MultiJointPositionGoalValidator( + lambda x: list(self.get_multiple_joint_positions(x).values()), + acceptable_revolute_joint_position_error=self.conf.revolute_joint_position_tolerance, + acceptable_prismatic_joint_position_error=self.conf.prismatic_joint_position_tolerance, + acceptable_percentage_of_goal_achieved=self.conf.acceptable_percentage_of_goal) + + def check_object_exists(self, obj: Object) -> bool: + """ + Check if the object exists in the simulator. + + :param obj: The object to check. + :return: True if the object is in the world, False otherwise. + """ + raise NotImplementedError + @abstractmethod def _init_world(self, mode: WorldMode): """ - Initializes the physics simulation. + Initialize the physics simulation. """ raise NotImplementedError def _init_events(self): """ - Initializes dynamic events that can be used to react to changes in the World. + Initialize dynamic events that can be used to react to changes in the World. """ self.detachment_event: Event = Event() self.attachment_event: Event = Event() @@ -199,86 +224,108 @@ def _init_events(self): def _init_and_sync_prospection_world(self): """ - Initializes the prospection world and the synchronization between the main and the prospection world. + Initialize the prospection world and the synchronization between the main and the prospection world. """ self._init_prospection_world() self._sync_prospection_world() def _update_local_transformer_worlds(self): """ - Updates the local transformer worlds with the current world and prospection world. + Update the local transformer worlds with the current world and prospection world. """ self.local_transformer.world = self self.local_transformer.prospection_world = self.prospection_world def _init_prospection_world(self): """ - Initializes the prospection world, if this is a prospection world itself it will not create another prospection, + Initialize the prospection world, if this is a prospection world itself it will not create another prospection, world, but instead set the prospection world to None, else it will create a prospection world. """ if self.is_prospection_world: # then no need to add another prospection world self.prospection_world = None else: self.prospection_world: World = self.__class__(WorldMode.DIRECT, - True, - World.simulation_frequency) + True) def _sync_prospection_world(self): """ - Synchronizes the prospection world with the main world, this means that every object in the main world will be + Synchronize the prospection world with the main world, this means that every object in the main world will be added to the prospection world and vice versa. """ if self.is_prospection_world: # then no need to add another prospection world self.world_sync = None else: self.world_sync: WorldSync = WorldSync(self, self.prospection_world) + self.pause_world_sync() self.world_sync.start() - def update_cache_dir_with_object(self, path: str, ignore_cached_files: bool, - obj: Object) -> str: + def preprocess_object_file_and_get_its_cache_path(self, path: str, ignore_cached_files: bool, + description: ObjectDescription, name: str, + scale_mesh: Optional[float] = None) -> str: """ - Updates the cache directory with the given object. + Update the cache directory with the given object. :param path: The path to the object. :param ignore_cached_files: If the cached files should be ignored. - :param obj: The object to be added to the cache directory. + :param description: The object description. + :param name: The name of the object. + :param scale_mesh: The scale of the mesh. + :return: The path of the cached object. """ - return self.cache_manager.update_cache_dir_with_object(path, ignore_cached_files, obj.description, obj.name) + return self.cache_manager.update_cache_dir_with_object(path, ignore_cached_files, description, name, scale_mesh) @property def simulation_time_step(self): """ The time step of the simulation in seconds. """ - return 1 / World.simulation_frequency + return 1 / self.__class__.conf.simulation_frequency @abstractmethod - def load_object_and_get_id(self, path: Optional[str] = None, pose: Optional[Pose] = None) -> int: + def load_object_and_get_id(self, path: Optional[str] = None, pose: Optional[Pose] = None, + obj_type: Optional[ObjectType] = None) -> int: """ - Loads a description file (e.g. URDF) at the given pose and returns the id of the loaded object. + Load a description file (e.g. URDF) at the given pose and returns the id of the loaded object. :param path: The path to the description file, if None the description file is assumed to be already loaded. :param pose: The pose at which the object should be loaded. + :param obj_type: The type of the object. :return: The id of the loaded object. """ pass + def load_generic_object_and_get_id(self, description: GenericObjectDescription, + pose: Optional[Pose] = None) -> int: + """ + Create a visual and collision box in the simulation and returns the id of the loaded object. + + :param description: The object description. + :param pose: The pose at which the object should be loaded. + """ + raise NotImplementedError + + def get_object_names(self) -> List[str]: + """ + Return the names of all objects in the World. + + :return: A list of object names. + """ + return [obj.name for obj in self.objects] + def get_object_by_name(self, name: str) -> Optional[Object]: """ - Returns the object with the given name. If there is no object with the given name, None is returned. + Return the object with the given name. If there is no object with the given name, None is returned. :param name: The name of the returned Objects. :return: The object with the given name, if there is one. """ - object = list(filter(lambda obj: obj.name == name, self.objects)) - if len(object) > 0: - return object[0] - return None + matching_objects = list(filter(lambda obj: obj.name == name, self.objects)) + return matching_objects[0] if len(matching_objects) > 0 else None def get_object_by_type(self, obj_type: ObjectType) -> List[Object]: """ - Returns a list of all Objects which have the type 'obj_type'. + Return a list of all Objects which have the type 'obj_type'. :param obj_type: The type of the returned Objects. :return: A list of all Objects that have the type 'obj_type'. @@ -287,59 +334,92 @@ def get_object_by_type(self, obj_type: ObjectType) -> List[Object]: def get_object_by_id(self, obj_id: int) -> Object: """ - Returns the single Object that has the unique id. + Return the single Object that has the unique id. :param obj_id: The unique id for which the Object should be returned. :return: The Object with the id 'id'. """ return list(filter(lambda obj: obj.id == obj_id, self.objects))[0] - @abstractmethod - def remove_object_by_id(self, obj_id: int) -> None: + def remove_visual_object(self, obj_id: int) -> bool: """ - Removes the object with the given id from the world. + Remove the object with the given id from the world, and saves a new original state for the world. :param obj_id: The unique id of the object to be removed. + :return: Whether the object was removed successfully. + """ + + removed = self._remove_visual_object(obj_id) + if removed: + self.update_simulator_state_id_in_original_state() + else: + logwarn(f"Object with id {obj_id} could not be removed.") + return removed + + @abstractmethod + def _remove_visual_object(self, obj_id: int) -> bool: + """ + Remove the visual object with the given id from the world, and update the simulator state in the original state. + + :param obj_id: The unique id of the visual object to be removed. + :return: Whether the object was removed successfully. """ pass @abstractmethod - def remove_object_from_simulator(self, obj: Object) -> None: + def remove_object_from_simulator(self, obj: Object) -> bool: """ - Removes an object from the physics simulator. + Remove an object from the physics simulator. :param obj: The object to be removed. + :return: Whether the object was removed successfully. """ pass def remove_object(self, obj: Object) -> None: """ - Removes this object from the current world. + Remove this object from the current world. For the object to be removed it has to be detached from all objects it is currently attached to. After this is done a call to world remove object is done to remove this Object from the simulation/world. :param obj: The object to be removed. """ - obj.detach_all() + self.object_lock.acquire() - self.objects.remove(obj) - - # This means the current world of the object is not the prospection world, since it - # has a reference to the prospection world - if self.prospection_world is not None: - self.world_sync.remove_obj_queue.put(obj) - self.world_sync.remove_obj_queue.join() + obj.detach_all() - self.remove_object_from_simulator(obj) + if self.remove_object_from_simulator(obj): + self.objects.remove(obj) + self.remove_object_from_original_state(obj) if World.robot == obj: World.robot = None + self.object_lock.release() + + def remove_object_from_original_state(self, obj: Object) -> None: + """ + Remove an object from the original state of the world. + + :param obj: The object to be removed. + """ + self.original_state.object_states.pop(obj.name) + self.original_state.simulator_state_id = self.save_physics_simulator_state(use_same_id=True) + + def add_object_to_original_state(self, obj: Object) -> None: + """ + Add an object to the original state of the world. + + :param obj: The object to be added. + """ + self.original_state.object_states[obj.name] = obj.current_state + self.update_simulator_state_id_in_original_state() + def add_fixed_constraint(self, parent_link: Link, child_link: Link, child_to_parent_transform: Transform) -> int: """ - Creates a fixed joint constraint between the given parent and child links, + Create a fixed joint constraint between the given parent and child links, the joint frame will be at the origin of the child link frame, and would have the same orientation as the child link frame. @@ -390,7 +470,7 @@ def get_joint_position(self, joint: Joint) -> float: @abstractmethod def get_object_joint_names(self, obj: Object) -> List[str]: """ - Returns the names of all joints of this object. + Return the names of all joints of this object. :param obj: The object. :return: A list of joint names. @@ -407,10 +487,60 @@ def get_link_pose(self, link: Link) -> Pose: """ pass + @abstractmethod + def get_multiple_link_poses(self, links: List[Link]) -> Dict[str, Pose]: + """ + Get the poses of multiple links of an articulated object with respect to the world frame. + + :param links: The links as a list of AbstractLink objects. + :return: A dictionary with link names as keys and Pose objects as values. + """ + pass + + @abstractmethod + def get_link_position(self, link: Link) -> List[float]: + """ + Get the position of a link of an articulated object with respect to the world frame. + + :param link: The link as a AbstractLink object. + :return: The position of the link as a list of floats. + """ + pass + + @abstractmethod + def get_link_orientation(self, link: Link) -> List[float]: + """ + Get the orientation of a link of an articulated object with respect to the world frame. + + :param link: The link as a AbstractLink object. + :return: The orientation of the link as a list of floats. + """ + pass + + @abstractmethod + def get_multiple_link_positions(self, links: List[Link]) -> Dict[str, List[float]]: + """ + Get the positions of multiple links of an articulated object with respect to the world frame. + + :param links: The links as a list of AbstractLink objects. + :return: A dictionary with link names as keys and lists of floats as values. + """ + pass + + @abstractmethod + def get_multiple_link_orientations(self, links: List[Link]) -> Dict[str, List[float]]: + """ + Get the orientations of multiple links of an articulated object with respect to the world frame. + + :param links: The links as a list of AbstractLink objects. + :return: A dictionary with link names as keys and lists of floats as values. + """ + pass + @abstractmethod def get_object_link_names(self, obj: Object) -> List[str]: """ - Returns the names of all links of this object. + Return the names of all links of this object. :param obj: The object. :return: A list of link names. @@ -419,7 +549,7 @@ def get_object_link_names(self, obj: Object) -> List[str]: def simulate(self, seconds: float, real_time: Optional[bool] = False) -> None: """ - Simulates Physics in the World for a given amount of seconds. Usually this simulation is faster than real + Simulate Physics in the World for a given amount of seconds. Usually this simulation is faster than real time. By setting the 'real_time' parameter this simulation is slowed down such that the simulated time is equal to real time. @@ -427,24 +557,24 @@ def simulate(self, seconds: float, real_time: Optional[bool] = False) -> None: :param real_time: If the simulation should happen in real time or faster. """ self.set_realtime(real_time) - for i in range(0, int(seconds * self.simulation_frequency)): - curr_time = rospy.Time.now() + for i in range(0, int(seconds * self.conf.simulation_frequency)): + curr_time = Time().now() self.step() for objects, callbacks in self.coll_callbacks.items(): contact_points = self.get_contact_points_between_two_objects(objects[0], objects[1]) - if contact_points != (): + if len(contact_points) > 0: callbacks.on_collision_cb() elif callbacks.no_collision_cb is not None: callbacks.no_collision_cb() if real_time: - loop_time = rospy.Time.now() - curr_time + loop_time = Time().now() - curr_time time_diff = self.simulation_time_step - loop_time.to_sec() time.sleep(max(0, time_diff)) self.update_all_objects_poses() def update_all_objects_poses(self) -> None: """ - Updates the positions of all objects in the world. + Update the positions of all objects in the world. """ for obj in self.objects: obj.update_pose() @@ -453,20 +583,89 @@ def update_all_objects_poses(self) -> None: def get_object_pose(self, obj: Object) -> Pose: """ Get the pose of an object in the world frame from the current object pose in the simulator. + + :param obj: The object. + """ + pass + + @abstractmethod + def get_multiple_object_poses(self, objects: List[Object]) -> Dict[str, Pose]: + """ + Get the poses of multiple objects in the world frame from the current object poses in the simulator. + + :param objects: The objects. + """ + pass + + @abstractmethod + def get_multiple_object_positions(self, objects: List[Object]) -> Dict[str, List[float]]: + """ + Get the positions of multiple objects in the world frame from the current object poses in the simulator. + + :param objects: The objects. """ pass + @abstractmethod + def get_object_position(self, obj: Object) -> List[float]: + """ + Get the position of an object in the world frame from the current object pose in the simulator. + + :param obj: The object. + """ + pass + + @abstractmethod + def get_multiple_object_orientations(self, objects: List[Object]) -> Dict[str, List[float]]: + """ + Get the orientations of multiple objects in the world frame from the current object poses in the simulator. + + :param objects: The objects. + """ + pass + + @abstractmethod + def get_object_orientation(self, obj: Object) -> List[float]: + """ + Get the orientation of an object in the world frame from the current object pose in the simulator. + + :param obj: The object. + """ + pass + + @property + def robot_virtual_joints(self) -> List[Joint]: + """ + The virtual joints of the robot. + """ + return [self.robot.joints[name] for name in self.robot_virtual_joints_names] + + @property + def robot_virtual_joints_names(self) -> List[str]: + """ + The names of the virtual joints of the robot. + """ + return self.robot_description.virtual_mobile_base_joints.names + + def get_robot_mobile_base_joints(self) -> VirtualMobileBaseJoints: + """ + Get the mobile base joints of the robot. + + :return: The mobile base joints. + """ + return self.robot_description.virtual_mobile_base_joints + @abstractmethod def perform_collision_detection(self) -> None: """ - Checks for collisions between all objects in the World and updates the contact points. + Check for collisions between all objects in the World and updates the contact points. """ pass @abstractmethod - def get_object_contact_points(self, obj: Object) -> List: + def get_object_contact_points(self, obj: Object) -> ContactPointsList: """ - Returns a list of contact points of this Object with all other Objects. + Return a list of contact points of this Object with all other Objects. :param obj: The object. :return: A list of all contact points with other objects @@ -474,9 +673,9 @@ def get_object_contact_points(self, obj: Object) -> List: pass @abstractmethod - def get_contact_points_between_two_objects(self, obj1: Object, obj2: Object) -> List: + def get_contact_points_between_two_objects(self, obj1: Object, obj2: Object) -> ContactPointsList: """ - Returns a list of contact points between obj1 and obj2. + Return a list of contact points between obj_a and obj_b. :param obj1: The first object. :param obj2: The second object. @@ -484,24 +683,97 @@ def get_contact_points_between_two_objects(self, obj1: Object, obj2: Object) -> """ pass + def get_object_closest_points(self, obj: Object, max_distance: float) -> ClosestPointsList: + """ + Return the closest points of this object with all other objects in the world. + + :param obj: The object. + :param max_distance: The maximum distance between the points. + :return: A list of the closest points. + """ + all_obj_closest_points = [self.get_closest_points_between_objects(obj, other_obj, max_distance) for other_obj in + self.objects + if other_obj != obj] + return ClosestPointsList([point for closest_points in all_obj_closest_points for point in closest_points]) + + def get_closest_points_between_objects(self, object_a: Object, object_b: Object, max_distance: float) \ + -> ClosestPointsList: + """ + Return the closest points between two objects. + + :param object_a: The first object. + :param object_b: The second object. + :param max_distance: The maximum distance between the points. + :return: A list of the closest points. + """ + raise NotImplementedError + + @validate_joint_position @abstractmethod - def reset_joint_position(self, joint: Joint, joint_position: float) -> None: + def reset_joint_position(self, joint: Joint, joint_position: float) -> bool: """ Reset the joint position instantly without physics simulation + .. note:: + It is recommended to use the validate_joint_position decorator to validate the joint position for + the implementation of this method. + :param joint: The joint to reset the position for. :param joint_position: The new joint pose. + :return: True if the reset was successful, False otherwise """ pass + @validate_multiple_joint_positions @abstractmethod - def reset_object_base_pose(self, obj: Object, pose: Pose): + def set_multiple_joint_positions(self, joint_positions: Dict[Joint, float]) -> bool: + """ + Set the positions of multiple joints of an articulated object. + + .. note:: + It is recommended to use the validate_multiple_joint_positions decorator to validate the + joint positions for the implementation of this method. + + :param joint_positions: A dictionary with joint objects as keys and joint positions as values. + :return: True if the set was successful, False otherwise. + """ + pass + + @abstractmethod + def get_multiple_joint_positions(self, joints: List[Joint]) -> Dict[str, float]: + """ + Get the positions of multiple joints of an articulated object. + + :param joints: The joints as a list of Joint objects. + """ + pass + + @validate_object_pose + @abstractmethod + def reset_object_base_pose(self, obj: Object, pose: Pose) -> bool: """ Reset the world position and orientation of the base of the object instantaneously, not through physics simulation. (x,y,z) position vector and (x,y,z,w) quaternion orientation. + .. note:: + It is recommended to use the validate_object_pose decorator to validate the object pose for the + implementation of this method. + :param obj: The object. :param pose: The new pose as a Pose object. + :return: True if the reset was successful, False otherwise. + """ + pass + + @validate_multiple_object_poses + @abstractmethod + def reset_multiple_objects_base_poses(self, objects: Dict[Object, Pose]) -> bool: + """ + Reset the world position and orientation of the base of multiple objects instantaneously, + not through physics simulation. (x,y,z) position vector and (x,y,z,w) quaternion orientation. + + :param objects: A dictionary with objects as keys and poses as values. + :return: True if the reset was successful, False otherwise. """ pass @@ -512,10 +784,20 @@ def step(self): """ pass + def get_arm_tool_frame_link(self, arm: Arms) -> Link: + """ + Get the tool frame link of the arm of the robot. + + :param arm: The arm for which the tool frame link should be returned. + :return: The tool frame link of the arm. + """ + ee_link_name = self.robot_description.get_arm_tool_frame(arm) + return self.robot.get_link(ee_link_name) + @abstractmethod def set_link_color(self, link: Link, rgba_color: Color): """ - Changes the rgba_color of a link of this object, the rgba_color has to be given as Color object. + Change the rgba_color of a link of this object, the rgba_color has to be given as Color object. :param link: The link which should be colored. :param rgba_color: The rgba_color as Color object with RGBA values between 0 and 1. @@ -545,7 +827,7 @@ def get_colors_of_object_links(self, obj: Object) -> Dict[str, Color]: @abstractmethod def get_object_axis_aligned_bounding_box(self, obj: Object) -> AxisAlignedBoundingBox: """ - Returns the axis aligned bounding box of this object. The return of this method are two points in + Return the axis aligned bounding box of this object. The return of this method are two points in world coordinate frame which define a bounding box. :param obj: The object for which the bounding box should be returned. @@ -556,7 +838,7 @@ def get_object_axis_aligned_bounding_box(self, obj: Object) -> AxisAlignedBoundi @abstractmethod def get_link_axis_aligned_bounding_box(self, link: Link) -> AxisAlignedBoundingBox: """ - Returns the axis aligned bounding box of the link. The return of this method are two points in + Return the axis aligned bounding box of the link. The return of this method are two points in world coordinate frame which define a bounding box. """ pass @@ -564,7 +846,7 @@ def get_link_axis_aligned_bounding_box(self, link: Link) -> AxisAlignedBoundingB @abstractmethod def set_realtime(self, real_time: bool) -> None: """ - Enables the real time simulation of Physics in the World. By default, this is disabled and Physics is only + Enable the real time simulation of Physics in the World. By default, this is disabled and Physics is only simulated to reason about it. :param real_time: Whether the World should simulate Physics in real time. @@ -574,7 +856,7 @@ def set_realtime(self, real_time: bool) -> None: @abstractmethod def set_gravity(self, gravity_vector: List[float]) -> None: """ - Sets the gravity that is used in the World. By default, it is set to the gravity on earth ([0, 0, -9.8]). + Set the gravity that is used in the World. By default, it is set to the gravity on earth ([0, 0, -9.8]). Gravity is given as a vector in x,y,z. Gravity is only applied while simulating Physic. :param gravity_vector: The gravity vector that should be used in the World. @@ -583,7 +865,7 @@ def set_gravity(self, gravity_vector: List[float]) -> None: def set_robot_if_not_set(self, robot: Object) -> None: """ - Sets the robot if it is not set yet. + Set the robot if it is not set yet. :param robot: The Object reference to the Object representing the robot. """ @@ -593,7 +875,7 @@ def set_robot_if_not_set(self, robot: Object) -> None: @staticmethod def set_robot(robot: Union[Object, None]) -> None: """ - Sets the global variable for the robot Object This should be set on spawning the robot. + Set the global variable for the robot Object This should be set on spawning the robot. :param robot: The Object reference to the Object representing the robot. """ @@ -602,17 +884,21 @@ def set_robot(robot: Union[Object, None]) -> None: @staticmethod def robot_is_set() -> bool: """ - Returns whether the robot has been set or not. + Return whether the robot has been set or not. :return: True if the robot has been set, False otherwise. """ return World.robot is not None - def exit(self) -> None: + def exit(self, remove_saved_states: bool = True) -> None: """ - Closes the World as well as the prospection world, also collects any other thread that is running. + Close the World as well as the prospection world, also collects any other thread that is running. + + :param remove_saved_states: Whether to remove the saved states. """ self.exit_prospection_world_if_exists() + self.reset_world(remove_saved_states) + self.remove_all_objects() self.disconnect_from_physics_server() self.reset_robot() self.join_threads() @@ -621,7 +907,7 @@ def exit(self) -> None: def exit_prospection_world_if_exists(self) -> None: """ - Exits the prospection world if it exists. + Exit the prospection world if it exists. """ if self.prospection_world: self.terminate_world_sync() @@ -630,21 +916,21 @@ def exit_prospection_world_if_exists(self) -> None: @abstractmethod def disconnect_from_physics_server(self) -> None: """ - Disconnects the world from the physics server. + Disconnect the world from the physics server. """ pass def reset_current_world(self) -> None: """ - Resets the pose of every object in the World to the pose it was spawned in and sets every joint to 0. + Reset the pose of every object in the World to the pose it was spawned in and sets every joint to 0. """ for obj in self.objects: obj.set_pose(obj.original_pose) - obj.set_joint_positions(dict(zip(list(obj.joint_names), [0] * len(obj.joint_names)))) + obj.set_multiple_joint_positions(dict(zip(list(obj.joint_names), [0] * len(obj.joint_names)))) def reset_robot(self) -> None: """ - Sets the robot class variable to None. + Set the robot class variable to None. """ self.set_robot(None) @@ -657,19 +943,22 @@ def join_threads(self) -> None: def terminate_world_sync(self) -> None: """ - Terminates the world sync thread. + Terminate the world sync thread. """ self.world_sync.terminate = True + self.resume_world_sync() self.world_sync.join() - def save_state(self, state_id: Optional[int] = None) -> int: + def save_state(self, state_id: Optional[int] = None, use_same_id: bool = False) -> int: """ - Returns the id of the saved state of the World. The saved state contains the states of all the objects and + Return the id of the saved state of the World. The saved state contains the states of all the objects and the state of the physics simulator. + :param state_id: The id of the saved state. + :param use_same_id: Whether to use the same current state id for the new saved state. :return: A unique id of the state """ - state_id = self.save_physics_simulator_state() + state_id = self.save_physics_simulator_state(state_id=state_id, use_same_id=use_same_id) self.save_objects_state(state_id) self._current_state = WorldState(state_id, self.object_states) return super().save_state(state_id) @@ -677,18 +966,24 @@ def save_state(self, state_id: Optional[int] = None) -> int: @property def current_state(self) -> WorldState: if self._current_state is None: - self._current_state = WorldState(self.save_physics_simulator_state(), self.object_states) - return self._current_state + simulator_state = None if self.conf.use_physics_simulator_state else ( + self.save_physics_simulator_state(use_same_id=True)) + self._current_state = WorldState(simulator_state, self.object_states) + return WorldState(self._current_state.simulator_state_id, self.object_states) @current_state.setter def current_state(self, state: WorldState) -> None: - self.restore_physics_simulator_state(state.simulator_state_id) - self.object_states = state.object_states + if self.current_state != state: + if self.conf.use_physics_simulator_state: + self.restore_physics_simulator_state(state.simulator_state_id) + else: + for obj in self.objects: + self.get_object_by_name(obj.name).current_state = state.object_states[obj.name] @property def object_states(self) -> Dict[str, ObjectState]: """ - Returns the states of all objects in the World. + Return the states of all objects in the World. :return: A dictionary with the object id as key and the object state as value. """ @@ -697,14 +992,14 @@ def object_states(self) -> Dict[str, ObjectState]: @object_states.setter def object_states(self, states: Dict[str, ObjectState]) -> None: """ - Sets the states of all objects in the World. + Set the states of all objects in the World. """ for obj_name, obj_state in states.items(): self.get_object_by_name(obj_name).current_state = obj_state def save_objects_state(self, state_id: int) -> None: """ - Saves the state of all objects in the World according to the given state using the unique state id. + Save the state of all objects in the World according to the given state using the unique state id. :param state_id: The unique id representing the state. """ @@ -712,10 +1007,12 @@ def save_objects_state(self, state_id: int) -> None: obj.save_state(state_id) @abstractmethod - def save_physics_simulator_state(self) -> int: + def save_physics_simulator_state(self, state_id: Optional[int] = None, use_same_id: bool = False) -> int: """ - Saves the state of the physics simulator and returns the unique id of the state. + Save the state of the physics simulator and returns the unique id of the state. + :param state_id: The used specified unique id representing the state. + :param use_same_id: If the same id should be used for the state. :return: The unique id representing the state. """ pass @@ -723,7 +1020,7 @@ def save_physics_simulator_state(self) -> int: @abstractmethod def remove_physics_simulator_state(self, state_id: int) -> None: """ - Removes the state of the physics simulator with the given id. + Remove the state of the physics simulator with the given id. :param state_id: The unique id representing the state. """ @@ -732,7 +1029,7 @@ def remove_physics_simulator_state(self, state_id: int) -> None: @abstractmethod def restore_physics_simulator_state(self, state_id: int) -> None: """ - Restores the objects and environment state in the physics simulator according to + Restore the objects and environment state in the physics simulator according to the given state using the unique state id. :param state_id: The unique id representing the state. @@ -744,7 +1041,7 @@ def get_images_for_target(self, cam_pose: Pose, size: Optional[int] = 256) -> List[np.ndarray]: """ - Calculates the view and projection Matrix and returns 3 images: + Calculate the view and projection Matrix and returns 3 images: 1. An RGB image 2. A depth image @@ -763,7 +1060,7 @@ def register_two_objects_collision_callbacks(self, on_collision_callback: Callable, on_collision_removal_callback: Optional[Callable] = None) -> None: """ - Registers callback methods for contact between two Objects. There can be a callback for when the two Objects + Register callback methods for contact between two Objects. There can be a callback for when the two Objects get in contact and, optionally, for when they are not in contact anymore. :param object_a: An object in the World @@ -775,80 +1072,115 @@ def register_two_objects_collision_callbacks(self, on_collision_removal_callback) @classmethod - def add_resource_path(cls, path: str) -> None: + def get_data_directories(cls) -> List[str]: + """ + The resources directories where the objects, robots, and environments are stored. + """ + return cls.cache_manager.data_directories + + @classmethod + def add_resource_path(cls, path: str, prepend: bool = False) -> None: """ - Adds a resource path in which the World will search for files. This resource directory is searched if an + Add a resource path in which the World will search for files. This resource directory is searched if an Object is spawned only with a filename. :param path: A path in the filesystem in which to search for files. + :param prepend: Put the new path at the beginning of the list such that it is searched first. + """ + if prepend: + cls.cache_manager.data_directories = [path] + cls.cache_manager.data_directories + else: + cls.cache_manager.data_directories.append(path) + + @classmethod + def remove_resource_path(cls, path: str) -> None: + """ + Remove the given path from the data_directories list. + + :param path: The path to remove. """ - cls.data_directory.append(path) + cls.cache_manager.data_directories.remove(path) + + @classmethod + def change_cache_dir_path(cls, path: str) -> None: + """ + Change the cache directory to the given path + + :param path: The new path for the cache directory. + """ + cls.cache_manager.cache_dir = os.path.join(path, cls.conf.cache_dir_name) def get_prospection_object_for_object(self, obj: Object) -> Object: """ - Returns the corresponding object from the prospection world for a given object in the main world. + Return the corresponding object from the prospection world for a given object in the main world. If the given Object is already in the prospection world, it is returned. :param obj: The object for which the corresponding object in the prospection World should be found. :return: The corresponding object in the prospection world. """ - self.world_sync.add_obj_queue.join() - try: - return self.world_sync.object_mapping[obj] - except KeyError: - prospection_world = self if self.is_prospection_world else self.prospection_world - if obj in prospection_world.objects: - return obj - else: - raise ValueError( - f"There is no prospection object for the given object: {obj}, this could be the case if" - f" the object isn't anymore in the main (graphical) World" - f" or if the given object is already a prospection object. ") + with UseProspectionWorld(): + return self.world_sync.get_prospection_object(obj) def get_object_for_prospection_object(self, prospection_object: Object) -> Object: """ - Returns the corresponding object from the main World for a given + Return the corresponding object from the main World for a given object in the prospection world. If the given object is not in the prospection world an error will be raised. :param prospection_object: The object for which the corresponding object in the main World should be found. :return: The object in the main World. """ - object_map = self.world_sync.object_mapping - try: - return list(object_map.keys())[list(object_map.values()).index(prospection_object)] - except ValueError: - raise ValueError("The given object is not in the prospection world.") + with UseProspectionWorld(): + return self.world_sync.get_world_object(prospection_object) - def reset_world(self, remove_saved_states=True) -> None: + def remove_all_objects(self, exclude_objects: Optional[List[Object]] = None) -> None: """ - Resets the World to the state it was first spawned in. + Remove all objects from the World. + + :param exclude_objects: A list of objects that should not be removed. + """ + objs_copy = [obj for obj in self.objects] + exclude_objects = [] if exclude_objects is None else exclude_objects + [self.remove_object(obj) for obj in objs_copy if obj not in exclude_objects] + + def reset_world(self, remove_saved_states=False) -> None: + """ + Reset the World to the state it was first spawned in. All attached objects will be detached, all joints will be set to the default position of 0 and all objects will be set to the position and orientation in which they were spawned. :param remove_saved_states: If the saved states should be removed. """ - + self.restore_state(self.original_state_id) if remove_saved_states: self.remove_saved_states() - - for obj in self.objects: - obj.reset(remove_saved_states) + self.original_state_id = self.save_state() def remove_saved_states(self) -> None: """ - Removes all saved states of the World. + Remove all saved states of the World. """ - for state_id in self.saved_states: - self.remove_physics_simulator_state(state_id) + if self.conf.use_physics_simulator_state: + for state_id in self.saved_states: + self.remove_physics_simulator_state(state_id) + else: + self.remove_objects_saved_states() super().remove_saved_states() + self.original_state_id = None + + def remove_objects_saved_states(self) -> None: + """ + Remove all saved states of the objects in the World. + """ + for obj in self.objects: + obj.remove_saved_states() def update_transforms_for_objects_in_current_world(self) -> None: """ Updates transformations for all objects that are currently in :py:attr:`~pycram.world.World.current_world`. """ - curr_time = rospy.Time.now() + curr_time = Time().now() for obj in list(self.current_world.objects): obj.update_link_transforms(curr_time) @@ -883,6 +1215,12 @@ def create_visual_shape(self, visual_shape: VisualShape) -> int: :param visual_shape: The visual shape to be created, uses the VisualShape dataclass defined in world_dataclasses :return: The unique id of the created shape. """ + return self._simulator_object_creator(self._create_visual_shape, visual_shape) + + def _create_visual_shape(self, visual_shape: VisualShape) -> int: + """ + See :py:meth:`~pycram.world.World.create_visual_shape` + """ raise NotImplementedError def create_multi_body_from_visual_shapes(self, visual_shape_ids: List[int], pose: Pose) -> int: @@ -920,51 +1258,92 @@ def create_multi_body(self, multi_body: MultiBody) -> int: :param multi_body: The multi body to be created, uses the MultiBody dataclass defined in world_dataclasses. :return: The unique id of the created multi body. """ + return self._simulator_object_creator(self._create_multi_body, multi_body) + + def _create_multi_body(self, multi_body: MultiBody) -> int: + """ + See :py:meth:`~pycram.world.World.create_multi_body` + """ raise NotImplementedError def create_box_visual_shape(self, shape_data: BoxVisualShape) -> int: """ Creates a box visual shape in the physics simulator and returns the unique id of the created shape. - :param shape_data: The parameters that define the box visual shape to be created, uses the BoxVisualShape dataclass defined in world_dataclasses. + :param shape_data: The parameters that define the box visual shape to be created, uses the BoxVisualShape + dataclass defined in world_dataclasses. :return: The unique id of the created shape. """ + return self._simulator_object_creator(self._create_box_visual_shape, shape_data) + + def _create_box_visual_shape(self, shape_data: BoxVisualShape) -> int: + """ + See :py:meth:`~pycram.world.World.create_box_visual_shape` + """ raise NotImplementedError def create_cylinder_visual_shape(self, shape_data: CylinderVisualShape) -> int: """ Creates a cylinder visual shape in the physics simulator and returns the unique id of the created shape. - :param shape_data: The parameters that define the cylinder visual shape to be created, uses the CylinderVisualShape dataclass defined in world_dataclasses. + :param shape_data: The parameters that define the cylinder visual shape to be created, uses the + CylinderVisualShape dataclass defined in world_dataclasses. :return: The unique id of the created shape. """ + return self._simulator_object_creator(self._create_cylinder_visual_shape, shape_data) + + def _create_cylinder_visual_shape(self, shape_data: CylinderVisualShape) -> int: + """ + See :py:meth:`~pycram.world.World.create_cylinder_visual_shape` + """ raise NotImplementedError def create_sphere_visual_shape(self, shape_data: SphereVisualShape) -> int: """ Creates a sphere visual shape in the physics simulator and returns the unique id of the created shape. - :param shape_data: The parameters that define the sphere visual shape to be created, uses the SphereVisualShape dataclass defined in world_dataclasses. + :param shape_data: The parameters that define the sphere visual shape to be created, uses the SphereVisualShape + dataclass defined in world_dataclasses. :return: The unique id of the created shape. """ + return self._simulator_object_creator(self._create_sphere_visual_shape, shape_data) + + def _create_sphere_visual_shape(self, shape_data: SphereVisualShape) -> int: + """ + See :py:meth:`~pycram.world.World.create_sphere_visual_shape` + """ raise NotImplementedError def create_capsule_visual_shape(self, shape_data: CapsuleVisualShape) -> int: """ Creates a capsule visual shape in the physics simulator and returns the unique id of the created shape. - :param shape_data: The parameters that define the capsule visual shape to be created, uses the CapsuleVisualShape dataclass defined in world_dataclasses. + :param shape_data: The parameters that define the capsule visual shape to be created, uses the + CapsuleVisualShape dataclass defined in world_dataclasses. :return: The unique id of the created shape. """ + return self._simulator_object_creator(self._create_capsule_visual_shape, shape_data) + + def _create_capsule_visual_shape(self, shape_data: CapsuleVisualShape) -> int: + """ + See :py:meth:`~pycram.world.World.create_capsule_visual_shape` + """ raise NotImplementedError def create_plane_visual_shape(self, shape_data: PlaneVisualShape) -> int: """ Creates a plane visual shape in the physics simulator and returns the unique id of the created shape. - :param shape_data: The parameters that define the plane visual shape to be created, uses the PlaneVisualShape dataclass defined in world_dataclasses. + :param shape_data: The parameters that define the plane visual shape to be created, uses the PlaneVisualShape + dataclass defined in world_dataclasses. :return: The unique id of the created shape. """ + return self._simulator_object_creator(self._create_plane_visual_shape, shape_data) + + def _create_plane_visual_shape(self, shape_data: PlaneVisualShape) -> int: + """ + See :py:meth:`~pycram.world.World.create_plane_visual_shape` + """ raise NotImplementedError def create_mesh_visual_shape(self, shape_data: MeshVisualShape) -> int: @@ -975,6 +1354,12 @@ def create_mesh_visual_shape(self, shape_data: MeshVisualShape) -> int: uses the MeshVisualShape dataclass defined in world_dataclasses. :return: The unique id of the created shape. """ + return self._simulator_object_creator(self._create_mesh_visual_shape, shape_data) + + def _create_mesh_visual_shape(self, shape_data: MeshVisualShape) -> int: + """ + See :py:meth:`~pycram.world.World.create_mesh_visual_shape` + """ raise NotImplementedError def add_text(self, text: str, position: List[float], orientation: Optional[List[float]] = None, size: float = 0.1, @@ -985,14 +1370,26 @@ def add_text(self, text: str, position: List[float], orientation: Optional[List[ :param text: The text to be added. :param position: The position of the text in the world. - :param orientation: By default, debug text will always face the camera, automatically rotation. By specifying a text orientation (quaternion), the orientation will be fixed in world space or local space (when parent is specified). + :param orientation: By default, debug text will always face the camera, automatically rotation. By specifying a + text orientation (quaternion), the orientation will be fixed in world space or local space + (when parent is specified). :param size: The size of the text. :param color: The color of the text. - :param life_time: The lifetime in seconds of the text to remain in the world, if 0 the text will remain in the world until it is removed manually. + :param life_time: The lifetime in seconds of the text to remain in the world, if 0 the text will remain in the + world until it is removed manually. :param parent_object_id: The id of the object to which the text should be attached. :param parent_link_id: The id of the link to which the text should be attached. :return: The id of the added text. """ + return self._simulator_object_creator(self._add_text, text, position, orientation, size, color, life_time, + parent_object_id, parent_link_id) + + def _add_text(self, text: str, position: List[float], orientation: Optional[List[float]] = None, size: float = 0.1, + color: Optional[Color] = Color(), life_time: Optional[float] = 0, + parent_object_id: Optional[int] = None, parent_link_id: Optional[int] = None) -> int: + """ + See :py:meth:`~pycram.world.World.add_text` + """ raise NotImplementedError def remove_text(self, text_id: Optional[int] = None) -> None: @@ -1001,6 +1398,12 @@ def remove_text(self, text_id: Optional[int] = None) -> None: :param text_id: The id of the text to be removed. """ + self._simulator_object_remover(self._remove_text, text_id) + + def _remove_text(self, text_id: Optional[int] = None) -> None: + """ + See :py:meth:`~pycram.world.World.remove_text` + """ raise NotImplementedError def enable_joint_force_torque_sensor(self, obj: Object, fts_joint_idx: int) -> None: @@ -1027,7 +1430,7 @@ def disable_joint_force_torque_sensor(self, obj: Object, joint_id: int) -> None: def get_joint_reaction_force_torque(self, obj: Object, joint_id: int) -> List[float]: """ - Returns the joint reaction forces and torques of the specified joint. + Get the joint reaction forces and torques of the specified joint. :param obj: The object in which the joint is located. :param joint_id: The id of the joint for which the force torque should be returned. @@ -1037,7 +1440,7 @@ def get_joint_reaction_force_torque(self, obj: Object, joint_id: int) -> List[fl def get_applied_joint_motor_torque(self, obj: Object, joint_id: int) -> float: """ - Returns the applied torque by a joint motor. + Get the applied torque by a joint motor. :param obj: The object in which the joint is located. :param joint_id: The id of the joint for which the applied motor torque should be returned. @@ -1045,6 +1448,85 @@ def get_applied_joint_motor_torque(self, obj: Object, joint_id: int) -> float: """ raise NotImplementedError + def pause_world_sync(self) -> None: + """ + Pause the world synchronization. + """ + self.world_sync.sync_lock.acquire() + + def resume_world_sync(self) -> None: + """ + Resume the world synchronization. + """ + self.world_sync.sync_lock.release() + + def add_vis_axis(self, pose: Pose) -> int: + """ + Add a visual axis to the world. + + :param pose: The pose of the visual axis. + :return: The id of the added visual axis. + """ + return self._simulator_object_creator(self._add_vis_axis, pose) + + def _add_vis_axis(self, pose: Pose) -> None: + """ + See :py:meth:`~pycram.world.World.add_vis_axis` + """ + logwarn(f"Visual axis is not supported in {self.__class__.__name__}") + + def remove_vis_axis(self) -> None: + """ + Remove the visual axis from the world. + """ + self._simulator_object_remover(self._remove_vis_axis) + + def _remove_vis_axis(self) -> None: + """ + See :py:meth:`~pycram.world.World.remove_vis_axis` + """ + logwarn(f"Visual axis is not supported in {self.__class__.__name__}") + + def _simulator_object_creator(self, creator_func: Callable, *args, **kwargs) -> int: + """ + Create an object in the physics simulator and returns the created object id. + + :param creator_func: The function that creates the object in the physics simulator. + :param args: The arguments for the creator function. + :param kwargs: The keyword arguments for the creator function. + :return: The created object id. + """ + obj_id = creator_func(*args, **kwargs) + self.update_simulator_state_id_in_original_state() + return obj_id + + def _simulator_object_remover(self, remover_func: Callable, *args, **kwargs) -> None: + """ + Remove an object from the physics simulator. + + :param remover_func: The function that removes the object from the physics simulator. + :param args: The arguments for the remover function. + :param kwargs: The keyword arguments for the remover function. + """ + remover_func(*args, **kwargs) + self.update_simulator_state_id_in_original_state() + + def update_simulator_state_id_in_original_state(self, use_same_id: bool = False) -> None: + """ + Update the simulator state id in the original state if use_physics_simulator_state is True in the configuration. + + :param use_same_id: If the same id should be used for the state. + """ + if self.conf.use_physics_simulator_state: + self.original_state.simulator_state_id = self.save_physics_simulator_state(use_same_id=use_same_id) + + @property + def original_state(self) -> WorldState: + """ + The saved original state of the world. + """ + return self.saved_states[self.original_state_id] + def __del__(self): self.exit() @@ -1059,35 +1541,21 @@ class UseProspectionWorld: NavigateAction.Action([[1, 0, 0], [0, 0, 0, 1]]).perform() """ - WAIT_TIME_FOR_ADDING_QUEUE = 20 - """ - The time in seconds to wait for the adding queue to be ready. - """ def __init__(self): self.prev_world: Optional[World] = None # The previous world is saved to restore it after the with block is exited. - def sync_worlds(self): - """ - Synchronizes the state of the prospection world with the main world. - """ - for world_obj, prospection_obj in World.current_world.world_sync.object_mapping.items(): - prospection_obj.current_state = world_obj.current_state - def __enter__(self): """ This method is called when entering the with block, it will set the current world to the prospection world """ + # Please do not edit this function, it works as it is now! if not World.current_world.is_prospection_world: - time.sleep(self.WAIT_TIME_FOR_ADDING_QUEUE * World.current_world.simulation_time_step) - # blocks until the adding queue is ready - World.current_world.world_sync.add_obj_queue.join() - self.sync_worlds() - self.prev_world = World.current_world - World.current_world.world_sync.pause_sync = True World.current_world = World.current_world.prospection_world + # This is also a join statement since it is called from the main thread. + World.current_world.world_sync.sync_worlds() def __exit__(self, *args): """ @@ -1095,7 +1563,6 @@ def __exit__(self, *args): """ if self.prev_world is not None: World.current_world = self.prev_world - World.current_world.world_sync.pause_sync = False class WorldSync(threading.Thread): @@ -1103,12 +1570,15 @@ class WorldSync(threading.Thread): Synchronizes the state between the World and its prospection world. Meaning the cartesian and joint position of everything in the prospection world will be synchronized with the main World. - Adding and removing objects is done via queues, such that loading times of objects - in the prospection world does not affect the World. The class provides the possibility to pause the synchronization, this can be used if reasoning should be done in the prospection world. """ + WAIT_TIME_AS_N_SIMULATION_STEPS = 20 + """ + The time in simulation steps to wait between each iteration of the syncing loop. + """ + def __init__(self, world: World, prospection_world: World): threading.Thread.__init__(self) self.world: World = world @@ -1116,48 +1586,110 @@ def __init__(self, world: World, prospection_world: World): self.prospection_world.world_sync = self self.terminate: bool = False - self.add_obj_queue: Queue = Queue() - self.remove_obj_queue: Queue = Queue() self.pause_sync: bool = False # Maps world to prospection world objects - self.object_mapping: Dict[Object, Object] = {} + self.object_to_prospection_object_map: Dict[Object, Object] = {} + self.prospection_object_to_object_map: Dict[Object, Object] = {} self.equal_states = False + self.sync_lock: threading.Lock = threading.Lock() - def run(self, wait_time_as_n_simulation_steps: Optional[int] = 1): + def run(self): """ Main method of the synchronization, this thread runs in a loop until the terminate flag is set. While this loop runs it continuously checks the cartesian and joint position of every object in the World and updates the corresponding object in the - prospection world. When there are entries in the adding or removing queue the corresponding objects will - be added or removed in the same iteration. - - :param wait_time_as_n_simulation_steps: The time in simulation steps to wait between each iteration of - the syncing loop. + prospection world. """ while not self.terminate: - self.check_for_pause() - while not self.add_obj_queue.empty(): - obj = self.add_obj_queue.get() - # Maps the World object to the prospection world object - self.object_mapping[obj] = copy(obj) - self.add_obj_queue.task_done() - while not self.remove_obj_queue.empty(): - obj = self.remove_obj_queue.get() - # Get prospection world object reference from object mapping - prospection_obj = self.object_mapping[obj] - prospection_obj.remove() - del self.object_mapping[obj] - self.remove_obj_queue.task_done() - self.check_for_pause() - time.sleep(wait_time_as_n_simulation_steps * self.world.simulation_time_step) - - def check_for_pause(self) -> None: - """ - Checks if :py:attr:`~self.pause_sync` is true and sleeps this thread until it isn't anymore. - """ - while self.pause_sync: - time.sleep(0.1) + self.sync_lock.acquire() + if not self.terminate: + self.sync_worlds() + self.sync_lock.release() + time.sleep(WorldSync.WAIT_TIME_AS_N_SIMULATION_STEPS * self.world.simulation_time_step) + + def get_world_object(self, prospection_object: Object) -> Object: + """ + Get the corresponding object from the main World for a given object in the prospection world. + + :param prospection_object: The object for which the corresponding object in the main World should be found. + :return: The object in the main World. + """ + try: + return self.prospection_object_to_object_map[prospection_object] + except KeyError: + if prospection_object in self.world.objects: + return prospection_object + raise WorldObjectNotFound(prospection_object) + + def get_prospection_object(self, obj: Object) -> Object: + """ + Get the corresponding object from the prospection world for a given object in the main world. + + :param obj: The object for which the corresponding object in the prospection World should be found. + :return: The corresponding object in the prospection world. + """ + try: + return self.object_to_prospection_object_map[obj] + except KeyError: + if obj in self.prospection_world.objects: + return obj + raise ProspectionObjectNotFound(obj) + + def sync_worlds(self): + """ + Syncs the prospection world with the main world by adding and removing objects and synchronizing their states. + """ + self.remove_objects_not_in_world() + self.add_objects_not_in_prospection_world() + self.prospection_object_to_object_map = {prospection_obj: obj for obj, prospection_obj in + self.object_to_prospection_object_map.items()} + self.sync_objects_states() + + def remove_objects_not_in_world(self): + """ + Removes all objects that are not in the main world from the prospection world. + """ + obj_map_copy = copy(self.object_to_prospection_object_map) + [self.remove_object(obj) for obj in obj_map_copy.keys() if obj not in self.world.objects] + + def add_objects_not_in_prospection_world(self): + """ + Adds all objects that are in the main world but not in the prospection world to the prospection world. + """ + obj_map_copy = copy(self.object_to_prospection_object_map) + [self.add_object(obj) for obj in self.world.objects if obj not in obj_map_copy.keys()] + + def add_object(self, obj: Object) -> None: + """ + Adds an object to the prospection world. + + :param obj: The object to be added. + """ + self.object_to_prospection_object_map[obj] = obj.copy_to_prospection() + + def remove_object(self, obj: Object) -> None: + """ + Removes an object from the prospection world. + + :param obj: The object to be removed. + """ + prospection_obj = self.object_to_prospection_object_map[obj] + prospection_obj.remove() + del self.object_to_prospection_object_map[obj] + + def sync_objects_states(self) -> None: + """ + Synchronizes the state of all objects in the World with the prospection world. + """ + # Set the pose of the prospection objects to the pose of the world objects + obj_pose_dict = {prospection_obj: obj.pose + for obj, prospection_obj in self.object_to_prospection_object_map.items()} + self.world.prospection_world.reset_multiple_objects_base_poses(obj_pose_dict) + for obj, prospection_obj in self.object_to_prospection_object_map.items(): + prospection_obj.set_attachments(obj.attachments) + prospection_obj.link_states = obj.link_states + prospection_obj.joint_states = obj.joint_states def check_for_equal(self) -> bool: """ @@ -1167,7 +1699,12 @@ def check_for_equal(self) -> bool: :return: True if both Worlds have the same state, False otherwise. """ eql = True - for obj, prospection_obj in self.object_mapping.items(): + prospection_names = self.prospection_world.get_object_names() + eql = eql and [name in prospection_names for name in self.world.get_object_names()] + eql = eql and len(prospection_names) == len(self.world.get_object_names()) + if not eql: + return False + for obj, prospection_obj in self.object_to_prospection_object_map.items(): eql = eql and obj.get_pose().dist(prospection_obj.get_pose()) < 0.001 self.equal_states = eql return eql diff --git a/src/pycram/datastructures/world_entity.py b/src/pycram/datastructures/world_entity.py new file mode 100644 index 000000000..1e7c61e06 --- /dev/null +++ b/src/pycram/datastructures/world_entity.py @@ -0,0 +1,77 @@ +from abc import ABC, abstractmethod + +from typing_extensions import TYPE_CHECKING, Dict + +from .dataclasses import State + +if TYPE_CHECKING: + from ..datastructures.world import World + + +class StateEntity: + """ + The StateEntity class is used to store the state of an object or the physics simulator. This is used to save and + restore the state of the World. + """ + + def __init__(self): + self._saved_states: Dict[int, State] = {} + + @property + def saved_states(self) -> Dict[int, State]: + """ + :return: the saved states of this entity. + """ + return self._saved_states + + def save_state(self, state_id: int) -> int: + """ + Saves the state of this entity with the given state id. + + :param state_id: The unique id of the state. + """ + self._saved_states[state_id] = self.current_state + return state_id + + @property + @abstractmethod + def current_state(self) -> State: + """ + :return: The current state of this entity. + """ + pass + + @current_state.setter + @abstractmethod + def current_state(self, state: State) -> None: + """ + Sets the current state of this entity. + + :param state: The new state of this entity. + """ + pass + + def restore_state(self, state_id: int) -> None: + """ + Restores the state of this entity from a saved state using the given state id. + + :param state_id: The unique id of the state. + """ + self.current_state = self.saved_states[state_id] + + def remove_saved_states(self) -> None: + """ + Removes all saved states of this entity. + """ + self._saved_states = {} + + +class WorldEntity(StateEntity, ABC): + """ + A data class that represents an entity of the world, such as an object or a link. + """ + + def __init__(self, _id: int, world: 'World'): + StateEntity.__init__(self) + self.id = _id + self.world: 'World' = world diff --git a/src/pycram/description.py b/src/pycram/description.py index 0ad05d7f1..689bec2f5 100644 --- a/src/pycram/description.py +++ b/src/pycram/description.py @@ -1,34 +1,36 @@ from __future__ import annotations import logging +import os import pathlib from abc import ABC, abstractmethod -import rospy +from .ros.data_types import Time +import trimesh from geometry_msgs.msg import Point, Quaternion -from typing_extensions import Tuple, Union, Any, List, Optional, Dict, TYPE_CHECKING +from typing_extensions import Tuple, Union, Any, List, Optional, Dict, TYPE_CHECKING, Self, deprecated +from .datastructures.dataclasses import JointState, AxisAlignedBoundingBox, Color, LinkState, VisualShape from .datastructures.enums import JointType -from .local_transformer import LocalTransformer from .datastructures.pose import Pose, Transform -from .datastructures.world import WorldEntity -from .datastructures.dataclasses import JointState, AxisAlignedBoundingBox, Color, LinkState, VisualShape +from .datastructures.world_entity import WorldEntity +from .failures import ObjectDescriptionNotFound +from .local_transformer import LocalTransformer if TYPE_CHECKING: from .world_concepts.world_object import Object class EntityDescription(ABC): - """ - A class that represents a description of an entity. This can be a link, joint or object description. + A description of an entity. This can be a link, joint or object description. """ @property @abstractmethod def origin(self) -> Pose: """ - Returns the origin of this entity. + :return: the origin of this entity. """ pass @@ -36,14 +38,14 @@ def origin(self) -> Pose: @abstractmethod def name(self) -> str: """ - Returns the name of this entity. + :return: the name of this entity. """ pass class LinkDescription(EntityDescription): """ - A class that represents a link description of an object. + A link description of an object. """ def __init__(self, parsed_link_description: Any): @@ -53,7 +55,7 @@ def __init__(self, parsed_link_description: Any): @abstractmethod def geometry(self) -> Union[VisualShape, None]: """ - Returns the geometry type of the collision element of this link. + The geometry type of the collision element of this link. """ pass @@ -63,8 +65,13 @@ class JointDescription(EntityDescription): A class that represents the description of a joint. """ - def __init__(self, parsed_joint_description: Any): + def __init__(self, parsed_joint_description: Optional[Any] = None, is_virtual: bool = False): + """ + :param parsed_joint_description: The parsed description of the joint (e.g. from urdf or mjcf file). + :param is_virtual: True if the joint is virtual (i.e. not a physically existing joint), False otherwise. + """ self.parsed_description = parsed_joint_description + self.is_virtual: Optional[bool] = is_virtual @property @abstractmethod @@ -86,8 +93,6 @@ def axis(self) -> Point: @abstractmethod def has_limits(self) -> bool: """ - Checks if this joint has limits. - :return: True if the joint has limits, False otherwise. """ pass @@ -120,7 +125,7 @@ def upper_limit(self) -> Union[float, None]: @property @abstractmethod - def parent_link_name(self) -> str: + def parent(self) -> str: """ :return: The name of the parent link of this joint. """ @@ -128,7 +133,7 @@ def parent_link_name(self) -> str: @property @abstractmethod - def child_link_name(self) -> str: + def child(self) -> str: """ :return: The name of the child link of this joint. """ @@ -159,6 +164,13 @@ def __init__(self, _id: int, obj: Object): WorldEntity.__init__(self, _id, obj.world) self.object: Object = obj + @property + def object_name(self) -> str: + """ + The name of the object to which this joint belongs. + """ + return self.object.name + @property @abstractmethod def pose(self) -> Pose: @@ -170,7 +182,7 @@ def pose(self) -> Pose: @property def transform(self) -> Transform: """ - Returns the transform of this entity. + The transform of this entity. :return: The transform of this entity. """ @@ -180,7 +192,7 @@ def transform(self) -> Transform: @abstractmethod def tf_frame(self) -> str: """ - Returns the tf frame of this entity. + The tf frame of this entity. :return: The tf frame of this entity. """ @@ -196,7 +208,7 @@ def object_id(self) -> int: class Link(ObjectEntity, LinkDescription, ABC): """ - Represents a link of an Object in the World. + A link of an Object in the World. """ def __init__(self, _id: int, link_description: LinkDescription, obj: Object): @@ -204,7 +216,48 @@ def __init__(self, _id: int, link_description: LinkDescription, obj: Object): LinkDescription.__init__(self, link_description.parsed_description) self.local_transformer: LocalTransformer = LocalTransformer() self.constraint_ids: Dict[Link, int] = {} - self._update_pose() + self._current_pose: Optional[Pose] = None + self.update_pose() + + def set_pose(self, pose: Pose) -> None: + """ + Set the pose of this link to the given pose. + NOTE: This will move the entire object such that the link is at the given pose, it will not consider any joints + that can allow the link to be at the given pose. + + :param pose: The target pose for this link. + """ + self.object.set_pose(self.get_object_pose_given_link_pose(pose)) + + def get_object_pose_given_link_pose(self, pose): + """ + Get the object pose given the link pose, which could be a hypothetical link pose to see what would be the object + pose in that case (assuming that the object itself moved not the joints). + + :param pose: The link pose. + """ + return (pose.to_transform(self.tf_frame) * self.get_transform_to_root_link()).to_pose() + + def get_pose_given_object_pose(self, pose): + """ + Get the link pose given the object pose, which could be a hypothetical object pose to see what would be the link + pose in that case (assuming that the object itself moved not the joints). + + :param pose: The object pose. + """ + return (pose.to_transform(self.object.tf_frame) * self.get_transform_from_root_link()).to_pose() + + def get_transform_from_root_link(self) -> Transform: + """ + Return the transformation from the root link of the object to this link. + """ + return self.get_transform_from_link(self.object.root_link) + + def get_transform_to_root_link(self) -> Transform: + """ + Return the transformation from this link to the root link of the object. + """ + return self.get_transform_to_link(self.object.root_link) @property def current_state(self) -> LinkState: @@ -212,25 +265,28 @@ def current_state(self) -> LinkState: @current_state.setter def current_state(self, link_state: LinkState) -> None: - self.constraint_ids = link_state.constraint_ids + if self.current_state != link_state: + self.constraint_ids = link_state.constraint_ids - def add_fixed_constraint_with_link(self, child_link: 'Link') -> int: + def add_fixed_constraint_with_link(self, child_link: Self, + child_to_parent_transform: Optional[Transform] = None) -> int: """ - Adds a fixed constraint between this link and the given link, used to create attachments for example. + Add a fixed constraint between this link and the given link, to create attachments for example. :param child_link: The child link to which a fixed constraint should be added. + :param child_to_parent_transform: The transformation between the two links. :return: The unique id of the constraint. """ - constraint_id = self.world.add_fixed_constraint(self, - child_link, - child_link.get_transform_from_link(self)) + if child_to_parent_transform is None: + child_to_parent_transform = child_link.get_transform_to_link(self) + constraint_id = self.world.add_fixed_constraint(self, child_link, child_to_parent_transform) self.constraint_ids[child_link] = constraint_id child_link.constraint_ids[self] = constraint_id return constraint_id def remove_constraint_with_link(self, child_link: 'Link') -> None: """ - Removes the constraint between this link and the given link. + Remove the constraint between this link and the given link. :param child_link: The child link of the constraint that should be removed. """ @@ -240,17 +296,22 @@ def remove_constraint_with_link(self, child_link: 'Link') -> None: del child_link.constraint_ids[self] @property - def is_root(self) -> bool: + def is_only_link(self) -> bool: + """ + :return: True if this link is the only link, False otherwise. """ - Returns whether this link is the root link of the object. + return self.object.has_one_link + @property + def is_root(self) -> bool: + """ :return: True if this link is the root link, False otherwise. """ return self.object.get_root_link_id() == self.id - def update_transform(self, transform_time: Optional[rospy.Time] = None) -> None: + def update_transform(self, transform_time: Optional[Time] = None) -> None: """ - Updates the transformation of this link at the given time. + Update the transformation of this link at the given time. :param transform_time: The time at which the transformation should be updated. """ @@ -258,8 +319,6 @@ def update_transform(self, transform_time: Optional[rospy.Time] = None) -> None: def get_transform_to_link(self, link: 'Link') -> Transform: """ - Returns the transformation from this link to the given link. - :param link: The link to which the transformation should be returned. :return: A Transform object with the transformation from this link to the given link. """ @@ -267,8 +326,6 @@ def get_transform_to_link(self, link: 'Link') -> Transform: def get_transform_from_link(self, link: 'Link') -> Transform: """ - Returns the transformation from the given link to this link. - :param link: The link from which the transformation should be returned. :return: A Transform object with the transformation from the given link to this link. """ @@ -276,8 +333,6 @@ def get_transform_from_link(self, link: 'Link') -> Transform: def get_pose_wrt_link(self, link: 'Link') -> Pose: """ - Returns the pose of this link with respect to the given link. - :param link: The link with respect to which the pose should be returned. :return: A Pose object with the pose of this link with respect to the given link. """ @@ -285,8 +340,6 @@ def get_pose_wrt_link(self, link: 'Link') -> Pose: def get_axis_aligned_bounding_box(self) -> AxisAlignedBoundingBox: """ - Returns the axis aligned bounding box of this link. - :return: An AxisAlignedBoundingBox object with the axis aligned bounding box of this link. """ return self.world.get_link_axis_aligned_bounding_box(self) @@ -294,8 +347,6 @@ def get_axis_aligned_bounding_box(self) -> AxisAlignedBoundingBox: @property def position(self) -> Point: """ - The getter for the position of the link relative to the world frame. - :return: A Point object containing the position of the link relative to the world frame. """ return self.pose.position @@ -303,8 +354,6 @@ def position(self) -> Point: @property def position_as_list(self) -> List[float]: """ - The getter for the position of the link relative to the world frame as a list. - :return: A list containing the position of the link relative to the world frame. """ return self.pose.position_as_list() @@ -312,8 +361,6 @@ def position_as_list(self) -> List[float]: @property def orientation(self) -> Quaternion: """ - The getter for the orientation of the link relative to the world frame. - :return: A Quaternion object containing the orientation of the link relative to the world frame. """ return self.pose.orientation @@ -321,55 +368,58 @@ def orientation(self) -> Quaternion: @property def orientation_as_list(self) -> List[float]: """ - The getter for the orientation of the link relative to the world frame as a list. - :return: A list containing the orientation of the link relative to the world frame. """ return self.pose.orientation_as_list() - def _update_pose(self) -> None: + def update_pose(self) -> None: """ - Updates the current pose of this link from the world. + Update the current pose of this link from the world. """ self._current_pose = self.world.get_link_pose(self) @property def pose(self) -> Pose: """ - The pose of the link relative to the world frame. - :return: A Pose object containing the pose of the link relative to the world frame. """ + if self.world.conf.update_poses_from_sim_on_get: + self.update_pose() return self._current_pose @property def pose_as_list(self) -> List[List[float]]: """ - The pose of the link relative to the world frame as a list. - :return: A list containing the position and orientation of the link relative to the world frame. """ return self.pose.to_list() def get_origin_transform(self) -> Transform: """ - Returns the transformation between the link frame and the origin frame of this link. + :return: the transformation between the link frame and the origin frame of this link. """ return self.origin.to_transform(self.tf_frame) @property def color(self) -> Color: """ - The getter for the rgba_color of this link. - :return: A Color object containing the rgba_color of this link. """ return self.world.get_link_color(self) + @deprecated("Use color property setter instead") + def set_color(self, color: Color) -> None: + """ + Set the color of this link, could be rgb or rgba. + + :param color: The color as a list of floats, either rgb or rgba. + """ + self.color = color + @color.setter def color(self, color: Color) -> None: """ - The setter for the color of this link, could be rgb or rgba. + Set the color of this link, could be rgb or rgba. :param color: The color as a list of floats, either rgb or rgba. """ @@ -401,8 +451,8 @@ def __hash__(self): class RootLink(Link, ABC): """ - Represents the root link of an Object in the World. - It differs from the normal AbstractLink class in that the pose ande the tf_frame is the same as that of the object. + The root link of an Object in the World. + This differs from the normal AbstractLink class in that the pose and the tf_frame is the same as that of the object. """ def __init__(self, obj: Object): @@ -411,12 +461,12 @@ def __init__(self, obj: Object): @property def tf_frame(self) -> str: """ - Returns the tf frame of the root link, which is the same as the tf frame of the object. + :return: the tf frame of the root link, which is the same as the tf frame of the object. """ return self.object.tf_frame - def _update_pose(self) -> None: - self._current_pose = self.object.get_pose() + def update_pose(self) -> None: + self._current_pose = self.world.get_object_pose(self.object) def __copy__(self): return RootLink(self.object) @@ -424,14 +474,16 @@ def __copy__(self): class Joint(ObjectEntity, JointDescription, ABC): """ - Represents a joint of an Object in the World. + Represent a joint of an Object in the World. """ def __init__(self, _id: int, joint_description: JointDescription, - obj: Object): + obj: Object, is_virtual: Optional[bool] = False): ObjectEntity.__init__(self, _id, obj) - JointDescription.__init__(self, joint_description.parsed_description) + JointDescription.__init__(self, joint_description.parsed_description, is_virtual) + self.acceptable_error = (self.world.conf.revolute_joint_position_tolerance if self.type == JointType.REVOLUTE + else self.world.conf.prismatic_joint_position_tolerance) self._update_position() @property @@ -444,38 +496,34 @@ def tf_frame(self) -> str: @property def pose(self) -> Pose: """ - Returns the pose of this joint. The pose is the pose of the child link of this joint. - - :return: The pose of this joint. + :return: The pose of this joint. The pose is the pose of the child link of this joint. """ return self.child_link.pose def _update_position(self) -> None: """ - Updates the current position of the joint from the physics simulator. + Update the current position of the joint from the physics simulator. """ self._current_position = self.world.get_joint_position(self) @property def parent_link(self) -> Link: """ - Returns the parent link of this joint. - :return: The parent link as a AbstractLink object. """ - return self.object.get_link(self.parent_link_name) + return self.object.get_link(self.parent) @property def child_link(self) -> Link: """ - Returns the child link of this joint. - :return: The child link as a AbstractLink object. """ - return self.object.get_link(self.child_link_name) + return self.object.get_link(self.child) @property def position(self) -> float: + if self.world.conf.update_poses_from_sim_on_get: + self._update_position() return self._current_position def reset_position(self, position: float) -> None: @@ -484,8 +532,6 @@ def reset_position(self, position: float) -> None: def get_object_id(self) -> int: """ - Returns the id of the object to which this joint belongs. - :return: The integer id of the object to which this joint belongs. """ return self.object.id @@ -493,8 +539,8 @@ def get_object_id(self) -> int: @position.setter def position(self, joint_position: float) -> None: """ - Sets the position of the given joint to the given joint pose. If the pose is outside the joint limits, - an error will be printed. However, the joint will be set either way. + Set the position of the given joint to the given joint pose. If the pose is outside the joint limits, + issue a warning. However, set the joint either way. :param joint_position: The target pose for this joint """ @@ -524,16 +570,16 @@ def get_applied_motor_torque(self) -> float: @property def current_state(self) -> JointState: - return JointState(self.position) + return JointState(self.position, self.acceptable_error) @current_state.setter def current_state(self, joint_state: JointState) -> None: """ - Updates the current state of this joint from the given joint state if the position is different. + Update the current state of this joint from the given joint state if the position is different. :param joint_state: The joint state to update from. """ - if self._current_position != joint_state.position: + if self.current_state != joint_state: self.position = joint_state.position def __copy__(self): @@ -547,12 +593,11 @@ def __hash__(self): class ObjectDescription(EntityDescription): - """ A class that represents the description of an object. """ - mesh_extensions: Tuple[str] = (".obj", ".stl", ".dae") + mesh_extensions: Tuple[str] = (".obj", ".stl", ".dae", ".ply") """ The file extensions of the mesh files that can be used to generate a description file. """ @@ -570,23 +615,107 @@ def __init__(self, path: Optional[str] = None): """ :param path: The path of the file to update the description data from. """ + + self._links: Optional[List[LinkDescription]] = None + self._joints: Optional[List[JointDescription]] = None + self._link_map: Optional[Dict[str, Any]] = None + self._joint_map: Optional[Dict[str, Any]] = None + if path: self.update_description_from_file(path) else: self._parsed_description = None + self.virtual_joint_names: List[str] = [] + + @property + @abstractmethod + def child_map(self) -> Dict[str, List[Tuple[str, str]]]: + """ + :return: A dictionary mapping the name of a link to its children which are represented as a tuple of the child + joint name and the link name. + """ + pass + + @property + @abstractmethod + def parent_map(self) -> Dict[str, Tuple[str, str]]: + """ + :return: A dictionary mapping the name of a link to its parent joint and link as a tuple. + """ + pass + + @property + @abstractmethod + def link_map(self) -> Dict[str, LinkDescription]: + """ + :return: A dictionary mapping the name of a link to its description. + """ + pass + + @property + @abstractmethod + def joint_map(self) -> Dict[str, JointDescription]: + """ + :return: A dictionary mapping the name of a joint to its description. + """ + pass + + def is_joint_virtual(self, name: str) -> bool: + """ + :param name: The name of the joint. + :return: True if the joint is virtual, False otherwise. + """ + return name in self.virtual_joint_names + + @abstractmethod + def add_joint(self, name: str, child: str, joint_type: JointType, + axis: Point, parent: Optional[str] = None, origin: Optional[Pose] = None, + lower_limit: Optional[float] = None, upper_limit: Optional[float] = None, + is_virtual: Optional[bool] = False) -> None: + """ + Add a joint to this object. + + :param name: The name of the joint. + :param child: The name of the child link. + :param joint_type: The type of the joint. + :param axis: The axis of the joint. + :param parent: The name of the parent link. + :param origin: The origin of the joint. + :param lower_limit: The lower limit of the joint. + :param upper_limit: The upper limit of the joint. + :param is_virtual: True if the joint is virtual, False otherwise. + """ + pass + def update_description_from_file(self, path: str) -> None: """ - Updates the description of this object from the file at the given path. + Update the description of this object from the file at the given path. :param path: The path of the file to update from. """ self._parsed_description = self.load_description(path) + def update_description_from_string(self, description_string: str) -> None: + """ + Update the description of this object from the given description string. + + :param description_string: The description string to update from. + """ + self._parsed_description = self.load_description_from_string(description_string) + + def load_description_from_string(self, description_string: str) -> Any: + """ + Load the description from the given string. + + :param description_string: The description string to load from. + """ + raise NotImplementedError + @property def parsed_description(self) -> Any: """ - Return the object parsed from the description file. + :return: The object parsed from the description file. """ return self._parsed_description @@ -600,46 +729,74 @@ def parsed_description(self, parsed_description: Any): @abstractmethod def load_description(self, path: str) -> Any: """ - Loads the description from the file at the given path. + Load the description from the file at the given path. :param path: The path to the source file, if only a filename is provided then the resources directories will be searched. """ pass - def generate_description_from_file(self, path: str, name: str, extension: str) -> str: + def generate_description_from_file(self, path: str, name: str, extension: str, save_path: str, + scale_mesh: Optional[float] = None) -> None: """ - Generates and preprocesses the description from the file at the given path and returns the preprocessed - description as a string. + Generate and preprocess the description from the file at the given path and save the preprocessed + description. The generated description will be saved at the given save path. :param path: The path of the file to preprocess. :param name: The name of the object. :param extension: The file extension of the file to preprocess. - :return: The processed description string. + :param save_path: The path to save the generated description file. + :param scale_mesh: The scale of the mesh. + :raises ObjectDescriptionNotFound: If the description file could not be found/read. """ - description_string = None if extension in self.mesh_extensions: - description_string = self.generate_from_mesh_file(path, name) + if extension == ".ply": + mesh = trimesh.load(path) + path = path.replace(extension, ".obj") + if scale_mesh is not None: + mesh.apply_scale(scale_mesh) + mesh.export(path) + self.generate_from_mesh_file(path, name, save_path=save_path) elif extension == self.get_file_extension(): - description_string = self.generate_from_description_file(path) + self.generate_from_description_file(path, save_path=save_path) else: try: # Using the description from the parameter server - description_string = self.generate_from_parameter_server(path) + self.generate_from_parameter_server(path, save_path=save_path) except KeyError: - logging.warning(f"Couldn't find dile data in the ROS parameter server") - if description_string is None: - logging.error(f"Could not find file with path {path} in the resources directory nor" - f" in the ros parameter server.") - raise FileNotFoundError + logging.warning(f"Couldn't find file data in the ROS parameter server") - return description_string + if not self.check_description_file_exists_and_can_be_read(save_path): + raise ObjectDescriptionNotFound(name, path, extension) - def get_file_name(self, path_object: pathlib.Path, extension: str, object_name: str) -> str: + @staticmethod + def check_description_file_exists_and_can_be_read(path: str) -> bool: """ - Returns the file name of the description file. + Check if the description file exists at the given path. + :param path: The path to the description file. + :return: True if the file exists, False otherwise. + """ + exists = os.path.exists(path) + if exists: + with open(path, "r") as file: + exists = bool(file.read()) + return exists + + @staticmethod + def write_description_to_file(description_string: str, save_path: str) -> None: + """ + Write the description string to the file at the given path. + + :param description_string: The description string to write. + :param save_path: The path of the file to write to. + """ + with open(save_path, "w") as file: + file.write(description_string) + + def get_file_name(self, path_object: pathlib.Path, extension: str, object_name: str) -> str: + """ :param path_object: The path object of the description file or the mesh file. :param extension: The file extension of the description file or the mesh file. :param object_name: The name of the object. @@ -656,36 +813,39 @@ def get_file_name(self, path_object: pathlib.Path, extension: str, object_name: @classmethod @abstractmethod - def generate_from_mesh_file(cls, path: str, name: str) -> str: + def generate_from_mesh_file(cls, path: str, name: str, save_path: str) -> None: """ - Generates a description file from one of the mesh types defined in the mesh_extensions and - returns the path of the generated file. + Generate a description file from one of the mesh types defined in the mesh_extensions and + return the path of the generated file. The generated file will be saved at the given save_path. :param path: The path to the .obj file. :param name: The name of the object. - :return: The path of the generated description file. + :param save_path: The path to save the generated description file. """ pass @classmethod @abstractmethod - def generate_from_description_file(cls, path: str) -> str: + def generate_from_description_file(cls, path: str, save_path: str, make_mesh_paths_absolute: bool = True) -> None: """ - Preprocesses the given file and returns the preprocessed description string. + Preprocess the given file and return the preprocessed description string. The preprocessed description will be + saved at the given save_path. :param path: The path of the file to preprocess. - :return: The preprocessed description string. + :param save_path: The path to save the preprocessed description file. + :param make_mesh_paths_absolute: Whether to make the mesh paths absolute. """ pass @classmethod @abstractmethod - def generate_from_parameter_server(cls, name: str) -> str: + def generate_from_parameter_server(cls, name: str, save_path: str) -> None: """ - Preprocesses the description from the ROS parameter server and returns the preprocessed description string. + Preprocess the description from the ROS parameter server and return the preprocessed description string. + The preprocessed description will be saved at the given save_path. :param name: The name of the description on the parameter server. - :return: The preprocessed description string. + :param save_path: The path to save the preprocessed description file. """ pass @@ -697,12 +857,11 @@ def links(self) -> List[LinkDescription]: """ pass - @abstractmethod def get_link_by_name(self, link_name: str) -> LinkDescription: """ :return: The link description with the given name. """ - pass + return self.link_map[link_name] @property @abstractmethod @@ -712,12 +871,11 @@ def joints(self) -> List[JointDescription]: """ pass - @abstractmethod def get_joint_by_name(self, joint_name: str) -> JointDescription: """ :return: The joint description with the given name. """ - pass + return self.joint_map[joint_name] @abstractmethod def get_root(self) -> str: @@ -726,8 +884,15 @@ def get_root(self) -> str: """ pass + def get_tip(self) -> str: + """ + :return: the name of the tip link of this object. + """ + raise NotImplementedError + @abstractmethod - def get_chain(self, start_link_name: str, end_link_name: str) -> List[str]: + def get_chain(self, start_link_name: str, end_link_name: str, joints: Optional[bool] = True, + links: Optional[bool] = True, fixed: Optional[bool] = True) -> List[str]: """ :return: the chain of links from 'start_link_name' to 'end_link_name'. """ diff --git a/src/pycram/designator.py b/src/pycram/designator.py index d5a8878b2..4790bac19 100644 --- a/src/pycram/designator.py +++ b/src/pycram/designator.py @@ -5,12 +5,13 @@ from abc import ABC, abstractmethod from inspect import isgenerator, isgeneratorfunction -import rospy +from .ros.logging import logwarn, loginfo + try: import owlready2 except ImportError: owlready2 = None - rospy.logwarn("owlready2 is not installed!") + logwarn("owlready2 is not installed!") from sqlalchemy.orm.session import Session @@ -365,9 +366,8 @@ def ground(self) -> Any: def get_slots(self) -> List[str]: """ - Returns a list of all slots of this description. Can be used for inspecting different descriptions and debugging. - - :return: A list of all slots. + :return: a list of all slots of this description. Can be used for inspecting different descriptions and + debugging. """ return list(self.__dict__.keys()) @@ -376,7 +376,7 @@ def copy(self) -> DesignatorDescription: def get_default_ontology_concept(self) -> owlready2.Thing | None: """ - Returns the first element of ontology_concept_holders if there is, else None + :return: The first element of ontology_concept_holders if there is, else None """ return self.ontology_concept_holders[0].ontology_concept if self.ontology_concept_holders else None @@ -575,7 +575,7 @@ def to_sql(self) -> ORMObjectDesignator: :return: The created ORM object. """ - return ORMObjectDesignator(self.obj_type, self.name) + return ORMObjectDesignator(name=self.name, obj_type=self.obj_type) def insert(self, session: Session) -> ORMObjectDesignator: """ @@ -597,8 +597,6 @@ def insert(self, session: Session) -> ORMObjectDesignator: def frozen_copy(self) -> 'ObjectDesignatorDescription.Object': """ - Returns a copy of this designator containing only the fields. - :return: A copy containing only the fields of this class. The WorldObject attached to this pycram object is not copied. The _pose gets set to a method that statically returns the pose of the object when this method was called. """ result = ObjectDesignatorDescription.Object(self.name, self.obj_type, None) @@ -633,7 +631,7 @@ def __repr__(self): def special_knowledge_adjustment_pose(self, grasp: str, pose: Pose) -> Pose: """ - Returns the adjusted target pose based on special knowledge for "grasp front". + Get the adjusted target pose based on special knowledge for "grasp front". :param grasp: From which side the object should be grasped :param pose: Pose at which the object should be grasped, before adjustment @@ -652,7 +650,7 @@ def special_knowledge_adjustment_pose(self, grasp: str, pose: Pose) -> Pose: pose_in_object.pose.position.x += value[0] pose_in_object.pose.position.y += value[1] pose_in_object.pose.position.z += value[2] - rospy.loginfo("Adjusted target pose based on special knowledge for grasp: %s", grasp) + loginfo("Adjusted target pose based on special knowledge for grasp: %s", grasp) return pose_in_object return pose diff --git a/src/pycram/designators/action_designator.py b/src/pycram/designators/action_designator.py index d14ad54f2..2c77c29fe 100644 --- a/src/pycram/designators/action_designator.py +++ b/src/pycram/designators/action_designator.py @@ -8,17 +8,14 @@ import numpy as np from sqlalchemy.orm import Session from tf import transformations -from typing_extensions import Any, List, Union, Callable, Optional, Type - -import rospy +from typing_extensions import List, Union, Callable, Optional, Type from .location_designator import CostmapLocation from .motion_designator import MoveJointsMotion, MoveGripperMotion, MoveArmJointsMotion, MoveTCPMotion, MoveMotion, \ LookingMotion, DetectingMotion, OpeningMotion, ClosingMotion from .object_designator import ObjectDesignatorDescription, BelieveObject, ObjectPart from ..local_transformer import LocalTransformer -from ..plan_failures import ObjectUnfetchable, ReachabilityFailure -# from ..robot_descriptions import robot_description +from ..failures import ObjectUnfetchable, ReachabilityFailure from ..robot_description import RobotDescription from ..tasktree import with_tree @@ -537,11 +534,11 @@ def to_sql(self) -> Action: :return: An instance of the ORM equivalent of the action with the parameters set """ - # get all class parameters (ignore inherited ones) + # get all class parameters class_variables = {key: value for key, value in vars(self).items() if key in inspect.getfullargspec(self.__init__).args} - # get all orm class parameters (ignore inherited ones) + # get all orm class parameters orm_class_variables = inspect.getfullargspec(self.orm_class.__init__).args # list of parameters that will be passed to the ORM class. If the name does not match the orm_class equivalent @@ -565,11 +562,11 @@ def insert(self, session: Session, **kwargs) -> Action: action = super().insert(session) - # get all class parameters (ignore inherited ones) + # get all class parameters class_variables = {key: value for key, value in vars(self).items() if key in inspect.getfullargspec(self.__init__).args} - # get all orm class parameters (ignore inherited ones) + # get all orm class parameters orm_class_variables = inspect.getfullargspec(self.orm_class.__init__).args # loop through all class parameters and insert them into the session unless they are already added by the ORM @@ -716,10 +713,13 @@ class PickUpActionPerformable(ActionAbstract): """ orm_class: Type[ActionAbstract] = field(init=False, default=ORMPickUpAction) - @with_tree - def perform(self) -> None: + def __post_init__(self): + super(ActionAbstract, self).__post_init__() # Store the object's data copy at execution self.object_at_execution = self.object_designator.frozen_copy() + + @with_tree + def perform(self) -> None: robot = World.robot # Retrieve object and robot from designators object = self.object_designator.world_object @@ -775,6 +775,17 @@ def perform(self) -> None: # Remove the vis axis from the world World.current_world.remove_vis_axis() + #TODO find a way to use object_at_execution instead of object_designator in the automatic orm mapping in ActionAbstract + def to_sql(self) -> Action: + return ORMPickUpAction(arm=self.arm, grasp=self.grasp) + + def insert(self, session: Session, **kwargs) -> Action: + action = super(ActionAbstract, self).insert(session) + action.object = self.object_at_execution.insert(session) + + session.add(action) + return action + @dataclass class PlaceActionPerformable(ActionAbstract): @@ -860,7 +871,7 @@ def perform(self) -> None: ParkArmsActionPerformable(Arms.BOTH).perform() pickup_loc = CostmapLocation(target=self.object_designator, reachable_for=robot_desig.resolve(), reachable_arm=self.arm) - # Tries to find a pick-up posotion for the robot that uses the given arm + # Tries to find a pick-up position for the robot that uses the given arm pickup_pose = None for pose in pickup_loc: if self.arm in pose.reachable_arms: diff --git a/src/pycram/designators/location_designator.py b/src/pycram/designators/location_designator.py index d59f60a30..5bd50bcfb 100644 --- a/src/pycram/designators/location_designator.py +++ b/src/pycram/designators/location_designator.py @@ -178,6 +178,7 @@ def __iter__(self): if self.visible_for or self.reachable_for: robot_object = self.visible_for.world_object if self.visible_for else self.reachable_for.world_object test_robot = World.current_world.get_prospection_object_for_object(robot_object) + with UseProspectionWorld(): for maybe_pose in PoseGenerator(final_map, number_of_samples=600): res = True @@ -249,7 +250,6 @@ def __iter__(self) -> Location: final_map = occupancy + gaussian - test_robot = World.current_world.get_prospection_object_for_object(self.robot) # Find a Joint of type prismatic which is above the handle in the URDF tree @@ -283,8 +283,10 @@ def __iter__(self) -> Location: valid_goal, arms_goal = reachability_validator(maybe_pose, test_robot, goal_pose, allowed_collision={test_robot: hand_links}) - if valid_init and valid_goal: - yield self.Location(maybe_pose, list(set(arms_init).intersection(set(arms_goal)))) + arms_list = list(set(arms_init).intersection(set(arms_goal))) + + if valid_init and valid_goal and len(arms_list) > 0: + yield self.Location(maybe_pose, arms_list) class SemanticCostmapLocation(LocationDesignatorDescription): diff --git a/src/pycram/designators/motion_designator.py b/src/pycram/designators/motion_designator.py index b008c8561..589d5f9ad 100644 --- a/src/pycram/designators/motion_designator.py +++ b/src/pycram/designators/motion_designator.py @@ -5,20 +5,21 @@ from .object_designator import ObjectDesignatorDescription, ObjectPart, RealObject from ..designator import ResolutionError from ..orm.base import ProcessMetaData -from ..plan_failures import PerceptionObjectNotFound +from ..failures import PerceptionObjectNotFound from ..process_module import ProcessModuleManager from ..orm.motion_designator import (MoveMotion as ORMMoveMotion, MoveTCPMotion as ORMMoveTCPMotion, LookingMotion as ORMLookingMotion, MoveGripperMotion as ORMMoveGripperMotion, DetectingMotion as ORMDetectingMotion, OpeningMotion as ORMOpeningMotion, ClosingMotion as ORMClosingMotion, Motion as ORMMotionDesignator) -from ..datastructures.enums import ObjectType, Arms, GripperState +from ..datastructures.enums import ObjectType, Arms, GripperState, ExecutionType from typing_extensions import Dict, Optional, get_type_hints from ..datastructures.pose import Pose from ..tasktree import with_tree from ..designator import BaseMotion + @dataclass class MoveMotion(BaseMotion): """ @@ -160,9 +161,9 @@ def perform(self): if not world_object: raise PerceptionObjectNotFound( f"Could not find an object with the type {self.object_type} in the FOV of the robot") - if ProcessModuleManager.execution_type == "real": + if ProcessModuleManager.execution_type == ExecutionType.REAL: return RealObject.Object(world_object.name, world_object.obj_type, - world_object, world_object.get_pose()) + world_object, world_object.get_pose()) return ObjectDesignatorDescription.Object(world_object.name, world_object.obj_type, world_object) @@ -313,3 +314,26 @@ def insert(self, session: Session, *args, **kwargs) -> ORMClosingMotion: session.add(motion) return motion + + +@dataclass +class TalkingMotion(BaseMotion): + """ + Talking Motion, lets the robot say a sentence. + """ + + cmd: str + """ + Talking Motion, let the robot say a sentence. + """ + + @with_tree + def perform(self): + pm_manager = ProcessModuleManager.get_manager() + return pm_manager.talk().execute(self) + + def to_sql(self) -> ORMMotionDesignator: + pass + + def insert(self, session: Session, *args, **kwargs) -> ORMMotionDesignator: + pass diff --git a/src/pycram/designators/object_designator.py b/src/pycram/designators/object_designator.py index 85d090499..303dc7939 100644 --- a/src/pycram/designators/object_designator.py +++ b/src/pycram/designators/object_designator.py @@ -3,17 +3,19 @@ import dataclasses from typing_extensions import List, Optional, Callable, TYPE_CHECKING import sqlalchemy.orm +from ..datastructures.enums import ObjectType from ..datastructures.world import World from ..world_concepts.world_object import Object as WorldObject from ..designator import ObjectDesignatorDescription from ..orm.base import ProcessMetaData from ..orm.object_designator import (BelieveObject as ORMBelieveObject, ObjectPart as ORMObjectPart) from ..datastructures.pose import Pose -from ..external_interfaces.robokudo import query +from ..external_interfaces.robokudo import * if TYPE_CHECKING: import owlready2 + class BelieveObject(ObjectDesignatorDescription): """ Description for Objects that are only believed in. @@ -26,7 +28,7 @@ class Object(ObjectDesignatorDescription.Object): """ def to_sql(self) -> ORMBelieveObject: - return ORMBelieveObject(self.obj_type, self.name) + return ORMBelieveObject(name=self.name, obj_type=self.obj_type) def insert(self, session: sqlalchemy.orm.session.Session) -> ORMBelieveObject: metadata = ProcessMetaData().insert(session) @@ -49,7 +51,7 @@ class Object(ObjectDesignatorDescription.Object): part_pose: Pose def to_sql(self) -> ORMObjectPart: - return ORMObjectPart(self.obj_type, self.name) + return ORMObjectPart(obj_type=self.obj_type, name=self.name) def insert(self, session: sqlalchemy.orm.session.Session) -> ORMObjectPart: metadata = ProcessMetaData().insert(session) @@ -63,7 +65,7 @@ def insert(self, session: sqlalchemy.orm.session.Session) -> ORMObjectPart: def __init__(self, names: List[str], part_of: ObjectDesignatorDescription.Object, - type: Optional[str] = None, + type: Optional[ObjectType] = None, resolver: Optional[Callable] = None): """ Describing the relationship between an object and a specific part of it. @@ -79,7 +81,7 @@ def __init__(self, names: List[str], if not part_of: raise AttributeError("part_of cannot be None.") - self.type: Optional[str] = type + self.type: Optional[ObjectType] = type self.names: Optional[List[str]] = names self.part_of = part_of @@ -138,6 +140,8 @@ def __init__(self, names: List[str], types: List[str], self.timestamps: List[float] = timestamps +@DeprecationWarning +# Depricated class this will be done differently class RealObject(ObjectDesignatorDescription): """ Object designator representing an object in the real world, when resolving this object designator description ] @@ -155,11 +159,11 @@ class Object(ObjectDesignatorDescription.Object): def __init__(self, names: Optional[List[str]] = None, types: Optional[List[str]] = None, world_object: WorldObject = None, resolver: Optional[Callable] = None): """ - - :param names: - :param types: + + :param names: + :param types: :param world_object: - :param resolver: + :param resolver: """ super().__init__(resolver) self.types: Optional[List[str]] = types diff --git a/src/pycram/designators/specialized_designators/action/__init__.py b/src/pycram/designators/specialized_designators/action/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/pycram/designators/specialized_designators/action/dual_arm_pickup_action.py b/src/pycram/designators/specialized_designators/action/dual_arm_pickup_action.py new file mode 100644 index 000000000..6347b0a72 --- /dev/null +++ b/src/pycram/designators/specialized_designators/action/dual_arm_pickup_action.py @@ -0,0 +1,78 @@ +from typing_extensions import List, Union, Optional +from numpy.linalg import norm +from numpy import array +from geometry_msgs.msg import Vector3 + +from owlready2 import Thing + +from ...action_designator import PickUpAction, PickUpActionPerformable +from ....local_transformer import LocalTransformer +from ....datastructures.world import World +from ....datastructures.pose import Pose, Transform +from ....datastructures.enums import Arms, Grasp +from ....robot_description import RobotDescription, KinematicChainDescription +from ....designator import ObjectDesignatorDescription +from ....ros.logging import loginfo + + +class DualArmPickupAction(PickUpAction): + """ + Specialization version of the PickUpAction designator which uses heuristics to solve for a dual pickup solution. + """ + + def __init__(self, + object_designator_description: Union[ObjectDesignatorDescription, ObjectDesignatorDescription.Object], + grasps: List[Grasp], resolver=None, + ontology_concept_holders: Optional[List[Thing]] = None): + """ + Specialized version of the PickUpAction designator which uses heuristics to solve for a dual pickup problem. The + designator will choose the arm which is closest to the object that is to be picked up. + + :param object_designator_description: List of object designator which should be picked up + :param grasps: List of possible grasps which should be used for the pickup + :param resolver: Optional specialized_designators that returns a performable designator with elements from the + lists of possible parameter + :param ontology_concept_holders: List of ontology concepts that the action is categorized as or associated with + """ + super().__init__(object_designator_description, + arms=[Arms.LEFT, Arms.RIGHT], + grasps=grasps, + resolver=resolver, + ontology_concept_holders=ontology_concept_holders) + + self.object_designator_description: Union[ + ObjectDesignatorDescription, ObjectDesignatorDescription.Object] = object_designator_description + + left_gripper = RobotDescription.current_robot_description.get_arm_chain(Arms.LEFT) + right_gripper = RobotDescription.current_robot_description.get_arm_chain(Arms.RIGHT) + self.gripper_list: List[KinematicChainDescription] = [left_gripper, right_gripper] + + + def ground(self) -> PickUpActionPerformable: + if isinstance(self.object_designator_description, ObjectDesignatorDescription.Object): + obj_desig = self.object_designator_description + else: + obj_desig = self.object_designator_description.resolve() + + loginfo("Calculating closest gripper to object {}".format(obj_desig.name)) + + local_transformer = LocalTransformer() + + object_pose: Pose = obj_desig.world_object.pose + distances = [] + # Iterate over possible grippers + for gripper in self.gripper_list: + # Object pose in gripper frame + gripper_frame = World.robot.get_link_tf_frame(gripper.get_tool_frame()) + + object_T_gripper: Pose = local_transformer.transform_pose(object_pose, gripper_frame) + object_V_gripper: Vector3 = object_T_gripper.pose.position # translation vector + distance = norm(array([object_V_gripper.x, object_V_gripper.y, object_V_gripper.z])) + loginfo(f"Distance between {gripper} and {obj_desig.name}: {distance}") + distances.append(distance) + + min_index = distances.index(min(distances)) + winner = self.gripper_list[min_index] + loginfo(f"Winner is {winner.arm_type.name} with distance {min(distances):.2f}") + + return PickUpActionPerformable(object_designator=obj_desig, arm=winner.arm_type, grasp=self.grasps[0]) diff --git a/src/pycram/designators/specialized_designators/location/giskard_location.py b/src/pycram/designators/specialized_designators/location/giskard_location.py index 1400a8e63..de0d0a6e8 100644 --- a/src/pycram/designators/specialized_designators/location/giskard_location.py +++ b/src/pycram/designators/specialized_designators/location/giskard_location.py @@ -53,7 +53,7 @@ def __iter__(self) -> CostmapLocation.Location: prospection_robot = World.current_world.get_prospection_object_for_object(World.robot) with UseProspectionWorld(): - prospection_robot.set_joint_positions(robot_joint_states) + prospection_robot.set_multiple_joint_positions(robot_joint_states) prospection_robot.set_pose(pose) gripper_pose = prospection_robot.get_link_pose(chain.get_tool_frame()) diff --git a/src/pycram/designators/specialized_designators/probabilistic/probabilistic_action.py b/src/pycram/designators/specialized_designators/probabilistic/probabilistic_action.py index fe6668bee..606969672 100644 --- a/src/pycram/designators/specialized_designators/probabilistic/probabilistic_action.py +++ b/src/pycram/designators/specialized_designators/probabilistic/probabilistic_action.py @@ -1,8 +1,8 @@ import numpy as np import tqdm -from probabilistic_model.probabilistic_circuit.distributions import GaussianDistribution, SymbolicDistribution -from probabilistic_model.probabilistic_circuit.probabilistic_circuit import ProbabilisticCircuit, \ - DecomposableProductUnit +from probabilistic_model.probabilistic_circuit.nx.distributions import GaussianDistribution, SymbolicDistribution +from probabilistic_model.probabilistic_circuit.nx.probabilistic_circuit import ProbabilisticCircuit, \ + ProductUnit from probabilistic_model.utils import MissingDict from random_events.interval import * from random_events.product_algebra import Event, SimpleEvent @@ -19,7 +19,7 @@ from ....designator import ActionDesignatorDescription, ObjectDesignatorDescription from ....local_transformer import LocalTransformer from ....orm.views import PickUpWithContextView -from ....plan_failures import ObjectUnreachable, PlanFailure +from ....failures import ObjectUnreachable, PlanFailure class Grasp(SetElement): @@ -124,7 +124,7 @@ def create_model_with_center(self) -> ProbabilisticCircuit: """ Create a fully factorized gaussian at the center of the map. """ - centered_model = DecomposableProductUnit() + centered_model = ProductUnit() centered_model.add_subcircuit(GaussianDistribution(self.relative_x, 0., np.sqrt(self.variance))) centered_model.add_subcircuit(GaussianDistribution(self.relative_y, 0., np.sqrt(self.variance))) @@ -206,7 +206,7 @@ def sample_to_action(self, sample: List) -> MoveAndPickUpPerformable: pose = Pose(position, frame=self.object_designator.world_object.tf_frame) standing_position = LocalTransformer().transform_pose(pose, "map") standing_position.position.z = 0 - action = MoveAndPickUpPerformable(standing_position, self.object_designator, EArms(int(arm)), EGrasp(int(grasp))) + action = MoveAndPickUpPerformable(standing_position, self.object_designator, EArms[Arms(int(arm)).name], EGrasp(int(grasp))) return action def events_from_occupancy_and_visibility_costmap(self) -> Event: @@ -300,8 +300,6 @@ def query_for_database(): def batch_rollout(self): """ Try the policy without conditioning on visibility and occupancy and count the successful tries. - - :amount: The amount of tries """ # initialize statistics diff --git a/src/pycram/external_interfaces/giskard.py b/src/pycram/external_interfaces/giskard.py index 925584fa5..9fd922866 100644 --- a/src/pycram/external_interfaces/giskard.py +++ b/src/pycram/external_interfaces/giskard.py @@ -2,17 +2,18 @@ import threading import time -import rospy import sys -import rosnode + +from ..ros.data_types import Time +from ..ros.logging import logwarn, loginfo_once +from ..ros.ros_tools import get_node_names from ..datastructures.enums import JointType, ObjectType from ..datastructures.pose import Pose -# from ..robot_descriptions import robot_description from ..datastructures.world import World from ..datastructures.dataclasses import MeshVisualShape +from ..ros.service import get_service_proxy from ..world_concepts.world_object import Object -# from ..robot_description import ManipulatorDescription from ..robot_description import RobotDescription from typing_extensions import List, Dict, Callable, Optional @@ -22,9 +23,8 @@ try: from giskardpy.python_interface.old_python_interface import OldGiskardWrapper as GiskardWrapper from giskard_msgs.msg import WorldBody, MoveResult, CollisionEntry - from giskard_msgs.srv import UpdateWorldRequest, UpdateWorld, UpdateWorldResponse, RegisterGroupResponse except ModuleNotFoundError as e: - rospy.logwarn("Failed to import Giskard messages, the real robot will not be available") + logwarn("Failed to import Giskard messages, the real robot will not be available") giskard_wrapper = None giskard_update_service = None @@ -66,28 +66,27 @@ def wrapper(*args, **kwargs): global giskard_wrapper global giskard_update_service global is_init - if is_init and "/giskard" in rosnode.get_node_names(): + if is_init and "/giskard" in get_node_names(): return func(*args, **kwargs) - elif is_init and "/giskard" not in rosnode.get_node_names(): - rospy.logwarn("Giskard node is not available anymore, could not initialize giskard interface") + elif is_init and "/giskard" not in get_node_names(): + logwarn("Giskard node is not available anymore, could not initialize giskard interface") is_init = False giskard_wrapper = None return if "giskard_msgs" not in sys.modules: - rospy.logwarn("Could not initialize the Giskard interface since the giskard_msgs are not imported") + logwarn("Could not initialize the Giskard interface since the giskard_msgs are not imported") return - if "/giskard" in rosnode.get_node_names(): + if "/giskard" in get_node_names(): giskard_wrapper = GiskardWrapper() - giskard_update_service = rospy.ServiceProxy("/giskard/update_world", UpdateWorld) - rospy.loginfo_once("Successfully initialized Giskard interface") + giskard_update_service = get_service_proxy("/giskard/update_world", UpdateWorld) + loginfo_once("Successfully initialized Giskard interface") is_init = True else: - rospy.logwarn("Giskard is not running, could not initialize Giskard interface") + logwarn("Giskard is not running, could not initialize Giskard interface") return return func(*args, **kwargs) - return wrapper @@ -169,7 +168,7 @@ def spawn_object(object: Object) -> None: :param object: World object that should be spawned """ if len(object.link_name_to_id) == 1: - geometry = object.get_link_geometry(object.root_link_name) + geometry = object.get_link_geometry(object.root_link.name) if isinstance(geometry, MeshVisualShape): filename = geometry.file_name spawn_mesh(object.name, filename, object.get_pose()) @@ -318,7 +317,7 @@ def achieve_joint_goal(goal_poses: Dict[str, float]) -> 'MoveResult': @init_giskard_interface @thread_safe -def achieve_cartesian_goal(goal_pose: Pose, tip_link: str, root_link: str) -> 'MoveResult': +def achieve_cartesian_goal(goal_pose: Pose, tip_link: str, root_link: str, position_threshold: float = 0.02, orientation_threshold: float = 0.02) -> 'MoveResult': """ Takes a cartesian position and tries to move the tip_link to this position using the chain defined by tip_link and root_link. @@ -326,6 +325,8 @@ def achieve_cartesian_goal(goal_pose: Pose, tip_link: str, root_link: str) -> 'M :param goal_pose: The position which should be achieved with tip_link :param tip_link: The end link of the chain as well as the link which should achieve the goal_pose :param root_link: The starting link of the chain which should be used to achieve this goal + :param position_threshold: Position distance at which the goal is successfully reached + :param orientation_threshold: Orientation distance at which the goal is successfully reached :return: MoveResult message for this goal """ sync_worlds() @@ -334,8 +335,19 @@ def achieve_cartesian_goal(goal_pose: Pose, tip_link: str, root_link: str) -> 'M if par_return: return par_return - giskard_wrapper.set_cart_goal(_pose_to_pose_stamped(goal_pose), tip_link, root_link) - # giskard_wrapper.add_default_end_motion_conditions() + cart_monitor1 = giskard_wrapper.monitors.add_cartesian_pose(root_link=root_link, tip_link=tip_link, + goal_pose=_pose_to_pose_stamped(goal_pose), + position_threshold=position_threshold, orientation_threshold=orientation_threshold, + name='cart goal 1') + end_monitor = giskard_wrapper.monitors.add_local_minimum_reached(start_condition=cart_monitor1) + + giskard_wrapper.motion_goals.add_cartesian_pose(name='g1', root_link=root_link, tip_link=tip_link, + goal_pose=_pose_to_pose_stamped(goal_pose), + end_condition=cart_monitor1) + + giskard_wrapper.monitors.add_end_motion(start_condition=end_monitor) + giskard_wrapper.motion_goals.avoid_all_collisions() + giskard_wrapper.motion_goals.allow_collision(group1='gripper', group2=CollisionEntry.ALL) return giskard_wrapper.execute() @@ -578,9 +590,7 @@ def allow_gripper_collision(gripper: str) -> None: @init_giskard_interface def get_gripper_group_names() -> List[str]: """ - Returns a list of groups that are registered in giskard which have 'gripper' in their name. - - :return: The list of gripper groups + :return: The list of groups that are registered in giskard which have 'gripper' in their name. """ groups = giskard_wrapper.get_group_names() return list(filter(lambda elem: "gripper" in elem, groups)) @@ -589,7 +599,7 @@ def get_gripper_group_names() -> List[str]: @init_giskard_interface def add_gripper_groups() -> None: """ - Adds the gripper links as a group for collision avoidance. + Add the gripper links as a group for collision avoidance. :return: Response of the RegisterGroup Service """ @@ -633,7 +643,7 @@ def avoid_collisions(object1: Object, object2: Object) -> None: @init_giskard_interface def make_world_body(object: Object) -> 'WorldBody': """ - Creates a WorldBody message for a World Object. The WorldBody will contain the URDF of the World Object + Create a WorldBody message for a World Object. The WorldBody will contain the URDF of the World Object :param object: The World Object :return: A WorldBody message for the World Object @@ -656,7 +666,7 @@ def make_point_stamped(point: List[float]) -> PointStamped: :return: A PointStamped message """ msg = PointStamped() - msg.header.stamp = rospy.Time.now() + msg.header.stamp = Time().now() msg.header.frame_id = "map" msg.point.x = point[0] @@ -674,7 +684,7 @@ def make_quaternion_stamped(quaternion: List[float]) -> QuaternionStamped: :return: A QuaternionStamped message """ msg = QuaternionStamped() - msg.header.stamp = rospy.Time.now() + msg.header.stamp = Time().now() msg.header.frame_id = "map" msg.quaternion.x = quaternion[0] @@ -693,7 +703,7 @@ def make_vector_stamped(vector: List[float]) -> Vector3Stamped: :return: A Vector3Stamped message """ msg = Vector3Stamped() - msg.header.stamp = rospy.Time.now() + msg.header.stamp = Time().now() msg.header.frame_id = "map" msg.vector.x = vector[0] diff --git a/src/pycram/external_interfaces/ik.py b/src/pycram/external_interfaces/ik.py index 17ceca769..5a89a679f 100644 --- a/src/pycram/external_interfaces/ik.py +++ b/src/pycram/external_interfaces/ik.py @@ -2,7 +2,9 @@ import tf from typing_extensions import List, Union, Tuple, Dict -import rospy +from ..ros.data_types import Duration, ServiceException +from ..ros.logging import loginfo_once, logerr +from ..ros.service import get_service_proxy, wait_for_service from moveit_msgs.msg import PositionIKRequest from moveit_msgs.msg import RobotState from moveit_msgs.srv import GetPositionIK @@ -14,7 +16,7 @@ from ..local_transformer import LocalTransformer from ..datastructures.pose import Pose from ..robot_description import RobotDescription -from ..plan_failures import IKError +from ..failures import IKError from ..external_interfaces.giskard import projection_cartesian_goal, allow_gripper_collision @@ -49,7 +51,7 @@ def _make_request_msg(root_link: str, tip_link: str, target_pose: Pose, robot_ob msg_request.pose_stamped = target_pose msg_request.avoid_collisions = False msg_request.robot_state = robot_state - msg_request.timeout = rospy.Duration(secs=1000) + msg_request.timeout = Duration(1000) # msg_request.attempts = 1000 return msg_request @@ -74,15 +76,15 @@ def call_ik(root_link: str, tip_link: str, target_pose: Pose, robot_object: Obje else: ik_service = "/kdl_ik_service/get_ik" - rospy.loginfo_once(f"Waiting for IK service: {ik_service}") - rospy.wait_for_service(ik_service) + loginfo_once(f"Waiting for IK service: {ik_service}") + wait_for_service(ik_service) req = _make_request_msg(root_link, tip_link, target_pose, robot_object, joints) req.pose_stamped.header.frame_id = root_link - ik = rospy.ServiceProxy(ik_service, GetPositionIK) + ik = get_service_proxy(ik_service, GetPositionIK) try: resp = ik(req) - except rospy.ServiceException as e: + except ServiceException as e: if RobotDescription.current_robot_description.name == "pr2": raise IKError(target_pose, root_link, tip_link) else: @@ -151,7 +153,7 @@ def try_to_reach(pose_or_object: Union[Pose, Object], prospection_robot: Object, try: inv = request_ik(input_pose, prospection_robot, joints, gripper_name) except IKError as e: - rospy.logerr(f"Pose is not reachable: {e}") + logerr(f"Pose is not reachable: {e}") return None _apply_ik(prospection_robot, inv) @@ -213,7 +215,7 @@ def request_giskard_ik(target_pose: Pose, robot: Object, gripper: str) -> Tuple[ :param gripper: Name of the tool frame which should grasp, this should be at the end of the given joint chain. :return: A list of joint values. """ - rospy.loginfo_once(f"Using Giskard for full body IK") + loginfo_once(f"Using Giskard for full body IK") local_transformer = LocalTransformer() target_map = local_transformer.transform_pose(target_pose, "map") @@ -234,7 +236,7 @@ def request_giskard_ik(target_pose: Pose, robot: Object, gripper: str) -> Tuple[ robot_joint_states[joint_name] = state with UseProspectionWorld(): - prospection_robot.set_joint_positions(robot_joint_states) + prospection_robot.set_multiple_joint_positions(robot_joint_states) prospection_robot.set_pose(pose) tip_pose = prospection_robot.get_link_pose(gripper) diff --git a/src/pycram/external_interfaces/move_base.py b/src/pycram/external_interfaces/move_base.py index d443ef739..a63ed0bec 100644 --- a/src/pycram/external_interfaces/move_base.py +++ b/src/pycram/external_interfaces/move_base.py @@ -1,15 +1,16 @@ import sys -import rospy -import actionlib -import rosnode +from ..ros.action_lib import create_action_client, SimpleActionClient +from ..ros.logging import logwarn, loginfo +from ..ros.ros_tools import get_node_names + from geometry_msgs.msg import PoseStamped from typing import Callable try: from move_base_msgs.msg import MoveBaseAction, MoveBaseGoal except ModuleNotFoundError as e: - rospy.logwarn(f"Could not import MoveBase messages, Navigation interface could not be initialized") + logwarn(f"Could not import MoveBase messages, Navigation interface could not be initialized") # Global variables for shared resources @@ -17,10 +18,10 @@ is_init = False -def create_nav_action_client() -> actionlib.SimpleActionClient: +def create_nav_action_client() -> SimpleActionClient: """Creates a new action client for the move_base interface.""" - client = actionlib.SimpleActionClient("move_base", MoveBaseAction) - rospy.loginfo("Waiting for move_base action server") + client = create_action_client("move_base", MoveBaseAction) + loginfo("Waiting for move_base action server") client.wait_for_server() return client @@ -36,15 +37,15 @@ def wrapper(*args, **kwargs): return func(*args, **kwargs) if "move_base_msgs" not in sys.modules: - rospy.logwarn("Could not initialize the navigation interface: move_base_msgs not imported") + logwarn("Could not initialize the navigation interface: move_base_msgs not imported") return - if "/move_base" in rosnode.get_node_names(): + if "/move_base" in get_node_names(): nav_action_client = create_nav_action_client() - rospy.loginfo("Successfully initialized navigation interface") + loginfo("Successfully initialized navigation interface") is_init = True else: - rospy.logwarn("Move_base is not running, could not initialize navigation interface") + logwarn("Move_base is not running, could not initialize navigation interface") return return func(*args, **kwargs) @@ -59,10 +60,10 @@ def query_pose_nav(navpose: PoseStamped): global query_result def active_callback(): - rospy.loginfo("Sent query to move_base") + loginfo("Sent query to move_base") def done_callback(state, result): - rospy.loginfo("Finished moving") + loginfo("Finished moving") global query_result query_result = result diff --git a/src/pycram/external_interfaces/robokudo.py b/src/pycram/external_interfaces/robokudo.py index a1a7d618a..9a71c05bb 100644 --- a/src/pycram/external_interfaces/robokudo.py +++ b/src/pycram/external_interfaces/robokudo.py @@ -1,135 +1,168 @@ import sys -from typing_extensions import Callable +from threading import Lock, RLock +from typing import Any -import rospy -import actionlib -import rosnode +from ..ros.action_lib import create_action_client +from ..ros.logging import logwarn, loginfo, loginfo_once +from ..ros.ros_tools import get_node_names + +from geometry_msgs.msg import PointStamped +from typing_extensions import List, Callable, Optional -from ..designator import ObjectDesignatorDescription from ..datastructures.pose import Pose -from ..local_transformer import LocalTransformer -from ..datastructures.world import World -from ..datastructures.enums import ObjectType +from ..designator import ObjectDesignatorDescription try: - from robokudo_msgs.msg import ObjectDesignator as robokudo_ObjetDesignator + from robokudo_msgs.msg import ObjectDesignator as robokudo_ObjectDesignator from robokudo_msgs.msg import QueryAction, QueryGoal, QueryResult except ModuleNotFoundError as e: - rospy.logwarn(f"Could not import RoboKudo messages, RoboKudo interface could not be initialized") + logwarn("Failed to import Robokudo messages, the real robot will not be available") + +is_init = False + +number_of_par_goals = 0 +robokudo_lock = Lock() +robokudo_rlock = RLock() +with robokudo_rlock: + par_threads = {} + par_motion_goal = {} + + +def thread_safe(func: Callable) -> Callable: + """ + Adds thread safety to a function via a decorator. This uses the robokudo_lock + + :param func: Function that should be thread safe + :return: A function with thread safety + """ + + def wrapper(*args, **kwargs): + with robokudo_rlock: + return func(*args, **kwargs) -robokudo_action_client = None + return wrapper def init_robokudo_interface(func: Callable) -> Callable: """ - Tries to import the RoboKudo messages and with that initialize the RoboKudo interface. + Checks if the ROS messages are available and if Robokudo is running, if that is the case the interface will be + initialized. + + :param func: Function this decorator should be wrapping + :return: A callable function which initializes the interface and then calls the wrapped function """ + def wrapper(*args, **kwargs): - global robokudo_action_client - topics = list(map(lambda x: x[0], rospy.get_published_topics())) + global is_init + if is_init and "/robokudo" in get_node_names(): + return func(*args, **kwargs) + elif is_init and "/robokudo" not in get_node_names(): + logwarn("Robokudo node is not available anymore, could not initialize robokudo interface") + is_init = False + return + if "robokudo_msgs" not in sys.modules: - rospy.logwarn("Could not initialize the RoboKudo interface since the robokudo_msgs are not imported") + logwarn("Could not initialize the Robokudo interface since the robokudo_msgs are not imported") return - if "/robokudo" in rosnode.get_node_names(): - robokudo_action_client = create_robokudo_action_client() - rospy.loginfo("Successfully initialized robokudo interface") + if "/robokudo" in get_node_names(): + loginfo_once("Successfully initialized Robokudo interface") + is_init = True else: - rospy.logwarn("RoboKudo is not running, could not initialize RoboKudo interface") + logwarn("Robokudo is not running, could not initialize Robokudo interface") return - return func(*args, **kwargs) + return wrapper -def create_robokudo_action_client() -> Callable: - """ - Creates a new action client for the RoboKudo query interface and returns a function encapsulating the action client. - The returned function can be called with an ObjectDesigantor as parameter and returns the result of the action client. - - :return: A callable function encapsulating the action client - """ - client = actionlib.SimpleActionClient('robokudo/query', QueryAction) - rospy.loginfo("Waiting for action server") +@init_robokudo_interface +def send_query(obj_type: Optional[str] = None, region: Optional[str] = None, + attributes: Optional[List[str]] = None) -> Any: + """Generic function to send a query to RoboKudo.""" + goal = QueryGoal() + + if obj_type: + goal.obj.type = obj_type + if region: + goal.obj.location = region + if attributes: + goal.obj.attribute = attributes + + # client = actionlib.SimpleActionClient('robokudo/query', QueryAction) + client = create_action_client("robokudo/query", QueryAction) + loginfo("Waiting for action server") client.wait_for_server() - def action_client(object_desc): - global query_result - - def active_callback(): - rospy.loginfo("Send query to Robokudo") - - def done_callback(state, result): - rospy.loginfo("Finished perceiving") - global query_result - query_result = result + query_result = None - def feedback_callback(msg): - pass + def done_callback(state, result): + nonlocal query_result + query_result = result + loginfo("Query completed") - object_goal = make_query_goal_msg(object_desc) - client.send_goal(object_goal, active_cb=active_callback, done_cb=done_callback, feedback_cb=feedback_callback) - wait = client.wait_for_result() - return query_result + client.send_goal(goal, done_cb=done_callback) + client.wait_for_result() + return query_result - return action_client +@init_robokudo_interface +def query_object(obj_desc: ObjectDesignatorDescription) -> dict: + """Query RoboKudo for an object that fits the description.""" + goal = QueryGoal() + goal.obj.uid = str(id(obj_desc)) + goal.obj.type = str(obj_desc.types[0].name) -def msg_from_obj_desig(obj_desc: ObjectDesignatorDescription) -> 'robokudo_ObjetDesignator': - """ - Creates a RoboKudo Object designator from a PyCRAM Object Designator description + result = send_query(obj_type=goal.obj.type) - :param obj_desc: The PyCRAM Object designator that should be converted - :return: The RobotKudo Object Designator for the given PyCRAM designator - """ - obj_msg = robokudo_ObjetDesignator() - obj_msg.uid = str(id(obj_desc)) - obj_msg.obj_type = obj_desc.types[0] # For testing purposes + pose_candidates = {} + if result and result.res: + for i in range(len(result.res[0].pose)): + pose = Pose.from_pose_stamped(result.res[0].pose[i]) + source = result.res[0].pose_source[0] + pose_candidates[source] = pose + return pose_candidates - return obj_msg +@init_robokudo_interface +def query_human() -> PointStamped: + """Query RoboKudo for human detection and return the detected human's pose.""" + result = send_query(obj_type='human') + if result: + return result # Assuming result is of type PointStamped or similar. + return None -def make_query_goal_msg(obj_desc: ObjectDesignatorDescription) -> 'QueryGoal': - """ - Creates a QueryGoal message from a PyCRAM Object designator description for the use of Querying RobotKudo. - :param obj_desc: The PyCRAM object designator description that should be converted - :return: The RoboKudo QueryGoal for the given object designator description - """ - goal_msg = QueryGoal() - goal_msg.obj.uid = str(id(obj_desc)) - goal_msg.obj.obj_type = str(obj_desc.types[0].name) # For testing purposes - if ObjectType.JEROEN_CUP == obj_desc.types[0]: - goal_msg.obj.color.append("blue") - elif ObjectType.BOWL == obj_desc.types[0]: - goal_msg.obj.color.append("red") - return goal_msg +@init_robokudo_interface +def stop_query(): + """Stop any ongoing query to RoboKudo.""" + #client = actionlib.SimpleActionClient('robokudo/query', QueryAction) + client = create_action_client('robokudo/query', QueryAction) + client.wait_for_server() + client.cancel_all_goals() + loginfo("Cancelled current RoboKudo query goal") @init_robokudo_interface -def query(object_desc: ObjectDesignatorDescription) -> ObjectDesignatorDescription.Object: - """ - Sends a query to RoboKudo to look for an object that fits the description given by the Object designator description. - For sending the query to RoboKudo a simple action client will be created and the Object designator description is - sent as a goal. +def query_specific_region(region: str) -> Any: + """Query RoboKudo to scan a specific region.""" + return send_query(region=region) - :param object_desc: The object designator description which describes the object that should be perceived - :return: An object designator for the found object, if there was an object that fitted the description. - """ - query_result = robokudo_action_client(object_desc) - pose_candidates = {} - if query_result.res == []: - rospy.logwarn("No suitable object could be found") - return - for i in range(0, len(query_result.res[0].pose)): - pose = Pose.from_pose_stamped(query_result.res[0].pose[i]) - pose.frame = World.current_world.robot.get_link_tf_frame(pose.frame) # TODO: pose.frame is a link name? - source = query_result.res[0].poseSource[i] - - lt = LocalTransformer() - pose = lt.transform_pose(pose, "map") +@init_robokudo_interface +def query_human_attributes() -> Any: + """Query RoboKudo for human attributes like brightness of clothes, headgear, and gender.""" + return send_query(obj_type='human', attributes=["attributes"]) - pose_candidates[source] = pose - return pose_candidates +@init_robokudo_interface +def query_waving_human() -> Pose: + """Query RoboKudo for detecting a waving human.""" + result = send_query(obj_type='human') + if result and result.res: + try: + pose = Pose.from_pose_stamped(result.res[0].pose[0]) + return pose + except IndexError: + pass + return None diff --git a/src/pycram/external_interfaces/tmc.py b/src/pycram/external_interfaces/tmc.py new file mode 100644 index 000000000..daf3384e0 --- /dev/null +++ b/src/pycram/external_interfaces/tmc.py @@ -0,0 +1,60 @@ +from typing_extensions import Optional + +from ..datastructures.enums import GripperState +from ..designators.motion_designator import MoveGripperMotion, TalkingMotion +from ..ros.logging import loginfo +from ..ros.publisher import create_publisher +from ..ros.data_types import Rate + +is_init = False + + +def init_tmc_interface(): + global is_init + if is_init: + return + try: + from tmc_control_msgs.msg import GripperApplyEffortActionGoal + from tmc_msgs.msg import Voice + is_init = True + loginfo("Successfully initialized tmc interface") + except ModuleNotFoundError as e: + logwarn(f"Could not import TMC messages, tmc interface could not be initialized") + + +def tmc_gripper_control(designator: MoveGripperMotion, topic_name: Optional[str] = '/hsrb/gripper_controller/grasp/goal'): + """ + Publishes a message to the gripper controller to open or close the gripper for the HSR. + + :param designator: The designator containing the motion to be executed + :param topic_name: The topic name to publish the message to + """ + if (designator.motion == GripperState.OPEN): + pub_gripper = create_publisher(topic_name, GripperApplyEffortActionGoal, 10) + rate = Rate(10) + msg = GripperApplyEffortActionGoal() + msg.goal.effort = 0.8 + pub_gripper.publish(msg) + + elif (designator.motion == GripperState.CLOSE): + pub_gripper = create_publisher(topic_name, GripperApplyEffortActionGoal, 10) + rate = Rate(10) + msg = GripperApplyEffortActionGoal() + msg.goal.effort = -0.8 + pub_gripper.publish(msg) + + +def tmc_talk(designator: TalkingMotion, topic_name: Optional[str] = '/talk_request'): + """ + Publishes a sentence to the talk_request topic of the HSRB robot + + :param designator: The designator containing the sentence to be spoken + :param topic_name: The topic name to publish the sentence to + """ + pub = create_publisher(topic_name, Voice, 10) + texttospeech = Voice() + # language 1 = english (0 = japanese) + texttospeech.language = 1 + texttospeech.sentence = designator.cmd + + pub.publish(texttospeech) diff --git a/src/pycram/failure_handling.py b/src/pycram/failure_handling.py index 8fb266282..1c53061a4 100644 --- a/src/pycram/failure_handling.py +++ b/src/pycram/failure_handling.py @@ -1,8 +1,12 @@ +from .datastructures.enums import State from .designator import DesignatorDescription -from .plan_failures import PlanFailure +from .failures import PlanFailure +from threading import Lock +from typing_extensions import Union, Tuple, Any, List +from .language import Language, Monitor -class FailureHandling: +class FailureHandling(Language): """ Base class for failure handling mechanisms in automated systems or workflows. @@ -11,11 +15,12 @@ class FailureHandling: to be extended by subclasses that implement specific failure handling behaviors. """ - def __init__(self, designator_description: DesignatorDescription): + def __init__(self, designator_description: Union[DesignatorDescription, Monitor]): """ Initializes a new instance of the FailureHandling class. - :param designator_description: The description or context of the task or process for which the failure handling is being set up. + :param Union[DesignatorDescription, Monitor] designator_description: The description or context of the task + or process for which the failure handling is being set up. """ self.designator_description = designator_description @@ -37,15 +42,10 @@ class Retry(FailureHandling): This class represents a specific failure handling strategy where the system attempts to retry a failed action a certain number of times before giving up. - - Attributes: - max_tries (int): The maximum number of attempts to retry the action. - - Inherits: - All attributes and methods from the FailureHandling class. - - Overrides: - perform(): Implements the retry logic. + """ + max_tries: int + """ + The maximum number of attempts to retry the action. """ def __init__(self, designator_description: DesignatorDescription, max_tries: int = 3): @@ -58,7 +58,7 @@ def __init__(self, designator_description: DesignatorDescription, max_tries: int super().__init__(designator_description) self.max_tries = max_tries - def perform(self): + def perform(self) -> Tuple[State, List[Any]]: """ Implementation of the retry mechanism. @@ -79,5 +79,93 @@ def perform(self): raise e +class RetryMonitor(FailureHandling): + """ + A subclass of FailureHandling that implements a retry mechanism that works with a Monitor. + This class represents a specific failure handling strategy that allows us to retry a demo that is + being monitored, in case that monitoring condition is triggered. + """ + max_tries: int + """ + The maximum number of attempts to retry the action. + """ + recovery: dict + """ + A dictionary that maps exception types to recovery actions + """ + def __init__(self, designator_description: Monitor, max_tries: int = 3, recovery: dict = None): + """ + Initializes a new instance of the RetryMonitor class. + :param Monitor designator_description: The Monitor instance to be used. + :param int max_tries: The maximum number of attempts to retry. Defaults to 3. + :param dict recovery: A dictionary that maps exception types to recovery actions. Defaults to None. + """ + super().__init__(designator_description) + self.max_tries = max_tries + self.lock = Lock() + if recovery is None: + self.recovery = {} + else: + if not isinstance(recovery, dict): + raise ValueError( + "Recovery must be a dictionary with exception types as keys and Language instances as values.") + for key, value in recovery.items(): + if not issubclass(key, BaseException): + raise TypeError("Keys in the recovery dictionary must be exception types.") + if not isinstance(value, Language): + raise TypeError("Values in the recovery dictionary must be instances of the Language class.") + self.recovery = recovery + + def perform(self) -> Tuple[State, List[Any]]: + """ + This method attempts to perform the Monitor + plan specified in the designator_description. If the action + fails, it is retried up to max_tries times. If all attempts fail, the last exception is raised. In every + loop, we need to clear the kill_event, and set all relevant 'interrupted' variables to False, to make sure + the Monitor and plan are executed properly again. + + :raises PlanFailure: If all retry attempts fail. + + :return: The state of the execution performed, as well as a flattened list of the + results, in the correct order + """ + + def reset_interrupted(child): + child.interrupted = False + try: + for sub_child in child.children: + reset_interrupted(sub_child) + except AttributeError: + pass + + def flatten(result): + flattened_list = [] + if result: + for item in result: + if isinstance(item, list): + flattened_list.extend(item) + else: + flattened_list.append(item) + return flattened_list + return None + + status, res = None, None + with self.lock: + tries = 0 + while True: + self.designator_description.kill_event.clear() + self.designator_description.interrupted = False + for child in self.designator_description.children: + reset_interrupted(child) + try: + status, res = self.designator_description.perform() + break + except PlanFailure as e: + tries += 1 + if tries >= self.max_tries: + raise e + exception_type = type(e) + if exception_type in self.recovery: + self.recovery[exception_type].perform() + return status, flatten(res) diff --git a/src/pycram/plan_failures.py b/src/pycram/failures.py similarity index 78% rename from src/pycram/plan_failures.py rename to src/pycram/failures.py index e8ac32fe1..736adf8ee 100644 --- a/src/pycram/plan_failures.py +++ b/src/pycram/failures.py @@ -1,3 +1,12 @@ +from pathlib import Path + +from typing_extensions import TYPE_CHECKING, List + +if TYPE_CHECKING: + from .world_concepts.world_object import Object + from .datastructures.enums import JointType + + class PlanFailure(Exception): """Implementation of plan failures.""" @@ -127,8 +136,10 @@ def __init__(self, *args, **kwargs): class IKError(PlanFailure): """Thrown when no inverse kinematics solution could be found""" + def __init__(self, pose, base_frame, tip_frame): - self.message = "Position {} in frame '{}' is not reachable for end effector: '{}'".format(pose, base_frame, tip_frame) + self.message = "Position {} in frame '{}' is not reachable for end effector: '{}'".format(pose, base_frame, + tip_frame) super(IKError, self).__init__(self.message) @@ -395,3 +406,66 @@ def __init__(*args, **kwargs): class CollisionError(PlanFailure): def __init__(*args, **kwargs): super().__init__(*args, **kwargs) + + +""" +The following exceptions are used in the PyCRAM framework to handle errors related to the world and the objects in it. +They are usually related to a bug in the code or a misuse of the framework (e.g. logical errors in the code). +""" + + +class ProspectionObjectNotFound(KeyError): + def __init__(self, obj: 'Object'): + super().__init__(f"The given object {obj.name} is not in the prospection world.") + + +class WorldObjectNotFound(KeyError): + def __init__(self, obj: 'Object'): + super().__init__(f"The given object {obj.name} is not in the main world.") + + +class ObjectAlreadyExists(Exception): + def __init__(self, obj: 'Object'): + super().__init__(f"An object with the name {obj.name} already exists in the world.") + + +class ObjectDescriptionNotFound(KeyError): + def __init__(self, object_name: str, path: str, extension: str): + super().__init__(f"{object_name} with path {path} and extension {extension} is not in supported extensions, and" + f" the description data was not found on the ROS parameter server") + + +class WorldMismatchErrorBetweenObjects(Exception): + def __init__(self, obj_1: 'Object', obj_2: 'Object'): + super().__init__(f"World mismatch between the attached objects {obj_1.name} and {obj_2.name}," + f"obj_1.world: {obj_1.world}, obj_2.world: {obj_2.world}") + + +class ObjectFrameNotFoundError(KeyError): + def __init__(self, frame_name: str): + super().__init__(f"Frame {frame_name} does not belong to any of the objects in the world.") + + +class MultiplePossibleTipLinks(Exception): + def __init__(self, object_name: str, start_link: str, tip_links: List[str]): + super().__init__(f"Multiple possible tip links found for object {object_name} with start link {start_link}:" + f" {tip_links}") + + +class UnsupportedFileExtension(Exception): + def __init__(self, object_name: str, path: str): + extension = Path(path).suffix + super().__init__(f"Unsupported file extension for object {object_name} with path {path}" + f"and extension {extension}") + + +class ObjectDescriptionUndefined(Exception): + def __init__(self, object_name: str): + super().__init__(f"Object description for object {object_name} is not defined, eith a path or a description" + f"object should be provided.") + + +class UnsupportedJointType(Exception): + def __init__(self, joint_type: 'JointType'): + super().__init__(f"Unsupported joint type: {joint_type}") + diff --git a/src/pycram/helper.py b/src/pycram/helper.py index 73cf77dbc..fb2f78927 100644 --- a/src/pycram/helper.py +++ b/src/pycram/helper.py @@ -3,6 +3,13 @@ Classes: Singleton -- implementation of singleton metaclass """ +import os +from typing_extensions import Dict, Optional +import xml.etree.ElementTree as ET + +from pycram.ros.logging import logwarn + + class Singleton(type): """ Metaclass for singletons @@ -16,4 +23,94 @@ class Singleton(type): def __call__(cls, *args, **kwargs): if cls not in cls._instances: cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) - return cls._instances[cls] \ No newline at end of file + return cls._instances[cls] + + +def parse_mjcf_actuators(file_path: str) -> Dict[str, str]: + """ + Parse the actuator elements from an MJCF file. + + :param file_path: The path to the MJCF file. + """ + tree = ET.parse(file_path) + root = tree.getroot() + + joint_actuators = {} + + # Iterate through all actuator elements + for actuator in root.findall(".//actuator/*"): + name = actuator.get('name') + joint = actuator.get('joint') + if name and joint: + joint_actuators[joint] = name + + return joint_actuators + + +def get_robot_mjcf_path(robot_relative_dir: str, robot_name: str, xml_name: Optional[str] = None) -> Optional[str]: + """ + Get the path to the MJCF file of a robot. + + :param robot_relative_dir: The relative directory of the robot in the Multiverse resources/robots directory. + :param robot_name: The name of the robot. + :param xml_name: The name of the XML file of the robot. + :return: The path to the MJCF file of the robot if it exists, otherwise None. + """ + xml_name = xml_name if xml_name is not None else robot_name + if '.xml' not in xml_name: + xml_name = xml_name + '.xml' + multiverse_resources = find_multiverse_resources_path() + try: + robot_folder = os.path.join(multiverse_resources, 'robots', robot_relative_dir, robot_name) + except TypeError: + logwarn("Multiverse resources path not found.") + return None + if multiverse_resources is not None: + list_dir = os.listdir(robot_folder) + if 'mjcf' in list_dir: + if xml_name in os.listdir(robot_folder + '/mjcf'): + return os.path.join(robot_folder, 'mjcf', xml_name) + elif xml_name in os.listdir(robot_folder): + return os.path.join(robot_folder, xml_name) + return None + + +def find_multiverse_resources_path() -> Optional[str]: + """ + :return: The path to the Multiverse resources directory. + """ + # Get the path to the Multiverse installation + multiverse_path = find_multiverse_path() + + # Check if the path to the Multiverse installation was found + if multiverse_path: + # Construct the path to the resources directory + resources_path = os.path.join(multiverse_path, 'resources') + + # Check if the resources directory exists + if os.path.exists(resources_path): + return resources_path + + return None + + +def find_multiverse_path() -> Optional[str]: + """ + :return: the path to the Multiverse installation. + """ + # Get the value of PYTHONPATH environment variable + pythonpath = os.getenv('PYTHONPATH') + multiverse_relative_path = "Multiverse/multiverse" + + # Check if PYTHONPATH is set + if pythonpath: + # Split the PYTHONPATH into individual paths using the platform-specific path separator + paths = pythonpath.split(os.pathsep) + + # Iterate through each path and check if 'Multiverse' is in it + for path in paths: + if multiverse_relative_path in path: + multiverse_path = path.split(multiverse_relative_path)[0] + return multiverse_path + multiverse_relative_path + + diff --git a/src/pycram/language.py b/src/pycram/language.py index 81612a78b..1d8180e30 100644 --- a/src/pycram/language.py +++ b/src/pycram/language.py @@ -1,16 +1,17 @@ # used for delayed evaluation of typing until python 3.11 becomes mainstream from __future__ import annotations -import time -from typing_extensions import Iterable, Optional, Callable, Dict, Any, List, Union +from queue import Queue +from typing_extensions import Iterable, Optional, Callable, Dict, Any, List, Union, Tuple from anytree import NodeMixin, Node, PreOrderIter -from pycram.datastructures.enums import State +from .datastructures.enums import State import threading from .fluent import Fluent -from .plan_failures import PlanFailure, NotALanguageExpression +from .failures import PlanFailure, NotALanguageExpression from .external_interfaces import giskard +from .ros.ros_tools import sleep class Language(NodeMixin): @@ -260,6 +261,7 @@ def __init__(self, condition: Union[Callable, Fluent] = None): """ super().__init__(None, None) self.kill_event = threading.Event() + self.exception_queue = Queue() if callable(condition): self.condition = Fluent(condition) elif isinstance(condition, Fluent): @@ -267,27 +269,43 @@ def __init__(self, condition: Union[Callable, Fluent] = None): else: raise AttributeError("The condition of a Monitor has to be a Callable or a Fluent") - def perform(self): + def perform(self) -> Tuple[State, Any]: """ Behavior of the Monitor, starts a new Thread which checks the condition and then performs the attached language expression - :return: The result of the attached language expression + :return: The state of the attached language expression, as well as a list of the results of the children """ def check_condition(): - while not self.condition.get_value() and not self.kill_event.is_set(): - time.sleep(0.1) - if self.kill_event.is_set(): - return - for child in self.children: - child.interrupt() + while not self.kill_event.is_set(): + try: + cond = self.condition.get_value() + if cond: + for child in self.children: + try: + child.interrupt() + except NotImplementedError: + pass + if isinstance(cond, type) and issubclass(cond, Exception): + self.exception_queue.put(cond) + else: + self.exception_queue.put(PlanFailure("Condition met in Monitor")) + return + except Exception as e: + self.exception_queue.put(e) + return + sleep(0.1) t = threading.Thread(target=check_condition) t.start() - res = self.children[0].perform() - self.kill_event.set() - t.join() - return res + try: + state, result = self.children[0].perform() + if not self.exception_queue.empty(): + raise self.exception_queue.get() + finally: + self.kill_event.set() + t.join() + return state, result def interrupt(self) -> None: """ @@ -303,28 +321,35 @@ class Sequential(Language): Instead, the exception is saved to a list of all exceptions thrown during execution and returned. Behaviour: - Return the state :py:attr:`~State.SUCCEEDED` *iff* all children are executed without exception. - In any other case the State :py:attr:`~State.FAILED` will be returned. + Returns a tuple containing the final state of execution (SUCCEEDED, FAILED) and a list of results from each + child's perform() method. The state is :py:attr:`~State.SUCCEEDED` *iff* all children are executed without + exception. In any other case the State :py:attr:`~State.FAILED` will be returned. """ - def perform(self) -> State: + def perform(self) -> Tuple[State, List[Any]]: """ Behaviour of Sequential, calls perform() on each child sequentially - :return: The state according to the behaviour described in :func:`Sequential` + :return: The state and list of results according to the behaviour described in :func:`Sequential` """ + children_return_values = [None] * len(self.children) try: - for child in self.children: + for index, child in enumerate(self.children): if self.interrupted: if threading.get_ident() in self.block_list: self.block_list.remove(threading.get_ident()) - return + return State.FAILED, children_return_values self.root.executing_thread[child] = threading.get_ident() - child.resolve().perform() + ret_val = child.resolve().perform() + if isinstance(ret_val, tuple): + child_state, child_result = ret_val + children_return_values[index] = child_result + else: + children_return_values[index] = ret_val except PlanFailure as e: self.root.exceptions[self] = e - return State.FAILED - return State.SUCCEEDED + return State.FAILED, children_return_values + return State.SUCCEEDED, children_return_values def interrupt(self) -> None: """ @@ -343,33 +368,40 @@ class TryInOrder(Language): Instead, the exception is saved to a list of all exceptions thrown during execution and returned. Behaviour: - Returns the State :py:attr:`~State.SUCCEEDED` if one or more children are executed without + Returns a tuple containing the final state of execution (SUCCEEDED, FAILED) and a list of results from each + child's perform() method. The state is :py:attr:`~State.SUCCEEDED` if one or more children are executed without exception. In the case that all children could not be executed the State :py:attr:`~State.FAILED` will be returned. """ - def perform(self) -> State: + def perform(self) -> Tuple[State, List[Any]]: """ Behaviour of TryInOrder, calls perform() on each child sequentially and catches raised exceptions. - :return: The state according to the behaviour described in :func:`TryInOrder` + :return: The state and list of results according to the behaviour described in :func:`TryInOrder` """ failure_list = [] - for child in self.children: + children_return_values = [None] * len(self.children) + for index, child in enumerate(self.children): if self.interrupted: if threading.get_ident() in self.block_list: self.block_list.remove(threading.get_ident()) - return + return State.INTERRUPTED, children_return_values try: - child.resolve().perform() + ret_val = child.resolve().perform() + if isinstance(ret_val, tuple): + child_state, child_result = ret_val + children_return_values[index] = child_result + else: + children_return_values[index] = ret_val except PlanFailure as e: failure_list.append(e) if len(failure_list) > 0: self.root.exceptions[self] = failure_list if len(failure_list) == len(self.children): self.root.exceptions[self] = failure_list - return State.FAILED + return State.FAILED, children_return_values else: - return State.SUCCEEDED + return State.SUCCEEDED, children_return_values def interrupt(self) -> None: """ @@ -388,19 +420,27 @@ class Parallel(Language): exceptions during execution will be caught, saved to a list and returned upon end. Behaviour: - Returns the State :py:attr:`~State.SUCCEEDED` *iff* all children could be executed without an exception. In any - other case the State :py:attr:`~State.FAILED` will be returned. + Returns a tuple containing the final state of execution (SUCCEEDED, FAILED) and a list of results from + each child's perform() method. The state is :py:attr:`~State.SUCCEEDED` *iff* all children could be executed without + an exception. In any other case the State :py:attr:`~State.FAILED` will be returned. + """ - def perform(self) -> State: + def perform(self) -> Tuple[State, List[Any]]: """ Behaviour of Parallel, creates a new thread for each child and calls perform() of the child in the respective thread. - :return: The state according to the behaviour described in :func:`Parallel` + :return: The state and list of results according to the behaviour described in :func:`Parallel` + """ + results = [None] * len(self.children) + self.threads: List[threading.Thread] = [] + state = State.SUCCEEDED + results_lock = threading.Lock() - def lang_call(child_node): + def lang_call(child_node, index): + nonlocal state if ("DesignatorDescription" in [cls.__name__ for cls in child_node.__class__.__mro__] and self.__class__.__name__ not in self.do_not_use_giskard): if self not in giskard.par_threads.keys(): @@ -409,26 +449,39 @@ def lang_call(child_node): giskard.par_threads[self].append(threading.get_ident()) try: self.root.executing_thread[child] = threading.get_ident() - child_node.resolve().perform() + result = child_node.resolve().perform() + if isinstance(result, tuple): + child_state, child_result = result + with results_lock: + results[index] = child_result + else: + with results_lock: + results[index] = result except PlanFailure as e: + nonlocal state + with results_lock: + state = State.FAILED if self in self.root.exceptions.keys(): self.root.exceptions[self].append(e) else: self.root.exceptions[self] = [e] - for child in self.children: + for index, child in enumerate(self.children): if self.interrupted: + state = State.FAILED break - t = threading.Thread(target=lambda: lang_call(child)) + t = threading.Thread(target=lambda: lang_call(child, index)) t.start() self.threads.append(t) for thread in self.threads: thread.join() - if thread.ident in self.block_list: - self.block_list.remove(thread.ident) + with results_lock: + for thread in self.threads: + if thread.ident in self.block_list: + self.block_list.remove(thread.ident) if self in self.root.exceptions.keys() and len(self.root.exceptions[self]) != 0: - return State.FAILED - return State.SUCCEEDED + state = State.FAILED + return state, results def interrupt(self) -> None: """ @@ -448,20 +501,24 @@ class TryAll(Language): exceptions during execution will be caught, saved to a list and returned upon end. Behaviour: - Returns the State :py:attr:`~State.SUCCEEDED` if one or more children could be executed without raising an - exception. If all children fail the State :py:attr:`~State.FAILED` will be returned. + Returns a tuple containing the final state of execution (SUCCEEDED, FAILED) and a list of results from each + child's perform() method. The state is :py:attr:`~State.SUCCEEDED` if one or more children could be executed + without raising an exception. If all children fail the State :py:attr:`~State.FAILED` will be returned. """ - def perform(self) -> State: + def perform(self) -> Tuple[State, List[Any]]: """ Behaviour of TryAll, creates a new thread for each child and executes all children in their respective threads. - :return: The state according to the behaviour described in :func:`TryAll` + :return: The state and list of results according to the behaviour described in :func:`TryAll` """ + results = [None] * len(self.children) + results_lock = threading.Lock() + state = State.SUCCEEDED self.threads: List[threading.Thread] = [] failure_list = [] - def lang_call(child_node): + def lang_call(child_node, index): if ("DesignatorDescription" in [cls.__name__ for cls in child_node.__class__.__mro__] and self.__class__.__name__ not in self.do_not_use_giskard): if self not in giskard.par_threads.keys(): @@ -469,27 +526,37 @@ def lang_call(child_node): else: giskard.par_threads[self].append(threading.get_ident()) try: - child_node.resolve().perform() + result = child_node.resolve().perform() + if isinstance(result, tuple): + child_state, child_result = result + with results_lock: + results[index] = child_result + else: + with results_lock: + results[index] = result except PlanFailure as e: failure_list.append(e) if self in self.root.exceptions.keys(): self.root.exceptions[self].append(e) else: self.root.exceptions[self] = [e] - - for child in self.children: - t = threading.Thread(target=lambda: lang_call(child)) - t.start() + for index, child in enumerate(self.children): + if self.interrupted: + state = State.FAILED + break + t = threading.Thread(target=lambda: lang_call(child, index)) self.threads.append(t) + t.start() for thread in self.threads: thread.join() - if thread.ident in self.block_list: - self.block_list.remove(thread.ident) + with results_lock: + for thread in self.threads: + if thread.ident in self.block_list: + self.block_list.remove(thread.ident) if len(self.children) == len(failure_list): self.root.exceptions[self] = failure_list - return State.FAILED - else: - return State.SUCCEEDED + state = State.FAILED + return state, results def interrupt(self) -> None: """ @@ -529,9 +596,16 @@ def execute(self) -> Any: """ Execute the code with its arguments - :returns: Anything that the function associated with this object will return. + :returns: State.SUCCEEDED, and anything that the function associated with this object will return. """ - return self.function(**self.kwargs) + child_state = State.SUCCEEDED + ret_val = self.function(**self.kwargs) + if isinstance(ret_val, tuple): + child_state, child_result = ret_val + else: + child_result = ret_val + + return child_state, child_result def interrupt(self) -> None: raise NotImplementedError diff --git a/src/pycram/local_transformer.py b/src/pycram/local_transformer.py index f8aef28ac..9b5520ccb 100644 --- a/src/pycram/local_transformer.py +++ b/src/pycram/local_transformer.py @@ -1,14 +1,14 @@ import sys import logging +from .ros.data_types import Time, Duration +from .ros.logging import logerr + if 'world' in sys.modules: logging.warning("(publisher) Make sure that you are not loading this module from pycram.world.") -import rospy from tf import TransformerROS -from rospy import Duration -from geometry_msgs.msg import TransformStamped from .datastructures.pose import Pose, Transform from typing_extensions import List, Optional, Union, Iterable @@ -29,6 +29,7 @@ class LocalTransformer(TransformerROS): """ _instance = None + prospection_prefix: str = "prospection/" def __new__(cls, *args, **kwargs): if not cls._instance: @@ -65,18 +66,14 @@ def transform_to_object_frame(self, pose: Pose, target_frame = world_object.tf_frame return self.transform_pose(pose, target_frame) - def update_transforms_for_objects(self, source_object_name: str, target_object_name: str) -> None: + def update_transforms_for_objects(self, object_names: List[str]) -> None: """ Updates the transforms for objects affected by the transformation. The objects are identified by their names. - :param source_object_name: Name of the object of the source frame - :param target_object_name: Name of the object of the target frame + :param object_names: List of object names for which the transforms should be updated """ - source_object = self.world.get_object_by_name(source_object_name) - target_object = self.world.get_object_by_name(target_object_name) - for obj in {source_object, target_object}: - if obj: - obj.update_link_transforms() + objects = list(map(self.world.get_object_by_name, object_names)) + [obj.update_link_transforms() for obj in objects] def transform_pose(self, pose: Pose, target_frame: str) -> Optional[Pose]: """ @@ -86,34 +83,51 @@ def transform_pose(self, pose: Pose, target_frame: str) -> Optional[Pose]: :param target_frame: Name of the TF frame into which the Pose should be transformed :return: A transformed pose in the target frame """ - self.update_transforms_for_objects(self.get_object_name_for_frame(pose.frame), - self.get_object_name_for_frame(target_frame)) + objects = list(map(self.get_object_name_for_frame, [pose.frame, target_frame])) + self.update_transforms_for_objects([obj for obj in objects if obj is not None]) copy_pose = pose.copy() - copy_pose.header.stamp = rospy.Time(0) - if not self.canTransform(target_frame, pose.frame, rospy.Time(0)): - rospy.logerr( - f"Can not transform pose: \n {pose}\n to frame: {target_frame}.\n Maybe try calling 'update_transforms_for_object'") + copy_pose.header.stamp = Time(0) + if not self.canTransform(target_frame, pose.frame, Time(0)): + logerr( + f"Can not transform pose: \n {pose}\n to frame: {target_frame}." + f"\n Maybe try calling 'update_transforms_for_object'") return new_pose = super().transformPose(target_frame, copy_pose) copy_pose.pose = new_pose.pose copy_pose.header.frame_id = new_pose.header.frame_id - copy_pose.header.stamp = rospy.Time.now() + copy_pose.header.stamp = Time().now() return Pose(*copy_pose.to_list(), frame=new_pose.header.frame_id) - def get_object_name_for_frame(self, frame: str) -> str: + def get_object_name_for_frame(self, frame: str) -> Optional[str]: """ - Returns the name of the object that is associated with the given frame. + Get the name of the object that is associated with the given frame. :param frame: The frame for which the object name should be returned :return: The name of the object associated with the frame """ - return frame.split("/")[0] + world = self.prospection_world if self.prospection_prefix in frame else self.world + if frame == "map": + return None + obj_name = [obj.name for obj in world.objects if frame == obj.tf_frame] + return obj_name[0] if len(obj_name) > 0 else self.get_object_name_for_link_frame(frame) + + def get_object_name_for_link_frame(self, link_frame: str) -> Optional[str]: + """ + Get the name of the object that is associated with the given link frame. + + :param link_frame: The frame of the link for which the object name should be returned + :return: The name of the object associated with the link frame + """ + world = self.prospection_world if self.prospection_prefix in link_frame else self.world + object_name = [obj.name for obj in world.objects for link in obj.links.values() + if link_frame in (link.name, link.tf_frame)] + return object_name[0] if len(object_name) > 0 else None def lookup_transform_from_source_to_target_frame(self, source_frame: str, target_frame: str, - time: Optional[rospy.rostime.Time] = None) -> Transform: + time: Optional[Time] = None) -> Transform: """ Update the transforms for all world objects then Look up for the latest known transform that transforms a point from source frame to target frame. If no time is given the last common time between the two frames is used. @@ -123,28 +137,26 @@ def lookup_transform_from_source_to_target_frame(self, source_frame: str, target :param time: Time at which the transform should be looked up :return: The transform from source_frame to target_frame """ - self.update_transforms_for_objects(self.get_object_name_for_frame(source_frame), - self.get_object_name_for_frame(target_frame)) + objects = list(map(self.get_object_name_for_frame, [source_frame, target_frame])) + self.update_transforms_for_objects([obj for obj in objects if obj is not None]) tf_time = time if time else self.getLatestCommonTime(source_frame, target_frame) translation, rotation = self.lookupTransform(source_frame, target_frame, tf_time) return Transform(translation, rotation, source_frame, target_frame) - def update_transforms(self, transforms: Iterable[Transform], time: rospy.Time = None) -> None: + def update_transforms(self, transforms: Iterable[Transform], time: Time = None) -> None: """ Updates transforms by updating the time stamps of the header of each transform. If no time is given the current time is used. """ - time = time if time else rospy.Time.now() + time = time if time else Time().now() for transform in transforms: transform.header.stamp = time self.setTransform(transform) def get_all_frames(self) -> List[str]: """ - Returns all know coordinate frames as a list with human-readable entries. - - :return: A list of all know coordinate frames. + :return: A list of all known coordinate frames as a list with human-readable entries. """ frames = self.allFramesAsString().split("\n") frames.remove("") diff --git a/src/pycram/object_descriptors/generic.py b/src/pycram/object_descriptors/generic.py new file mode 100644 index 000000000..fb2b456ec --- /dev/null +++ b/src/pycram/object_descriptors/generic.py @@ -0,0 +1,189 @@ +from typing import Optional, Tuple + +from typing_extensions import List, Any, Union, Dict + +from geometry_msgs.msg import Point + +from ..datastructures.dataclasses import VisualShape, BoxVisualShape, Color +from ..datastructures.enums import JointType +from ..datastructures.pose import Pose +from ..description import JointDescription as AbstractJointDescription, LinkDescription as AbstractLinkDescription, \ + ObjectDescription as AbstractObjectDescription + + +class NamedBoxVisualShape(BoxVisualShape): + def __init__(self, name: str, color: Color, visual_frame_position: List[float], half_extents: List[float]): + super().__init__(color, visual_frame_position, half_extents) + self._name: str = name + + @property + def name(self) -> str: + return self._name + + +class LinkDescription(AbstractLinkDescription): + + def __init__(self, name: str, visual_frame_position: List[float], half_extents: List[float], + color: Color = Color()): + super().__init__(NamedBoxVisualShape(name, color, visual_frame_position, half_extents)) + + @property + def geometry(self) -> Union[VisualShape, None]: + return self.parsed_description + + @property + def origin(self) -> Pose: + return Pose(self.parsed_description.visual_frame_position) + + @property + def name(self) -> str: + return self.parsed_description.name + + @property + def color(self) -> Color: + return self.parsed_description.rgba_color + + +class JointDescription(AbstractJointDescription): + + @property + def parent(self) -> str: + raise NotImplementedError + + @property + def child(self) -> str: + raise NotImplementedError + + @property + def type(self) -> JointType: + return JointType.UNKNOWN + + @property + def axis(self) -> Point: + return Point(0, 0, 0) + + @property + def has_limits(self) -> bool: + return False + + @property + def lower_limit(self) -> Union[float, None]: + return 0 + + @property + def upper_limit(self) -> Union[float, None]: + return 0 + + @property + def parent_link_name(self) -> str: + raise NotImplementedError + + @property + def child_link_name(self) -> str: + raise NotImplementedError + + @property + def origin(self) -> Pose: + raise NotImplementedError + + @property + def name(self) -> str: + raise NotImplementedError + + +class ObjectDescription(AbstractObjectDescription): + """ + A generic description of an object in the environment. This description can be applied to any object. + The current use case involves perceiving objects using RoboKudo and spawning them with specified size and color. + """ + + class Link(AbstractObjectDescription.Link, LinkDescription): + ... + + class RootLink(AbstractObjectDescription.RootLink, Link): + ... + + class Joint(AbstractObjectDescription.Joint, JointDescription): + ... + + def __init__(self, *args, **kwargs): + self._links = [LinkDescription(*args, **kwargs)] + + def load_description(self, path: str) -> Any: + ... + + @classmethod + def generate_from_mesh_file(cls, path: str, name: str, save_path: str) -> str: + raise NotImplementedError + + @classmethod + def generate_from_description_file(cls, path: str, save_path: str, make_mesh_paths_absolute: bool = True) -> str: + raise NotImplementedError + + @classmethod + def generate_from_parameter_server(cls, name: str, save_path: str) -> str: + raise NotImplementedError + + @property + def parent_map(self) -> Dict[str, Tuple[str, str]]: + return {} + + @property + def link_map(self) -> Dict[str, LinkDescription]: + return {self._links[0].name: self._links[0]} + + @property + def joint_map(self) -> Dict[str, JointDescription]: + return {} + + @property + def child_map(self) -> Dict[str, List[Tuple[str, str]]]: + return {} + + def add_joint(self, name: str, child: str, joint_type: JointType, + axis: Point, parent: Optional[str] = None, origin: Optional[Pose] = None, + lower_limit: Optional[float] = None, upper_limit: Optional[float] = None, + is_virtual: Optional[bool] = False) -> None: + ... + + @property + def shape_data(self) -> List[float]: + return self._links[0].geometry.shape_data()['halfExtents'] + + @property + def color(self) -> Color: + return self._links[0].color + + @property + def links(self) -> List[LinkDescription]: + return self._links + + def get_link_by_name(self, link_name: str) -> LinkDescription: + if link_name == self._links[0].name: + return self._links[0] + + @property + def joints(self) -> List[JointDescription]: + return [] + + def get_joint_by_name(self, joint_name: str) -> JointDescription: + ... + + def get_root(self) -> str: + return self._links[0].name + + def get_chain(self, start_link_name: str, end_link_name: str, joints: Optional[bool] = True, + links: Optional[bool] = True, fixed: Optional[bool] = True) -> List[str]: + raise NotImplementedError("Do Not Do This on generic objects as they have no chains") + + @staticmethod + def get_file_extension() -> str: + raise NotImplementedError("Do Not Do This on generic objects as they have no extensions") + + @property + def origin(self) -> Pose: + return self._links[0].origin + + @property + def name(self) -> str: + return self._links[0].name diff --git a/src/pycram/object_descriptors/mjcf.py b/src/pycram/object_descriptors/mjcf.py new file mode 100644 index 000000000..36e895372 --- /dev/null +++ b/src/pycram/object_descriptors/mjcf.py @@ -0,0 +1,502 @@ +import os +import pathlib + +import numpy as np +from dm_control import mjcf +from geometry_msgs.msg import Point +from typing_extensions import Union, List, Optional, Dict, Tuple +from xml.etree import ElementTree as ET + +from ..datastructures.dataclasses import Color, VisualShape, BoxVisualShape, CylinderVisualShape, \ + SphereVisualShape, MeshVisualShape +from ..datastructures.enums import JointType, MJCFGeomType, MJCFJointType +from ..datastructures.pose import Pose +from ..description import JointDescription as AbstractJointDescription, \ + LinkDescription as AbstractLinkDescription, ObjectDescription as AbstractObjectDescription +from ..failures import MultiplePossibleTipLinks +from ..ros.ros_tools import get_parameter + +try: + from multiverse_parser import Configuration, Factory, InertiaSource, GeomBuilder + from multiverse_parser import (WorldBuilder, + GeomType, GeomProperty, + MeshProperty, + MaterialProperty) + from multiverse_parser import MjcfExporter + from pxr import Usd, UsdGeom +except ImportError: + # do not import this module if multiverse is not found + raise ImportError("Multiverse not found.") + + +class LinkDescription(AbstractLinkDescription): + """ + A class that represents a link description of an object. + """ + + def __init__(self, mjcf_description: mjcf.Element): + super().__init__(mjcf_description) + + @property + def geometry(self) -> Union[VisualShape, None]: + """ + :return: The geometry type of the collision element of this link. + """ + return self._get_visual_shape(self.parsed_description.find_all('geom')[0]) + + @staticmethod + def _get_visual_shape(mjcf_geometry) -> Union[VisualShape, None]: + """ + :param mjcf_geometry: The MJCFGeometry to get the visual shape for. + :return: The VisualShape of the given MJCFGeometry object. + """ + if mjcf_geometry.type == MJCFGeomType.BOX.value: + return BoxVisualShape(Color(), [0, 0, 0], mjcf_geometry.size) + if mjcf_geometry.type == MJCFGeomType.CYLINDER.value: + return CylinderVisualShape(Color(), [0, 0, 0], mjcf_geometry.size[0], mjcf_geometry.size[1] * 2) + if mjcf_geometry.type == MJCFGeomType.SPHERE.value: + return SphereVisualShape(Color(), [0, 0, 0], mjcf_geometry.size[0]) + if mjcf_geometry.type == MJCFGeomType.MESH.value: + return MeshVisualShape(Color(), [0, 0, 0], mjcf_geometry.scale, mjcf_geometry.filename) + return None + + @property + def origin(self) -> Union[Pose, None]: + """ + :return: The origin of this link. + """ + return parse_pose_from_body_element(self.parsed_description) + + @property + def name(self) -> str: + return self.parsed_description.name + + +class JointDescription(AbstractJointDescription): + + mjcf_type_map = { + MJCFJointType.HINGE.value: JointType.REVOLUTE, + MJCFJointType.BALL.value: JointType.SPHERICAL, + MJCFJointType.SLIDE.value: JointType.PRISMATIC, + MJCFJointType.FREE.value: JointType.FLOATING + } + """ + A dictionary mapping the MJCF joint types to the PyCRAM joint types. + """ + + pycram_type_map = {pycram_type: mjcf_type for mjcf_type, pycram_type in mjcf_type_map.items()} + """ + A dictionary mapping the PyCRAM joint types to the MJCF joint types. + """ + + def __init__(self, mjcf_description: mjcf.Element, is_virtual: Optional[bool] = False): + super().__init__(mjcf_description, is_virtual=is_virtual) + + @property + def origin(self) -> Pose: + return parse_pose_from_body_element(self.parsed_description) + + @property + def name(self) -> str: + return self.parsed_description.name + + @property + def has_limits(self) -> bool: + return self.parsed_description.limited + + @property + def type(self) -> JointType: + """ + :return: The type of this joint. + """ + if hasattr(self.parsed_description, 'type'): + return self.mjcf_type_map[self.parsed_description.type] + else: + return self.mjcf_type_map[MJCFJointType.FREE.value] + + @property + def axis(self) -> Point: + """ + :return: The axis of this joint, for example the rotation axis for a revolute joint. + """ + return Point(*self.parsed_description.axis) + + @property + def lower_limit(self) -> Union[float, None]: + """ + :return: The lower limit of this joint, or None if the joint has no limits. + """ + if self.has_limits: + return self.parsed_description.range[0] + else: + return None + + @property + def upper_limit(self) -> Union[float, None]: + """ + :return: The upper limit of this joint, or None if the joint has no limits. + """ + if self.has_limits: + return self.parsed_description.range[1] + else: + return None + + @property + def parent(self) -> str: + """ + :return: The name of the parent link of this joint. + """ + return self._parent_link_element.parent.name + + @property + def child(self) -> str: + """ + :return: The name of the child link of this joint. + """ + return self._parent_link_element.name + + @property + def _parent_link_element(self) -> mjcf.Element: + return self.parsed_description.parent + + @property + def damping(self) -> float: + """ + :return: The damping of this joint. + """ + return self.parsed_description.damping + + @property + def friction(self) -> float: + raise NotImplementedError("Friction is not implemented for MJCF joints.") + + +class ObjectFactory(Factory): + """ + Create MJCF object descriptions from mesh files. + """ + def __init__(self, object_name: str, file_path: str, config: Configuration, texture_type: str = "png"): + super().__init__(file_path, config) + + self._world_builder = WorldBuilder(usd_file_path=self.tmp_usd_file_path) + + body_builder = self._world_builder.add_body(body_name=object_name) + + tmp_usd_mesh_file_path, tmp_origin_mesh_file_path = self.import_mesh( + mesh_file_path=file_path, merge_mesh=True) + mesh_stage = Usd.Stage.Open(tmp_usd_mesh_file_path) + for idx, mesh_prim in enumerate([prim for prim in mesh_stage.Traverse() if prim.IsA(UsdGeom.Mesh)]): + mesh_name = mesh_prim.GetName() + mesh_path = mesh_prim.GetPath() + mesh_property = MeshProperty.from_mesh_file_path(mesh_file_path=tmp_usd_mesh_file_path, + mesh_path=mesh_path) + # mesh_property._texture_coordinates = None # TODO: See if needed otherwise remove it. + geom_property = GeomProperty(geom_type=GeomType.MESH, + is_visible=False, + is_collidable=True) + geom_builder = body_builder.add_geom(geom_name=f"SM_{object_name}_mesh_{idx}", + geom_property=geom_property) + geom_builder.add_mesh(mesh_name=mesh_name, mesh_property=mesh_property) + + # Add texture if available + texture_file_path = file_path.replace(pathlib.Path(file_path).suffix, f".{texture_type}") + if pathlib.Path(texture_file_path).exists(): + self.add_material_with_texture(geom_builder=geom_builder, material_name=f"M_{object_name}_{idx}", + texture_file_path=texture_file_path) + + geom_builder.build() + + body_builder.compute_and_set_inertial(inertia_source=InertiaSource.FROM_COLLISION_MESH) + + @staticmethod + def add_material_with_texture(geom_builder: GeomBuilder, material_name: str, texture_file_path: str): + """ + Add a material with a texture to the geom builder. + + :param geom_builder: The geom builder to add the material to. + :param material_name: The name of the material. + :param texture_file_path: The path to the texture file. + """ + material_property = MaterialProperty(diffuse_color=texture_file_path, + opacity=None, + emissive_color=None, + specular_color=None) + geom_builder.add_material(material_name=material_name, + material_property=material_property) + + def export_to_mjcf(self, output_file_path: str): + """ + Export the object to a MJCF file. + + :param output_file_path: The path to the output file. + """ + exporter = MjcfExporter(self, output_file_path) + exporter.build() + exporter.export(keep_usd=False) + + +class ObjectDescription(AbstractObjectDescription): + """ + A class that represents an object description of an object. + """ + + COMPILER_TAG = 'compiler' + """ + The tag of the compiler element in the MJCF file. + """ + MESH_DIR_ATTR = 'meshdir' + TEXTURE_DIR_ATTR = 'texturedir' + """ + The attributes of the compiler element in the MJCF file. The meshdir attribute is the directory where the mesh files + are stored and the texturedir attribute is the directory where the texture files are stored.""" + + class Link(AbstractObjectDescription.Link, LinkDescription): + ... + + class RootLink(AbstractObjectDescription.RootLink, Link): + ... + + class Joint(AbstractObjectDescription.Joint, JointDescription): + ... + + def __init__(self): + super().__init__() + self._link_map = None + self._joint_map = None + self._child_map = None + self._parent_map = None + self._links = None + self._joints = None + self.virtual_joint_names = [] + + @property + def child_map(self) -> Dict[str, List[Tuple[str, str]]]: + """ + :return: A dictionary mapping the name of a link to its children which are represented as a tuple of the child + joint name and the link name. + """ + if self._child_map is None: + self._child_map = self._construct_child_map() + return self._child_map + + def _construct_child_map(self) -> Dict[str, List[Tuple[str, str]]]: + """ + Construct the child map of the object. + """ + child_map = {} + for joint in self.joints: + if joint.parent not in child_map: + child_map[joint.parent] = [(joint.name, joint.child)] + else: + child_map[joint.parent].append((joint.name, joint.child)) + return child_map + + @property + def parent_map(self) -> Dict[str, Tuple[str, str]]: + """ + :return: A dictionary mapping the name of a link to its parent joint and link as a tuple. + """ + if self._parent_map is None: + self._parent_map = self._construct_parent_map() + return self._parent_map + + def _construct_parent_map(self) -> Dict[str, Tuple[str, str]]: + """ + Construct the parent map of the object. + """ + child_map = self.child_map + parent_map = {} + for parent, children in child_map.items(): + for child in children: + parent_map[child[1]] = (child[0], parent) + return parent_map + + @property + def link_map(self) -> Dict[str, LinkDescription]: + """ + :return: A dictionary mapping the name of a link to its description. + """ + if self._link_map is None: + self._link_map = {link.name: link for link in self.links} + return self._link_map + + @property + def joint_map(self) -> Dict[str, JointDescription]: + """ + :return: A dictionary mapping the name of a joint to its description. + """ + if self._joint_map is None: + self._joint_map = {joint.name: joint for joint in self.joints} + return self._joint_map + + def add_joint(self, name: str, child: str, joint_type: JointType, + axis: Point, parent: Optional[str] = None, origin: Optional[Pose] = None, + lower_limit: Optional[float] = None, upper_limit: Optional[float] = None, + is_virtual: Optional[bool] = False) -> None: + """ + Finds the child link and adds a joint to it in the object description. + for arguments documentation see :meth:`pycram.description.ObjectDescription.add_joint` + """ + + position: Optional[List[float]] = None + quaternion: Optional[List[float]] = None + lower_limit: float = 0.0 if lower_limit is None else lower_limit + upper_limit: float = 0.0 if upper_limit is None else upper_limit + limit = [lower_limit, upper_limit] + + if origin is not None: + position = origin.position_as_list() + quaternion = origin.orientation_as_list() + quaternion = [quaternion[1], quaternion[2], quaternion[3], quaternion[0]] + if axis is not None: + axis = [axis.x, axis.y, axis.z] + self.parsed_description.find(child).add('joint', name=name, type=JointDescription.pycram_type_map[joint_type], + axis=axis, pos=position, quat=quaternion, range=limit) + if is_virtual: + self.virtual_joint_names.append(name) + + def load_description(self, path) -> mjcf.RootElement: + return mjcf.from_file(path, model_dir=pathlib.Path(path).parent) + + def load_description_from_string(self, description_string: str) -> mjcf.RootElement: + return mjcf.from_xml_string(description_string) + + def generate_from_mesh_file(self, path: str, name: str, color: Optional[Color] = Color(), + save_path: Optional[str] = None) -> None: + """ + Generate a mjcf xml file with the given .obj or .stl file as mesh. In addition, use the given rgba_color + to create a material tag in the xml. + + :param path: The path to the mesh file. + :param name: The name of the object. + :param color: The color of the object. + :param save_path: The path to save the generated xml file. + """ + factory = ObjectFactory(object_name=name, file_path=path, + config=Configuration(model_name=name, + fixed_base=False, + default_rgba=np.array(color.get_rgba()))) + factory.export_to_mjcf(output_file_path=save_path) + + def generate_from_description_file(self, path: str, save_path: str, make_mesh_paths_absolute: bool = True) -> None: + model_str = self.replace_relative_paths_with_absolute_paths(path) + self.write_description_to_file(model_str, save_path) + + def replace_relative_paths_with_absolute_paths(self, model_path: str) -> str: + """ + Replace the relative paths in the xml file to be absolute paths. + + :param model_path: The path to the xml file. + """ + tree = ET.parse(model_path) + root = tree.getroot() + compiler = root.find(self.COMPILER_TAG) + model_dir = pathlib.Path(model_path).parent + for rel_dir_attrib in [self.MESH_DIR_ATTR, self.TEXTURE_DIR_ATTR]: + rel_dir = compiler.get(rel_dir_attrib) + abs_dir = str(pathlib.Path(os.path.join(model_dir, rel_dir)).resolve()) + compiler.set(rel_dir_attrib, abs_dir) + return ET.tostring(root, encoding='unicode', method='xml') + + def generate_from_parameter_server(self, name: str, save_path: str) -> None: + mjcf_string = get_parameter(name) + self.write_description_to_file(mjcf_string, save_path) + + @property + def joints(self) -> List[JointDescription]: + """ + :return: A list of joints descriptions of this object. + """ + if self._joints is None: + self._joints = [JointDescription(joint) for joint in self.parsed_description.find_all('joint')] + return self._joints + + @property + def links(self) -> List[LinkDescription]: + """ + :return: A list of link descriptions of this object. + """ + if self._links is None: + self._links = [LinkDescription(link) for link in self.parsed_description.find_all('body')] + return self._links + + def get_root(self) -> str: + """ + :return: the name of the root link of this object. + """ + if len(self.links) == 1: + return self.links[0].name + elif len(self.links) > 1: + return self.links[1].name + else: + raise ValueError("No links found in the object description.") + + def get_tip(self) -> str: + """ + :return: the name of the tip link of this object. + :raises MultiplePossibleTipLinks: If there are multiple possible tip links. + """ + link = self.get_root() + while link in self.child_map: + children = self.child_map[link] + if len(children) > 1: + # Multiple children, can't decide which one to take (e.g. fingers of a hand) + raise MultiplePossibleTipLinks(self.name, link, [child[1] for child in children]) + else: + child = children[0][1] + link = child + return link + + def get_chain(self, start_link_name: str, end_link_name: str, joints: Optional[bool] = True, + links: Optional[bool] = True, fixed: Optional[bool] = True) -> List[str]: + """ + :param start_link_name: The name of the start link of the chain. + :param end_link_name: The name of the end link of the chain. + :param joints: Whether to include joints in the chain. + :param links: Whether to include links in the chain. + :param fixed: Whether to include fixed joints in the chain (Note: not used in MJCF). + :return: the chain of links from 'start_link_name' to 'end_link_name'. + """ + chain = [] + if links: + chain.append(end_link_name) + link = end_link_name + while link != start_link_name: + (joint, parent) = self.parent_map[link] + if joints: + chain.append(joint) + if links: + chain.append(parent) + link = parent + chain.reverse() + return chain + + @staticmethod + def get_file_extension() -> str: + """ + :return: The file extension of the URDF file. + """ + return '.xml' + + @property + def origin(self) -> Pose: + return parse_pose_from_body_element(self.parsed_description) + + @property + def name(self) -> str: + return self.parsed_description.name + + +def parse_pose_from_body_element(body: mjcf.Element) -> Pose: + """ + Parse the pose from a body element. + + :param body: The body element. + :return: The pose of the body. + """ + position = body.pos + quaternion = body.quat + position = [0, 0, 0] if position is None else position + quaternion = [1, 0, 0, 0] if quaternion is None else quaternion + quaternion = [quaternion[1], quaternion[2], quaternion[3], quaternion[0]] + return Pose(position, quaternion) diff --git a/src/pycram/object_descriptors/urdf.py b/src/pycram/object_descriptors/urdf.py index 75ab98a03..694d421be 100644 --- a/src/pycram/object_descriptors/urdf.py +++ b/src/pycram/object_descriptors/urdf.py @@ -1,11 +1,14 @@ +import os import pathlib -from xml.etree import ElementTree +import xml.etree.ElementTree as ET -import rospkg -import rospy +import numpy as np + +from ..ros.logging import logerr +from ..ros.ros_tools import create_ros_pack, ResourceNotFound, get_parameter from geometry_msgs.msg import Point -from tf.transformations import quaternion_from_euler -from typing_extensions import Union, List, Optional +from tf.transformations import quaternion_from_euler, euler_from_quaternion +from typing_extensions import Union, List, Optional, Dict, Tuple from urdf_parser_py import urdf from urdf_parser_py.urdf import (URDF, Collision, Box as URDF_Box, Cylinder as URDF_Cylinder, Sphere as URDF_Sphere, Mesh as URDF_Mesh) @@ -16,6 +19,7 @@ LinkDescription as AbstractLinkDescription, ObjectDescription as AbstractObjectDescription from ..datastructures.dataclasses import Color, VisualShape, BoxVisualShape, CylinderVisualShape, \ SphereVisualShape, MeshVisualShape +from ..failures import MultiplePossibleTipLinks from ..utils import suppress_stdout_stderr @@ -30,7 +34,7 @@ def __init__(self, urdf_description: urdf.Link): @property def geometry(self) -> Union[VisualShape, None]: """ - Returns the geometry type of the URDF collision element of this link. + :return: The geometry type of the URDF collision element of this link. """ if self.collision is None: return None @@ -40,10 +44,12 @@ def geometry(self) -> Union[VisualShape, None]: @staticmethod def _get_visual_shape(urdf_geometry) -> Union[VisualShape, None]: """ - Returns the VisualShape of the given URDF geometry. + :param urdf_geometry: The URDFGeometry for which the visual shape is returned. + :return: the VisualShape of the given URDF geometry. """ if isinstance(urdf_geometry, URDF_Box): - return BoxVisualShape(Color(), [0, 0, 0], urdf_geometry.size) + half_extents = np.array(urdf_geometry.size) / 2 + return BoxVisualShape(Color(), [0, 0, 0], half_extents.tolist()) if isinstance(urdf_geometry, URDF_Cylinder): return CylinderVisualShape(Color(), [0, 0, 0], urdf_geometry.radius, urdf_geometry.length) if isinstance(urdf_geometry, URDF_Sphere): @@ -79,8 +85,10 @@ class JointDescription(AbstractJointDescription): 'planar': JointType.PLANAR, 'fixed': JointType.FIXED} - def __init__(self, urdf_description: urdf.Joint): - super().__init__(urdf_description) + pycram_type_map = {pycram_type: urdf_type for urdf_type, pycram_type in urdf_type_map.items()} + + def __init__(self, urdf_description: urdf.Joint, is_virtual: Optional[bool] = False): + super().__init__(urdf_description, is_virtual=is_virtual) @property def origin(self) -> Pose: @@ -130,14 +138,14 @@ def upper_limit(self) -> Union[float, None]: return None @property - def parent_link_name(self) -> str: + def parent(self) -> str: """ :return: The name of the parent link of this joint. """ return self.parsed_description.parent @property - def child_link_name(self) -> str: + def child(self) -> str: """ :return: The name of the child link of this joint. """ @@ -172,21 +180,83 @@ class RootLink(AbstractObjectDescription.RootLink, Link): class Joint(AbstractObjectDescription.Joint, JointDescription): ... + @property + def child_map(self) -> Dict[str, List[Tuple[str, str]]]: + """ + :return: A dictionary mapping the name of a link to its children which are represented as a tuple of the child + joint name and the link name. + """ + return self.parsed_description.child_map + + @property + def parent_map(self) -> Dict[str, Tuple[str, str]]: + """ + :return: A dictionary mapping the name of a link to its parent joint and link as a tuple. + """ + return self.parsed_description.parent_map + + @property + def link_map(self) -> Dict[str, LinkDescription]: + """ + :return: A dictionary mapping the name of a link to its description. + """ + if self._link_map is None: + self._link_map = {link.name: link for link in self.links} + return self._link_map + + @property + def joint_map(self) -> Dict[str, JointDescription]: + """ + :return: A dictionary mapping the name of a joint to its description. + """ + if self._joint_map is None: + self._joint_map = {joint.name: joint for joint in self.joints} + return self._joint_map + + def add_joint(self, name: str, child: str, joint_type: JointType, + axis: Point, parent: Optional[str] = None, origin: Optional[Pose] = None, + lower_limit: Optional[float] = None, upper_limit: Optional[float] = None, + is_virtual: Optional[bool] = False) -> None: + """ + Add a joint to the object description, could be a virtual joint as well. + For documentation of the parameters, see :meth:`pycram.description.ObjectDescription.add_joint`. + """ + if lower_limit is not None or upper_limit is not None: + limit = urdf.JointLimit(lower=lower_limit, upper=upper_limit) + else: + limit = None + if origin is not None: + origin = urdf.Pose(origin.position_as_list(), euler_from_quaternion(origin.orientation_as_list())) + if axis is not None: + axis = [axis.x, axis.y, axis.z] + if parent is None: + parent = self.get_root() + else: + parent = self.get_link_by_name(parent).parsed_description + joint = urdf.Joint(name, + parent, + self.get_link_by_name(child).parsed_description, + JointDescription.pycram_type_map[joint_type], + axis, origin, limit) + self.parsed_description.add_joint(joint) + if is_virtual: + self.virtual_joint_names.append(name) + def load_description(self, path) -> URDF: with open(path, 'r') as file: # Since parsing URDF causes a lot of warning messages which can't be deactivated, we suppress them with suppress_stdout_stderr(): return URDF.from_xml_string(file.read()) - def generate_from_mesh_file(self, path: str, name: str, color: Optional[Color] = Color()) -> str: + def generate_from_mesh_file(self, path: str, name: str, save_path: str, color: Optional[Color] = Color()) -> None: """ - Generates an URDf file with the given .obj or .stl file as mesh. In addition, the given rgba_color will be - used to create a material tag in the URDF. + Generate a URDf file with the given .obj or .stl file as mesh. In addition, use the given rgba_color to create a + material tag in the URDF. The URDF file will be saved to the given save_path. :param path: The path to the mesh file. :param name: The name of the object. + :param save_path: The path to save the URDF file to. :param color: The color of the object. - :return: The absolute path of the created file """ urdf_template = ' \n \ \n \ @@ -211,55 +281,45 @@ def generate_from_mesh_file(self, path: str, name: str, color: Optional[Color] = pathlib_obj = pathlib.Path(path) path = str(pathlib_obj.resolve()) content = urdf_template.replace("~a", name).replace("~b", path).replace("~c", rgb) - return content + self.write_description_to_file(content, save_path) - def generate_from_description_file(self, path: str) -> str: + def generate_from_description_file(self, path: str, save_path: str, make_mesh_paths_absolute: bool = True) -> None: with open(path, mode="r") as f: urdf_string = self.fix_missing_inertial(f.read()) - urdf_string = self.remove_error_tags(urdf_string) - urdf_string = self.fix_link_attributes(urdf_string) - try: - urdf_string = self.correct_urdf_string(urdf_string) - except rospkg.ResourceNotFound as e: - rospy.logerr(f"Could not find resource package linked in this URDF") - raise e - return urdf_string - - def generate_from_parameter_server(self, name: str) -> str: - urdf_string = rospy.get_param(name) - return self.correct_urdf_string(urdf_string) - - def get_link_by_name(self, link_name: str) -> LinkDescription: - """ - :return: The link description with the given name. - """ - for link in self.links: - if link.name == link_name: - return link - raise ValueError(f"Link with name {link_name} not found") + urdf_string = self.remove_error_tags(urdf_string) + urdf_string = self.fix_link_attributes(urdf_string) + try: + urdf_string = self.replace_relative_references_with_absolute_paths(urdf_string) + urdf_string = self.fix_missing_inertial(urdf_string) + except ResourceNotFound as e: + logerr(f"Could not find resource package linked in this URDF") + raise e + urdf_string = self.make_mesh_paths_absolute(urdf_string, path) if make_mesh_paths_absolute else urdf_string + self.write_description_to_file(urdf_string, save_path) + + def generate_from_parameter_server(self, name: str, save_path: str) -> None: + urdf_string = get_parameter(name) + urdf_string = self.replace_relative_references_with_absolute_paths(urdf_string) + urdf_string = self.fix_missing_inertial(urdf_string) + self.write_description_to_file(urdf_string, save_path) @property - def links(self) -> List[LinkDescription]: - """ - :return: A list of links descriptions of this object. - """ - return [LinkDescription(link) for link in self.parsed_description.links] - - def get_joint_by_name(self, joint_name: str) -> JointDescription: + def joints(self) -> List[JointDescription]: """ - :return: The joint description with the given name. + :return: A list of joints descriptions of this object. """ - for joint in self.joints: - if joint.name == joint_name: - return joint - raise ValueError(f"Joint with name {joint_name} not found") + if self._joints is None: + self._joints = [JointDescription(joint) for joint in self.parsed_description.joints] + return self._joints @property - def joints(self) -> List[JointDescription]: + def links(self) -> List[LinkDescription]: """ - :return: A list of joints descriptions of this object. + :return: A list of link descriptions of this object. """ - return [JointDescription(joint) for joint in self.parsed_description.joints] + if self._links is None: + self._links = [LinkDescription(link) for link in self.parsed_description.links] + return self._links def get_root(self) -> str: """ @@ -267,21 +327,44 @@ def get_root(self) -> str: """ return self.parsed_description.get_root() - def get_chain(self, start_link_name: str, end_link_name: str) -> List[str]: - """ + def get_tip(self) -> str: + """ + :return: the name of the tip link of this object. + :raises MultiplePossibleTipLinks: If there are multiple possible tip links. + """ + link = self.get_root() + while link in self.parsed_description.child_map: + children = self.parsed_description.child_map[link] + if len(children) > 1: + # Multiple children, can't decide which one to take (e.g. fingers of a hand) + raise MultiplePossibleTipLinks(self.parsed_description.name, link, [child[1] for child in children]) + else: + child = children[0][1] + link = child + return link + + def get_chain(self, start_link_name: str, end_link_name: str, joints: Optional[bool] = True, + links: Optional[bool] = True, fixed: Optional[bool] = True) -> List[str]: + """ + :param start_link_name: The name of the start link of the chain. + :param end_link_name: The name of the end link of the chain. + :param joints: Whether to include joints in the chain. + :param links: Whether to include links in the chain. + :param fixed: Whether to include fixed joints in the chain. :return: the chain of links from 'start_link_name' to 'end_link_name'. """ - return self.parsed_description.get_chain(start_link_name, end_link_name) + return self.parsed_description.get_chain(start_link_name, end_link_name, joints, links, fixed) - def correct_urdf_string(self, urdf_string: str) -> str: + @staticmethod + def replace_relative_references_with_absolute_paths(urdf_string: str) -> str: """ - Changes paths for files in the URDF from ROS paths to paths in the file system. Since World (PyBullet legacy) - can't deal with ROS package paths. + Change paths for files in the URDF from ROS paths and file dir references to paths in the file system. Since + World (PyBullet legacy) can't deal with ROS package paths. :param urdf_string: The name of the URDf on the parameter server :return: The URDF string with paths in the filesystem instead of ROS packages """ - r = rospkg.RosPack() + r = create_ros_pack() new_urdf_string = "" for line in urdf_string.split('\n'): if "package://" in line: @@ -289,9 +372,37 @@ def correct_urdf_string(self, urdf_string: str) -> str: s1 = s[1].split('/') path = r.get_path(s1[0]) line = line.replace("package://" + s1[0], path) + if 'file://' in line: + line = line.replace("file://", './') new_urdf_string += line + '\n' - return self.fix_missing_inertial(new_urdf_string) + return new_urdf_string + + @staticmethod + def make_mesh_paths_absolute(urdf_string: str, urdf_file_path: str) -> str: + """ + Convert all relative mesh paths in the URDF to absolute paths. + + :param urdf_string: The URDF description as string + :param urdf_file_path: The path to the URDF file + :returns: The new URDF description as string. + """ + # Parse the URDF file + root = ET.fromstring(urdf_string) + + # Iterate through all mesh tags + for mesh in root.findall('.//mesh'): + filename = mesh.attrib.get('filename', '') + if filename: + # If the filename is a relative path, convert it to an absolute path + if not os.path.isabs(filename): + # Deduce the base path from the relative path + base_path = os.path.dirname( + os.path.abspath(os.path.join(os.path.dirname(urdf_file_path), filename))) + abs_path = os.path.abspath(os.path.join(base_path, os.path.basename(filename))) + mesh.set('filename', abs_path) + + return ET.tostring(root, encoding='unicode') @staticmethod def fix_missing_inertial(urdf_string: str) -> str: @@ -303,10 +414,10 @@ def fix_missing_inertial(urdf_string: str) -> str: :returns: The new, corrected URDF description as string. """ - inertia_tree = ElementTree.ElementTree(ElementTree.Element("inertial")) - inertia_tree.getroot().append(ElementTree.Element("mass", {"value": "0.1"})) - inertia_tree.getroot().append(ElementTree.Element("origin", {"rpy": "0 0 0", "xyz": "0 0 0"})) - inertia_tree.getroot().append(ElementTree.Element("inertia", {"ixx": "0.01", + inertia_tree = ET.ElementTree(ET.Element("inertial")) + inertia_tree.getroot().append(ET.Element("mass", {"value": "0.1"})) + inertia_tree.getroot().append(ET.Element("origin", {"rpy": "0 0 0", "xyz": "0 0 0"})) + inertia_tree.getroot().append(ET.Element("inertia", {"ixx": "0.01", "ixy": "0", "ixz": "0", "iyy": "0.01", @@ -314,48 +425,48 @@ def fix_missing_inertial(urdf_string: str) -> str: "izz": "0.01"})) # create tree from string - tree = ElementTree.ElementTree(ElementTree.fromstring(urdf_string)) + tree = ET.ElementTree(ET.fromstring(urdf_string)) for link_element in tree.iter("link"): inertial = [*link_element.iter("inertial")] if len(inertial) == 0: link_element.append(inertia_tree.getroot()) - return ElementTree.tostring(tree.getroot(), encoding='unicode') + return ET.tostring(tree.getroot(), encoding='unicode') @staticmethod def remove_error_tags(urdf_string: str) -> str: """ - Removes all tags in the removing_tags list from the URDF since these tags are known to cause errors with the + Remove all tags in the removing_tags list from the URDF since these tags are known to cause errors with the URDF_parser :param urdf_string: String of the URDF from which the tags should be removed :return: The URDF string with the tags removed """ - tree = ElementTree.ElementTree(ElementTree.fromstring(urdf_string)) + tree = ET.ElementTree(ET.fromstring(urdf_string)) removing_tags = ["gazebo", "transmission"] for tag_name in removing_tags: all_tags = tree.findall(tag_name) for tag in all_tags: tree.getroot().remove(tag) - return ElementTree.tostring(tree.getroot(), encoding='unicode') + return ET.tostring(tree.getroot(), encoding='unicode') @staticmethod def fix_link_attributes(urdf_string: str) -> str: """ - Removes the attribute 'type' from links since this is not parsable by the URDF parser. + Remove the attribute 'type' from links since this is not parsable by the URDF parser. :param urdf_string: The string of the URDF from which the attributes should be removed :return: The URDF string with the attributes removed """ - tree = ElementTree.ElementTree(ElementTree.fromstring(urdf_string)) + tree = ET.ElementTree(ET.fromstring(urdf_string)) for link in tree.iter("link"): if "type" in link.attrib.keys(): del link.attrib["type"] - return ElementTree.tostring(tree.getroot(), encoding='unicode') + return ET.tostring(tree.getroot(), encoding='unicode') @staticmethod def get_file_extension() -> str: diff --git a/src/pycram/ontology/ontology.py b/src/pycram/ontology/ontology.py index 51a054d7b..f768d9315 100644 --- a/src/pycram/ontology/ontology.py +++ b/src/pycram/ontology/ontology.py @@ -4,17 +4,16 @@ import itertools import logging import os.path +import sqlite3 from pathlib import Path from typing import Callable, Dict, List, Optional, Type, Tuple, Union -import owlready2 -import rospy - from owlready2 import (Namespace, Ontology, World as OntologyWorld, Thing, EntityClass, Imp, Property, ObjectProperty, OwlReadyError, types, onto_path, default_world, get_namespace, get_ontology, destroy_entity, - sync_reasoner_pellet, sync_reasoner_hermit) + sync_reasoner_pellet, sync_reasoner_hermit, + OwlReadyOntologyParsingError) from owlready2.class_construct import GeneralClassAxiom from ..datastructures.enums import ObjectType @@ -22,7 +21,9 @@ from ..designator import DesignatorDescription, ObjectDesignatorDescription from ..ontology.ontology_common import (OntologyConceptHolderStore, OntologyConceptHolder, - ONTOLOGY_SQL_BACKEND_FILE_EXTENSION) + ONTOLOGY_SQL_BACKEND_FILE_EXTENSION, + ONTOLOGY_SQL_IN_MEMORY_BACKEND) +from ..ros.logging import loginfo, logerr, logwarn SOMA_HOME_ONTOLOGY_IRI = "http://www.ease-crc.org/ont/SOMA-HOME.owl" SOMA_ONTOLOGY_IRI = "http://www.ease-crc.org/ont/SOMA.owl" @@ -35,12 +36,14 @@ class OntologyManager(object, metaclass=Singleton): Singleton class as the adapter accessing data of an OWL ontology, largely based on owlready2. """ - def __init__(self, main_ontology_iri: Optional[str] = None, ontology_search_path: Optional[str] = None, + def __init__(self, main_ontology_iri: Optional[str] = None, main_sql_backend_filename: Optional[str] = None, + ontology_search_path: Optional[str] = None, use_global_default_world: bool = True): """ Create the singleton object of OntologyManager class :param main_ontology_iri: Ontology IRI (Internationalized Resource Identifier), either a URL to a remote OWL file or the full name path of a local one + :param main_sql_backend_filename: a full file path (no need to already exist) being used as SQL backend for the ontology world. If None, in-memory is used instead :param ontology_search_path: directory path from which a possibly existing ontology is searched. This is appended to `owlready2.onto_path`, a global variable containing a list of directories for searching local copies of ontologies (similarly to python `sys.path` for modules/packages). If not specified, the path is "$HOME/ontologies" :param use_global_default_world: whether or not using the owlready2-provided global default persistent world """ @@ -74,22 +77,14 @@ def __init__(self, main_ontology_iri: Optional[str] = None, ontology_search_path #: Namespace of the main ontology self.main_ontology_namespace: Optional[Namespace] = None - # Create the main ontology world holding triples, of which a sqlite3 file path, of same name with `main_ontology` & - # at the same folder with `main_ontology_iri` (if it is a local abosulte path), is automatically registered as cache of the world - self.main_ontology_world = self.create_ontology_world( - sql_backend_filename=os.path.join(self.get_main_ontology_dir(), - f"{Path(self.main_ontology_iri).stem}{ONTOLOGY_SQL_BACKEND_FILE_EXTENSION}"), - use_global_default_world=use_global_default_world) + #: SQL backend for :attr:`main_ontology_world`, being either "memory" or a full file path (no need to already exist) + self.main_ontology_sql_backend = main_sql_backend_filename if main_sql_backend_filename else ONTOLOGY_SQL_IN_MEMORY_BACKEND - # Load ontologies from `main_ontology_iri` to `main_ontology_world` - # If `main_ontology_iri` is a remote URL, Owlready2 first searches for a local copy of the OWL file (from `onto_path`), - # if not found, tries to download it from the Internet. - ontology_info = self.load_ontology(self.main_ontology_iri) - if ontology_info: - self.main_ontology, self.main_ontology_namespace = ontology_info - if self.main_ontology and self.main_ontology.loaded: - self.soma = self.ontologies.get(SOMA_ONTOLOGY_NAMESPACE) - self.dul = self.ontologies.get(DUL_ONTOLOGY_NAMESPACE) + # Create the main ontology world holding triples + self.create_main_ontology_world(use_global_default_world=use_global_default_world) + + # Create the main ontology & its namespace, fetching :attr:`soma`, :attr:`dul` if loading from SOMA ontology + self.create_main_ontology() @staticmethod def print_ontology_class(ontology_class: Type[Thing]): @@ -100,20 +95,20 @@ def print_ontology_class(ontology_class: Type[Thing]): """ if ontology_class is None: return - rospy.loginfo(f"{ontology_class} {type(ontology_class)}") - rospy.loginfo(f"Defined class: {ontology_class.get_defined_class()}") - rospy.loginfo(f"Super classes: {ontology_class.is_a}") - rospy.loginfo(f"Equivalent to: {EntityClass.get_equivalent_to(ontology_class)}") - rospy.loginfo(f"Indirectly equivalent to: {ontology_class.get_indirect_equivalent_to()}") - rospy.loginfo(f"Ancestors: {list(ontology_class.ancestors())}") - rospy.loginfo(f"Subclasses: {list(ontology_class.subclasses())}") - rospy.loginfo(f"Disjoint unions: {ontology_class.get_disjoint_unions()}") - rospy.loginfo(f"Properties: {list(ontology_class.get_class_properties())}") - rospy.loginfo(f"Indirect Properties: {list(ontology_class.INDIRECT_get_class_properties())}") - rospy.loginfo(f"Instances: {list(ontology_class.instances())}") - rospy.loginfo(f"Direct Instances: {list(ontology_class.direct_instances())}") - rospy.loginfo(f"Inverse Restrictions: {list(ontology_class.inverse_restrictions())}") - rospy.loginfo("-------------------") + loginfo(f"{ontology_class} {type(ontology_class)}") + loginfo(f"Defined class: {ontology_class.get_defined_class()}") + loginfo(f"Super classes: {ontology_class.is_a}") + loginfo(f"Equivalent to: {EntityClass.get_equivalent_to(ontology_class)}") + loginfo(f"Indirectly equivalent to: {ontology_class.get_indirect_equivalent_to()}") + loginfo(f"Ancestors: {list(ontology_class.ancestors())}") + loginfo(f"Subclasses: {list(ontology_class.subclasses())}") + loginfo(f"Disjoint unions: {ontology_class.get_disjoint_unions()}") + loginfo(f"Properties: {list(ontology_class.get_class_properties())}") + loginfo(f"Indirect Properties: {list(ontology_class.INDIRECT_get_class_properties())}") + loginfo(f"Instances: {list(ontology_class.instances())}") + loginfo(f"Direct Instances: {list(ontology_class.direct_instances())}") + loginfo(f"Inverse Restrictions: {list(ontology_class.inverse_restrictions())}") + loginfo("-------------------") @staticmethod def print_ontology_property(ontology_property: Property): @@ -125,17 +120,17 @@ def print_ontology_property(ontology_property: Property): if ontology_property is None: return property_class = type(ontology_property) - rospy.loginfo(f"{ontology_property} {property_class}") - rospy.loginfo(f"Relations: {list(ontology_property.get_relations())}") - rospy.loginfo(f"Domain: {ontology_property.get_domain()}") - rospy.loginfo(f"Range: {ontology_property.get_range()}") + loginfo(f"{ontology_property} {property_class}") + loginfo(f"Relations: {list(ontology_property.get_relations())}") + loginfo(f"Domain: {ontology_property.get_domain()}") + loginfo(f"Range: {ontology_property.get_range()}") if hasattr(property_class, "_equivalent_to"): - rospy.loginfo(f"Equivalent classes: {EntityClass.get_equivalent_to(property_class)}") + loginfo(f"Equivalent classes: {EntityClass.get_equivalent_to(property_class)}") if hasattr(property_class, "_indirect"): - rospy.loginfo(f"Indirectly equivalent classes: {EntityClass.get_indirect_equivalent_to(property_class)}") - rospy.loginfo(f"Property chain: {ontology_property.get_property_chain()}") - rospy.loginfo(f"Class property type: {ontology_property.get_class_property_type()}") - rospy.loginfo("-------------------") + loginfo(f"Indirectly equivalent classes: {EntityClass.get_indirect_equivalent_to(property_class)}") + loginfo(f"Property chain: {ontology_property.get_property_chain()}") + loginfo(f"Class property type: {ontology_property.get_class_property_type()}") + loginfo("-------------------") @staticmethod def get_default_ontology_search_path() -> Optional[str]: @@ -147,7 +142,7 @@ def get_default_ontology_search_path() -> Optional[str]: if onto_path: return onto_path[0] else: - rospy.logerr("No ontology search path has been configured!") + logerr("No ontology search path has been configured!") return None def get_main_ontology_dir(self) -> Optional[str]: @@ -160,35 +155,88 @@ def get_main_ontology_dir(self) -> Optional[str]: return os.path.dirname(self.main_ontology_iri) if os.path.isabs( self.main_ontology_iri) else self.get_default_ontology_search_path() + def is_main_ontology_sql_backend_in_memory(self) -> bool: + """ + Whether the main ontology's SQL backend is in-memory + + :return: true if the main ontology's SQL backend is in-memory + """ + return self.main_ontology_sql_backend == ONTOLOGY_SQL_IN_MEMORY_BACKEND + + def create_main_ontology_world(self, use_global_default_world: bool = True) -> None: + """ + Create the main ontology world, either reusing the owlready2-provided global default ontology world or create a new one + A backend sqlite3 file of same name with `main_ontology` is also created at the same folder with :attr:`main_ontology_iri` + (if it is a local absolute path). The file is automatically registered as cache for the main ontology world. + + :param use_global_default_world: whether or not using the owlready2-provided global default persistent world + :param sql_backend_filename: a full file path (no need to already exist) being used as SQL backend for the ontology world. If None, memory is used instead + """ + self.main_ontology_world = self.create_ontology_world( + sql_backend_filename=self.main_ontology_sql_backend, + use_global_default_world=use_global_default_world) + @staticmethod def create_ontology_world(use_global_default_world: bool = False, sql_backend_filename: Optional[str] = None) -> OntologyWorld: """ - Either reuse the owlready2-provided global default ontology world or create a new one + Either reuse the owlready2-provided global default ontology world or create a new one. :param use_global_default_world: whether or not using the owlready2-provided global default persistent world - :param sql_backend_filename: a full file path (no need to already exist) being used as SQL backend for the ontology world. If None, memory is used instead + :param sql_backend_filename: an absolute file path (no need to already exist) being used as SQL backend for the ontology world. If it is None or non-absolute path, in-memory is used instead :return: owlready2-provided global default ontology world or a newly created ontology world """ world = default_world - sql_backend_path_valid = sql_backend_filename and os.path.isabs(sql_backend_filename) - sql_backend_name = sql_backend_filename if sql_backend_path_valid else "memory" - if use_global_default_world: - # Reuse default world - if sql_backend_path_valid: - world.set_backend(filename=sql_backend_filename, exclusive=False, enable_thread_parallelism=True) - else: - world.set_backend(exclusive=False, enable_thread_parallelism=True) - rospy.loginfo(f"Using global default ontology world with SQL backend: {sql_backend_name}") - else: - # Create a new world with parallelized file parsing enabled - if sql_backend_path_valid: - world = OntologyWorld(filename=sql_backend_filename, exclusive=False, enable_thread_parallelism=True) + sql_backend_path_absolute = (sql_backend_filename and os.path.isabs(sql_backend_filename)) + if sql_backend_filename and (sql_backend_filename != ONTOLOGY_SQL_IN_MEMORY_BACKEND): + if not sql_backend_path_absolute: + logerr(f"For ontology world accessing, either f{ONTOLOGY_SQL_IN_MEMORY_BACKEND}" + f"or an absolute path to its SQL file backend is expected: {sql_backend_filename}") + return default_world + elif not sql_backend_filename.endswith(ONTOLOGY_SQL_BACKEND_FILE_EXTENSION): + logerr( + f"Ontology world SQL backend file path, {sql_backend_filename}," + f"is expected to be of extension {ONTOLOGY_SQL_BACKEND_FILE_EXTENSION}!") + return default_world + + sql_backend_path_valid = sql_backend_path_absolute + sql_backend_name = sql_backend_filename if sql_backend_path_valid else ONTOLOGY_SQL_IN_MEMORY_BACKEND + try: + if use_global_default_world: + # Reuse default world + if sql_backend_path_valid: + world.set_backend(filename=sql_backend_filename, exclusive=False, enable_thread_parallelism=True) + else: + world.set_backend(exclusive=False, enable_thread_parallelism=True) + loginfo(f"Using global default ontology world with SQL backend: {sql_backend_name}") else: - world = OntologyWorld(exclusive=False, enable_thread_parallelism=True) - rospy.loginfo(f"Created a new ontology world with SQL backend: {sql_backend_name}") + # Create a new world with parallelized file parsing enabled + if sql_backend_path_valid: + world = OntologyWorld(filename=sql_backend_filename, exclusive=False, enable_thread_parallelism=True) + else: + world = OntologyWorld(exclusive=False, enable_thread_parallelism=True) + loginfo(f"Created a new ontology world with SQL backend: {sql_backend_name}") + except sqlite3.Error as e: + logerr(f"Failed accessing the SQL backend of ontology world: {sql_backend_name}", + e.sqlite_errorcode, e.sqlite_errorname) return world + def create_main_ontology(self) -> bool: + """ + Load ontologies from :attr:`main_ontology_iri` to :attr:`main_ontology_world` + If `main_ontology_iri` is a remote URL, Owlready2 first searches for a local copy of the OWL file (from `onto_path`), + if not found, tries to download it from the Internet. + + :return: True if loading succeeds + """ + ontology_info = self.load_ontology(self.main_ontology_iri) + if ontology_info: + self.main_ontology, self.main_ontology_namespace = ontology_info + if self.main_ontology and self.main_ontology.loaded: + self.soma = self.ontologies.get(SOMA_ONTOLOGY_NAMESPACE) + self.dul = self.ontologies.get(DUL_ONTOLOGY_NAMESPACE) + return ontology_info is not None + def load_ontology(self, ontology_iri: str) -> Optional[Tuple[Ontology, Namespace]]: """ Load an ontology from an IRI @@ -197,34 +245,49 @@ def load_ontology(self, ontology_iri: str) -> Optional[Tuple[Ontology, Namespace :return: A tuple including an ontology instance & its namespace """ if not ontology_iri: - rospy.logerr("Ontology IRI is empty") + logerr("Ontology IRI is empty") return None - # If `ontology_iri` is a local path -> create an empty ontology file if not existing - if not (ontology_iri.startswith("http:") or ontology_iri.startswith("https:")) \ - and not Path(ontology_iri).exists(): - with open(ontology_iri, 'w'): + is_local_ontology_iri = not (ontology_iri.startswith("http:") or ontology_iri.startswith("https:")) + + # If `ontology_iri` is a local path + if is_local_ontology_iri and not Path(ontology_iri).exists(): + # -> Create an empty ontology file if not existing + ontology_path = ontology_iri if os.path.isabs(ontology_iri) else ( + os.path.join(self.get_main_ontology_dir(), ontology_iri)) + with open(ontology_path, 'w'): pass # Load ontology from `ontology_iri` - if self.main_ontology_world: - ontology = self.main_ontology_world.get_ontology(ontology_iri).load(reload_if_newer=True) - else: - ontology = get_ontology(ontology_iri).load(reload_if_newer=True) + ontology = None + try: + if self.main_ontology_world: + ontology = self.main_ontology_world.get_ontology(ontology_iri).load(reload_if_newer=True) + else: + ontology = get_ontology(ontology_iri).load(reload_if_newer=True) + except OwlReadyOntologyParsingError as error: + logwarn(error) + if is_local_ontology_iri: + logerr(f"Main ontology failed being loaded from {ontology_iri}") + else: + logwarn(f"Main ontology failed being downloaded from the remote {ontology_iri}") + return None + + # Browse loaded `ontology`, fetching sub-ontologies ontology_namespace = get_namespace(ontology_iri) - if ontology.loaded: - rospy.loginfo( + if ontology and ontology.loaded: + loginfo( f'Ontology [{ontology.base_iri}]\'s name: {ontology.name} has been loaded') - rospy.loginfo(f'- main namespace: {ontology_namespace.name}') - rospy.loginfo(f'- loaded ontologies:') + loginfo(f'- main namespace: {ontology_namespace.name}') + loginfo(f'- loaded ontologies:') def fetch_ontology(ontology__): self.ontologies[ontology__.name] = ontology__ - rospy.loginfo(ontology__.base_iri) + loginfo(ontology__.base_iri) self.browse_ontologies(ontology, condition=None, func=lambda ontology__: fetch_ontology(ontology__)) else: - rospy.logerr(f"Ontology [{ontology.base_iri}]\'s name: {ontology.name} failed being loaded") + logerr(f"Ontology [{ontology.base_iri}]\'s name: {ontology.name} failed being loaded") return ontology, ontology_namespace def initialized(self) -> bool: @@ -246,10 +309,10 @@ def browse_ontologies(ontology: Ontology, :param func: a Callable specifying the operations to perform on all the loaded ontologies if condition is None, otherwise only the first ontology which meets the condition """ if ontology is None: - rospy.logerr(f"Ontology {ontology=} is None!") + logerr(f"Ontology {ontology=} is None!") return elif not ontology.loaded: - rospy.logerr(f"Ontology {ontology} was not loaded!") + logerr(f"Ontology {ontology} was not loaded!") return will_do_func = func is not None @@ -283,23 +346,23 @@ def save(self, target_filename: Optional[str] = None, overwrite: bool = False) - else f"{self.get_main_ontology_dir()}/{Path(self.main_ontology_iri).name}" save_to_same_file = is_current_ontology_local and (target_filename == current_ontology_filename) if save_to_same_file and not overwrite: - rospy.logerr( + logerr( f"Ontologies cannot be saved to the originally loaded [{target_filename}] if not by overwriting") return False else: save_filename = target_filename if target_filename else current_ontology_filename self.main_ontology.save(save_filename) if save_to_same_file and overwrite: - rospy.logwarn(f"Main ontology {self.main_ontology.name} has been overwritten to {save_filename}") + logwarn(f"Main ontology {self.main_ontology.name} has been overwritten to {save_filename}") else: - rospy.loginfo(f"Main ontology {self.main_ontology.name} has been saved to {save_filename}") + loginfo(f"Main ontology {self.main_ontology.name} has been saved to {save_filename}") # Commit the whole graph data of the current ontology world, saving it into SQLite3, to be reused the next time # the ontologies are loaded main_ontology_sql_filename = self.main_ontology_world.filename self.main_ontology_world.save() if os.path.isfile(main_ontology_sql_filename): - rospy.loginfo( + loginfo( f"Main ontology world for {self.main_ontology.name} has been cached and saved to SQL: {main_ontology_sql_filename}") #else: it could be using memory cache as SQL backend return True @@ -322,7 +385,7 @@ def create_ontology_concept_class(self, class_name: str, return ontology_concept_class if getattr(ontology, class_name, None): - rospy.logerr(f"Ontology concept class {ontology.name}.{class_name} already exists") + logerr(f"Ontology concept class {ontology.name}.{class_name} already exists") return None with ontology: @@ -348,7 +411,7 @@ def create_ontology_property_class(self, class_name: str, else None if getattr(ontology, class_name, None): - rospy.logerr(f"Ontology property class {ontology.name}.{class_name} already exists") + logerr(f"Ontology property class {ontology.name}.{class_name} already exists") return None with ontology: @@ -378,7 +441,7 @@ def get_ontology_classes_by_condition(self, condition: Callable, first_match_onl return out_classes if not out_classes: - rospy.loginfo(f"No class with {kwargs} is found in the ontology {self.main_ontology}") + loginfo(f"No class with {kwargs} is found in the ontology {self.main_ontology}") return out_classes @staticmethod @@ -499,7 +562,7 @@ def create_ontology_triple_classes(self, subject_class_name: str, object_class_n ontology_subject_parent_class, ontology=ontology) if not ontology_subject_class: - rospy.logerr(f"{ontology.name}: Failed creating ontology subject class named {subject_class_name}") + logerr(f"{ontology.name}: Failed creating ontology subject class named {subject_class_name}") return False # Object @@ -512,7 +575,7 @@ def create_ontology_triple_classes(self, subject_class_name: str, object_class_n ontology_object_class = ontology_object_parent_class if not ontology_object_class: - rospy.logerr(f"{ontology.name}: Failed creating ontology object class named {object_class_name}") + logerr(f"{ontology.name}: Failed creating ontology object class named {object_class_name}") return False # Predicate @@ -520,7 +583,7 @@ def create_ontology_triple_classes(self, subject_class_name: str, object_class_n ontology_property_parent_class, ontology=ontology) if not ontology_predicate_class: - rospy.logerr(f"{ontology.name}: Failed creating ontology predicate class named {predicate_class_name}") + logerr(f"{ontology.name}: Failed creating ontology predicate class named {predicate_class_name}") return False ontology_predicate_class.domain = [ontology_subject_class] ontology_predicate_class.range = [ontology_object_class] @@ -532,7 +595,7 @@ def create_ontology_triple_classes(self, subject_class_name: str, object_class_n ontology_inverse_property_parent_class, ontology=ontology) if not ontology_inverse_predicate_class: - rospy.logerr( + logerr( f"{ontology.name}: Failed creating ontology inverse-predicate class named {inverse_predicate_class_name}") return False ontology_inverse_predicate_class.inverse_property = ontology_predicate_class @@ -574,7 +637,7 @@ def create_ontology_linked_designator_by_concept(self, designator_class: Type[De """ ontology_concept_name = f'{object_name}_concept' if len(OntologyConceptHolderStore().get_designators_of_ontology_concept(ontology_concept_name)) > 0: - rospy.logerr( + logerr( f"A designator named [{object_name}] is already created for ontology concept [{ontology_concept_name}]") return None @@ -582,7 +645,7 @@ def create_ontology_linked_designator_by_concept(self, designator_class: Type[De is_object_designator = issubclass(designator_class, ObjectDesignatorDescription) if is_object_designator: if not object_name: - rospy.logerr( + logerr( f"An empty object name was given as creating its Object designator for ontology concept class [{ontology_concept_class.name}]") return None designator = designator_class(names=[object_name]) @@ -635,7 +698,7 @@ def set_ontology_relation(subject_designator: DesignatorDescription, object_concepts_list.append(holder.ontology_concept) return True else: - rospy.logerr(f"Ontology concept [{subject_concept.name}] has no predicate named [{predicate_name}]") + logerr(f"Ontology concept [{subject_concept.name}] has no predicate named [{predicate_name}]") return False @staticmethod @@ -778,7 +841,7 @@ def reason(self, world: OntologyWorld = None, use_pellet_reasoner: bool = True) reasoner_name = "HermiT" sync_reasoner_hermit(x=reasoning_world, infer_property_values=True) except OwlReadyError as error: - rospy.logerr(f"{reasoner_name} reasoning failed: {error}") + logerr(f"{reasoner_name} reasoning failed: {error}") return False - rospy.loginfo(f"{reasoner_name} reasoning finishes!") + loginfo(f"{reasoner_name} reasoning finishes!") return True diff --git a/src/pycram/ontology/ontology_common.py b/src/pycram/ontology/ontology_common.py index 080078a08..0df9ef570 100644 --- a/src/pycram/ontology/ontology_common.py +++ b/src/pycram/ontology/ontology_common.py @@ -2,15 +2,17 @@ import itertools from typing import Callable, Dict, List, Optional, Type, TYPE_CHECKING -import rospy from ..helper import Singleton +from ..ros.logging import logerr + if TYPE_CHECKING: from ..designator import DesignatorDescription from owlready2 import issubclass, Thing ONTOLOGY_SQL_BACKEND_FILE_EXTENSION = ".sqlite3" +ONTOLOGY_SQL_IN_MEMORY_BACKEND = "memory" ONTOLOGY_OWL_FILE_EXTENSION = ".owl" @@ -35,7 +37,7 @@ def add_ontology_concept_holder(self, ontology_concept_name: str, ontology_conce :return: True if the ontology concept can be added into the concept store (if not already existing), otherwise False """ if ontology_concept_name in self.__all_ontology_concept_holders: - rospy.logerr(f"OntologyConceptHolder for `{ontology_concept_name}` was already created!") + logerr(f"OntologyConceptHolder for `{ontology_concept_name}` was already created!") return False else: self.__all_ontology_concept_holders.setdefault(ontology_concept_name, ontology_concept_holder) diff --git a/src/pycram/orm/action_designator.py b/src/pycram/orm/action_designator.py index c84a16ea2..90d4ba5de 100644 --- a/src/pycram/orm/action_designator.py +++ b/src/pycram/orm/action_designator.py @@ -47,7 +47,7 @@ class SetGripperAction(Action): motion: Mapped[GripperState] -class Release(ObjectMixin, Action): +class ReleaseAction(ObjectMixin, Action): """ORM Class of pycram.designators.action_designator.Release.""" id: Mapped[int] = mapped_column(ForeignKey(f'{Action.__tablename__}.id'), primary_key=True, init=False) @@ -102,7 +102,6 @@ class OpenAction(ObjectMixin, Action): id: Mapped[int] = mapped_column(ForeignKey(f'{Action.__tablename__}.id'), primary_key=True, init=False) arm: Mapped[Arms] - # distance: Mapped[float] = mapped_column(init=False) class CloseAction(ObjectMixin, Action): diff --git a/src/pycram/orm/base.py b/src/pycram/orm/base.py index 0e064bd33..06918ebe2 100644 --- a/src/pycram/orm/base.py +++ b/src/pycram/orm/base.py @@ -120,7 +120,7 @@ class PoseMixin(MappedAsDataclass): @declared_attr def pose_id(self) -> Mapped[int]: - return mapped_column(ForeignKey(f'{Pose.__tablename__}.id'), init=self.pose_to_init) + return mapped_column(ForeignKey(f'{Pose.__tablename__}.id'), init=self.pose_to_init, nullable=True) @declared_attr def pose(self): diff --git a/src/pycram/orm/utils.py b/src/pycram/orm/utils.py index 624773a88..adcc5ebcc 100644 --- a/src/pycram/orm/utils.py +++ b/src/pycram/orm/utils.py @@ -1,12 +1,11 @@ import traceback -import rospy import sqlalchemy -import pycram.orm.base -from pycram.designators.object_designator import * +from .base import Base +from ..designators.object_designator import * import json -from pycram.designators.action_designator import * -import pycram.orm +from ..ros.logging import loginfo, logwarn + def write_database_to_file(in_sessionmaker: sqlalchemy.orm.sessionmaker, filename: str, @@ -21,7 +20,7 @@ def write_database_to_file(in_sessionmaker: sqlalchemy.orm.sessionmaker, filenam with in_sessionmaker() as session: with open("whatever.txt", "w") as f: to_json_dict = dict() - for table in pycram.orm.base.Base.metadata.sorted_tables: + for table in Base.metadata.sorted_tables: list_of_row = list() for column_object in session.query(table).all(): list_of_row.append(column_object) @@ -37,13 +36,13 @@ def print_database(in_sessionmaker: sqlalchemy.orm.sessionmaker): :param in_sessionmaker: Database Session which should be printed """ with in_sessionmaker() as session: - for table in pycram.orm.base.Base.metadata.sorted_tables: + for table in Base.metadata.sorted_tables: try: smt = sqlalchemy.select('*').select_from(table) result = session.execute(smt).all() - rospy.loginfo("Table: {}\tcontent:{}".format(table, result)) + loginfo("Table: {}\tcontent:{}".format(table, result)) except sqlalchemy.exc.ArgumentError as e: - rospy.logwarn(e) + logwarn(e) def update_primary_key(source_session_maker: sqlalchemy.orm.sessionmaker, @@ -60,7 +59,7 @@ def update_primary_key(source_session_maker: sqlalchemy.orm.sessionmaker, """ destination_session = destination_session_maker() source_session = source_session_maker() - sortedTables = pycram.orm.base.Base.metadata.sorted_tables + sortedTables = Base.metadata.sorted_tables for table in sortedTables: try: list_of_primary_keys_of_this_table = table.primary_key.columns.values() @@ -77,7 +76,7 @@ def update_primary_key(source_session_maker: sqlalchemy.orm.sessionmaker, results = destination_session.execute(sqlalchemy.select(table)) for column_object in results: # iterate over all columns if column_object.__getattr__(key.name) in all_source_key_values: - rospy.loginfo( + loginfo( "Found primary_key collision in table {} value: {} max value in memory {}".format(table, column_object.__getattr__( key.name), @@ -90,8 +89,8 @@ def update_primary_key(source_session_maker: sqlalchemy.orm.sessionmaker, highest_free_key_value += 1 destination_session.commit() # commit after every table except AttributeError as e: - rospy.logwarn("Possible found abstract ORM class {}".format(e.__name__)) - rospy.logwarn(e) + logwarn("Possible found abstract ORM class {}".format(e.__name__)) + logwarn(e) destination_session.close() @@ -112,7 +111,7 @@ def copy_database(source_session_maker: sqlalchemy.orm.sessionmaker, """ with source_session_maker() as source_session, destination_session_maker() as destination_session: - sorted_tables = pycram.orm.base.Base.metadata.sorted_tables + sorted_tables = Base.metadata.sorted_tables for table in sorted_tables: for value in source_session.query(table).all(): insert_statement = sqlalchemy.insert(table).values(value) @@ -121,7 +120,7 @@ def copy_database(source_session_maker: sqlalchemy.orm.sessionmaker, def update_primary_key_constrains(session_maker: sqlalchemy.orm.sessionmaker): - ''' + """ Iterates through all tables related to any ORM Class and sets in their corresponding foreign keys in the given endpoint to "ON UPDATE CASCADING". @@ -130,15 +129,15 @@ def update_primary_key_constrains(session_maker: sqlalchemy.orm.sessionmaker): :param session_maker: :return: empty - ''' + """ with session_maker() as session: - for table in pycram.orm.base.Base.metadata.sorted_tables: + for table in Base.metadata.sorted_tables: try: foreign_key_statement = sqlalchemy.text( "SELECT con.oid, con.conname, con.contype, con.confupdtype, con.confdeltype, con.confmatchtype, pg_get_constraintdef(con.oid) FROM pg_catalog.pg_constraint con INNER JOIN pg_catalog.pg_class rel ON rel.oid = con.conrelid INNER JOIN pg_catalog.pg_namespace nsp ON nsp.oid = connamespace WHERE rel.relname = '{}';".format( table)) response = session.execute(foreign_key_statement) - rospy.loginfo(25 * '~' + "{}".format(table) + 25 * '~') + loginfo(25 * '~' + "{}".format(table) + 25 * '~') for line in response: if line.conname.endswith("fkey"): if 'a' in line.confupdtype: # a --> no action | if there is no action we set it to cascading @@ -157,7 +156,7 @@ def update_primary_key_constrains(session_maker: sqlalchemy.orm.sessionmaker): alter_statement) # There is no real data coming back for this session.commit() except AttributeError: - rospy.loginfo("Attribute Error: {} has no attribute __tablename__".format(table)) + loginfo("Attribute Error: {} has no attribute __tablename__".format(table)) def migrate_neems(source_session_maker: sqlalchemy.orm.sessionmaker, diff --git a/src/pycram/orm/views.py b/src/pycram/orm/views.py index 51c9136f2..999a4aabe 100644 --- a/src/pycram/orm/views.py +++ b/src/pycram/orm/views.py @@ -1,4 +1,4 @@ -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import declarative_base, Mapped, column_property from typing_extensions import Union import sqlalchemy.orm from sqlalchemy import table, inspect, event, select, engine, MetaData, Select, TableClause, ExecutableDDLElement @@ -122,23 +122,11 @@ class PickUpWithContextView(base): 3D Vector for object position """ - __relative_x = (__robot_position.x - __object_position.x) - """ - Distance on x axis between robot and object - """ - - __relative_y = (__robot_position.y - __object_position.y) - """ - Distance on y axis between robot and object - """ - __table__ = view("PickUpWithContextView", Base.metadata, - (select(PickUpAction.id.label("id"), PickUpAction.arm.label("arm"), - PickUpAction.grasp.label("grasp"), RobotState.torso_height.label("torso_height"), - __relative_x.label("relative_x"), __relative_y.label("relative_y"), - Quaternion.x.label("quaternion_x"), Quaternion.y.label("quaternion_y"), - Quaternion.z.label("quaternion_z"), Quaternion.w.label("quaternion_w"), - Object.obj_type.label("obj_type"), TaskTreeNode.status.label("status")) + (select(PickUpAction.id, PickUpAction.arm, PickUpAction.grasp, RobotState.torso_height, + (__robot_position.x-__object_position.x).label("relative_x"), + (__robot_position.y-__object_position.y).label("relative_y"), Quaternion.x, Quaternion.y, + Quaternion.z, Quaternion.w, Object.obj_type, TaskTreeNode.status) .join(TaskTreeNode.action.of_type(PickUpAction)) .join(PickUpAction.robot_state) .join(__robot_pose, RobotState.pose) @@ -147,3 +135,16 @@ class PickUpWithContextView(base): .join(PickUpAction.object) .join(Object.pose) .join(__object_position, Pose.position))) + + id: Mapped[int] = __table__.c.id + arm: Mapped[str] = __table__.c.arm + grasp: Mapped[str] = __table__.c.grasp + torso_height: Mapped[float] = __table__.c.torso_height + relative_x: Mapped[float] = column_property(__table__.c.relative_x) + relative_y: Mapped[float] = column_property(__table__.c.relative_y) + quaternion_x: Mapped[float] = __table__.c.x + quaternion_y: Mapped[float] = __table__.c.y + quaternion_z: Mapped[float] = __table__.c.z + quaternion_w: Mapped[float] = __table__.c.w + obj_type: Mapped[str] = __table__.c.obj_type + status: Mapped[str] = __table__.c.status diff --git a/src/pycram/pose_generator_and_validator.py b/src/pycram/pose_generator_and_validator.py index 87b2d477f..6672b6c6c 100644 --- a/src/pycram/pose_generator_and_validator.py +++ b/src/pycram/pose_generator_and_validator.py @@ -1,5 +1,5 @@ -import tf import numpy as np +import tf from .datastructures.world import World from .world_concepts.world_object import Object @@ -9,8 +9,7 @@ from .datastructures.pose import Pose, Transform from .robot_description import RobotDescription from .external_interfaces.ik import request_ik -from .plan_failures import IKError -from .utils import _apply_ik +from .failures import IKError from typing_extensions import Tuple, List, Union, Dict, Iterable @@ -120,13 +119,13 @@ def visibility_validator(pose: Pose, robot_pose = robot.get_pose() if isinstance(object_or_pose, Object): robot.set_pose(pose) - camera_pose = robot.get_link_pose(RobotDescription.current_robot_description.get_camera_frame()) + camera_pose = robot.get_link_pose(RobotDescription.current_robot_description.get_camera_link()) robot.set_pose(Pose([100, 100, 0], [0, 0, 0, 1])) ray = world.ray_test(camera_pose.position_as_list(), object_or_pose.get_position_as_list()) res = ray == object_or_pose.id else: robot.set_pose(pose) - camera_pose = robot.get_link_pose(RobotDescription.current_robot_description.get_camera_frame()) + camera_pose = robot.get_link_pose(RobotDescription.current_robot_description.get_camera_link()) robot.set_pose(Pose([100, 100, 0], [0, 0, 0, 1])) # TODO: Check if this is correct ray = world.ray_test(camera_pose.position_as_list(), object_or_pose) @@ -186,7 +185,8 @@ def reachability_validator(pose: Pose, res = False arms = [] for description in manipulator_descs: - retract_target_pose = LocalTransformer().transform_pose(target, robot.get_link_tf_frame(description.end_effector.tool_frame)) + retract_target_pose = LocalTransformer().transform_pose(target, robot.get_link_tf_frame( + description.end_effector.tool_frame)) retract_target_pose.position.x -= 0.07 # Care hard coded value copied from PlaceAction class # retract_pose needs to be in world frame? @@ -203,14 +203,14 @@ def reachability_validator(pose: Pose, # test the possible solution and apply it to the robot pose, joint_states = request_ik(target, robot, joints, tool_frame) robot.set_pose(pose) - robot.set_joint_positions(joint_states) + robot.set_multiple_joint_positions(joint_states) # _apply_ik(robot, resp, joints) in_contact = collision_check(robot, allowed_collision) if not in_contact: # only check for retract pose if pose worked pose, joint_states = request_ik(retract_target_pose, robot, joints, tool_frame) robot.set_pose(pose) - robot.set_joint_positions(joint_states) + robot.set_multiple_joint_positions(joint_states) # _apply_ik(robot, resp, joints) in_contact = collision_check(robot, allowed_collision) if not in_contact: @@ -218,7 +218,7 @@ def reachability_validator(pose: Pose, except IKError: pass finally: - robot.set_joint_positions(joint_state_before_ik) + robot.set_multiple_joint_positions(joint_state_before_ik) if arms: res = True return res, arms @@ -245,8 +245,7 @@ def collision_check(robot: Object, allowed_collision: Dict[Object, List]): for obj in World.current_world.objects: if obj.name == "floor": continue - in_contact= _in_contact(robot, obj, allowed_collision, allowed_robot_links) - + in_contact = _in_contact(robot, obj, allowed_collision, allowed_robot_links) + if in_contact: + break return in_contact - - diff --git a/src/pycram/process_module.py b/src/pycram/process_module.py index 50ce0ee71..dc3eabab7 100644 --- a/src/pycram/process_module.py +++ b/src/pycram/process_module.py @@ -11,11 +11,11 @@ from abc import ABC from typing_extensions import Callable, Type, Any, Union -import rospy - from .language import Language from .robot_description import RobotDescription from typing_extensions import TYPE_CHECKING +from .datastructures.enums import ExecutionType +from .ros.logging import logerr, logwarn_once if TYPE_CHECKING: from .designators.motion_designator import BaseMotion @@ -26,7 +26,7 @@ class ProcessModule: Implementation of process modules. Process modules are the part that communicate with the outer world to execute designators. """ - execution_delay = True + execution_delay = False """ Adds a delay of 0.5 seconds after executing a process module, to make the execution in simulation more realistic """ @@ -89,7 +89,7 @@ def __enter__(self): sets it to 'real' """ self.pre = ProcessModuleManager.execution_type - ProcessModuleManager.execution_type = "real" + ProcessModuleManager.execution_type = ExecutionType.REAL self.pre_delay = ProcessModule.execution_delay ProcessModule.execution_delay = False @@ -127,7 +127,7 @@ def __enter__(self): sets it to 'simulated' """ self.pre = ProcessModuleManager.execution_type - ProcessModuleManager.execution_type = "simulated" + ProcessModuleManager.execution_type = ExecutionType.SIMULATED def __exit__(self, _type, value, traceback): """ @@ -140,6 +140,41 @@ def __call__(self): return self +class SemiRealRobot: + """ + Management class for executing designators on the semi-real robot. This is intended to be used in a with environment. + When importing this class an instance is imported instead. + + Example: + + .. code-block:: python + + with semi_real_robot: + some designators + """ + + def __init__(self): + self.pre: str = "" + + def __enter__(self): + """ + Entering function for 'with' scope, saves the previously set :py:attr:`~ProcessModuleManager.execution_type` and + sets it to 'semi_real' + """ + self.pre = ProcessModuleManager.execution_type + ProcessModuleManager.execution_type = ExecutionType.SEMI_REAL + + def __exit__(self, type, value, traceback): + """ + Exit method for the 'with' scope, sets the :py:attr:`~ProcessModuleManager.execution_type` to the previously + used one. + """ + ProcessModuleManager.execution_type = self.pre + + def __call__(self): + return self + + def with_real_robot(func: Callable) -> Callable: """ Decorator to execute designators in the decorated class on the real robot. @@ -158,7 +193,7 @@ def plan(): def wrapper(*args, **kwargs): pre = ProcessModuleManager.execution_type - ProcessModuleManager.execution_type = "real" + ProcessModuleManager.execution_type = ExecutionType.REAL ret = func(*args, **kwargs) ProcessModuleManager.execution_type = pre return ret @@ -184,7 +219,7 @@ def plan(): def wrapper(*args, **kwargs): pre = ProcessModuleManager.execution_type - ProcessModuleManager.execution_type = "simulated" + ProcessModuleManager.execution_type = ExecutionType.SIMULATED ret = func(*args, **kwargs) ProcessModuleManager.execution_type = pre return ret @@ -195,6 +230,7 @@ def wrapper(*args, **kwargs): # These are imported, so they don't have to be initialized when executing with simulated_robot = SimulatedRobot() real_robot = RealRobot() +semi_real_robot = SemiRealRobot() class ProcessModuleManager(ABC): @@ -240,14 +276,12 @@ def __init__(self, robot_name): @staticmethod def get_manager() -> Union[ProcessModuleManager, None]: """ - Returns the Process Module manager for the currently loaded robot or None if there is no Manager. - :return: ProcessModuleManager instance of the current robot """ manager = None _default_manager = None if not ProcessModuleManager.execution_type: - rospy.logerr( + logerr( f"No execution_type is set, did you use the with_simulated_robot or with_real_robot decorator?") return @@ -260,17 +294,17 @@ def get_manager() -> Union[ProcessModuleManager, None]: if manager: return manager elif _default_manager: - rospy.logwarn_once(f"No Process Module Manager found for robot: '{RobotDescription.current_robot_description.name}'" + logwarn_once(f"No Process Module Manager found for robot: '{RobotDescription.current_robot_description.name}'" f", using default process modules") return _default_manager else: - rospy.logerr(f"No Process Module Manager found for robot: '{RobotDescription.current_robot_description.name}'" + logerr(f"No Process Module Manager found for robot: '{RobotDescription.current_robot_description.name}'" f", and no default process modules available") return None def navigate(self) -> Type[ProcessModule]: """ - Returns the Process Module for navigating the robot with respect to + Get the Process Module for navigating the robot with respect to the :py:attr:`~ProcessModuleManager.execution_type` :return: The Process Module for navigating @@ -280,7 +314,7 @@ def navigate(self) -> Type[ProcessModule]: def pick_up(self) -> Type[ProcessModule]: """ - Returns the Process Module for picking up with respect to the :py:attr:`~ProcessModuleManager.execution_type` + Get the Process Module for picking up with respect to the :py:attr:`~ProcessModuleManager.execution_type` :return: The Process Module for picking up an object """ @@ -289,7 +323,7 @@ def pick_up(self) -> Type[ProcessModule]: def place(self) -> Type[ProcessModule]: """ - Returns the Process Module for placing with respect to the :py:attr:`~ProcessModuleManager.execution_type` + Get the Process Module for placing with respect to the :py:attr:`~ProcessModuleManager.execution_type` :return: The Process Module for placing an Object """ @@ -298,7 +332,7 @@ def place(self) -> Type[ProcessModule]: def looking(self) -> Type[ProcessModule]: """ - Returns the Process Module for looking at a point with respect to + Get the Process Module for looking at a point with respect to the :py:attr:`~ProcessModuleManager.execution_type` :return: The Process Module for looking at a specific point @@ -308,7 +342,7 @@ def looking(self) -> Type[ProcessModule]: def detecting(self) -> Type[ProcessModule]: """ - Returns the Process Module for detecting an object with respect to + Get the Process Module for detecting an object with respect to the :py:attr:`~ProcessModuleManager.execution_type` :return: The Process Module for detecting an object @@ -318,7 +352,7 @@ def detecting(self) -> Type[ProcessModule]: def move_tcp(self) -> Type[ProcessModule]: """ - Returns the Process Module for moving the Tool Center Point with respect to + Get the Process Module for moving the Tool Center Point with respect to the :py:attr:`~ProcessModuleManager.execution_type` :return: The Process Module for moving the TCP @@ -328,7 +362,7 @@ def move_tcp(self) -> Type[ProcessModule]: def move_arm_joints(self) -> Type[ProcessModule]: """ - Returns the Process Module for moving the joints of the robot arm + Get the Process Module for moving the joints of the robot arm with respect to the :py:attr:`~ProcessModuleManager.execution_type` :return: The Process Module for moving the arm joints @@ -338,7 +372,7 @@ def move_arm_joints(self) -> Type[ProcessModule]: def world_state_detecting(self) -> Type[ProcessModule]: """ - Returns the Process Module for detecting an object using the world state with respect to the + Get the Process Module for detecting an object using the world state with respect to the :py:attr:`~ProcessModuleManager.execution_type` :return: The Process Module for world state detecting @@ -348,7 +382,7 @@ def world_state_detecting(self) -> Type[ProcessModule]: def move_joints(self) -> Type[ProcessModule]: """ - Returns the Process Module for moving any joint of the robot with respect to the + Get the Process Module for moving any joint of the robot with respect to the :py:attr:`~ProcessModuleManager.execution_type` :return: The Process Module for moving joints @@ -358,7 +392,7 @@ def move_joints(self) -> Type[ProcessModule]: def move_gripper(self) -> Type[ProcessModule]: """ - Returns the Process Module for moving the gripper with respect to + Get the Process Module for moving the gripper with respect to the :py:attr:`~ProcessModuleManager.execution_type` :return: The Process Module for moving the gripper @@ -368,7 +402,7 @@ def move_gripper(self) -> Type[ProcessModule]: def open(self) -> Type[ProcessModule]: """ - Returns the Process Module for opening drawers with respect to + Get the Process Module for opening drawers with respect to the :py:attr:`~ProcessModuleManager.execution_type` :return: The Process Module for opening drawers @@ -378,7 +412,7 @@ def open(self) -> Type[ProcessModule]: def close(self) -> Type[ProcessModule]: """ - Returns the Process Module for closing drawers with respect to + Get the Process Module for closing drawers with respect to the :py:attr:`~ProcessModuleManager.execution_type` :return: The Process Module for closing drawers diff --git a/src/pycram/process_modules/boxy_process_modules.py b/src/pycram/process_modules/boxy_process_modules.py index 5dc25d8f9..4cb8b6e89 100644 --- a/src/pycram/process_modules/boxy_process_modules.py +++ b/src/pycram/process_modules/boxy_process_modules.py @@ -89,13 +89,13 @@ def _execute(self, desig): pose_in_shoulder = local_transformer.transform_pose(target, robot.get_link_tf_frame("neck_shoulder_link")) if pose_in_shoulder.position.x >= 0 and pose_in_shoulder.position.x >= abs(pose_in_shoulder.position.y): - robot.set_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("neck", "front")) + robot.set_multiple_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("neck", "front")) if pose_in_shoulder.position.y >= 0 and pose_in_shoulder.position.y >= abs(pose_in_shoulder.position.x): - robot.set_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("neck", "neck_right")) + robot.set_multiple_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("neck", "neck_right")) if pose_in_shoulder.position.x <= 0 and abs(pose_in_shoulder.position.x) > abs(pose_in_shoulder.position.y): - robot.set_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("neck", "back")) + robot.set_multiple_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("neck", "back")) if pose_in_shoulder.position.y <= 0 and abs(pose_in_shoulder.position.y) > abs(pose_in_shoulder.position.x): - robot.set_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("neck", "neck_left")) + robot.set_multiple_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("neck", "neck_left")) pose_in_shoulder = local_transformer.transform_pose(target, robot.get_link_tf_frame("neck_shoulder_link")) @@ -115,7 +115,7 @@ def _execute(self, desig): robot = World.robot gripper = desig.gripper motion = desig.motion - robot.set_joint_positions(RobotDescription.current_robot_description.kinematic_chains[gripper].get_static_gripper_state(motion)) + robot.set_multiple_joint_positions(RobotDescription.current_robot_description.kinematic_chains[gripper].get_static_gripper_state(motion)) class BoxyDetecting(ProcessModule): @@ -160,9 +160,9 @@ def _execute(self, desig: MoveArmJointsMotion): robot = World.robot if desig.right_arm_poses: - robot.set_joint_positions(desig.right_arm_poses) + robot.set_multiple_joint_positions(desig.right_arm_poses) if desig.left_arm_poses: - robot.set_joint_positions(desig.left_arm_poses) + robot.set_multiple_joint_positions(desig.left_arm_poses) class BoxyWorldStateDetecting(ProcessModule): @@ -200,29 +200,29 @@ def __init__(self): self._close_lock = Lock() def navigate(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return BoxyNavigation(self._navigate_lock) def looking(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return BoxyMoveHead(self._looking_lock) def detecting(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return BoxyDetecting(self._detecting_lock) def move_tcp(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return BoxyMoveTCP(self._move_tcp_lock) def move_arm_joints(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return BoxyMoveArmJoints(self._move_arm_joints_lock) def world_state_detecting(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return BoxyWorldStateDetecting(self._world_state_detecting_lock) def move_gripper(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return BoxyMoveGripper(self._move_gripper_lock) diff --git a/src/pycram/process_modules/default_process_modules.py b/src/pycram/process_modules/default_process_modules.py index 7c9a76dd5..40a3dbdc7 100644 --- a/src/pycram/process_modules/default_process_modules.py +++ b/src/pycram/process_modules/default_process_modules.py @@ -196,41 +196,41 @@ def __init__(self): self._close_lock = Lock() def navigate(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return DefaultNavigation(self._navigate_lock) def looking(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return DefaultMoveHead(self._looking_lock) def detecting(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return DefaultDetecting(self._detecting_lock) def move_tcp(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return DefaultMoveTCP(self._move_tcp_lock) def move_arm_joints(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return DefaultMoveArmJoints(self._move_arm_joints_lock) def world_state_detecting(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return DefaultWorldStateDetecting(self._world_state_detecting_lock) def move_joints(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return DefaultMoveJoints(self._move_joints_lock) def move_gripper(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return DefaultMoveGripper(self._move_gripper_lock) def open(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return DefaultOpen(self._open_lock) def close(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return DefaultClose(self._close_lock) diff --git a/src/pycram/process_modules/donbot_process_modules.py b/src/pycram/process_modules/donbot_process_modules.py index e9ff45edc..e39b5bcb3 100644 --- a/src/pycram/process_modules/donbot_process_modules.py +++ b/src/pycram/process_modules/donbot_process_modules.py @@ -69,13 +69,13 @@ def _execute(self, desig): pose_in_shoulder = local_transformer.transform_pose(target, robot.get_link_tf_frame("ur5_shoulder_link")) if pose_in_shoulder.position.x >= 0 and pose_in_shoulder.position.x >= abs(pose_in_shoulder.position.y): - robot.set_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("left", "front")) + robot.set_multiple_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("left", "front")) if pose_in_shoulder.position.y >= 0 and pose_in_shoulder.position.y >= abs(pose_in_shoulder.position.x): - robot.set_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("left", "arm_right")) + robot.set_multiple_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("left", "arm_right")) if pose_in_shoulder.position.x <= 0 and abs(pose_in_shoulder.position.x) > abs(pose_in_shoulder.position.y): - robot.set_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("left", "back")) + robot.set_multiple_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("left", "back")) if pose_in_shoulder.position.y <= 0 and abs(pose_in_shoulder.position.y) > abs(pose_in_shoulder.position.x): - robot.set_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("left", "arm_left")) + robot.set_multiple_joint_positions(RobotDescription.current_robot_description.get_static_joint_chain("left", "arm_left")) pose_in_shoulder = local_transformer.transform_pose(target, robot.get_link_tf_frame("ur5_shoulder_link")) @@ -94,7 +94,7 @@ def _execute(self, desig): robot = World.robot gripper = desig.gripper motion = desig.motion - robot.set_joint_positions(RobotDescription.current_robot_description.get_arm_chain(gripper).get_static_gripper_state(motion)) + robot.set_multiple_joint_positions(RobotDescription.current_robot_description.get_arm_chain(gripper).get_static_gripper_state(motion)) class DonbotMoveTCP(ProcessModule): @@ -118,7 +118,7 @@ class DonbotMoveJoints(ProcessModule): def _execute(self, desig: MoveArmJointsMotion): robot = World.robot if desig.left_arm_poses: - robot.set_joint_positions(desig.left_arm_poses) + robot.set_multiple_joint_positions(desig.left_arm_poses) class DonbotWorldStateDetecting(ProcessModule): @@ -149,33 +149,33 @@ def __init__(self): self._close_lock = Lock() def navigate(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return DonbotNavigation(self._navigate_lock) def place(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return DonbotPlace(self._place_lock) def looking(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return DonbotMoveHead(self._looking_lock) def detecting(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return DonbotDetecting(self._detecting_lock) def move_tcp(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return DonbotMoveTCP(self._move_tcp_lock) def move_arm_joints(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return DonbotMoveJoints(self._move_arm_joints_lock) def world_state_detecting(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return DonbotWorldStateDetecting(self._world_state_detecting_lock) def move_gripper(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return DonbotMoveGripper(self._move_gripper_lock) diff --git a/src/pycram/process_modules/hsrb_process_modules.py b/src/pycram/process_modules/hsrb_process_modules.py index 5b66d6ec2..ffd1f4287 100644 --- a/src/pycram/process_modules/hsrb_process_modules.py +++ b/src/pycram/process_modules/hsrb_process_modules.py @@ -1,45 +1,22 @@ import numpy as np -import rospy from threading import Lock -from typing import Any +from typing_extensions import Any -from ..datastructures.enums import JointType +from ..datastructures.enums import ExecutionType +from ..external_interfaces.tmc import tmc_gripper_control, tmc_talk from ..robot_description import RobotDescription from ..process_module import ProcessModule -from ..datastructures.pose import Point -from ..utils import _apply_ik -from ..external_interfaces.ik import request_ik -from .. import world_reasoning as btr from ..local_transformer import LocalTransformer from ..designators.motion_designator import * from ..external_interfaces import giskard -from ..world_concepts.world_object import Object from ..datastructures.world import World +from pydub import AudioSegment +from pydub.playback import play +from gtts import gTTS +import io -def calculate_and_apply_ik(robot, gripper: str, target_position: Point, max_iterations: Optional[int] = None): - """ - Calculates the inverse kinematics for the given target pose and applies it to the robot. - """ - target_position_l = [target_position.x, target_position.y, target_position.z] - # TODO: Check if this is correct (getting the arm and using its joints), previously joints was not provided. - arm = "right" if gripper == RobotDescription.current_robot_description.kinematic_chains["right"].get_tool_frame() else "left" - inv = request_ik(Pose(target_position_l, [0, 0, 0, 1]), - robot, RobotDescription.current_robot_description.kinematic_chains[arm].joints, gripper) - _apply_ik(robot, inv) - - -def _park_arms(arm): - """ - Defines the joint poses for the parking positions of the arms of HSRB and applies them to the - in the World defined robot. - :return: None - """ - - robot = World.robot - if arm == "left": - for joint, pose in RobotDescription.current_robot_description.get_static_joint_chain("left", "park").items(): - robot.set_joint_position(joint, pose) +from ..ros.logging import logdebug class HSRBNavigation(ProcessModule): @@ -52,176 +29,14 @@ def _execute(self, desig: MoveMotion): robot.set_pose(desig.target) -class HSRBMoveHead(ProcessModule): - """ - This process module moves the head to look at a specific point in the world coordinate frame. - This point can either be a position or an object. - """ - - def _execute(self, desig: LookingMotion): - target = desig.target - robot = World.robot - - local_transformer = LocalTransformer() - pose_in_pan = local_transformer.transform_pose(target, robot.get_link_tf_frame("head_pan_link")) - pose_in_tilt = local_transformer.transform_pose(target, robot.get_link_tf_frame("head_tilt_link")) - - new_pan = np.arctan2(pose_in_pan.position.y, pose_in_pan.position.x) - new_tilt = np.arctan2(pose_in_tilt.position.z, pose_in_tilt.position.x ** 2 + pose_in_tilt.position.y ** 2) * -1 - - current_pan = robot.get_joint_position("head_pan_joint") - current_tilt = robot.get_joint_position("head_tilt_joint") - - robot.set_joint_position("head_pan_joint", new_pan + current_pan) - robot.set_joint_position("head_tilt_joint", new_tilt + current_tilt) - - -class HSRBMoveGripper(ProcessModule): - """ - This process module controls the gripper of the robot. They can either be opened or closed. - Furthermore, it can only moved one gripper at a time. - """ - - def _execute(self, desig: MoveGripperMotion): - robot = World.robot - gripper = desig.gripper - motion = desig.motion - for joint, state in RobotDescription.current_robot_description.get_arm_chain(gripper).get_static_gripper_state(motion).items(): - robot.set_joint_position(joint, state) - - class HSRBDetecting(ProcessModule): """ This process module tries to detect an object with the given type. To be detected the object has to be in the field of view of the robot. """ - - def _execute(self, desig: DetectingMotion): - rospy.loginfo("Detecting technique: {}".format(desig.technique)) - robot = World.robot - object_type = desig.object_type - # Should be "wide_stereo_optical_frame" - cam_frame_name = RobotDescription.current_robot_description.get_camera_frame() - # should be [0, 0, 1] - front_facing_axis = RobotDescription.current_robot_description.get_default_camera().front_facing_axis - # if desig.technique == 'all': - # rospy.loginfo("Fake detecting all generic objects") - # objects = BulletWorld.current_bullet_world.get_all_objets_not_robot() - # elif desig.technique == 'human': - # rospy.loginfo("Fake detecting human -> spawn 0,0,0") - # human = [] - # human.append(Object("human", ObjectType.HUMAN, "human_male.stl", pose=Pose([0, 0, 0]))) - # object_dict = {} - # - # # Iterate over the list of objects and store each one in the dictionary - # for i, obj in enumerate(human): - # object_dict[obj.name] = obj - # return object_dict - # - # else: - # rospy.loginfo("Fake -> Detecting specific object type") - objects = World.current_world.get_object_by_type(object_type) - - object_dict = {} - - perceived_objects = [] - for obj in objects: - if btr.visible(obj, robot.get_link_pose(cam_frame_name), front_facing_axis): - return obj - - -class HSRBMoveTCP(ProcessModule): - """ - This process moves the tool center point of either the right or the left arm. - """ - - def _execute(self, desig: MoveTCPMotion): - target = desig.target - robot = World.robot - - _move_arm_tcp(target, robot, desig.arm) - - -class HSRBMoveArmJoints(ProcessModule): - """ - This process modules moves the joints of either the right or the left arm. The joint states can be given as - list that should be applied or a pre-defined position can be used, such as "parking" - """ - - def _execute(self, desig: MoveArmJointsMotion): - - robot = World.robot - if desig.right_arm_poses: - robot.set_joint_positions(desig.right_arm_poses) - if desig.left_arm_poses: - robot.set_joint_positions(desig.left_arm_poses) - - -class HSRBMoveJoints(ProcessModule): - """ - Process Module for generic joint movements, is not confined to the arms but can move any joint of the robot - """ - - def _execute(self, desig: MoveJointsMotion): - robot = World.robot - robot.set_joint_positions(dict(zip(desig.names, desig.positions))) - - -class HSRBWorldStateDetecting(ProcessModule): - """ - This process module detectes an object even if it is not in the field of view of the robot. - """ - - def _execute(self, desig: WorldStateDetectingMotion): - obj_type = desig.object_type - return list(filter(lambda obj: obj.obj_type == obj_type, World.current_world.objects))[0] - - -class HSRBOpen(ProcessModule): - """ - Low-level implementation of opening a container in the simulation. Assumes the handle is already grasped. - """ - - def _execute(self, desig: OpeningMotion): - part_of_object = desig.object_part.world_object - - container_joint = part_of_object.find_joint_above_link(desig.object_part.name, JointType.PRISMATIC) - - goal_pose = btr.link_pose_for_joint_config(part_of_object, { - container_joint: part_of_object.get_joint_limits(container_joint)[1] - 0.05}, desig.object_part.name) - - _move_arm_tcp(goal_pose, World.robot, desig.arm) - - desig.object_part.world_object.set_joint_position(container_joint, - part_of_object.get_joint_limits(container_joint)[1]) - - -class HSRBClose(ProcessModule): - """ - Low-level implementation that lets the robot close a grasped container, in simulation - """ - - def _execute(self, desig: ClosingMotion): - part_of_object = desig.object_part.world_object - - container_joint = part_of_object.find_joint_above_link(desig.object_part.name, JointType.PRISMATIC) - - goal_pose = btr.link_pose_for_joint_config(part_of_object, { - container_joint: part_of_object.get_joint_limits(container_joint)[0]}, desig.object_part.name) - - _move_arm_tcp(goal_pose, World.robot, desig.arm) - - desig.object_part.world_object.set_joint_position(container_joint, - part_of_object.get_joint_limits(container_joint)[0]) - - -def _move_arm_tcp(target: Pose, robot: Object, arm: Arms) -> None: - gripper = RobotDescription.current_robot_description.get_arm_chain(arm).get_tool_frame() - - joints = RobotDescription.current_robot_description.get_arm_chain(arm).joints - - inv = request_ik(target, robot, joints, gripper) - _apply_ik(robot, inv) + # pass + def _execute(self, desig: DetectingMotion) -> Any: + pass ########################################################### @@ -235,47 +50,20 @@ class HSRBNavigationReal(ProcessModule): """ def _execute(self, designator: MoveMotion) -> Any: - rospy.logdebug(f"Sending goal to giskard to Move the robot") + logdebug(f"Sending goal to giskard to Move the robot") # giskard.achieve_cartesian_goal(designator.target, robot_description.base_link, "map") - queryPoseNav(designator.target) - - -class HSRBNavigationSemiReal(ProcessModule): - """ - Process module for the real HSRB that sends a cartesian goal to giskard to move the robot base - """ - - def _execute(self, designator: MoveMotion) -> Any: - rospy.logdebug(f"Sending goal to giskard to Move the robot") - giskard.achieve_cartesian_goal(designator.target, RobotDescription.current_robot_description.base_link, "map") + # todome fix this # queryPoseNav(designator.target) class HSRBMoveHeadReal(ProcessModule): """ - Process module for the real robot to move that such that it looks at the given position. Uses the same calculation - as the simulated one + Process module for the real HSRB that sends a pose goal to giskard to move the robot head """ def _execute(self, desig: LookingMotion): target = desig.target - robot = World.robot - - local_transformer = LocalTransformer() - pose_in_pan = local_transformer.transform_pose(target, robot.get_link_tf_frame("head_pan_link")) - pose_in_tilt = local_transformer.transform_pose(target, robot.get_link_tf_frame("head_tilt_link")) - - new_pan = np.arctan2(pose_in_pan.position.y, pose_in_pan.position.x) - new_tilt = np.arctan2(pose_in_tilt.position.z, pose_in_tilt.position.x + pose_in_tilt.position.y) - - current_pan = robot.get_joint_position("head_pan_joint") - current_tilt = robot.get_joint_position("head_tilt_joint") - - giskard.avoid_all_collisions() - giskard.achieve_joint_goal( - {"head_pan_joint": new_pan + current_pan, "head_tilt_joint": new_tilt + current_tilt}) - giskard.achieve_joint_goal( - {"head_pan_joint": new_pan + current_pan, "head_tilt_joint": new_tilt + current_tilt}) + giskard.move_head_to_pose(target) class HSRBDetectingReal(ProcessModule): @@ -285,98 +73,12 @@ class HSRBDetectingReal(ProcessModule): """ def _execute(self, desig: DetectingMotion) -> Any: - # todo at the moment perception ignores searching for a specific object type so we do as well on real - if desig.technique == 'human' and (desig.state == "start" or desig.state == None): - human_pose = queryHuman() - pose = Pose.from_pose_stamped(human_pose) - pose.position.z = 0 - human = [] - human.append(Object("human", ObjectType.HUMAN, "human_male.stl", pose=pose)) - object_dict = {} - - # Iterate over the list of objects and store each one in the dictionary - for i, obj in enumerate(human): - object_dict[obj.name] = obj - return object_dict - - return human_pose - elif desig.technique == 'human' and desig.state == "stop": - stop_queryHuman() - return "stopped" - - query_result = queryEmpty(ObjectDesignatorDescription(types=[desig.object_type])) - perceived_objects = [] - for i in range(0, len(query_result.res)): - # this has to be pose from pose stamped since we spawn the object with given header - obj_pose = Pose.from_pose_stamped(query_result.res[i].pose[0]) - # obj_pose.orientation = [0, 0, 0, 1] - # obj_pose_tmp = query_result.res[i].pose[0] - obj_type = query_result.res[i].type - obj_size = query_result.res[i].shape_size - obj_color = query_result.res[i].color[0] - color_switch = { - "red": [1, 0, 0, 1], - "green": [0, 1, 0, 1], - "blue": [0, 0, 1, 1], - "black": [0, 0, 0, 1], - "white": [1, 1, 1, 1], - # add more colors if needed - } - color = color_switch.get(obj_color) - if color is None: - color = [0, 0, 0, 1] - - # atm this is the string size that describes the object but it is not the shape size thats why string - def extract_xyz_values(input_string): - # Split the input string by commas and colon to separate key-value pairs - # key_value_pairs = input_string.split(', ') - - # Initialize variables to store the X, Y, and Z values - x_value = None - y_value = None - z_value = None - - for key in input_string: - x_value = key.dimensions.x - y_value = key.dimensions.y - z_value = key.dimensions.z - - # - # # Iterate through the key-value pairs to extract the values - # for pair in key_value_pairs: - # key, value = pair.split(': ') - # if key == 'x': - # x_value = float(value) - # elif key == 'y': - # y_value = float(value) - # elif key == 'z': - # z_value = float(value) - - return x_value, y_value, z_value - - x, y, z = extract_xyz_values(obj_size) - size = (x, z / 2, y) - size_box = (x / 2, z / 2, y / 2) - hard_size = (0.02, 0.02, 0.03) - id = World.current_world.add_rigid_box(obj_pose, hard_size, color) - box_object = Object(obj_type + "_" + str(rospy.get_time()), obj_type, pose=obj_pose, color=color, id=id, - customGeom={"size": [hard_size[0], hard_size[1], hard_size[2]]}) - box_object.set_pose(obj_pose) - box_desig = ObjectDesignatorDescription.Object(box_object.name, box_object.type, box_object) - - perceived_objects.append(box_desig) - - object_dict = {} - - # Iterate over the list of objects and store each one in the dictionary - for i, obj in enumerate(perceived_objects): - object_dict[obj.name] = obj - return object_dict + pass class HSRBMoveTCPReal(ProcessModule): """ - Moves the tool center point of the real HSRB while avoiding all collisions + Moves the tool center point of the real HSRB while avoiding all collisions via giskard """ def _execute(self, designator: MoveTCPMotion) -> Any: @@ -385,13 +87,13 @@ def _execute(self, designator: MoveTCPMotion) -> Any: giskard.avoid_all_collisions() if designator.allow_gripper_collision: giskard.allow_gripper_collision(designator.arm) - giskard.achieve_cartesian_goal(pose_in_map, RobotDescription.current_robot_description.get_arm_chain(designator.arm).get_tool_frame(), - "map") + giskard.achieve_cartesian_goal(pose_in_map, RobotDescription.current_robot_description.get_arm_chain( + designator.arm).get_tool_frame(), "map") class HSRBMoveArmJointsReal(ProcessModule): """ - Moves the arm joints of the real HSRB to the given configuration while avoiding all collisions + Moves the arm joints of the real HSRB to the given configuration while avoiding all collisions via giskard """ def _execute(self, designator: MoveArmJointsMotion) -> Any: @@ -419,67 +121,84 @@ class HSRBMoveGripperReal(ProcessModule): """ def _execute(self, designator: MoveGripperMotion) -> Any: - if (designator.motion == "open"): - pub_gripper = rospy.Publisher('/hsrb/gripper_controller/grasp/goal', GripperApplyEffortActionGoal, - queue_size=10) - rate = rospy.Rate(10) - rospy.sleep(2) - msg = GripperApplyEffortActionGoal() # sprechen joint gripper_controll_manager an, indem wir goal publishen type den giskard fürs greifen erwartet - msg.goal.effort = 0.8 - pub_gripper.publish(msg) - - elif (designator.motion == "close"): - pub_gripper = rospy.Publisher('/hsrb/gripper_controller/grasp/goal', GripperApplyEffortActionGoal, - queue_size=10) - rate = rospy.Rate(10) - rospy.sleep(2) - msg = GripperApplyEffortActionGoal() - msg.goal.effort = -0.8 - pub_gripper.publish(msg) - - # if designator.allow_gripper_collision: - # giskard.allow_gripper_collision("left") - # giskard.achieve_gripper_motion_goal(designator.motion) + tmc_gripper_control(designator) class HSRBOpenReal(ProcessModule): """ - Tries to open an already grasped container + This process Modules tries to open an already grasped container via giskard """ def _execute(self, designator: OpeningMotion) -> Any: - giskard.achieve_open_container_goal(RobotDescription.current_robot_description.get_arm_chain(designator.arm).get_tool_frame(), - designator.object_part.name) + giskard.achieve_open_container_goal( + RobotDescription.current_robot_description.get_arm_chain(designator.arm).get_tool_frame(), + designator.object_part.name) class HSRBCloseReal(ProcessModule): """ - Tries to close an already grasped container + This process module executes close a an already grasped container via giskard """ def _execute(self, designator: ClosingMotion) -> Any: - giskard.achieve_close_container_goal(RobotDescription.current_robot_description.get_arm_chain(designator.arm).get_tool_frame(), - designator.object_part.name) - - -# class HSRBTalkReal(ProcessModule): -# """ -# Tries to close an already grasped container -# """ -# -# def _execute(self, designator: TalkingMotion.Motion) -> Any: -# pub = rospy.Publisher('/talk_request', Voice, queue_size=10) -# -# # fill message of type Voice with required data: -# texttospeech = Voice() -# # language 1 = english (0 = japanese) -# texttospeech.language = 1 -# texttospeech.sentence = designator.cmd -# -# rospy.sleep(1) -# pub.publish(texttospeech) + giskard.achieve_close_container_goal( + RobotDescription.current_robot_description.get_arm_chain(designator.arm).get_tool_frame(), + designator.object_part.name) + + +class HSRBTalkReal(ProcessModule): + """ + Let the robot speak over tmc interface. + """ + + def _execute(self, designator: TalkingMotion) -> Any: + tmc_talk(designator) + + +########################################################### +########## Process Modules for the Semi Real HSRB ############### +########################################################### +class HSRBNavigationSemiReal(ProcessModule): + """ + Process module for the real HSRB that sends a cartesian goal to giskard to move the robot base + """ + def _execute(self, designator: MoveMotion) -> Any: + logdebug(f"Sending goal to giskard to Move the robot") + giskard.teleport_robot(designator.target) + + +class HSRBTalkSemiReal(ProcessModule): + """ + Low Level implementation to let the robot talk using gTTS and pydub. + """ + + def _execute(self, designator: TalkingMotion) -> Any: + """ + Convert text to speech using gTTS, modify the pitch and play it without saving to disk. + """ + sentence = designator.cmd + # Create a gTTS object + tts = gTTS(text=sentence, lang='en', slow=False) + + # Save the speech to an in-memory file + mp3_fp = io.BytesIO() + tts.write_to_fp(mp3_fp) + mp3_fp.seek(0) + # Load the audio into pydub from the in-memory file + audio = AudioSegment.from_file(mp3_fp, format="mp3") + + # Speed up the audio slightly + faster_audio = audio.speedup(playback_speed=1.2) + + # Play the modified audio + play(faster_audio) + + +########################################################### +########## HSRB MANAGER ############### +########################################################### class HSRBManager(ProcessModuleManager): def __init__(self): @@ -499,85 +218,65 @@ def __init__(self): self._talk_lock = Lock() def navigate(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return HSRBNavigation(self._navigate_lock) - elif ProcessModuleManager.execution_type == "real": + elif ProcessModuleManager.execution_type == ExecutionType.REAL: return HSRBNavigationReal(self._navigate_lock) - elif ProcessModuleManager.execution_type == "semi_real": + elif ProcessModuleManager.execution_type == ExecutionType.SEMI_REAL: return HSRBNavigationSemiReal(self._navigate_lock) def looking(self): - if ProcessModuleManager.execution_type == "simulated": - return HSRBMoveHead(self._looking_lock) - elif ProcessModuleManager.execution_type == "real": + if ProcessModuleManager.execution_type == ExecutionType.REAL: return HSRBMoveHeadReal(self._looking_lock) - elif ProcessModuleManager.execution_type == "semi_real": + elif ProcessModuleManager.execution_type == ExecutionType.SEMI_REAL: return HSRBMoveHeadReal(self._looking_lock) def detecting(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return HSRBDetecting(self._detecting_lock) - elif ProcessModuleManager.execution_type == "real": + elif ProcessModuleManager.execution_type == ExecutionType.REAL: return HSRBDetectingReal(self._detecting_lock) - elif ProcessModuleManager.execution_type == "semi_real": + elif ProcessModuleManager.execution_type == ExecutionType.SEMI_REAL: return HSRBDetecting(self._detecting_lock) def move_tcp(self): - if ProcessModuleManager.execution_type == "simulated": - return HSRBMoveTCP(self._move_tcp_lock) - elif ProcessModuleManager.execution_type == "real": + if ProcessModuleManager.execution_type == ExecutionType.REAL: return HSRBMoveTCPReal(self._move_tcp_lock) - elif ProcessModuleManager.execution_type == "semi_real": + elif ProcessModuleManager.execution_type == ExecutionType.SEMI_REAL: return HSRBMoveTCPReal(self._move_tcp_lock) def move_arm_joints(self): - if ProcessModuleManager.execution_type == "simulated": - return HSRBMoveArmJoints(self._move_arm_joints_lock) - elif ProcessModuleManager.execution_type == "real": + if ProcessModuleManager.execution_type == ExecutionType.REAL: return HSRBMoveArmJointsReal(self._move_arm_joints_lock) - elif ProcessModuleManager.execution_type == "semi_real": + elif ProcessModuleManager.execution_type == ExecutionType.SEMI_REAL: return HSRBMoveArmJointsReal(self._move_arm_joints_lock) - def world_state_detecting(self): - if ProcessModuleManager.execution_type == "simulated" or ProcessModuleManager.execution_type == "real": - return HSRBWorldStateDetecting(self._world_state_detecting_lock) - elif ProcessModuleManager.execution_type == "semi_real": - return HSRBWorldStateDetecting(self._world_state_detecting_lock) - def move_joints(self): - if ProcessModuleManager.execution_type == "simulated": - return HSRBMoveJoints(self._move_joints_lock) - elif ProcessModuleManager.execution_type == "real": + if ProcessModuleManager.execution_type == ExecutionType.REAL: return HSRBMoveJointsReal(self._move_joints_lock) - elif ProcessModuleManager.execution_type == "semi_real": + elif ProcessModuleManager.execution_type == ExecutionType.SEMI_REAL: return HSRBMoveJointsReal(self._move_joints_lock) def move_gripper(self): - if ProcessModuleManager.execution_type == "simulated": - return HSRBMoveGripper(self._move_gripper_lock) - elif ProcessModuleManager.execution_type == "real": + if ProcessModuleManager.execution_type == ExecutionType.REAL: return HSRBMoveGripperReal(self._move_gripper_lock) - elif ProcessModuleManager.execution_type == "semi_real": + elif ProcessModuleManager.execution_type == ExecutionType.SEMI_REAL: return HSRBMoveGripperReal(self._move_gripper_lock) def open(self): - if ProcessModuleManager.execution_type == "simulated": - return HSRBOpen(self._open_lock) - elif ProcessModuleManager.execution_type == "real": + if ProcessModuleManager.execution_type == ExecutionType.REAL: return HSRBOpenReal(self._open_lock) - elif ProcessModuleManager.execution_type == "semi_real": + elif ProcessModuleManager.execution_type == ExecutionType.SEMI_REAL: return HSRBOpenReal(self._open_lock) def close(self): - if ProcessModuleManager.execution_type == "simulated": - return HSRBClose(self._close_lock) - elif ProcessModuleManager.execution_type == "real": + if ProcessModuleManager.execution_type == ExecutionType.REAL: return HSRBCloseReal(self._close_lock) - elif ProcessModuleManager.execution_type == "semi_real": + elif ProcessModuleManager.execution_type == ExecutionType.SEMI_REAL: return HSRBCloseReal(self._close_lock) - # def talk(self): - # if ProcessModuleManager.execution_type == "real": - # return HSRBTalkReal(self._talk_lock) - # elif ProcessModuleManager.execution_type == "semi_real": - # return HSRBTalkReal(self._talk_lock) + def talk(self): + if ProcessModuleManager.execution_type == ExecutionType.REAL: + return HSRBTalkReal(self._talk_lock) + elif ProcessModuleManager.execution_type == ExecutionType.SEMI_REAL: + return HSRBTalkSemiReal(self._talk_lock) diff --git a/src/pycram/process_modules/pr2_process_modules.py b/src/pycram/process_modules/pr2_process_modules.py index 0ac65ba8d..099662bf9 100644 --- a/src/pycram/process_modules/pr2_process_modules.py +++ b/src/pycram/process_modules/pr2_process_modules.py @@ -5,10 +5,10 @@ from .. import world_reasoning as btr import numpy as np -import rospy from ..process_module import ProcessModule, ProcessModuleManager from ..external_interfaces.ik import request_ik +from ..ros.logging import logdebug from ..utils import _apply_ik from ..local_transformer import LocalTransformer from ..designators.object_designator import ObjectDesignatorDescription @@ -19,9 +19,9 @@ from ..datastructures.world import World from ..world_concepts.world_object import Object from ..datastructures.pose import Pose -from ..datastructures.enums import JointType, ObjectType, Arms +from ..datastructures.enums import JointType, ObjectType, Arms, ExecutionType from ..external_interfaces import giskard -from ..external_interfaces.robokudo import query +from ..external_interfaces.robokudo import * try: from pr2_controllers_msgs.msg import Pr2GripperCommandGoal, Pr2GripperCommandAction, Pr2 @@ -87,7 +87,7 @@ def _execute(self, desig: DetectingMotion): robot = World.robot object_type = desig.object_type # Should be "wide_stereo_optical_frame" - cam_frame_name = RobotDescription.current_robot_description.get_camera_frame() + camera_link_name = RobotDescription.current_robot_description.get_camera_link() # should be [0, 0, 1] camera_description = RobotDescription.current_robot_description.cameras[ list(RobotDescription.current_robot_description.cameras.keys())[0]] @@ -95,7 +95,7 @@ def _execute(self, desig: DetectingMotion): objects = World.current_world.get_object_by_type(object_type) for obj in objects: - if btr.visible(obj, robot.get_link_pose(cam_frame_name), front_facing_axis): + if btr.visible(obj, robot.get_link_pose(camera_link_name), front_facing_axis): return obj @@ -121,9 +121,9 @@ def _execute(self, desig: MoveArmJointsMotion): robot = World.robot if desig.right_arm_poses: - robot.set_joint_positions(desig.right_arm_poses) + robot.set_multiple_joint_positions(desig.right_arm_poses) if desig.left_arm_poses: - robot.set_joint_positions(desig.left_arm_poses) + robot.set_multiple_joint_positions(desig.left_arm_poses) class PR2MoveJoints(ProcessModule): @@ -133,7 +133,7 @@ class PR2MoveJoints(ProcessModule): def _execute(self, desig: MoveJointsMotion): robot = World.robot - robot.set_joint_positions(dict(zip(desig.names, desig.positions))) + robot.set_multiple_joint_positions(dict(zip(desig.names, desig.positions))) class Pr2WorldStateDetecting(ProcessModule): @@ -206,7 +206,7 @@ class Pr2NavigationReal(ProcessModule): """ def _execute(self, designator: MoveMotion) -> Any: - rospy.logdebug(f"Sending goal to giskard to Move the robot") + logdebug(f"Sending goal to giskard to Move the robot") giskard.achieve_cartesian_goal(designator.target, RobotDescription.current_robot_description.base_link, "map") @@ -315,10 +315,10 @@ class Pr2MoveGripperReal(ProcessModule): def _execute(self, designator: MoveGripperMotion) -> Any: def activate_callback(): - rospy.loginfo("Started gripper Movement") + loginfo("Started gripper Movement") def done_callback(state, result): - rospy.loginfo(f"Reached goal {designator.motion}: {result.reached_goal}") + loginfo(f"Reached goal {designator.motion}: {result.reached_goal}") def feedback_callback(msg): pass @@ -331,7 +331,7 @@ def feedback_callback(msg): else: controller_topic = "l_gripper_controller/gripper_action" client = actionlib.SimpleActionClient(controller_topic, Pr2GripperCommandAction) - rospy.loginfo("Waiting for action server") + loginfo("Waiting for action server") client.wait_for_server() client.send_goal(goal, active_cb=activate_callback, done_cb=done_callback, feedback_cb=feedback_callback) wait = client.wait_for_result() @@ -375,59 +375,60 @@ def __init__(self): self._close_lock = Lock() def navigate(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return Pr2Navigation(self._navigate_lock) - elif ProcessModuleManager.execution_type == "real": + elif ProcessModuleManager.execution_type == ExecutionType.REAL: return Pr2NavigationReal(self._navigate_lock) def looking(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return Pr2MoveHead(self._looking_lock) - elif ProcessModuleManager.execution_type == "real": + elif ProcessModuleManager.execution_type == ExecutionType.REAL: return Pr2MoveHeadReal(self._looking_lock) def detecting(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return Pr2Detecting(self._detecting_lock) - elif ProcessModuleManager.execution_type == "real": + elif ProcessModuleManager.execution_type == ExecutionType.REAL: return Pr2DetectingReal(self._detecting_lock) def move_tcp(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return Pr2MoveTCP(self._move_tcp_lock) - elif ProcessModuleManager.execution_type == "real": + elif ProcessModuleManager.execution_type == ExecutionType.REAL: return Pr2MoveTCPReal(self._move_tcp_lock) def move_arm_joints(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return Pr2MoveArmJoints(self._move_arm_joints_lock) - elif ProcessModuleManager.execution_type == "real": + elif ProcessModuleManager.execution_type == ExecutionType.REAL: return Pr2MoveArmJointsReal(self._move_arm_joints_lock) def world_state_detecting(self): - if ProcessModuleManager.execution_type == "simulated" or ProcessModuleManager.execution_type == "real": + if (ProcessModuleManager.execution_type == ExecutionType.SIMULATED or + ProcessModuleManager.execution_type == ExecutionType.REAL): return Pr2WorldStateDetecting(self._world_state_detecting_lock) def move_joints(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return PR2MoveJoints(self._move_joints_lock) - elif ProcessModuleManager.execution_type == "real": + elif ProcessModuleManager.execution_type == ExecutionType.REAL: return Pr2MoveJointsReal(self._move_joints_lock) def move_gripper(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return Pr2MoveGripper(self._move_gripper_lock) - elif ProcessModuleManager.execution_type == "real": + elif ProcessModuleManager.execution_type == ExecutionType.REAL: return Pr2MoveGripperReal(self._move_gripper_lock) def open(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return Pr2Open(self._open_lock) - elif ProcessModuleManager.execution_type == "real": + elif ProcessModuleManager.execution_type == ExecutionType.REAL: return Pr2OpenReal(self._open_lock) def close(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return Pr2Close(self._close_lock) - elif ProcessModuleManager.execution_type == "real": + elif ProcessModuleManager.execution_type == ExecutionType.REAL: return Pr2CloseReal(self._close_lock) diff --git a/src/pycram/process_modules/stretch_process_modules.py b/src/pycram/process_modules/stretch_process_modules.py index 77b33596a..028b8f333 100644 --- a/src/pycram/process_modules/stretch_process_modules.py +++ b/src/pycram/process_modules/stretch_process_modules.py @@ -1,8 +1,7 @@ from typing import Any -import rospy - -from ..external_interfaces.robokudo import query +from ..external_interfaces.robokudo import * +from ..ros.logging import logdebug from ..utils import _apply_ik from ..external_interfaces import giskard from .default_process_modules import * @@ -132,7 +131,7 @@ def _move_arm_tcp(target: Pose, robot: Object, arm: Arms) -> None: # inv = request_ik(target, robot, joints, gripper) pose, joint_states = request_giskard_ik(target, robot, gripper) robot.set_pose(pose) - robot.set_joint_positions(joint_states) + robot.set_multiple_joint_positions(joint_states) ########################################################### @@ -146,7 +145,7 @@ class StretchNavigationReal(ProcessModule): """ def _execute(self, designator: MoveMotion) -> Any: - rospy.logdebug(f"Sending goal to giskard to Move the robot") + logdebug(f"Sending goal to giskard to Move the robot") giskard.achieve_cartesian_goal(designator.target, RobotDescription.current_robot_description.base_link, "map") @@ -295,59 +294,59 @@ def __init__(self): self._close_lock = Lock() def navigate(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return StretchNavigate(self._navigate_lock) - elif ProcessModuleManager.execution_type == "real": + elif ProcessModuleManager.execution_type == ExecutionType.REAL: return StretchNavigationReal(self._navigate_lock) def looking(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return StretchMoveHead(self._looking_lock) - elif ProcessModuleManager.execution_type == "real": + elif ProcessModuleManager.execution_type == ExecutionType.REAL: return StretchMoveHeadReal(self._looking_lock) def detecting(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return StretchDetecting(self._detecting_lock) - elif ProcessModuleManager.execution_type == "real": + elif ProcessModuleManager.execution_type == ExecutionType.REAL: return StretchDetectingReal(self._detecting_lock) def move_tcp(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return StretchMoveTCP(self._move_tcp_lock) - elif ProcessModuleManager.execution_type == "real": + elif ProcessModuleManager.execution_type == ExecutionType.REAL: return StretchMoveTCPReal(self._move_tcp_lock) def move_arm_joints(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return StretchMoveArmJoints(self._move_arm_joints_lock) - elif ProcessModuleManager.execution_type == "real": + elif ProcessModuleManager.execution_type == ExecutionType.REAL: return StretchMoveArmJointsReal(self._move_arm_joints_lock) def world_state_detecting(self): - if ProcessModuleManager.execution_type == "simulated" or ProcessModuleManager.execution_type == "real": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED or ProcessModuleManager.execution_type == ExecutionType.REAL: return StretchWorldStateDetecting(self._world_state_detecting_lock) def move_joints(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return StretchMoveJoints(self._move_joints_lock) - elif ProcessModuleManager.execution_type == "real": + elif ProcessModuleManager.execution_type == ExecutionType.REAL: return StretchMoveJointsReal(self._move_joints_lock) def move_gripper(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return StretchMoveGripper(self._move_gripper_lock) - elif ProcessModuleManager.execution_type == "real": + elif ProcessModuleManager.execution_type == ExecutionType.REAL: return StretchMoveGripperReal(self._move_gripper_lock) def open(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return StretchOpen(self._open_lock) - elif ProcessModuleManager.execution_type == "real": + elif ProcessModuleManager.execution_type == ExecutionType.REAL: return StretchOpenReal(self._open_lock) def close(self): - if ProcessModuleManager.execution_type == "simulated": + if ProcessModuleManager.execution_type == ExecutionType.SIMULATED: return StretchClose(self._close_lock) - elif ProcessModuleManager.execution_type == "real": + elif ProcessModuleManager.execution_type == ExecutionType.REAL: return StretchCloseReal(self._close_lock) diff --git a/src/pycram/robot_description.py b/src/pycram/robot_description.py index ce9805395..495267a3a 100644 --- a/src/pycram/robot_description.py +++ b/src/pycram/robot_description.py @@ -1,12 +1,13 @@ # used for delayed evaluation of typing until python 3.11 becomes mainstream from __future__ import annotations - -import rospy from typing_extensions import List, Dict, Union, Optional -from urdf_parser_py.urdf import URDF +from .datastructures.dataclasses import VirtualMobileBaseJoints +from .datastructures.enums import Arms, Grasp, GripperState, GripperType, JointType +from .object_descriptors.urdf import ObjectDescription as URDFObject +from .ros.logging import logerr from .utils import suppress_stdout_stderr -from .datastructures.enums import Arms, Grasp, GripperState, GripperType +from .helper import parse_mjcf_actuators class RobotDescriptionManager: @@ -42,7 +43,12 @@ def load_description(self, name: str): RobotDescription.current_robot_description = self.descriptions[name] return self.descriptions[name] else: - rospy.logerr(f"Robot description {name} not found") + for key in self.descriptions.keys(): + if key in name.lower(): + RobotDescription.current_robot_description = self.descriptions[key] + return self.descriptions[key] + else: + logerr(f"Robot description {name} not found") def register_description(self, description: RobotDescription): """ @@ -81,7 +87,7 @@ class RobotDescription: """ Torso joint of the robot """ - urdf_object: URDF + urdf_object: URDFObject """ Parsed URDF of the robot """ @@ -105,8 +111,14 @@ class RobotDescription: """ All joints defined in the URDF, by default fixed joints are not included """ + virtual_mobile_base_joints: Optional[VirtualMobileBaseJoints] = None + """ + Virtual mobile base joint names for mobile robots, these joints are not part of the URDF, however they are used to + move the robot in the simulation (e.g. set_pose for the robot would actually move these joints) + """ - def __init__(self, name: str, base_link: str, torso_link: str, torso_joint: str, urdf_path: str): + def __init__(self, name: str, base_link: str, torso_link: str, torso_joint: str, urdf_path: str, + virtual_mobile_base_joints: Optional[VirtualMobileBaseJoints] = None, mjcf_path: Optional[str] = None): """ Initialize the RobotDescription. The URDF is loaded from the given path and used as basis for the kinematic chains. @@ -116,6 +128,8 @@ def __init__(self, name: str, base_link: str, torso_link: str, torso_joint: str, :param torso_link: Torso link of the robot :param torso_joint: Torso joint of the robot, this is the joint that moves the torso upwards if there is one :param urdf_path: Path to the URDF file of the robot + :param virtual_mobile_base_joints: Virtual mobile base joint names for mobile robots + :param mjcf_path: Path to the MJCF file of the robot """ self.name = name self.base_link = base_link @@ -123,12 +137,35 @@ def __init__(self, name: str, base_link: str, torso_link: str, torso_joint: str, self.torso_joint = torso_joint with suppress_stdout_stderr(): # Since parsing URDF causes a lot of warning messages which can't be deactivated, we suppress them - self.urdf_object = URDF.from_xml_file(urdf_path) + self.urdf_object = URDFObject(urdf_path) + self.joint_types = {joint.name: joint.type for joint in self.urdf_object.joints} + self.joint_actuators: Optional[Dict] = parse_mjcf_actuators(mjcf_path) if mjcf_path is not None else None self.kinematic_chains: Dict[str, KinematicChainDescription] = {} self.cameras: Dict[str, CameraDescription] = {} self.grasps: Dict[Grasp, List[float]] = {} self.links: List[str] = [l.name for l in self.urdf_object.links] self.joints: List[str] = [j.name for j in self.urdf_object.joints] + self.virtual_mobile_base_joints: Optional[VirtualMobileBaseJoints] = virtual_mobile_base_joints + + @property + def has_actuators(self): + """ + Property to check if the robot has actuators defined in the MJCF file. + + :return: True if the robot has actuators, False otherwise + """ + return self.joint_actuators is not None + + def get_actuator_for_joint(self, joint: str) -> Optional[str]: + """ + Get the actuator name for a given joint. + + :param joint: Name of the joint + :return: Name of the actuator + """ + if self.has_actuators: + return self.joint_actuators.get(joint) + return None def add_kinematic_chain_description(self, chain: KinematicChainDescription): """ @@ -200,7 +237,7 @@ def add_grasp_orientations(self, orientations: Dict[Grasp, List[float]]): def get_manipulator_chains(self) -> List[KinematicChainDescription]: """ - Returns a list of all manipulator chains of the robot which posses an end effector. + Get a list of all manipulator chains of the robot which posses an end effector. :return: A list of KinematicChainDescription objects """ @@ -210,7 +247,7 @@ def get_manipulator_chains(self) -> List[KinematicChainDescription]: result.append(chain) return result - def get_camera_frame(self) -> str: + def get_camera_link(self) -> str: """ Quick method to get the name of a link of a camera. Uses the first camera in the list of cameras. @@ -218,9 +255,17 @@ def get_camera_frame(self) -> str: """ return self.cameras[list(self.cameras.keys())[0]].link_name + def get_camera_frame(self) -> str: + """ + Quick method to get the name of a link of a camera. Uses the first camera in the list of cameras. + + :return: A name of the link of a camera + """ + return f"{self.name}/{self.cameras[list(self.cameras.keys())[0]].link_name}" + def get_default_camera(self) -> CameraDescription: """ - Returns the first camera in the list of cameras. + Get the first camera in the list of cameras. :return: A CameraDescription object """ @@ -228,7 +273,7 @@ def get_default_camera(self) -> CameraDescription: def get_static_joint_chain(self, kinematic_chain_name: str, configuration_name: str): """ - Returns the static joint states of a kinematic chain for a specific configuration. When trying to access one of + Get the static joint states of a kinematic chain for a specific configuration. When trying to access one of the robot arms the function `:func: get_arm_chain` should be used. :param kinematic_chain_name: @@ -246,7 +291,7 @@ def get_static_joint_chain(self, kinematic_chain_name: str, configuration_name: def get_parent(self, name: str) -> str: """ - Returns the parent of a link or joint in the URDF. Always returns the imeadiate parent, for a link this is a joint + Get the parent of a link or joint in the URDF. Always returns the imeadiate parent, for a link this is a joint and vice versa. :param name: Name of the link or joint in the URDF @@ -266,7 +311,7 @@ def get_parent(self, name: str) -> str: def get_child(self, name: str, return_multiple_children: bool = False) -> Union[str, List[str]]: """ - Returns the child of a link or joint in the URDF. Always returns the immediate child, for a link this is a joint + Get the child of a link or joint in the URDF. Always returns the immediate child, for a link this is a joint and vice versa. Since a link can have multiple children, the return_multiple_children parameter can be set to True to get a list of all children. @@ -293,9 +338,19 @@ def get_child(self, name: str, return_multiple_children: bool = False) -> Union[ child_link = self.urdf_object.joint_map[name].child return child_link + def get_arm_tool_frame(self, arm: Arms) -> str: + """ + Get the name of the tool frame of a specific arm. + + :param arm: Arm for which the tool frame should be returned + :return: The name of the link of the tool frame in the URDF. + """ + chain = self.get_arm_chain(arm) + return chain.get_tool_frame() + def get_arm_chain(self, arm: Arms) -> Union[KinematicChainDescription, List[KinematicChainDescription]]: """ - Returns the kinematic chain of a specific arm. If the arm is set to BOTH, all kinematic chains are returned. + Get the kinematic chain of a specific arm. If the arm is set to BOTH, all kinematic chains are returned. :param arm: Arm for which the chain should be returned :return: KinematicChainDescription object of the arm @@ -329,7 +384,7 @@ class KinematicChainDescription: """ Last link of the chain """ - urdf_object: URDF + urdf_object: URDFObject """ Parsed URDF of the robot """ @@ -358,7 +413,7 @@ class KinematicChainDescription: Dictionary of static joint states for the chain """ - def __init__(self, name: str, start_link: str, end_link: str, urdf_object: URDF, arm_type: Arms = None, + def __init__(self, name: str, start_link: str, end_link: str, urdf_object: URDFObject, arm_type: Arms = None, include_fixed_joints=False): """ Initialize the KinematicChainDescription object. @@ -373,7 +428,7 @@ def __init__(self, name: str, start_link: str, end_link: str, urdf_object: URDF, self.name: str = name self.start_link: str = start_link self.end_link: str = end_link - self.urdf_object: URDF = urdf_object + self.urdf_object: URDFObject = urdf_object self.include_fixed_joints: bool = include_fixed_joints self.link_names: List[str] = [] self.joint_names: List[str] = [] @@ -395,11 +450,12 @@ def _init_joints(self): Initializes the joints of the chain by getting the chain from the URDF object. """ joints = self.urdf_object.get_chain(self.start_link, self.end_link, links=False) - self.joint_names = list(filter(lambda j: self.urdf_object.joint_map[j].type != "fixed" or self.include_fixed_joints, joints)) + self.joint_names = list(filter(lambda j: self.urdf_object.joint_map[j].type != JointType.FIXED + or self.include_fixed_joints, joints)) def get_joints(self) -> List[str]: """ - Returns a list of all joints of the chain. + Get a list of all joints of the chain. :return: List of joint names """ @@ -407,9 +463,7 @@ def get_joints(self) -> List[str]: def get_links(self) -> List[str]: """ - Returns a list of all links of the chain. - - :return: List of link names + :return: A list of all links of the chain. """ return self.link_names @@ -445,7 +499,7 @@ def add_static_joint_states(self, name: str, states: dict): def get_static_joint_states(self, name: str) -> Dict[str, float]: """ - Returns the dictionary of static joint states for a given name of the static joint states. + Get the dictionary of static joint states for a given name of the static joint states. :param name: Name of the static joint states :return: Dictionary of joint names and their values @@ -453,11 +507,11 @@ def get_static_joint_states(self, name: str) -> Dict[str, float]: try: return self.static_joint_states[name] except KeyError: - rospy.logerr(f"Static joint states for chain {name} not found") + logerr(f"Static joint states for chain {name} not found") def get_tool_frame(self) -> str: """ - Returns the name of the tool frame of the end effector of this chain, if it has an end effector. + Get the name of the tool frame of the end effector of this chain, if it has an end effector. :return: The name of the link of the tool frame in the URDF. """ @@ -468,7 +522,7 @@ def get_tool_frame(self) -> str: def get_static_gripper_state(self, state: GripperState) -> Dict[str, float]: """ - Returns the static joint states for the gripper of the chain. + Get the static joint states for the gripper of the chain. :param state: Name of the static joint states :return: Dictionary of joint names and their values @@ -552,7 +606,7 @@ class EndEffectorDescription: """ Name of the tool frame link in the URDf """ - urdf_object: URDF + urdf_object: URDFObject """ Parsed URDF of the robot """ @@ -577,7 +631,7 @@ class EndEffectorDescription: Distance the gripper can open, in cm """ - def __init__(self, name: str, start_link: str, tool_frame: str, urdf_object: URDF): + def __init__(self, name: str, start_link: str, tool_frame: str, urdf_object: URDFObject): """ Initialize the EndEffectorDescription object. @@ -589,7 +643,7 @@ def __init__(self, name: str, start_link: str, tool_frame: str, urdf_object: URD self.name: str = name self.start_link: str = start_link self.tool_frame: str = tool_frame - self.urdf_object: URDF = urdf_object + self.urdf_object: URDFObject = urdf_object self.link_names: List[str] = [] self.joint_names: List[str] = [] self.static_joint_states: Dict[GripperState, Dict[str, float]] = {} diff --git a/src/pycram/robot_descriptions/__init__.py b/src/pycram/robot_descriptions/__init__.py index 33dc54ca9..1ec759998 100644 --- a/src/pycram/robot_descriptions/__init__.py +++ b/src/pycram/robot_descriptions/__init__.py @@ -9,7 +9,8 @@ class DeprecatedRobotDescription: def raise_error(self): - raise DeprecationWarning("Robot description moved, please use RobotDescription.current_robot_description from pycram.robot_description") + raise DeprecationWarning("Robot description moved, please use RobotDescription.current_robot_description from" + " pycram.robot_description") @property def name(self): diff --git a/src/pycram/robot_descriptions/boxy_description.py b/src/pycram/robot_descriptions/boxy_description.py index 6af022b8e..f4dc7cfc1 100644 --- a/src/pycram/robot_descriptions/boxy_description.py +++ b/src/pycram/robot_descriptions/boxy_description.py @@ -1,10 +1,9 @@ -import rospkg +from ..ros.ros_tools import get_ros_package_path from ..robot_description import RobotDescription, CameraDescription, KinematicChainDescription, \ EndEffectorDescription, RobotDescriptionManager from ..datastructures.enums import Arms, Grasp, GripperState -rospack = rospkg.RosPack() -filename = rospack.get_path('pycram') + '/resources/robots/' + "boxy" + '.urdf' +filename = get_ros_package_path('pycram') + '/resources/robots/' + "boxy" + '.urdf' boxy_description = RobotDescription("boxy", "base_link", "triangle_base_link", "triangle_base_joint", filename) diff --git a/src/pycram/robot_descriptions/donbot_description.py b/src/pycram/robot_descriptions/donbot_description.py index 69d50ad02..f37958440 100644 --- a/src/pycram/robot_descriptions/donbot_description.py +++ b/src/pycram/robot_descriptions/donbot_description.py @@ -1,10 +1,9 @@ -import rospkg +from ..ros.ros_tools import get_ros_package_path from ..robot_description import RobotDescription, KinematicChainDescription, EndEffectorDescription, \ RobotDescriptionManager, CameraDescription from ..datastructures.enums import Arms, Grasp, GripperState -rospack = rospkg.RosPack() -filename = rospack.get_path('pycram') + '/resources/robots/' + "iai_donbot" + '.urdf' +filename = get_ros_package_path('pycram') + '/resources/robots/' + "iai_donbot" + '.urdf' donbot_description = RobotDescription("iai_donbot", "base_link", "ur5_base_link", "arm_base_mounting_joint", filename) diff --git a/src/pycram/robot_descriptions/hsrb_description.py b/src/pycram/robot_descriptions/hsrb_description.py index f83f23191..ae452e92e 100644 --- a/src/pycram/robot_descriptions/hsrb_description.py +++ b/src/pycram/robot_descriptions/hsrb_description.py @@ -1,11 +1,10 @@ -import rospkg +from ..ros.ros_tools import get_ros_package_path from ..robot_description import RobotDescription, KinematicChainDescription, EndEffectorDescription, \ RobotDescriptionManager, CameraDescription from ..datastructures.enums import GripperState, Grasp, Arms -rospack = rospkg.RosPack() -filename = rospack.get_path('pycram') + '/resources/robots/' + "hsrb" + '.urdf' +filename = get_ros_package_path('pycram') + '/resources/robots/' + "hsrb" + '.urdf' hsrb_description = RobotDescription("hsrb", "base_link", "arm_lift_link", "arm_lift_joint", filename) diff --git a/src/pycram/robot_descriptions/pr2_description.py b/src/pycram/robot_descriptions/pr2_description.py index 402125f2a..35ee28838 100644 --- a/src/pycram/robot_descriptions/pr2_description.py +++ b/src/pycram/robot_descriptions/pr2_description.py @@ -1,13 +1,17 @@ +from ..datastructures.dataclasses import VirtualMobileBaseJoints from ..robot_description import RobotDescription, KinematicChainDescription, EndEffectorDescription, \ RobotDescriptionManager, CameraDescription from ..datastructures.enums import Arms, Grasp, GripperState, GripperType -import rospkg +from ..ros.ros_tools import get_ros_package_path -rospack = rospkg.RosPack() -filename = rospack.get_path('pycram') + '/resources/robots/' + "pr2" + '.urdf' +from ..helper import get_robot_mjcf_path + +filename = get_ros_package_path('pycram') + '/resources/robots/' + "pr2" + '.urdf' + +mjcf_filename = get_robot_mjcf_path("", "pr2") pr2_description = RobotDescription("pr2", "base_link", "torso_lift_link", "torso_lift_joint", - filename) + filename, virtual_mobile_base_joints=VirtualMobileBaseJoints(), mjcf_path=mjcf_filename) ################################## Left Arm ################################## left_arm = KinematicChainDescription("left", "torso_lift_link", "l_wrist_roll_link", diff --git a/src/pycram/robot_descriptions/stretch_description.py b/src/pycram/robot_descriptions/stretch_description.py index fad1c7427..3c86be24f 100644 --- a/src/pycram/robot_descriptions/stretch_description.py +++ b/src/pycram/robot_descriptions/stretch_description.py @@ -1,11 +1,10 @@ -import rospkg +from ..ros.ros_tools import get_ros_package_path from ..robot_description import RobotDescription, KinematicChainDescription, EndEffectorDescription, \ CameraDescription, RobotDescriptionManager from ..datastructures.enums import GripperState, Arms, Grasp -rospack = rospkg.RosPack() -filename = rospack.get_path('pycram') + '/resources/robots/' + "stretch_description" + '.urdf' +filename = get_ros_package_path('pycram') + '/resources/robots/' + "stretch_description" + '.urdf' stretch_description = RobotDescription("stretch_description", "base_link", "link_lift", "joint_lift", filename) diff --git a/src/pycram/robot_descriptions/tiago_description.py b/src/pycram/robot_descriptions/tiago_description.py index 6a92d47ec..ef39e8a8b 100644 --- a/src/pycram/robot_descriptions/tiago_description.py +++ b/src/pycram/robot_descriptions/tiago_description.py @@ -1,13 +1,19 @@ -import rospkg +from ..ros.ros_tools import get_ros_package_path + +from ..datastructures.dataclasses import VirtualMobileBaseJoints +from ..datastructures.enums import GripperState, Arms, Grasp from ..robot_description import RobotDescription, KinematicChainDescription, EndEffectorDescription, \ RobotDescriptionManager, CameraDescription -from ..datastructures.enums import GripperState, Arms, Grasp +from ..helper import get_robot_mjcf_path + +filename = get_ros_package_path('pycram') + '/resources/robots/' + "tiago_dual" + '.urdf' -rospack = rospkg.RosPack() -filename = rospack.get_path('pycram') + '/resources/robots/' + "tiago_dual" + '.urdf' +mjcf_filename = get_robot_mjcf_path("pal_robotics", "tiago_dual") tiago_description = RobotDescription("tiago_dual", "base_link", "torso_lift_link", "torso_lift_joint", - filename) + filename, + virtual_mobile_base_joints=VirtualMobileBaseJoints(), + mjcf_path=mjcf_filename) ################################## Left Arm ################################## left_arm = KinematicChainDescription("left_arm", "torso_lift_link", "arm_left_7_link", diff --git a/src/pycram/robot_descriptions/turtlebot_description.py b/src/pycram/robot_descriptions/turtlebot_description.py new file mode 100644 index 000000000..34dbe79a9 --- /dev/null +++ b/src/pycram/robot_descriptions/turtlebot_description.py @@ -0,0 +1,13 @@ +from ..ros.ros_tools import get_ros_package_path + +from ..robot_description import RobotDescriptionManager, RobotDescription + +# Description for turtlebot3_waffle_pi +filename = get_ros_package_path('pycram') + '/resources/robots/' + "turtlebot" + '.urdf' + +turtlebot = RobotDescription("turtlebot", "world", "base_link", "base_joint", + filename) + +# Add to RobotDescriptionManager +rdm = RobotDescriptionManager() +rdm.register_description(turtlebot) diff --git a/src/pycram/robot_descriptions/ur5_description.py b/src/pycram/robot_descriptions/ur5_description.py index d50f189fd..3a931dbc3 100644 --- a/src/pycram/robot_descriptions/ur5_description.py +++ b/src/pycram/robot_descriptions/ur5_description.py @@ -1,10 +1,9 @@ -import rospkg +from ..ros.ros_tools import get_ros_package_path from ..robot_description import RobotDescription, KinematicChainDescription, EndEffectorDescription, \ RobotDescriptionManager from ..datastructures.enums import Arms, Grasp, GripperState -rospack = rospkg.RosPack() -filename = rospack.get_path('pycram') + '/resources/robots/' + "ur5_robotiq" + '.urdf' +filename = get_ros_package_path('pycram') + '/resources/robots/' + "ur5_robotiq" + '.urdf' ur5_description = RobotDescription("ur5_robotiq", "world", "base_link", "ee_link", filename) diff --git a/src/pycram/ros/__init__.py b/src/pycram/ros/__init__.py index e69de29bb..c22224413 100644 --- a/src/pycram/ros/__init__.py +++ b/src/pycram/ros/__init__.py @@ -0,0 +1,6 @@ +import rospy +from .ros_tools import is_master_online + +# Check is for sphinx autoAPI to be able to work in a CI workflow +if is_master_online(): + rospy.init_node("pycram") diff --git a/src/pycram/ros/action_lib.py b/src/pycram/ros/action_lib.py new file mode 100644 index 000000000..efc5f3048 --- /dev/null +++ b/src/pycram/ros/action_lib.py @@ -0,0 +1,7 @@ +import actionlib + +from actionlib import SimpleActionClient + +def create_action_client(topic_name: str, action_message) -> SimpleActionClient: + return actionlib.SimpleActionClient(topic_name, action_message) + diff --git a/src/pycram/ros/data_types.py b/src/pycram/ros/data_types.py new file mode 100644 index 000000000..4b6956279 --- /dev/null +++ b/src/pycram/ros/data_types.py @@ -0,0 +1,11 @@ +import rospy + +from rospy import ServiceException +def Time(time=0.0): + return rospy.Time(time) + +def Duration(duration=0.0): + return rospy.Duration(duration) + +def Rate(rate): + return rospy.Rate(rate) \ No newline at end of file diff --git a/src/pycram/ros/logging.py b/src/pycram/ros/logging.py new file mode 100644 index 000000000..184f22841 --- /dev/null +++ b/src/pycram/ros/logging.py @@ -0,0 +1,64 @@ +import rospy +import inspect +from pathlib import Path + + +def _get_caller_method_name(): + """ + Get the name of the method that called the function from which this function is called. It is intended as a helper + function for the log functions. + + :return: Name of the method that called the function from which this function is called. + """ + return inspect.stack()[2][3] + +def _get_caller_method_line(): + """ + Get the line of the method that called the function from which this function is called. It is intended as a helper + function for the log functions. + + :return: Line number of the method that called the function from which this function is called. + """ + return inspect.stack()[2][2] + +def _get_caller_file_name(): + """ + Get the file name of the method that called the function from which this function is called. It is intended as a helper + function for the log functions. + + :return: File name of the method that called the function from which this function is called. + """ + path = Path(inspect.stack()[2][1]) + return path.name + + +def logwarn(message: str): + rospy.logwarn(f"[{_get_caller_file_name()}:{_get_caller_method_line()}:{_get_caller_method_name()}] {message}") + + +def loginfo(message: str): + rospy.loginfo(f"[{_get_caller_file_name()}:{_get_caller_method_line()}:{_get_caller_method_name()}] {message}") + + +def logerr(message: str): + rospy.logerr(f"[{_get_caller_file_name()}:{_get_caller_method_line()}:{_get_caller_method_name()}] {message}") + + +def logdebug(message: str): + rospy.logdebug(f"[{_get_caller_file_name()}:{_get_caller_method_line()}:{_get_caller_method_name()}] {message}") + + +def logwarn_once(message: str): + rospy.logwarn_once(f"[{_get_caller_file_name()}:{_get_caller_method_line()}:{_get_caller_method_name()}] {message}") + + +def loginfo_once(message: str): + rospy.loginfo_once(f"[{_get_caller_file_name()}:{_get_caller_method_line()}:{_get_caller_method_name()}] {message}") + + +def logerr_once(message: str): + rospy.logerr_once(f"[{_get_caller_file_name()}:{_get_caller_method_line()}:{_get_caller_method_name()}] {message}") + + +def logdebug_once(message: str): + rospy.logdebug_once(f"[{_get_caller_file_name()}:{_get_caller_method_line()}:{_get_caller_method_name()}] {message}") diff --git a/src/pycram/ros/publisher.py b/src/pycram/ros/publisher.py new file mode 100644 index 000000000..533a7c89b --- /dev/null +++ b/src/pycram/ros/publisher.py @@ -0,0 +1,4 @@ +import rospy + +def create_publisher(topic, msg_type, queue_size=10) -> rospy.Publisher: + return rospy.Publisher(topic, msg_type, queue_size=queue_size) \ No newline at end of file diff --git a/src/pycram/ros/ros_tools.py b/src/pycram/ros/ros_tools.py new file mode 100644 index 000000000..96bcf7ce6 --- /dev/null +++ b/src/pycram/ros/ros_tools.py @@ -0,0 +1,45 @@ +import rosgraph +import rosnode +import rospy +# import rospkg + +from rospkg import RosPack, ResourceNotFound +from typing_extensions import Any + + +def get_node_names(namespace=None): + return rosnode.get_node_names(namespace) + + +def create_ros_pack(ros_paths: Any = None) -> RosPack: + """ + Creates a RosPack instance to search for resources of ros packages. + + :param ros_paths: An ordered list of paths to search for resources. + :return: An instance of RosPack + """ + return RosPack(ros_paths) + + +def get_ros_package_path(package_name: str) -> str: + rospack = create_ros_pack() + return rospack.get_path(package_name) + + +def get_parameter(name: str) -> Any: + return rospy.get_param(name) + + +def wait_for_message(topic_name: str): + return rospy.wait_for_message(topic_name) + + +def is_master_online(): + return rosgraph.is_master_online() + + +def sleep(duration: float): + rospy.sleep(duration) + +def create_timer(duration: int, callback, oneshot=False): + return rospy.Timer(rospy.Duration(duration), callback, oneshot=oneshot) \ No newline at end of file diff --git a/src/pycram/ros/service.py b/src/pycram/ros/service.py new file mode 100644 index 000000000..be2f84f4d --- /dev/null +++ b/src/pycram/ros/service.py @@ -0,0 +1,11 @@ +import rosservice +import rospy + + +def get_service_proxy(topic_name: str, service_message) -> rospy.ServiceProxy: + return rospy.ServiceProxy(topic_name, service_message) + + +def wait_for_service(topic_name: str): + rospy.loginfo_once(f"Waiting for service: {topic_name}") + rospy.wait_for_service(topic_name) diff --git a/src/pycram/ros/subscriber.py b/src/pycram/ros/subscriber.py new file mode 100644 index 000000000..d4ab3dd8b --- /dev/null +++ b/src/pycram/ros/subscriber.py @@ -0,0 +1,4 @@ +import rospy + +def create_subscriber(topic, msg_type, callback, queue_size=10) -> rospy.Subscriber: + return rospy.Subscriber(topic, msg_type, callback, queue_size=queue_size) \ No newline at end of file diff --git a/src/pycram/ros/viz_marker_publisher.py b/src/pycram/ros/viz_marker_publisher.py index 0aa149e9b..ab9a01c89 100644 --- a/src/pycram/ros/viz_marker_publisher.py +++ b/src/pycram/ros/viz_marker_publisher.py @@ -1,321 +1,3 @@ -import atexit -import threading -import time -from typing import List, Optional, Tuple - -import rospy -from geometry_msgs.msg import Vector3 -from std_msgs.msg import ColorRGBA -from visualization_msgs.msg import Marker, MarkerArray - -from ..datastructures.dataclasses import BoxVisualShape, CylinderVisualShape, MeshVisualShape, SphereVisualShape -from ..datastructures.pose import Pose, Transform -from ..designator import ObjectDesignatorDescription -from ..datastructures.world import World - - class VizMarkerPublisher: - """ - Publishes an Array of visualization marker which represent the situation in the World - """ - - def __init__(self, topic_name="/pycram/viz_marker", interval=0.1): - """ - The Publisher creates an Array of Visualization marker with a Marker for each link of each Object in the - World. This Array is published with a rate of interval. - - :param topic_name: The name of the topic to which the Visualization Marker should be published. - :param interval: The interval at which the visualization marker should be published, in seconds. - """ - self.topic_name = topic_name - self.interval = interval - - self.pub = rospy.Publisher(self.topic_name, MarkerArray, queue_size=10) - - self.thread = threading.Thread(target=self._publish) - self.kill_event = threading.Event() - self.main_world = World.current_world if not World.current_world.is_prospection_world else World.current_world.world_sync.world - - self.thread.start() - atexit.register(self._stop_publishing) - - def _publish(self) -> None: - """ - Constantly publishes the Marker Array. To the given topic name at a fixed rate. - """ - while not self.kill_event.is_set(): - marker_array = self._make_marker_array() - - self.pub.publish(marker_array) - time.sleep(self.interval) - - def _make_marker_array(self) -> MarkerArray: - """ - Creates the Marker Array to be published. There is one Marker for link for each object in the Array, each Object - creates a name space in the visualization Marker. The type of Visualization Marker is decided by the collision - tag of the URDF. - - :return: An Array of Visualization Marker - """ - marker_array = MarkerArray() - for obj in self.main_world.objects: - if obj.name == "floor": - continue - for link in obj.link_name_to_id.keys(): - geom = obj.get_link_geometry(link) - if not geom: - continue - msg = Marker() - msg.header.frame_id = "map" - msg.ns = obj.name - msg.id = obj.link_name_to_id[link] - msg.type = Marker.MESH_RESOURCE - msg.action = Marker.ADD - link_pose = obj.get_link_transform(link) - if obj.get_link_origin(link) is not None: - link_origin = obj.get_link_origin_transform(link) - else: - link_origin = Transform() - link_pose_with_origin = link_pose * link_origin - msg.pose = link_pose_with_origin.to_pose().pose - - color = [1, 1, 1, 1] if obj.link_name_to_id[link] == -1 else obj.get_link_color(link).get_rgba() - - msg.color = ColorRGBA(*color) - msg.lifetime = rospy.Duration(1) - - if isinstance(geom, MeshVisualShape): - msg.type = Marker.MESH_RESOURCE - msg.mesh_resource = "file://" + geom.file_name - msg.scale = Vector3(1, 1, 1) - msg.mesh_use_embedded_materials = True - elif isinstance(geom, CylinderVisualShape): - msg.type = Marker.CYLINDER - msg.scale = Vector3(geom.radius * 2, geom.radius * 2, geom.length) - elif isinstance(geom, BoxVisualShape): - msg.type = Marker.CUBE - msg.scale = Vector3(*geom.size) - elif isinstance(geom, SphereVisualShape): - msg.type = Marker.SPHERE - msg.scale = Vector3(geom.radius * 2, geom.radius * 2, geom.radius * 2) - - marker_array.markers.append(msg) - return marker_array - - def _stop_publishing(self) -> None: - """ - Stops the publishing of the Visualization Marker update by setting the kill event and collecting the thread. - """ - self.kill_event.set() - self.thread.join() - - -class ManualMarkerPublisher: - """ - Class to manually add and remove marker of objects and poses. - """ - - def __init__(self, topic_name: str = '/pycram/manual_marker', interval: float = 0.1): - """ - The Publisher creates an Array of Visualization marker with a marker for a pose or object. - This Array is published with a rate of interval. - - :param topic_name: Name of the marker topic - :param interval: Interval at which the marker should be published - """ - self.start_time = None - self.marker_array_pub = rospy.Publisher(topic_name, MarkerArray, queue_size=10) - - self.marker_array = MarkerArray() - self.marker_overview = {} - self.current_id = 0 - - self.interval = interval - self.log_message = None - - def publish(self, pose: Pose, color: Optional[List] = None, bw_object: Optional[ObjectDesignatorDescription] = None, - name: Optional[str] = None): - """ - Publish a pose or an object into the MarkerArray. - Priorities to add an object if possible - - :param pose: Pose of the marker - :param color: Color of the marker if no object is given - :param bw_object: Object to add as a marker - :param name: Name of the marker - """ - - if color is None: - color = [1, 0, 1, 1] - - self.start_time = time.time() - thread = threading.Thread(target=self._publish, args=(pose, bw_object, name, color)) - thread.start() - rospy.loginfo(self.log_message) - thread.join() - - def _publish(self, pose: Pose, bw_object: Optional[ObjectDesignatorDescription] = None, name: Optional[str] = None, - color: Optional[List] = None): - """ - Publish the marker into the MarkerArray - """ - stop_thread = False - duration = 2 - - while not stop_thread: - if time.time() - self.start_time > duration: - stop_thread = True - if bw_object is None: - self._publish_pose(name=name, pose=pose, color=color) - else: - self._publish_object(name=name, pose=pose, bw_object=bw_object) - - rospy.sleep(self.interval) - - def _publish_pose(self, name: str, pose: Pose, color: Optional[List] = None): - """ - Publish a Pose as a marker - - :param name: Name of the marker - :param pose: Pose of the marker - :param color: Color of the marker - """ - - if name is None: - name = 'pose_marker' - - if name in self.marker_overview.keys(): - self._update_marker(self.marker_overview[name], new_pose=pose) - return - - color_rgba = ColorRGBA(*color) - self._make_marker_array(name=name, marker_type=Marker.ARROW, marker_pose=pose, - marker_scales=(0.05, 0.05, 0.05), color_rgba=color_rgba) - self.marker_array_pub.publish(self.marker_array) - self.log_message = f"Pose '{name}' published" - - def _publish_object(self, name: Optional[str], pose: Pose, bw_object: ObjectDesignatorDescription): - """ - Publish an Object as a marker - - :param name: Name of the marker - :param pose: Pose of the marker - :param bw_object: ObjectDesignatorDescription for the marker - """ - - bw_real = bw_object.resolve() - - if name is None: - name = bw_real.name - - if name in self.marker_overview.keys(): - self._update_marker(self.marker_overview[name], new_pose=pose) - return - - path = bw_real.world_object.root_link.geometry.file_name - - self._make_marker_array(name=name, marker_type=Marker.MESH_RESOURCE, marker_pose=pose, - path_to_resource=path) - - self.marker_array_pub.publish(self.marker_array) - self.log_message = f"Object '{name}' published" - - def _make_marker_array(self, name, marker_type: int, marker_pose: Pose, marker_scales: Tuple = (1.0, 1.0, 1.0), - color_rgba: ColorRGBA = ColorRGBA(*[1.0, 1.0, 1.0, 1.0]), - path_to_resource: Optional[str] = None): - """ - Create a Marker and add it to the MarkerArray - - :param name: Name of the Marker - :param marker_type: Type of the marker to create - :param marker_pose: Pose of the marker - :param marker_scales: individual scaling of the markers axes - :param color_rgba: Color of the marker as RGBA - :param path_to_resource: Path to the resource of a Bulletworld object - """ - - frame_id = marker_pose.header.frame_id - new_marker = Marker() - new_marker.id = self.current_id - new_marker.header.frame_id = frame_id - new_marker.ns = name - new_marker.header.stamp = rospy.Time.now() - new_marker.type = marker_type - new_marker.action = Marker.ADD - new_marker.pose = marker_pose.pose - new_marker.scale.x = marker_scales[0] - new_marker.scale.y = marker_scales[1] - new_marker.scale.z = marker_scales[2] - new_marker.color.a = color_rgba.a - new_marker.color.r = color_rgba.r - new_marker.color.g = color_rgba.g - new_marker.color.b = color_rgba.b - - if path_to_resource is not None: - new_marker.mesh_resource = 'file://' + path_to_resource - - self.marker_array.markers.append(new_marker) - self.marker_overview[name] = new_marker.id - self.current_id += 1 - - def _update_marker(self, marker_id: int, new_pose: Pose) -> bool: - """ - Update an existing marker to a new pose - - :param marker_id: id of the marker that should be updated - :param new_pose: Pose where the updated marker is set - - :return: True if update was successful, False otherwise - """ - - # Find the marker with the specified ID - for marker in self.marker_array.markers: - if marker.id == marker_id: - # Update successful - marker.pose = new_pose - self.log_message = f"Marker '{marker.ns}' updated" - self.marker_array_pub.publish(self.marker_array) - return True - - # Update was not successful - rospy.logwarn(f"Marker {marker_id} not found for update") - return False - - def remove_marker(self, bw_object: Optional[ObjectDesignatorDescription] = None, name: Optional[str] = None): - """ - Remove a marker by object or name - - :param bw_object: Object which marker should be removed - :param name: Name of object that should be removed - """ - - if bw_object is not None: - bw_real = bw_object.resolve() - name = bw_real.name - - if name is None: - rospy.logerr('No name for object given, cannot remove marker') - return - - marker_id = self.marker_overview.pop(name) - - for marker in self.marker_array.markers: - if marker.id == marker_id: - marker.action = Marker.DELETE - - self.marker_array_pub.publish(self.marker_array) - self.marker_array.markers.pop(marker_id) - - rospy.loginfo(f"Removed Marker '{name}'") - - def clear_all_marker(self): - """ - Clear all existing markers - """ - for marker in self.marker_array.markers: - marker.action = Marker.DELETE - - self.marker_overview = {} - self.marker_array_pub.publish(self.marker_array) - - rospy.loginfo('Removed all markers') + def __init__(self): + raise DeprecationWarning("This function moved and can now be found in pycram.ros_utils.viz_marker_publisher") \ No newline at end of file diff --git a/src/pycram/ros_utils/__init__.py b/src/pycram/ros_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/pycram/ros/force_torque_sensor.py b/src/pycram/ros_utils/force_torque_sensor.py similarity index 93% rename from src/pycram/ros/force_torque_sensor.py rename to src/pycram/ros_utils/force_torque_sensor.py index 5f52503f5..3d98e79cd 100644 --- a/src/pycram/ros/force_torque_sensor.py +++ b/src/pycram/ros_utils/force_torque_sensor.py @@ -2,11 +2,11 @@ import time import threading -import rospy - from geometry_msgs.msg import WrenchStamped from std_msgs.msg import Header from ..datastructures.world import World +from ..ros.data_types import Time +from ..ros.publisher import create_publisher class ForceTorqueSensor: @@ -34,7 +34,7 @@ def __init__(self, joint_name, fts_topic="/pycram/fts", interval=0.1): f" does not exist in robot object") self.world.enable_joint_force_torque_sensor(self.world.robot, self.fts_joint_idx) - self.fts_pub = rospy.Publisher(fts_topic, WrenchStamped, queue_size=10) + self.fts_pub = create_publisher(fts_topic, WrenchStamped, queue_size=10) self.interval = interval self.kill_event = threading.Event() @@ -53,7 +53,7 @@ def _publish(self) -> None: joint_ft = self.world.get_joint_reaction_force_torque(self.world.robot, self.fts_joint_idx) h = Header() h.seq = seq - h.stamp = rospy.Time.now() + h.stamp = Time().now() h.frame_id = self.joint_name wrench_msg = WrenchStamped() diff --git a/src/pycram/ros/joint_state_publisher.py b/src/pycram/ros_utils/joint_state_publisher.py similarity index 91% rename from src/pycram/ros/joint_state_publisher.py rename to src/pycram/ros_utils/joint_state_publisher.py index 08b78edf2..c747b0af0 100644 --- a/src/pycram/ros/joint_state_publisher.py +++ b/src/pycram/ros_utils/joint_state_publisher.py @@ -2,11 +2,11 @@ import threading import atexit -import rospy - from sensor_msgs.msg import JointState from std_msgs.msg import Header from ..datastructures.world import World +from ..ros.data_types import Time +from ..ros.publisher import create_publisher class JointStatePublisher: @@ -23,7 +23,7 @@ def __init__(self, joint_state_topic="/pycram/joint_state", interval=0.1): """ self.world = World.current_world - self.joint_state_pub = rospy.Publisher(joint_state_topic, JointState, queue_size=10) + self.joint_state_pub = create_publisher(joint_state_topic, JointState, queue_size=10) self.interval = interval self.kill_event = threading.Event() self.thread = threading.Thread(target=self._publish) @@ -43,7 +43,7 @@ def _publish(self) -> None: while not self.kill_event.is_set(): current_joint_states = [robot.get_joint_position(joint_name) for joint_name in joint_names] h = Header() - h.stamp = rospy.Time.now() + h.stamp = Time().now() h.seq = seq h.frame_id = "" joint_state_msg = JointState() diff --git a/src/pycram/ros/robot_state_updater.py b/src/pycram/ros_utils/robot_state_updater.py similarity index 85% rename from src/pycram/ros/robot_state_updater.py rename to src/pycram/ros_utils/robot_state_updater.py index 0f3ad0f4f..314d85535 100644 --- a/src/pycram/ros/robot_state_updater.py +++ b/src/pycram/ros_utils/robot_state_updater.py @@ -1,4 +1,3 @@ -import rospy import atexit import tf import time @@ -8,6 +7,8 @@ from ..datastructures.world import World from ..robot_descriptions import robot_description from ..datastructures.pose import Pose +from ..ros.data_types import Time, Duration +from ..ros.ros_tools import wait_for_message, create_timer class RobotStateUpdater: @@ -31,8 +32,8 @@ def __init__(self, tf_topic: str, joint_state_topic: str): self.tf_topic = tf_topic self.joint_state_topic = joint_state_topic - self.tf_timer = rospy.Timer(rospy.Duration.from_sec(0.1), self._subscribe_tf) - self.joint_state_timer = rospy.Timer(rospy.Duration.from_sec(0.1), self._subscribe_joint_state) + self.tf_timer = create_timer(Duration().from_sec(0.1), self._subscribe_tf) + self.joint_state_timer = create_timer(Duration().from_sec(0.1), self._subscribe_joint_state) atexit.register(self._stop_subscription) @@ -42,7 +43,7 @@ def _subscribe_tf(self, msg: TransformStamped) -> None: :param msg: TransformStamped message published to the topic """ - trans, rot = self.tf_listener.lookupTransform("/map", robot_description.base_frame, rospy.Time(0)) + trans, rot = self.tf_listener.lookupTransform("/map", robot_description.base_frame, Time(0)) World.robot.set_pose(Pose(trans, rot)) def _subscribe_joint_state(self, msg: JointState) -> None: @@ -54,7 +55,7 @@ def _subscribe_joint_state(self, msg: JointState) -> None: :param msg: JointState message published to the topic. """ try: - msg = rospy.wait_for_message(self.joint_state_topic, JointState) + msg = wait_for_message(self.joint_state_topic, JointState) for name, position in zip(msg.name, msg.position): World.robot.set_joint_position(name, position) except AttributeError: diff --git a/src/pycram/ros/tf_broadcaster.py b/src/pycram/ros_utils/tf_broadcaster.py similarity index 85% rename from src/pycram/ros/tf_broadcaster.py rename to src/pycram/ros_utils/tf_broadcaster.py index 9f4661a43..2e6287e00 100644 --- a/src/pycram/ros/tf_broadcaster.py +++ b/src/pycram/ros_utils/tf_broadcaster.py @@ -1,18 +1,21 @@ import time -import rospy import threading import atexit from ..datastructures.pose import Pose from ..datastructures.world import World +from ..datastructures.enums import ExecutionType from tf2_msgs.msg import TFMessage +from ..ros.publisher import create_publisher +from ..ros.data_types import Time + class TFBroadcaster: """ Broadcaster that publishes TF frames for every object in the World. """ - def __init__(self, projection_namespace="simulated", odom_frame="odom", interval=0.1): + def __init__(self, projection_namespace=ExecutionType.SIMULATED, odom_frame="odom", interval=0.1): """ The broadcaster prefixes all published TF messages with a projection namespace to distinguish between the TF frames from the simulation and the one from the real robot. @@ -23,8 +26,8 @@ def __init__(self, projection_namespace="simulated", odom_frame="odom", interval """ self.world = World.current_world - self.tf_static_publisher = rospy.Publisher("/tf_static", TFMessage, queue_size=10) - self.tf_publisher = rospy.Publisher("/tf", TFMessage, queue_size=10) + self.tf_static_publisher = create_publisher("/tf_static", TFMessage, queue_size=10) + self.tf_publisher = create_publisher("/tf", TFMessage, queue_size=10) self.thread = threading.Thread(target=self._publish, daemon=True) self.kill_event = threading.Event() self.interval = interval @@ -52,11 +55,11 @@ def _update_objects(self) -> None: """ for obj in self.world.objects: pose = obj.get_pose() - pose.header.stamp = rospy.Time.now() + pose.header.stamp = Time.now() self._publish_pose(obj.tf_frame, pose) for link in obj.link_name_to_id.keys(): link_pose = obj.get_link_pose(link) - link_pose.header.stamp = rospy.Time.now() + link_pose.header.stamp = Time.now() self._publish_pose(obj.get_link_tf_frame(link), link_pose) def _update_static_odom(self) -> None: @@ -78,8 +81,8 @@ def _publish_pose(self, child_frame_id: str, pose: Pose, static=False) -> None: frame_id = pose.frame if frame_id != child_frame_id: tf_stamped = pose.to_transform(child_frame_id) - tf_stamped.frame = self.projection_namespace + "/" + tf_stamped.frame - tf_stamped.child_frame_id = self.projection_namespace + "/" + tf_stamped.child_frame_id + tf_stamped.frame = self.projection_namespace.name + "/" + tf_stamped.frame + tf_stamped.child_frame_id = self.projection_namespace.name + "/" + tf_stamped.child_frame_id tf2_msg = TFMessage() tf2_msg.transforms.append(tf_stamped) if static: diff --git a/src/pycram/ros_utils/viz_marker_publisher.py b/src/pycram/ros_utils/viz_marker_publisher.py new file mode 100644 index 000000000..ceeb11910 --- /dev/null +++ b/src/pycram/ros_utils/viz_marker_publisher.py @@ -0,0 +1,327 @@ +import atexit +import threading +import time +from typing import List, Optional, Tuple + +import numpy as np +from geometry_msgs.msg import Vector3 +from std_msgs.msg import ColorRGBA +from visualization_msgs.msg import Marker, MarkerArray + +from ..datastructures.dataclasses import BoxVisualShape, CylinderVisualShape, MeshVisualShape, SphereVisualShape +from ..datastructures.pose import Pose, Transform +from ..designator import ObjectDesignatorDescription +from ..datastructures.world import World +from ..ros.data_types import Duration, Time +from ..ros.logging import loginfo, logwarn, logerr +from ..ros.publisher import create_publisher +from ..ros.ros_tools import sleep + + +class VizMarkerPublisher: + """ + Publishes an Array of visualization marker which represent the situation in the World + """ + + def __init__(self, topic_name="/pycram/viz_marker", interval=0.1): + """ + The Publisher creates an Array of Visualization marker with a Marker for each link of each Object in the + World. This Array is published with a rate of interval. + + :param topic_name: The name of the topic to which the Visualization Marker should be published. + :param interval: The interval at which the visualization marker should be published, in seconds. + """ + self.topic_name = topic_name + self.interval = interval + + self.pub = create_publisher(self.topic_name, MarkerArray, queue_size=10) + + self.thread = threading.Thread(target=self._publish) + self.kill_event = threading.Event() + self.main_world = World.current_world if not World.current_world.is_prospection_world else World.current_world.world_sync.world + self.lock = self.main_world.object_lock + self.thread.start() + atexit.register(self._stop_publishing) + + def _publish(self) -> None: + """ + Constantly publishes the Marker Array. To the given topic name at a fixed rate. + """ + while not self.kill_event.is_set(): + self.lock.acquire() + marker_array = self._make_marker_array() + self.lock.release() + self.pub.publish(marker_array) + time.sleep(self.interval) + + def _make_marker_array(self) -> MarkerArray: + """ + Creates the Marker Array to be published. There is one Marker for link for each object in the Array, each Object + creates a name space in the visualization Marker. The type of Visualization Marker is decided by the collision + tag of the URDF. + + :return: An Array of Visualization Marker + """ + marker_array = MarkerArray() + for obj in self.main_world.objects: + if obj.name == "floor": + continue + for link in obj.link_name_to_id.keys(): + geom = obj.get_link_geometry(link) + if not geom: + continue + msg = Marker() + msg.header.frame_id = "map" + msg.ns = obj.name + msg.id = obj.link_name_to_id[link] + msg.type = Marker.MESH_RESOURCE + msg.action = Marker.ADD + link_pose = obj.get_link_transform(link) + if obj.get_link_origin(link) is not None: + link_origin = obj.get_link_origin_transform(link) + else: + link_origin = Transform() + link_pose_with_origin = link_pose * link_origin + msg.pose = link_pose_with_origin.to_pose().pose + + color = obj.get_link_color(link).get_rgba() + + msg.color = ColorRGBA(*color) + msg.lifetime = Duration(1) + + if isinstance(geom, MeshVisualShape): + msg.type = Marker.MESH_RESOURCE + msg.mesh_resource = "file://" + geom.file_name + msg.scale = Vector3(1, 1, 1) + msg.mesh_use_embedded_materials = True + elif isinstance(geom, CylinderVisualShape): + msg.type = Marker.CYLINDER + msg.scale = Vector3(geom.radius * 2, geom.radius * 2, geom.length) + elif isinstance(geom, BoxVisualShape): + msg.type = Marker.CUBE + size = np.array(geom.size) * 2 + msg.scale = Vector3(size[0], size[1], size[2]) + elif isinstance(geom, SphereVisualShape): + msg.type = Marker.SPHERE + msg.scale = Vector3(geom.radius * 2, geom.radius * 2, geom.radius * 2) + + marker_array.markers.append(msg) + return marker_array + + def _stop_publishing(self) -> None: + """ + Stops the publishing of the Visualization Marker update by setting the kill event and collecting the thread. + """ + self.kill_event.set() + self.thread.join() + + +class ManualMarkerPublisher: + """ + Class to manually add and remove marker of objects and poses. + """ + + def __init__(self, topic_name: str = '/pycram/manual_marker', interval: float = 0.1): + """ + The Publisher creates an Array of Visualization marker with a marker for a pose or object. + This Array is published with a rate of interval. + + :param topic_name: Name of the marker topic + :param interval: Interval at which the marker should be published + """ + self.start_time = None + self.marker_array_pub = create_publisher(topic_name, MarkerArray, queue_size=10) + + self.marker_array = MarkerArray() + self.marker_overview = {} + self.current_id = 0 + + self.interval = interval + self.log_message = None + + def publish(self, pose: Pose, color: Optional[List] = None, bw_object: Optional[ObjectDesignatorDescription] = None, + name: Optional[str] = None): + """ + Publish a pose or an object into the MarkerArray. + Priorities to add an object if possible + + :param pose: Pose of the marker + :param color: Color of the marker if no object is given + :param bw_object: Object to add as a marker + :param name: Name of the marker + """ + + if color is None: + color = [1, 0, 1, 1] + + self.start_time = time.time() + thread = threading.Thread(target=self._publish, args=(pose, bw_object, name, color)) + thread.start() + loginfo(self.log_message) + thread.join() + + def _publish(self, pose: Pose, bw_object: Optional[ObjectDesignatorDescription] = None, name: Optional[str] = None, + color: Optional[List] = None): + """ + Publish the marker into the MarkerArray + """ + stop_thread = False + duration = 2 + + while not stop_thread: + if time.time() - self.start_time > duration: + stop_thread = True + if bw_object is None: + self._publish_pose(name=name, pose=pose, color=color) + else: + self._publish_object(name=name, pose=pose, bw_object=bw_object) + + sleep(self.interval) + + def _publish_pose(self, name: str, pose: Pose, color: Optional[List] = None): + """ + Publish a Pose as a marker + + :param name: Name of the marker + :param pose: Pose of the marker + :param color: Color of the marker + """ + + if name is None: + name = 'pose_marker' + + if name in self.marker_overview.keys(): + self._update_marker(self.marker_overview[name], new_pose=pose) + return + + color_rgba = ColorRGBA(*color) + self._make_marker_array(name=name, marker_type=Marker.ARROW, marker_pose=pose, + marker_scales=(0.05, 0.05, 0.05), color_rgba=color_rgba) + self.marker_array_pub.publish(self.marker_array) + self.log_message = f"Pose '{name}' published" + + def _publish_object(self, name: Optional[str], pose: Pose, bw_object: ObjectDesignatorDescription): + """ + Publish an Object as a marker + + :param name: Name of the marker + :param pose: Pose of the marker + :param bw_object: ObjectDesignatorDescription for the marker + """ + + bw_real = bw_object.resolve() + + if name is None: + name = bw_real.name + + if name in self.marker_overview.keys(): + self._update_marker(self.marker_overview[name], new_pose=pose) + return + + path = bw_real.world_object.root_link.geometry.file_name + + self._make_marker_array(name=name, marker_type=Marker.MESH_RESOURCE, marker_pose=pose, + path_to_resource=path) + + self.marker_array_pub.publish(self.marker_array) + self.log_message = f"Object '{name}' published" + + def _make_marker_array(self, name, marker_type: int, marker_pose: Pose, marker_scales: Tuple = (1.0, 1.0, 1.0), + color_rgba: ColorRGBA = ColorRGBA(*[1.0, 1.0, 1.0, 1.0]), + path_to_resource: Optional[str] = None): + """ + Create a Marker and add it to the MarkerArray + + :param name: Name of the Marker + :param marker_type: Type of the marker to create + :param marker_pose: Pose of the marker + :param marker_scales: individual scaling of the markers axes + :param color_rgba: Color of the marker as RGBA + :param path_to_resource: Path to the resource of a Bulletworld object + """ + + frame_id = marker_pose.header.frame_id + new_marker = Marker() + new_marker.id = self.current_id + new_marker.header.frame_id = frame_id + new_marker.ns = name + new_marker.header.stamp = Time.now() + new_marker.type = marker_type + new_marker.action = Marker.ADD + new_marker.pose = marker_pose.pose + new_marker.scale.x = marker_scales[0] + new_marker.scale.y = marker_scales[1] + new_marker.scale.z = marker_scales[2] + new_marker.color.a = color_rgba.a + new_marker.color.r = color_rgba.r + new_marker.color.g = color_rgba.g + new_marker.color.b = color_rgba.b + + if path_to_resource is not None: + new_marker.mesh_resource = 'file://' + path_to_resource + + self.marker_array.markers.append(new_marker) + self.marker_overview[name] = new_marker.id + self.current_id += 1 + + def _update_marker(self, marker_id: int, new_pose: Pose) -> bool: + """ + Update an existing marker to a new pose + + :param marker_id: id of the marker that should be updated + :param new_pose: Pose where the updated marker is set + + :return: True if update was successful, False otherwise + """ + + # Find the marker with the specified ID + for marker in self.marker_array.markers: + if marker.id == marker_id: + # Update successful + marker.pose = new_pose + self.log_message = f"Marker '{marker.ns}' updated" + self.marker_array_pub.publish(self.marker_array) + return True + + # Update was not successful + logwarn(f"Marker {marker_id} not found for update") + return False + + def remove_marker(self, bw_object: Optional[ObjectDesignatorDescription] = None, name: Optional[str] = None): + """ + Remove a marker by object or name + + :param bw_object: Object which marker should be removed + :param name: Name of object that should be removed + """ + + if bw_object is not None: + bw_real = bw_object.resolve() + name = bw_real.name + + if name is None: + logerr('No name for object given, cannot remove marker') + return + + marker_id = self.marker_overview.pop(name) + + for marker in self.marker_array.markers: + if marker.id == marker_id: + marker.action = Marker.DELETE + + self.marker_array_pub.publish(self.marker_array) + self.marker_array.markers.pop(marker_id) + + loginfo(f"Removed Marker '{name}'") + + def clear_all_marker(self): + """ + Clear all existing markers + """ + for marker in self.marker_array.markers: + marker.action = Marker.DELETE + + self.marker_overview = {} + self.marker_array_pub.publish(self.marker_array) + + loginfo('Removed all markers') diff --git a/src/pycram/tasktree.py b/src/pycram/tasktree.py index 448e1f718..06440829e 100644 --- a/src/pycram/tasktree.py +++ b/src/pycram/tasktree.py @@ -2,28 +2,22 @@ # used for delayed evaluation of typing until python 3.11 becomes mainstream from __future__ import annotations - -from typing_extensions import TYPE_CHECKING - import datetime import inspect import logging from typing_extensions import List, Optional, Callable - import anytree import sqlalchemy.orm.session import tqdm - from .datastructures.world import World +from .helper import Singleton +from .orm.action_designator import Action from .orm.tasktree import TaskTreeNode as ORMTaskTreeNode from .orm.base import ProcessMetaData -from .plan_failures import PlanFailure +from .failures import PlanFailure from .datastructures.enums import TaskStatus from .datastructures.dataclasses import Color -if TYPE_CHECKING: - from .designators.performables import Action - class NoOperation: @@ -80,7 +74,7 @@ def __init__(self, action: Optional[Action] = NoOperation(), parent: Optional[Ta self.action = action self.status = TaskStatus.CREATED - self.start_time = None + self.start_time = datetime.datetime.now() self.end_time = None self.parent = parent self.reason: Optional[Exception] = reason @@ -117,7 +111,7 @@ def to_sql(self) -> ORMTaskTreeNode: else: reason = None - return ORMTaskTreeNode(self.start_time, self.end_time, self.status.name, reason) + return ORMTaskTreeNode(start_time=self.start_time, end_time=self.end_time, status=self.status, reason=reason) def insert(self, session: sqlalchemy.orm.session.Session, recursive: bool = True, parent: Optional[TaskTreeNode] = None, use_progress_bar: bool = True, @@ -186,7 +180,7 @@ def __enter__(self): self.suspended_tree = task_tree self.world_state = World.current_world.save_state() - self.simulated_root = TaskTreeNode() + self.simulated_root = TaskTree() task_tree = self.simulated_root World.current_world.add_text("Simulating...", [0, 0, 1.75], color=Color.from_rgb([0, 0, 0]), parent_object_id=1) @@ -202,21 +196,50 @@ def __exit__(self, exc_type, exc_val, exc_tb): World.current_world.remove_text() -task_tree: Optional[TaskTreeNode] = None -"""Current TaskTreeNode""" - - -def reset_tree() -> None: +class TaskTree(metaclass=Singleton): """ - Reset the current task tree to an empty root (NoOperation) node. + TaskTree represents the tree of functions that were called during a pycram plan. Consists of TaskTreeNodes. + Must be a singleton. """ - global task_tree - task_tree = TaskTreeNode() - task_tree.start_time = datetime.datetime.now() - task_tree.status = TaskStatus.RUNNING + def __init__(self): + """ + Create a new TaskTree with a root node. + """ + self.root = TaskTreeNode() + self.current_node = self.root -reset_tree() + def __len__(self): + """ + Get the number of nodes that are in this TaskTree. + + :return: The number of nodes. + """ + return len(self.root.children) + + def reset_tree(self): + """ + Reset the current task tree to an empty root (NoOperation) node. + """ + self.root = TaskTreeNode() + self.root.start_time = datetime.datetime.now() + self.root.status = TaskStatus.RUNNING + self.current_node = self.root + + def add_node(self, action: Optional[Action] = None) -> TaskTreeNode: + """ + Add a new node to the task tree and make it the current node. + + :param action: The action that is performed in this node. + :return: The new node. + """ + new_node = TaskTreeNode(action=action, parent=self.current_node) + self.current_node = new_node + return new_node + + +task_tree = TaskTree() +"""Current TaskTreeNode""" def with_tree(fun: Callable) -> Callable: @@ -227,7 +250,6 @@ def with_tree(fun: Callable) -> Callable: """ def handle_tree(*args, **kwargs): - # get the task tree global task_tree @@ -238,29 +260,31 @@ def handle_tree(*args, **kwargs): action = keyword_arguments.get("self", None) # create the task tree node - task_tree = TaskTreeNode(action, parent=task_tree) + task_tree.add_node(action) # Try to execute the task try: - task_tree.status = TaskStatus.CREATED - task_tree.start_time = datetime.datetime.now() + task_tree.current_node.status = TaskStatus.CREATED + task_tree.current_node.start_time = datetime.datetime.now() result = fun(*args, **kwargs) # if it succeeded set the flag - task_tree.status = TaskStatus.SUCCEEDED + task_tree.current_node.status = TaskStatus.SUCCEEDED # iff a PlanFailure occurs except PlanFailure as e: # log the error and set the flag - logging.exception("Task execution failed at %s. Reason %s" % (repr(task_tree), e)) - task_tree.reason = e - task_tree.status = TaskStatus.FAILED + logging.exception("Task execution failed at %s. Reason %s" % (repr(task_tree.current_node), e)) + task_tree.current_node.reason = e + task_tree.current_node.status = TaskStatus.FAILED raise e + finally: - # set and time and update current node pointer - task_tree.end_time = datetime.datetime.now() - task_tree = task_tree.parent + if task_tree.current_node.parent is not None: + task_tree.current_node.end_time = datetime.datetime.now() + task_tree.current_node = task_tree.current_node.parent + return result return handle_tree diff --git a/src/pycram/utils.py b/src/pycram/utils.py index 28b5109ac..30bd10538 100644 --- a/src/pycram/utils.py +++ b/src/pycram/utils.py @@ -7,14 +7,20 @@ GeneratorList -- implementation of generator list wrappers. """ from inspect import isgeneratorfunction -from typing_extensions import List, Tuple, Callable - import os +import math + +import numpy as np +from matplotlib import pyplot as plt +import matplotlib.colors as mcolors +from typing_extensions import Tuple, Callable, List, Dict, TYPE_CHECKING from .datastructures.pose import Pose -import math +from .local_transformer import LocalTransformer -from typing_extensions import Dict +if TYPE_CHECKING: + from .world_concepts.world_object import Object + from .robot_description import CameraDescription class bcolors: @@ -36,7 +42,7 @@ class bcolors: UNDERLINE = '\033[4m' -def _apply_ik(robot: 'pycram.world_concepts.WorldObject', pose_and_joint_poses: Tuple[Pose, Dict[str, float]]) -> None: +def _apply_ik(robot: 'Object', pose_and_joint_poses: Tuple[Pose, Dict[str, float]]) -> None: """ Apllies a list of joint poses calculated by an inverse kinematics solver to a robot @@ -46,7 +52,7 @@ def _apply_ik(robot: 'pycram.world_concepts.WorldObject', pose_and_joint_poses: """ pose, joint_states = pose_and_joint_poses robot.set_pose(pose) - robot.set_joint_positions(joint_states) + robot.set_multiple_joint_positions(joint_states) class GeneratorList: @@ -113,7 +119,7 @@ def axis_angle_to_quaternion(axis: List, angle: float) -> Tuple: z = normalized_axis[2] * math.sin(angle / 2) w = math.cos(angle / 2) - return (x, y, z, w) + return tuple((x, y, z, w)) class suppress_stdout_stderr(object): @@ -130,7 +136,7 @@ class suppress_stdout_stderr(object): def __init__(self): # Open a pair of null files - self.null_fds = [os.open(os.devnull, os.O_RDWR) for x in range(2)] + self.null_fds = [os.open(os.devnull, os.O_RDWR) for _ in range(2)] # Save the actual stdout (1) and stderr (2) file descriptors. self.save_fds = [os.dup(1), os.dup(2)] @@ -148,3 +154,257 @@ def __exit__(self, *_): # Close all file descriptors for fd in self.null_fds + self.save_fds: os.close(fd) + + +class RayTestUtils: + + def __init__(self, ray_test_batch: Callable, object_id_to_name: Dict = None): + """ + Initialize the ray test helper. + """ + self.local_transformer = LocalTransformer() + self.ray_test_batch = ray_test_batch + self.object_id_to_name = object_id_to_name + + def get_images_for_target(self, cam_pose: Pose, + camera_description: 'CameraDescription', + camera_frame: str, + size: int = 256, + camera_min_distance: float = 0.1, + camera_max_distance: int = 3, + plot: bool = False) -> List[np.ndarray]: + """ + Note: The returned color image is a repeated depth image in 3 channels. + """ + + # get the list of start positions of the rays. + rays_start_positions = self.get_camera_rays_start_positions(camera_description, camera_frame, cam_pose, size, + camera_min_distance).tolist() + + # get the list of end positions of the rays + rays_end_positions = self.get_camera_rays_end_positions(camera_description, camera_frame, cam_pose, size, + camera_max_distance).tolist() + + # apply the ray test + object_ids, distances = self.ray_test_batch(rays_start_positions, rays_end_positions, return_distance=True) + + # construct the images/masks + segmentation_mask = self.construct_segmentation_mask_from_ray_test_object_ids(object_ids, size) + depth_image = self.construct_depth_image_from_ray_test_distances(distances, size) + camera_min_distance + color_depth_image = self.construct_color_image_from_depth_image(depth_image) + + if plot: + self.plot_segmentation_mask(segmentation_mask) + self.plot_depth_image(depth_image) + + return [color_depth_image, depth_image, segmentation_mask] + + @staticmethod + def construct_segmentation_mask_from_ray_test_object_ids(object_ids: List[int], size: int) -> np.ndarray: + """ + Construct a segmentation mask from the object ids returned by the ray test. + + :param object_ids: The object ids. + :param size: The size of the grid. + :return: The segmentation mask. + """ + return np.array(object_ids).squeeze(axis=1).reshape(size, size) + + @staticmethod + def construct_depth_image_from_ray_test_distances(distances: List[float], size: int) -> np.ndarray: + """ + Construct a depth image from the distances returned by the ray test. + + :param distances: The distances. + :param size: The size of the grid. + :return: The depth image. + """ + return np.array(distances).reshape(size, size) + + @staticmethod + def construct_color_image_from_depth_image(depth_image: np.ndarray) -> np.ndarray: + """ + Construct a color image from the depth image. + + :param depth_image: The depth image. + :return: The color image. + """ + min_distance = np.min(depth_image) + max_distance = np.max(depth_image) + normalized_depth_image = (depth_image - min_distance) * 255 / (max_distance - min_distance) + return np.repeat(normalized_depth_image[:, :, np.newaxis], 3, axis=2).astype(np.uint8) + + def get_camera_rays_start_positions(self, camera_description: 'CameraDescription', camera_frame: str, + camera_pose: Pose, size: int, + camera_min_distance: float) -> np.ndarray: + + # get the start pose of the rays from the camera pose and minimum distance. + start_pose = self.get_camera_rays_start_pose(camera_description, camera_frame, camera_pose, camera_min_distance) + + # get the list of start positions of the rays. + return np.repeat(np.array([start_pose.position_as_list()]), size * size, axis=0) + + def get_camera_rays_start_pose(self, camera_description: 'CameraDescription', camera_frame: str, camera_pose: Pose, + camera_min_distance: float) -> Pose: + """ + Get the start position of the camera rays, which is the camera pose shifted by the minimum distance of the + camera. + + :param camera_description: The camera description. + :param camera_frame: The camera tf frame. + :param camera_pose: The camera pose. + :param camera_min_distance: The minimum distance from which the camera can see. + """ + camera_pose_in_camera_frame = self.local_transformer.transform_pose(camera_pose, camera_frame) + start_position = (np.array(camera_description.front_facing_axis) * camera_min_distance + + np.array(camera_pose_in_camera_frame.position_as_list())) + start_pose = Pose(start_position.tolist(), camera_pose_in_camera_frame.orientation_as_list(), camera_frame) + return self.local_transformer.transform_pose(start_pose, "map") + + def get_camera_rays_end_positions(self, camera_description: 'CameraDescription', camera_frame: str, + camera_pose: Pose, size: int, camera_max_distance: float = 3.0) -> np.ndarray: + """ + Get the end positions of the camera rays. + + :param camera_description: The camera description. + :param camera_frame: The camera frame. + :param camera_pose: The camera pose. + :param size: The size of the grid. + :param camera_max_distance: The maximum distance of the camera. + :return: The end positions of the camera rays. + """ + rays_horizontal_angles, rays_vertical_angles = self.construct_grid_of_camera_rays_angles(camera_description, + size) + rays_end_positions = self.get_end_positions_of_rays_from_angles_and_distance(rays_vertical_angles, + rays_horizontal_angles, + camera_max_distance) + return self.transform_points_from_camera_frame_to_world_frame(camera_pose, camera_frame, rays_end_positions) + + @staticmethod + def transform_points_from_camera_frame_to_world_frame(camera_pose: Pose, camera_frame: str, + points: np.ndarray) -> np.ndarray: + """ + Transform points from the camera frame to the world frame. + + :param camera_pose: The camera pose. + :param camera_frame: The camera frame. + :param points: The points to transform. + :return: The transformed points. + """ + cam_to_world_transform = camera_pose.to_transform(camera_frame) + return cam_to_world_transform.apply_transform_to_array_of_points(points) + + @staticmethod + def get_end_positions_of_rays_from_angles_and_distance(vertical_angles: np.ndarray, horizontal_angles: np.ndarray, + distance: float) -> np.ndarray: + """ + Get the end positions of the rays from the angles and the distance. + + :param vertical_angles: The vertical angles of the rays. + :param horizontal_angles: The horizontal angles of the rays. + :param distance: The distance of the rays. + :return: The end positions of the rays. + """ + rays_end_positions_x = distance * np.cos(vertical_angles) * np.sin(horizontal_angles) + rays_end_positions_x = rays_end_positions_x.reshape(-1) + rays_end_positions_z = distance * np.cos(vertical_angles) * np.cos(horizontal_angles) + rays_end_positions_z = rays_end_positions_z.reshape(-1) + rays_end_positions_y = distance * np.sin(vertical_angles) + rays_end_positions_y = rays_end_positions_y.reshape(-1) + return np.stack((rays_end_positions_x, rays_end_positions_y, rays_end_positions_z), axis=1) + + @staticmethod + def construct_grid_of_camera_rays_angles(camera_description: 'CameraDescription', + size: int) -> Tuple[np.ndarray, np.ndarray]: + """ + Construct a 2D grid of camera rays angles. + + :param camera_description: The camera description. + :param size: The size of the grid. + :return: The 2D grid of the horizontal and the vertical angles of the camera rays. + """ + # get the camera fov angles + camera_horizontal_fov = camera_description.horizontal_angle + camera_vertical_fov = camera_description.vertical_angle + + # construct a 2d grid of rays angles + rays_horizontal_angles = np.linspace(-camera_horizontal_fov / 2, camera_horizontal_fov / 2, size) + rays_horizontal_angles = np.tile(rays_horizontal_angles, (size, 1)) + rays_vertical_angles = np.linspace(-camera_vertical_fov / 2, camera_vertical_fov / 2, size) + rays_vertical_angles = np.tile(rays_vertical_angles, (size, 1)).T + return rays_horizontal_angles, rays_vertical_angles + + @staticmethod + def plot_segmentation_mask(segmentation_mask, + object_id_to_name: Dict[int, str] = None): + """ + Plot the segmentation mask with different colors for each object. + + :param segmentation_mask: The segmentation mask. + :param object_id_to_name: The mapping from object id to object name. + """ + if object_id_to_name is None: + object_id_to_name = {uid: str(uid) for uid in np.unique(segmentation_mask)} + + # Create a custom color map + unique_ids = np.unique(segmentation_mask) + unique_ids = unique_ids[unique_ids != -1] # Exclude -1 values + + # Create a color map that assigns a unique color to each ID + colors = plt.cm.get_cmap('tab20', len(unique_ids)) # Use tab20 colormap for distinct colors + color_dict = {uid: colors(i) for i, uid in enumerate(unique_ids)} + + # Map each ID to its corresponding color + mask_shape = segmentation_mask.shape + segmentation_colored = np.zeros((mask_shape[0], mask_shape[1], 3)) + + for uid in unique_ids: + segmentation_colored[segmentation_mask == uid] = color_dict[uid][:3] # Ignore the alpha channel + + # Create a colormap for the color bar + cmap = mcolors.ListedColormap([color_dict[uid][:3] for uid in unique_ids]) + norm = mcolors.BoundaryNorm(boundaries=np.arange(len(unique_ids) + 1) - 0.5, ncolors=len(unique_ids)) + + # Plot the colored segmentation mask + fig, ax = plt.subplots() + _ = ax.imshow(segmentation_colored) + ax.axis('off') # Hide axes + ax.set_title('Segmentation Mask with Different Colors for Each Object') + + # Create color bar + cbar = fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax, ticks=np.arange(len(unique_ids))) + cbar.ax.set_yticklabels( + [object_id_to_name[uid] for uid in unique_ids]) # Label the color bar with object IDs + cbar.set_label('Object Name') + + plt.show() + + @staticmethod + def plot_depth_image(depth_image): + # Plot the depth image + fig, ax = plt.subplots() + cax = ax.imshow(depth_image, cmap='viridis', vmin=0, vmax=np.max(depth_image)) + ax.axis('off') # Hide axes + ax.set_title('Depth Image') + + # Create color bar + cbar = fig.colorbar(cax, ax=ax) + cbar.set_label('Depth Value') + + plt.show() + + +def wxyz_to_xyzw(wxyz: List[float]) -> List[float]: + """ + Convert a quaternion from WXYZ to XYZW format. + """ + return [wxyz[1], wxyz[2], wxyz[3], wxyz[0]] + + +def xyzw_to_wxyz(xyzw: List[float]) -> List[float]: + """ + Convert a quaternion from XYZW to WXYZ format. + + :param xyzw: The quaternion in XYZW format. + """ + return [xyzw[3], *xyzw[:3]] \ No newline at end of file diff --git a/src/pycram/validation/__init__.py b/src/pycram/validation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/pycram/validation/error_checkers.py b/src/pycram/validation/error_checkers.py new file mode 100644 index 000000000..f594b6bd6 --- /dev/null +++ b/src/pycram/validation/error_checkers.py @@ -0,0 +1,351 @@ +from abc import ABC, abstractmethod +from collections.abc import Iterable + +import numpy as np +from tf.transformations import quaternion_multiply, quaternion_inverse +from typing_extensions import List, Union, Optional, Any, Sized, Iterable as T_Iterable, TYPE_CHECKING, Tuple + +from ..datastructures.enums import JointType +if TYPE_CHECKING: + from ..datastructures.pose import Pose + + +class ErrorChecker(ABC): + """ + An abstract class that resembles an error checker. It has two main methods, one for calculating the error between + two values and another for checking if the error is acceptable. + """ + def __init__(self, acceptable_error: Union[float, T_Iterable[float]], is_iterable: Optional[bool] = False): + """ + Initialize the error checker. + + :param acceptable_error: The acceptable error. + :param is_iterable: Whether the error is iterable (i.e. list of errors). + """ + self._acceptable_error: np.ndarray = np.array(acceptable_error) + self.tiled_acceptable_error: Optional[np.ndarray] = None + self.is_iterable = is_iterable + + def reset(self) -> None: + """ + Reset the error checker. + """ + self.tiled_acceptable_error = None + + @property + def acceptable_error(self) -> np.ndarray: + return self._acceptable_error + + @acceptable_error.setter + def acceptable_error(self, new_acceptable_error: Union[float, T_Iterable[float]]) -> None: + self._acceptable_error = np.array(new_acceptable_error) + + def update_acceptable_error(self, new_acceptable_error: Optional[T_Iterable[float]] = None, + tile_to_match: Optional[Sized] = None,) -> None: + """ + Update the acceptable error with a new value, and tile it to match the length of the error if needed. + + :param new_acceptable_error: The new acceptable error. + :param tile_to_match: The iterable to match the length of the error with. + """ + if new_acceptable_error is not None: + self.acceptable_error = new_acceptable_error + if tile_to_match is not None and self.is_iterable: + self.update_tiled_acceptable_error(tile_to_match) + + def update_tiled_acceptable_error(self, tile_to_match: Sized) -> None: + """ + Tile the acceptable error to match the length of the error. + + :param tile_to_match: The object to match the length of the error. + :return: The tiled acceptable error. + """ + self.tiled_acceptable_error = np.tile(self.acceptable_error.flatten(), + len(tile_to_match) // self.acceptable_error.size) + + @abstractmethod + def _calculate_error(self, value_1: Any, value_2: Any) -> Union[float, List[float]]: + """ + Calculate the error between two values. + + :param value_1: The first value. + :param value_2: The second value. + :return: The error between the two values. + """ + pass + + def calculate_error(self, value_1: Any, value_2: Any) -> Union[float, List[float]]: + """ + Calculate the error between two values. + + :param value_1: The first value. + :param value_2: The second value. + :return: The error between the two values. + """ + if self.is_iterable: + return [self._calculate_error(v1, v2) for v1, v2 in zip(value_1, value_2)] + else: + return self._calculate_error(value_1, value_2) + + def is_error_acceptable(self, value_1: Any, value_2: Any) -> bool: + """ + Check if the error is acceptable. + + :param value_1: The first value. + :param value_2: The second value. + :return: Whether the error is acceptable. + """ + error = self.calculate_error(value_1, value_2) + if self.is_iterable: + error = np.array(error).flatten() + if self.tiled_acceptable_error is None or\ + len(error) != len(self.tiled_acceptable_error): + self.update_tiled_acceptable_error(error) + return np.all(error <= self.tiled_acceptable_error) + else: + return is_error_acceptable(error, self.acceptable_error) + + +class PoseErrorChecker(ErrorChecker): + + def __init__(self, acceptable_error: Union[Tuple[float], T_Iterable[Tuple[float]]] = (1e-3, np.pi / 180), + is_iterable: Optional[bool] = False): + """ + Initialize the pose error checker. + + :param acceptable_error: The acceptable pose error (position error, orientation error). + :param is_iterable: Whether the error is iterable (i.e. list of errors). + """ + super().__init__(acceptable_error, is_iterable) + + def _calculate_error(self, value_1: Any, value_2: Any) -> List[float]: + """ + Calculate the error between two poses. + + :param value_1: The first pose. + :param value_2: The second pose. + """ + return calculate_pose_error(value_1, value_2) + + +class PositionErrorChecker(ErrorChecker): + + def __init__(self, acceptable_error: Optional[float] = 1e-3, is_iterable: Optional[bool] = False): + """ + Initialize the position error checker. + + :param acceptable_error: The acceptable position error. + :param is_iterable: Whether the error is iterable (i.e. list of errors). + """ + super().__init__(acceptable_error, is_iterable) + + def _calculate_error(self, value_1: Any, value_2: Any) -> float: + """ + Calculate the error between two positions. + + :param value_1: The first position. + :param value_2: The second position. + :return: The error between the two positions. + """ + return calculate_position_error(value_1, value_2) + + +class OrientationErrorChecker(ErrorChecker): + + def __init__(self, acceptable_error: Optional[float] = np.pi / 180, is_iterable: Optional[bool] = False): + """ + Initialize the orientation error checker. + + :param acceptable_error: The acceptable orientation error. + :param is_iterable: Whether the error is iterable (i.e. list of errors). + """ + super().__init__(acceptable_error, is_iterable) + + def _calculate_error(self, value_1: Any, value_2: Any) -> float: + """ + Calculate the error between two quaternions. + + :param value_1: The first quaternion. + :param value_2: The second quaternion. + :return: The error between the two quaternions. + """ + return calculate_orientation_error(value_1, value_2) + + +class SingleValueErrorChecker(ErrorChecker): + + def __init__(self, acceptable_error: Optional[float] = 1e-3, is_iterable: Optional[bool] = False): + """ + Initialize the single value error checker. + + :param acceptable_error: The acceptable error between two values. + :param is_iterable: Whether the error is iterable (i.e. list of errors). + """ + super().__init__(acceptable_error, is_iterable) + + def _calculate_error(self, value_1: Any, value_2: Any) -> float: + """ + Calculate the error between two values. + + :param value_1: The first value. + :param value_2: The second value. + :return: The error between the two values. + """ + return abs(value_1 - value_2) + + +class RevoluteJointPositionErrorChecker(SingleValueErrorChecker): + + def __init__(self, acceptable_error: Optional[float] = np.pi / 180, is_iterable: Optional[bool] = False): + """ + Initialize the revolute joint position error checker. + + :param acceptable_error: The acceptable revolute joint position error. + :param is_iterable: Whether the error is iterable (i.e. list of errors). + """ + super().__init__(acceptable_error, is_iterable) + + +class PrismaticJointPositionErrorChecker(SingleValueErrorChecker): + + def __init__(self, acceptable_error: Optional[float] = 1e-3, is_iterable: Optional[bool] = False): + """ + Initialize the prismatic joint position error checker. + + :param acceptable_error: The acceptable prismatic joint position error. + :param is_iterable: Whether the error is iterable (i.e. list of errors). + """ + super().__init__(acceptable_error, is_iterable) + + +class IterableErrorChecker(ErrorChecker): + + def __init__(self, acceptable_error: Optional[T_Iterable[float]] = None): + """ + Initialize the iterable error checker. + + :param acceptable_error: The acceptable error between two values. + """ + super().__init__(acceptable_error, True) + + def _calculate_error(self, value_1: Any, value_2: Any) -> float: + """ + Calculate the error between two values. + + :param value_1: The first value. + :param value_2: The second value. + :return: The error between the two values. + """ + return abs(value_1 - value_2) + + +class MultiJointPositionErrorChecker(IterableErrorChecker): + + def __init__(self, joint_types: List[JointType], acceptable_error: Optional[T_Iterable[float]] = None): + """ + Initialize the multi-joint position error checker. + + :param joint_types: The types of the joints. + :param acceptable_error: The acceptable error between two joint positions. + """ + self.joint_types = joint_types + if acceptable_error is None: + acceptable_error = [np.pi/180 if jt == JointType.REVOLUTE else 1e-3 for jt in joint_types] + super().__init__(acceptable_error) + + def _calculate_error(self, value_1: Any, value_2: Any) -> float: + """ + Calculate the error between two joint positions. + + :param value_1: The first joint position. + :param value_2: The second joint position. + :return: The error between the two joint positions. + """ + return calculate_joint_position_error(value_1, value_2) + + +def calculate_pose_error(pose_1: 'Pose', pose_2: 'Pose') -> List[float]: + """ + Calculate the error between two poses. + + :param pose_1: The first pose. + :param pose_2: The second pose. + :return: The error between the two poses. + """ + return [calculate_position_error(pose_1.position_as_list(), pose_2.position_as_list()), + calculate_orientation_error(pose_1.orientation_as_list(), pose_2.orientation_as_list())] + + +def calculate_position_error(position_1: List[float], position_2: List[float]) -> float: + """ + Calculate the error between two positions. + + :param position_1: The first position. + :param position_2: The second position. + :return: The error between the two positions. + """ + return np.linalg.norm(np.array(position_1) - np.array(position_2)) + + +def calculate_orientation_error(quat_1: List[float], quat_2: List[float]) -> float: + """ + Calculate the error between two quaternions. + + :param quat_1: The first quaternion. + :param quat_2: The second quaternion. + :return: The error between the two quaternions. + """ + return calculate_angle_between_quaternions(quat_1, quat_2) + + +def calculate_joint_position_error(joint_position_1: float, joint_position_2: float) -> float: + """ + Calculate the error between two joint positions. + + :param joint_position_1: The first joint position. + :param joint_position_2: The second joint position. + :return: The error between the two joint positions. + """ + return abs(joint_position_1 - joint_position_2) + + +def is_error_acceptable(error: Union[float, T_Iterable[float]], + acceptable_error: Union[float, T_Iterable[float]]) -> bool: + """ + Check if the error is acceptable. + + :param error: The error. + :param acceptable_error: The acceptable error. + :return: Whether the error is acceptable. + """ + if isinstance(error, Iterable): + return all([error_i <= acceptable_error_i for error_i, acceptable_error_i in zip(error, acceptable_error)]) + else: + return error <= acceptable_error + + +def calculate_angle_between_quaternions(quat_1: List[float], quat_2: List[float]) -> float: + """ + Calculates the angle between two quaternions. + + :param quat_1: The first quaternion. + :param quat_2: The second quaternion. + :return: A float value that represents the angle between the two quaternions. + """ + quat_diff = calculate_quaternion_difference(quat_1, quat_2) + quat_diff_angle = 2 * np.arctan2(np.linalg.norm(quat_diff[0:3]), quat_diff[3]) + if quat_diff_angle > np.pi: + quat_diff_angle = 2 * np.pi - quat_diff_angle + return quat_diff_angle + + +def calculate_quaternion_difference(quat_1: List[float], quat_2: List[float]) -> List[float]: + """ + Calculates the quaternion difference. + + :param quat_1: The quaternion of the object at the first time step. + :param quat_2: The quaternion of the object at the second time step. + :return: A list of float values that represent the quaternion difference. + """ + quat_diff = quaternion_multiply(quaternion_inverse(quat_1), quat_2) + return quat_diff diff --git a/src/pycram/validation/goal_validator.py b/src/pycram/validation/goal_validator.py new file mode 100644 index 000000000..66756ae52 --- /dev/null +++ b/src/pycram/validation/goal_validator.py @@ -0,0 +1,550 @@ +from time import sleep, time + +import numpy as np +from typing_extensions import Any, Callable, Optional, Union, Iterable, Dict, TYPE_CHECKING, Tuple + +from ..datastructures.enums import JointType +from .error_checkers import ErrorChecker, PoseErrorChecker, PositionErrorChecker, \ + OrientationErrorChecker, SingleValueErrorChecker +from ..ros.logging import logerr, logwarn + +if TYPE_CHECKING: + from ..datastructures.world import World + from ..world_concepts.world_object import Object + from ..datastructures.pose import Pose + from ..description import ObjectDescription + + Joint = ObjectDescription.Joint + Link = ObjectDescription.Link + +OptionalArgCallable = Union[Callable[[], Any], Callable[[Any], Any]] + + +class GoalValidator: + """ + A class to validate the goal by tracking the goal achievement progress. + """ + + raise_error: Optional[bool] = False + """ + Whether to raise an error if the goal is not achieved. + """ + + def __init__(self, error_checker: ErrorChecker, current_value_getter: OptionalArgCallable, + acceptable_percentage_of_goal_achieved: Optional[float] = 0.8): + """ + Initialize the goal validator. + + :param error_checker: The error checker. + :param current_value_getter: The current value getter function which takes an optional input and returns the + current value. + :param acceptable_percentage_of_goal_achieved: The acceptable percentage of goal achieved, if given, will be + used to check if this percentage is achieved instead of the complete goal. + """ + self.error_checker: ErrorChecker = error_checker + self.current_value_getter: Callable[[Optional[Any]], Any] = current_value_getter + self.acceptable_percentage_of_goal_achieved: Optional[float] = acceptable_percentage_of_goal_achieved + self.goal_value: Optional[Any] = None + self.initial_error: Optional[np.ndarray] = None + self.current_value_getter_input: Optional[Any] = None + + def register_goal_and_wait_until_achieved(self, goal_value: Any, + current_value_getter_input: Optional[Any] = None, + initial_value: Optional[Any] = None, + acceptable_error: Optional[Union[float, Iterable[float]]] = None, + max_wait_time: Optional[float] = 1, + time_per_read: Optional[float] = 0.01) -> None: + """ + Register the goal value and wait until the target is reached. + + :param goal_value: The goal value. + :param current_value_getter_input: The values that are used as input to the current value getter. + :param initial_value: The initial value. + :param acceptable_error: The acceptable error. + :param max_wait_time: The maximum time to wait. + :param time_per_read: The time to wait between each read. + """ + self.register_goal(goal_value, current_value_getter_input, initial_value, acceptable_error) + self.wait_until_goal_is_achieved(max_wait_time, time_per_read) + + def wait_until_goal_is_achieved(self, max_wait_time: Optional[float] = 2, + time_per_read: Optional[float] = 0.01) -> None: + """ + Wait until the target is reached. + + :param max_wait_time: The maximum time to wait. + :param time_per_read: The time to wait between each read. + """ + if self.goal_value is None: + return # Skip if goal value is None + start_time = time() + current = self.current_value + while not self.goal_achieved: + sleep(time_per_read) + if time() - start_time > max_wait_time: + msg = f"Failed to achieve goal from initial error {self.initial_error} with" \ + f" goal {self.goal_value} within {max_wait_time}" \ + f" seconds, the current value is {current}, error is {self.current_error}, percentage" \ + f" of goal achieved is {self.percentage_of_goal_achieved}" + if self.raise_error: + logerr(msg) + raise TimeoutError(msg) + else: + logwarn(msg) + break + current = self.current_value + self.reset() + + def reset(self) -> None: + """ + Reset the goal validator. + """ + self.goal_value = None + self.initial_error = None + self.current_value_getter_input = None + self.error_checker.reset() + + @property + def _acceptable_error(self) -> np.ndarray: + """ + The acceptable error. + """ + if self.error_checker.is_iterable: + return self.tiled_acceptable_error + else: + return self.acceptable_error + + @property + def acceptable_error(self) -> np.ndarray: + """ + The acceptable error. + """ + return self.error_checker.acceptable_error + + @property + def tiled_acceptable_error(self) -> Optional[np.ndarray]: + """ + The tiled acceptable error. + """ + return self.error_checker.tiled_acceptable_error + + def register_goal(self, goal_value: Any, + current_value_getter_input: Optional[Any] = None, + initial_value: Optional[Any] = None, + acceptable_error: Optional[Union[float, Iterable[float]]] = None): + """ + Register the goal value. + + :param goal_value: The goal value. + :param current_value_getter_input: The values that are used as input to the current value getter. + :param initial_value: The initial value. + :param acceptable_error: The acceptable error. + """ + if goal_value is None or (hasattr(goal_value, '__len__') and len(goal_value) == 0): + return # Skip if goal value is None or empty + self.goal_value = goal_value + self.current_value_getter_input = current_value_getter_input + self.update_initial_error(goal_value, initial_value=initial_value) + self.error_checker.update_acceptable_error(acceptable_error, self.initial_error) + + def update_initial_error(self, goal_value: Any, initial_value: Optional[Any] = None) -> None: + """ + Calculate the initial error. + + :param goal_value: The goal value. + :param initial_value: The initial value. + """ + if initial_value is None: + self.initial_error: np.ndarray = self.current_error + else: + self.initial_error: np.ndarray = self.calculate_error(goal_value, initial_value) + + @property + def current_value(self) -> Any: + """ + The current value of the monitored variable. + """ + if self.current_value_getter_input is not None: + return self.current_value_getter(self.current_value_getter_input) + else: + return self.current_value_getter() + + @property + def current_error(self) -> np.ndarray: + """ + The current error. + """ + return self.calculate_error(self.goal_value, self.current_value) + + def calculate_error(self, value_1: Any, value_2: Any) -> np.ndarray: + """ + Calculate the error between two values. + + :param value_1: The first value. + :param value_2: The second value. + :return: The error. + """ + return np.array(self.error_checker.calculate_error(value_1, value_2)).flatten() + + @property + def percentage_of_goal_achieved(self) -> float: + """ + The relative (relative to the acceptable error) achieved percentage of goal. + """ + percent_array = 1 - self.relative_current_error / self.relative_initial_error + percent_array_filtered = percent_array[self.relative_initial_error > self._acceptable_error] + if len(percent_array_filtered) == 0: + return 1 + else: + return np.mean(percent_array_filtered) + + @property + def actual_percentage_of_goal_achieved(self) -> float: + """ + The percentage of goal achieved. + """ + percent_array = 1 - self.current_error / np.maximum(self.initial_error, 1e-3) + percent_array_filtered = percent_array[self.initial_error > self._acceptable_error] + if len(percent_array_filtered) == 0: + return 1 + else: + return np.mean(percent_array_filtered) + + @property + def relative_current_error(self) -> np.ndarray: + """ + The relative current error (relative to the acceptable error). + """ + return self.get_relative_error(self.current_error, threshold=0) + + @property + def relative_initial_error(self) -> np.ndarray: + """ + The relative initial error (relative to the acceptable error). + """ + return np.maximum(self.initial_error, 1e-3) + + def get_relative_error(self, error: Any, threshold: Optional[float] = 1e-3) -> np.ndarray: + """ + Get the relative error by comparing the error with the acceptable error and filtering out the errors that are + less than the threshold. + + :param error: The error. + :param threshold: The threshold. + :return: The relative error. + """ + return np.maximum(error - self._acceptable_error, threshold) + + @property + def goal_achieved(self) -> bool: + """ + Check if the goal is achieved. + """ + if self.acceptable_percentage_of_goal_achieved is None: + return self.is_current_error_acceptable + else: + return self.percentage_of_goal_achieved >= self.acceptable_percentage_of_goal_achieved + + @property + def is_current_error_acceptable(self) -> bool: + """ + Check if the error is acceptable. + """ + return self.error_checker.is_error_acceptable(self.current_value, self.goal_value) + + +class PoseGoalValidator(GoalValidator): + """ + A class to validate the pose goal by tracking the goal achievement progress. + """ + + def __init__(self, current_pose_getter: OptionalArgCallable = None, + acceptable_error: Union[Tuple[float], Iterable[Tuple[float]]] = (1e-3, np.pi / 180), + acceptable_percentage_of_goal_achieved: Optional[float] = 0.8, + is_iterable: Optional[bool] = False): + """ + Initialize the pose goal validator. + + :param current_pose_getter: The current pose getter function which takes an optional input and returns the + current pose. + :param acceptable_error: The acceptable error. + :param acceptable_percentage_of_goal_achieved: The acceptable percentage of goal achieved. + """ + super().__init__(PoseErrorChecker(acceptable_error, is_iterable=is_iterable), current_pose_getter, + acceptable_percentage_of_goal_achieved=acceptable_percentage_of_goal_achieved) + + +class MultiPoseGoalValidator(PoseGoalValidator): + """ + A class to validate the multi-pose goal by tracking the goal achievement progress. + """ + + def __init__(self, current_poses_getter: OptionalArgCallable = None, + acceptable_error: Union[Tuple[float], Iterable[Tuple[float]]] = (1e-2, 5 * np.pi / 180), + acceptable_percentage_of_goal_achieved: Optional[float] = 0.8): + """ + Initialize the multi-pose goal validator. + + :param current_poses_getter: The current poses getter function which takes an optional input and returns the + current poses. + :param acceptable_error: The acceptable error. + :param acceptable_percentage_of_goal_achieved: The acceptable percentage of goal achieved. + """ + super().__init__(current_poses_getter, acceptable_error, acceptable_percentage_of_goal_achieved, + is_iterable=True) + + +class PositionGoalValidator(GoalValidator): + """ + A class to validate the position goal by tracking the goal achievement progress. + """ + + def __init__(self, current_position_getter: OptionalArgCallable = None, + acceptable_error: Optional[float] = 1e-3, + acceptable_percentage_of_goal_achieved: Optional[float] = 0.8, + is_iterable: Optional[bool] = False): + """ + Initialize the position goal validator. + + :param current_position_getter: The current position getter function which takes an optional input and + returns the current position. + :param acceptable_error: The acceptable error. + :param acceptable_percentage_of_goal_achieved: The acceptable percentage of goal achieved. + :param is_iterable: Whether it is a sequence of position vectors. + """ + super().__init__(PositionErrorChecker(acceptable_error, is_iterable=is_iterable), current_position_getter, + acceptable_percentage_of_goal_achieved=acceptable_percentage_of_goal_achieved) + + +class MultiPositionGoalValidator(PositionGoalValidator): + """ + A class to validate the multi-position goal by tracking the goal achievement progress. + """ + + def __init__(self, current_positions_getter: OptionalArgCallable = None, + acceptable_error: Optional[float] = 1e-3, + acceptable_percentage_of_goal_achieved: Optional[float] = 0.8): + """ + Initialize the multi-position goal validator. + + :param current_positions_getter: The current positions getter function which takes an optional input and + returns the current positions. + :param acceptable_error: The acceptable error. + :param acceptable_percentage_of_goal_achieved: The acceptable percentage of goal achieved. + """ + super().__init__(current_positions_getter, acceptable_error, acceptable_percentage_of_goal_achieved, + is_iterable=True) + + +class OrientationGoalValidator(GoalValidator): + """ + A class to validate the orientation goal by tracking the goal achievement progress. + """ + + def __init__(self, current_orientation_getter: OptionalArgCallable = None, + acceptable_error: Optional[float] = np.pi / 180, + acceptable_percentage_of_goal_achieved: Optional[float] = 0.8, + is_iterable: Optional[bool] = False): + """ + Initialize the orientation goal validator. + + :param current_orientation_getter: The current orientation getter function which takes an optional input and + returns the current orientation. + :param acceptable_error: The acceptable error. + :param acceptable_percentage_of_goal_achieved: The acceptable percentage of goal achieved. + :param is_iterable: Whether it is a sequence of quaternions. + """ + super().__init__(OrientationErrorChecker(acceptable_error, is_iterable=is_iterable), current_orientation_getter, + acceptable_percentage_of_goal_achieved=acceptable_percentage_of_goal_achieved) + + +class MultiOrientationGoalValidator(OrientationGoalValidator): + """ + A class to validate the multi-orientation goal by tracking the goal achievement progress. + """ + + def __init__(self, current_orientations_getter: OptionalArgCallable = None, + acceptable_error: Optional[float] = np.pi / 180, + acceptable_percentage_of_goal_achieved: Optional[float] = 0.8): + """ + Initialize the multi-orientation goal validator. + + :param current_orientations_getter: The current orientations getter function which takes an optional input and + returns the current orientations. + :param acceptable_error: The acceptable error. + :param acceptable_percentage_of_goal_achieved: The acceptable percentage of goal achieved. + """ + super().__init__(current_orientations_getter, acceptable_error, acceptable_percentage_of_goal_achieved, + is_iterable=True) + + +class JointPositionGoalValidator(GoalValidator): + """ + A class to validate the joint position goal by tracking the goal achievement progress. + """ + + def __init__(self, current_position_getter: OptionalArgCallable = None, + acceptable_error: Optional[float] = None, + acceptable_revolute_joint_position_error: float = np.pi / 180, + acceptable_prismatic_joint_position_error: float = 1e-3, + acceptable_percentage_of_goal_achieved: float = 0.8, + is_iterable: bool = False): + """ + Initialize the joint position goal validator. + + :param current_position_getter: The current position getter function which takes an optional input and returns + the current position. + :param acceptable_error: The acceptable error. + :param acceptable_revolute_joint_position_error: The acceptable orientation error. + :param acceptable_prismatic_joint_position_error: The acceptable position error. + :param acceptable_percentage_of_goal_achieved: The acceptable percentage of goal achieved. + :param is_iterable: Whether it is a sequence of joint positions. + """ + super().__init__(SingleValueErrorChecker(acceptable_error, is_iterable=is_iterable), current_position_getter, + acceptable_percentage_of_goal_achieved=acceptable_percentage_of_goal_achieved) + self.acceptable_orientation_error = acceptable_revolute_joint_position_error + self.acceptable_position_error = acceptable_prismatic_joint_position_error + + def register_goal(self, goal_value: Any, joint_type: JointType, + current_value_getter_input: Optional[Any] = None, + initial_value: Optional[Any] = None, + acceptable_error: Optional[float] = None): + """ + Register the goal value. + + :param goal_value: The goal value. + :param joint_type: The joint type (e.g. REVOLUTE, PRISMATIC). + :param current_value_getter_input: The values that are used as input to the current value getter. + :param initial_value: The initial value. + :param acceptable_error: The acceptable error. + """ + if acceptable_error is None: + self.error_checker.acceptable_error = self.acceptable_orientation_error if joint_type == JointType.REVOLUTE\ + else self.acceptable_position_error + super().register_goal(goal_value, current_value_getter_input, initial_value, acceptable_error) + + +class MultiJointPositionGoalValidator(GoalValidator): + """ + A class to validate the multi-joint position goal by tracking the goal achievement progress. + """ + + def __init__(self, current_positions_getter: OptionalArgCallable = None, + acceptable_error: Optional[Iterable[float]] = None, + acceptable_revolute_joint_position_error: float = np.pi / 180, + acceptable_prismatic_joint_position_error: float = 1e-3, + acceptable_percentage_of_goal_achieved: float = 0.8): + """ + Initialize the multi-joint position goal validator. + + :param current_positions_getter: The current positions getter function which takes an optional input and + returns the current positions. + :param acceptable_error: The acceptable error. + :param acceptable_revolute_joint_position_error: The acceptable orientation error. + :param acceptable_prismatic_joint_position_error: The acceptable position error. + :param acceptable_percentage_of_goal_achieved: The acceptable percentage of goal achieved. + """ + super().__init__(SingleValueErrorChecker(acceptable_error, is_iterable=True), current_positions_getter, + acceptable_percentage_of_goal_achieved) + self.acceptable_orientation_error = acceptable_revolute_joint_position_error + self.acceptable_position_error = acceptable_prismatic_joint_position_error + + def register_goal(self, goal_value: Any, joint_type: Iterable[JointType], + current_value_getter_input: Optional[Any] = None, + initial_value: Optional[Any] = None, + acceptable_error: Optional[Iterable[float]] = None): + if acceptable_error is None: + self.error_checker.acceptable_error = [self.acceptable_orientation_error if jt == JointType.REVOLUTE + else self.acceptable_position_error for jt in joint_type] + super().register_goal(goal_value, current_value_getter_input, initial_value, acceptable_error) + + +def validate_object_pose(pose_setter_func): + """ + A decorator to validate the object pose. + + :param pose_setter_func: The function to set the pose of the object. + """ + + def wrapper(world: 'World', obj: 'Object', pose: 'Pose'): + + world.pose_goal_validator.register_goal(pose, obj) + + if not pose_setter_func(world, obj, pose): + world.pose_goal_validator.reset() + return False + + world.pose_goal_validator.wait_until_goal_is_achieved() + return True + + return wrapper + + +def validate_multiple_object_poses(pose_setter_func): + """ + A decorator to validate multiple object poses. + + :param pose_setter_func: The function to set multiple poses of the objects. + """ + + def wrapper(world: 'World', object_poses: Dict['Object', 'Pose']): + + world.multi_pose_goal_validator.register_goal(list(object_poses.values()), + list(object_poses.keys())) + + if not pose_setter_func(world, object_poses): + world.multi_pose_goal_validator.reset() + return False + + world.multi_pose_goal_validator.wait_until_goal_is_achieved() + return True + + return wrapper + + +def validate_joint_position(position_setter_func): + """ + A decorator to validate the joint position. + + :param position_setter_func: The function to set the joint position. + """ + + def wrapper(world: 'World', joint: 'Joint', position: float): + + joint_type = joint.type + world.joint_position_goal_validator.register_goal(position, joint_type, joint) + + if not position_setter_func(world, joint, position): + world.joint_position_goal_validator.reset() + return False + + world.joint_position_goal_validator.wait_until_goal_is_achieved() + return True + + return wrapper + + +def validate_multiple_joint_positions(position_setter_func): + """ + A decorator to validate the joint positions, this function does not validate the virtual joints, + as in multiverse the virtual joints take command velocities and not positions, so after their goals + are set, they are zeroed thus can't be validated. (They are actually validated by the robot pose in case + of virtual mobile base joints) + + :param position_setter_func: The function to set the joint positions. + """ + + def wrapper(world: 'World', joint_positions: Dict['Joint', float]): + joint_positions_to_validate = {joint: position for joint, position in joint_positions.items() + if not joint.is_virtual} + joint_types = [joint.type for joint in joint_positions_to_validate.keys()] + world.multi_joint_position_goal_validator.register_goal(list(joint_positions_to_validate.values()), joint_types, + list(joint_positions_to_validate.keys())) + if not position_setter_func(world, joint_positions): + world.multi_joint_position_goal_validator.reset() + return False + + world.multi_joint_position_goal_validator.wait_until_goal_is_achieved() + return True + + return wrapper diff --git a/src/pycram/world_concepts/constraints.py b/src/pycram/world_concepts/constraints.py index aa41e3bdf..8de36bda1 100644 --- a/src/pycram/world_concepts/constraints.py +++ b/src/pycram/world_concepts/constraints.py @@ -2,7 +2,7 @@ import numpy as np from geometry_msgs.msg import Point -from typing_extensions import Union, List, Optional, TYPE_CHECKING +from typing_extensions import Union, List, Optional, TYPE_CHECKING, Self from ..datastructures.enums import JointType from ..datastructures.pose import Transform, Pose @@ -30,6 +30,44 @@ def __init__(self, self.child_to_constraint = child_to_constraint self._parent_to_child = None + def get_child_object_pose(self) -> Pose: + """ + :return: The pose of the child object. + """ + return self.child_link.object.pose + + def get_child_object_pose_given_parent(self, pose: Pose) -> Pose: + """ + Get the pose of the child object given the parent pose. + + :param pose: The parent object pose. + :return: The pose of the child object. + """ + pose = self.parent_link.get_pose_given_object_pose(pose) + child_link_pose = self.get_child_link_target_pose_given_parent(pose) + return self.child_link.get_object_pose_given_link_pose(child_link_pose) + + def set_child_link_pose(self): + """ + Set the target pose of the child object to the current pose of the child object in the parent object frame. + """ + self.child_link.set_pose(self.get_child_link_target_pose()) + + def get_child_link_target_pose(self) -> Pose: + """ + :return: The target pose of the child object. (The pose of the child object in the parent object frame) + """ + return self.parent_to_child_transform.to_pose() + + def get_child_link_target_pose_given_parent(self, parent_pose: Pose) -> Pose: + """ + Get the target pose of the child object link given the parent link pose. + + :param parent_pose: The parent link pose. + :return: The target pose of the child object link. + """ + return (parent_pose.to_transform(self.parent_link.tf_frame) * self.parent_to_child_transform).to_pose() + @property def parent_to_child_transform(self) -> Union[Transform, None]: if self._parent_to_child is None: @@ -44,8 +82,6 @@ def parent_to_child_transform(self, transform: Transform) -> None: @property def parent_object_id(self) -> int: """ - Returns the id of the parent object of the constraint. - :return: The id of the parent object of the constraint """ return self.parent_link.object_id @@ -53,8 +89,6 @@ def parent_object_id(self) -> int: @property def child_object_id(self) -> int: """ - Returns the id of the child object of the constraint. - :return: The id of the child object of the constraint """ return self.child_link.object_id @@ -62,8 +96,6 @@ def child_object_id(self) -> int: @property def parent_link_id(self) -> int: """ - Returns the id of the parent link of the constraint. - :return: The id of the parent link of the constraint """ return self.parent_link.id @@ -71,8 +103,6 @@ def parent_link_id(self) -> int: @property def child_link_id(self) -> int: """ - Returns the id of the child link of the constraint. - :return: The id of the child link of the constraint """ return self.child_link.id @@ -80,8 +110,6 @@ def child_link_id(self) -> int: @property def position_wrt_parent_as_list(self) -> List[float]: """ - Returns the constraint frame pose with respect to the parent origin as a list. - :return: The constraint frame pose with respect to the parent origin as a list """ return self.pose_wrt_parent.position_as_list() @@ -89,8 +117,6 @@ def position_wrt_parent_as_list(self) -> List[float]: @property def orientation_wrt_parent_as_list(self) -> List[float]: """ - Returns the constraint frame orientation with respect to the parent origin as a list. - :return: The constraint frame orientation with respect to the parent origin as a list """ return self.pose_wrt_parent.orientation_as_list() @@ -98,8 +124,6 @@ def orientation_wrt_parent_as_list(self) -> List[float]: @property def pose_wrt_parent(self) -> Pose: """ - Returns the joint frame pose with respect to the parent origin. - :return: The joint frame pose with respect to the parent origin """ return self.parent_to_constraint.to_pose() @@ -107,8 +131,6 @@ def pose_wrt_parent(self) -> Pose: @property def position_wrt_child_as_list(self) -> List[float]: """ - Returns the constraint frame pose with respect to the child origin as a list. - :return: The constraint frame pose with respect to the child origin as a list """ return self.pose_wrt_child.position_as_list() @@ -116,8 +138,6 @@ def position_wrt_child_as_list(self) -> List[float]: @property def orientation_wrt_child_as_list(self) -> List[float]: """ - Returns the constraint frame orientation with respect to the child origin as a list. - :return: The constraint frame orientation with respect to the child origin as a list """ return self.pose_wrt_child.orientation_as_list() @@ -125,8 +145,6 @@ def orientation_wrt_child_as_list(self) -> List[float]: @property def pose_wrt_child(self) -> Pose: """ - Returns the joint frame pose with respect to the child origin. - :return: The joint frame pose with respect to the child origin """ return self.child_to_constraint.to_pose() @@ -151,8 +169,6 @@ def __init__(self, @property def axis_as_list(self) -> List[float]: """ - Returns the axis of this constraint as a list. - :return: The axis of this constraint as a list of xyz """ return [self.axis.x, self.axis.y, self.axis.z] @@ -162,9 +178,10 @@ class Attachment(AbstractConstraint): def __init__(self, parent_link: Link, child_link: Link, - bidirectional: Optional[bool] = False, + bidirectional: bool = False, parent_to_child_transform: Optional[Transform] = None, - constraint_id: Optional[int] = None): + constraint_id: Optional[int] = None, + is_inverse: bool = False): """ Creates an attachment between the parent object link and the child object link. This could be a bidirectional attachment, meaning that both objects will move when one moves. @@ -180,48 +197,61 @@ def __init__(self, self.id = constraint_id self.bidirectional: bool = bidirectional self._loose: bool = False + self.is_inverse: bool = is_inverse - if self.parent_to_child_transform is None: + if parent_to_child_transform is not None: + self.parent_to_child_transform = parent_to_child_transform + + elif self.parent_to_child_transform is None: self.update_transform() if self.id is None: self.add_fixed_constraint() + @property + def parent_object(self): + return self.parent_link.object + + @property + def child_object(self): + return self.child_link.object + def update_transform_and_constraint(self) -> None: """ - Updates the transform and constraint of this attachment. + Update the transform and constraint of this attachment. """ self.update_transform() self.update_constraint() def update_transform(self) -> None: """ - Updates the transform of this attachment by calculating the transform from the parent link to the child link. + Update the transform of this attachment by calculating the transform from the parent link to the child link. """ self.parent_to_child_transform = self.calculate_transform() def update_constraint(self) -> None: """ - Updates the constraint of this attachment by removing the old constraint if one exists and adding a new one. + Update the constraint of this attachment by removing the old constraint if one exists and adding a new one. """ self.remove_constraint_if_exists() self.add_fixed_constraint() def add_fixed_constraint(self) -> None: """ - Adds a fixed constraint between the parent link and the child link. + Add a fixed constraint between the parent link and the child link. """ - self.id = self.parent_link.add_fixed_constraint_with_link(self.child_link) + self.id = self.parent_link.add_fixed_constraint_with_link(self.child_link, + self.parent_to_child_transform.invert()) def calculate_transform(self) -> Transform: """ - Calculates the transform from the parent link to the child link. + Calculate the transform from the parent link to the child link. """ return self.parent_link.get_transform_to_link(self.child_link) def remove_constraint_if_exists(self) -> None: """ - Removes the constraint between the parent and the child links if one exists. + Remove the constraint between the parent and the child links if one exists. """ if self.child_link in self.parent_link.constraint_ids: self.parent_link.remove_constraint_with_link(self.child_link) @@ -231,7 +261,7 @@ def get_inverse(self) -> 'Attachment': :return: A new Attachment object with the parent and child links swapped. """ attachment = Attachment(self.child_link, self.parent_link, self.bidirectional, - constraint_id=self.id) + constraint_id=self.id, is_inverse=not self.is_inverse) attachment.loose = not self._loose return attachment @@ -245,22 +275,15 @@ def loose(self) -> bool: @loose.setter def loose(self, loose: bool) -> None: """ - Sets the loose property of this attachment. + Set the loose property of this attachment. :param loose: If true, then the child object will not move when parent moves. """ self._loose = loose and not self.bidirectional - @property - def is_reversed(self) -> bool: - """ - :return: True if the parent and child links are swapped. - """ - return self.loose - def __del__(self) -> None: """ - Removes the constraint between the parent and the child links if one exists when the attachment is deleted. + Remove the constraint between the parent and the child links if one exists when the attachment is deleted. """ self.remove_constraint_if_exists() @@ -272,10 +295,11 @@ def __eq__(self, other): return (self.parent_link.name == other.parent_link.name and self.child_link.name == other.child_link.name and self.bidirectional == other.bidirectional + and self.loose == other.loose and np.allclose(self.parent_to_child_transform.translation_as_list(), - other.parent_to_child_transform.translation_as_list(), rtol=0, atol=1e-4) + other.parent_to_child_transform.translation_as_list(), rtol=0, atol=1e-3) and np.allclose(self.parent_to_child_transform.rotation_as_list(), - other.parent_to_child_transform.rotation_as_list(), rtol=0, atol=1e-4)) + other.parent_to_child_transform.rotation_as_list(), rtol=0, atol=1e-3)) def __hash__(self): return hash((self.parent_link.name, self.child_link.name, self.bidirectional, self.parent_to_child_transform)) diff --git a/src/pycram/world_concepts/world_object.py b/src/pycram/world_concepts/world_object.py index 19b04e3fa..3c93d4272 100644 --- a/src/pycram/world_concepts/world_object.py +++ b/src/pycram/world_concepts/world_object.py @@ -2,23 +2,34 @@ import logging import os +from pathlib import Path import numpy as np -import rospy +from deprecated import deprecated from geometry_msgs.msg import Point, Quaternion from typing_extensions import Type, Optional, Dict, Tuple, List, Union -from ..description import ObjectDescription, LinkDescription, Joint -from ..object_descriptors.urdf import ObjectDescription as URDFObject -from ..robot_descriptions import robot_description -from ..datastructures.world import WorldEntity, World -from ..world_concepts.constraints import Attachment from ..datastructures.dataclasses import (Color, ObjectState, LinkState, JointState, - AxisAlignedBoundingBox, VisualShape) + AxisAlignedBoundingBox, VisualShape, ClosestPointsList, + ContactPointsList) from ..datastructures.enums import ObjectType, JointType -from ..local_transformer import LocalTransformer from ..datastructures.pose import Pose, Transform -from ..robot_description import RobotDescriptionManager +from ..datastructures.world import World +from ..datastructures.world_entity import WorldEntity +from ..description import ObjectDescription, LinkDescription, Joint +from ..failures import ObjectAlreadyExists, WorldMismatchErrorBetweenObjects, UnsupportedFileExtension, \ + ObjectDescriptionUndefined +from ..local_transformer import LocalTransformer +from ..object_descriptors.generic import ObjectDescription as GenericObjectDescription +from ..object_descriptors.urdf import ObjectDescription as URDF +from ..ros.logging import logwarn + +try: + from ..object_descriptors.mjcf import ObjectDescription as MJCF +except ImportError: + MJCF = None +from ..robot_description import RobotDescriptionManager, RobotDescription +from ..world_concepts.constraints import Attachment Link = ObjectDescription.Link @@ -28,18 +39,23 @@ class Object(WorldEntity): Represents a spawned Object in the World. """ - prospection_world_prefix: str = "prospection/" + tf_prospection_world_prefix: str = "prospection/" """ - The ObjectDescription of the object, this contains the name and type of the object as well as the path to the source - file. + The prefix for the tf frame of objects in the prospection world. """ - def __init__(self, name: str, obj_type: ObjectType, path: str, - description: Optional[Type[ObjectDescription]] = URDFObject, + extension_to_description_type: Dict[str, Type[ObjectDescription]] = {URDF.get_file_extension(): URDF} + """ + A dictionary that maps the file extension to the corresponding ObjectDescription type. + """ + + def __init__(self, name: str, obj_type: ObjectType, path: Optional[str] = None, + description: Optional[ObjectDescription] = None, pose: Optional[Pose] = None, world: Optional[World] = None, - color: Optional[Color] = Color(), - ignore_cached_files: Optional[bool] = False): + color: Color = Color(), + ignore_cached_files: bool = False, + scale_mesh: Optional[float] = None): """ The constructor loads the description file into the given World, if no World is specified the :py:attr:`~World.current_world` will be used. It is also possible to load .obj and .stl file into the World. @@ -48,37 +64,45 @@ def __init__(self, name: str, obj_type: ObjectType, path: str, :param name: The name of the object :param obj_type: The type of the object as an ObjectType enum. - :param path: The path to the source file, if only a filename is provided then the resources directories will be searched. + :param path: The path to the source file, if only a filename is provided then the resources directories will be + searched, it could be None in some cases when for example it is a generic object. :param description: The ObjectDescription of the object, this contains the joints and links of the object. :param pose: The pose at which the Object should be spawned - :param world: The World in which the object should be spawned, if no world is specified the :py:attr:`~World.current_world` will be used. + :param world: The World in which the object should be spawned, if no world is specified the + :py:attr:`~World.current_world` will be used. :param color: The rgba_color with which the object should be spawned. :param ignore_cached_files: If true the file will be spawned while ignoring cached files. + :param scale_mesh: The scale of the mesh. """ - super().__init__(-1, world) + super().__init__(-1, world if world is not None else World.current_world) + + pose = Pose() if pose is None else pose - if pose is None: - pose = Pose() - if name in [obj.name for obj in self.world.objects]: - rospy.logerr(f"An object with the name {name} already exists in the world.") - return None self.name: str = name + self.path: Optional[str] = path self.obj_type: ObjectType = obj_type self.color: Color = color - self.description = description() + self._resolve_description(path, description) self.cache_manager = self.world.cache_manager self.local_transformer = LocalTransformer() self.original_pose = self.local_transformer.transform_pose(pose, "map") self._current_pose = self.original_pose - self.id, self.path = self._load_object_and_get_id(path, ignore_cached_files) + if path is not None: + self.path = self.world.preprocess_object_file_and_get_its_cache_path(path, ignore_cached_files, + self.description, self.name, + scale_mesh=scale_mesh) - self.description.update_description_from_file(self.path) + self.description.update_description_from_file(self.path) - self.tf_frame = ((self.prospection_world_prefix if self.world.is_prospection_world else "") - + f"{self.name}") + if self.obj_type == ObjectType.ROBOT and not self.world.is_prospection_world: + self._update_world_robot_and_description() + + self.id = self._spawn_object_and_get_id() + + self.tf_frame = (self.tf_prospection_world_prefix if self.world.is_prospection_world else "") + self.name self._init_joint_name_and_id_map() self._init_link_name_and_id_map() @@ -88,26 +112,178 @@ def __init__(self, name: str, obj_type: ObjectType, path: str, self.attachments: Dict[Object, Attachment] = {} - if not self.world.is_prospection_world: - self._add_to_world_sync_obj_queue() + self.world.add_object(self) - self.world.objects.append(self) + def _resolve_description(self, path: Optional[str] = None, description: Optional[ObjectDescription] = None) -> None: + """ + Find the correct description type of the object and initialize it and set the description of this object to it. - if self.obj_type == ObjectType.ROBOT and not self.world.is_prospection_world: - rdm = RobotDescriptionManager() - rdm.load_description(self.name) - World.robot = self + :param path: The path to the source file. + :param description: The ObjectDescription of the object. + """ + if description is not None: + self.description = description + return + if path is None: + raise ObjectDescriptionUndefined(self.name) + extension = Path(path).suffix + if extension in self.extension_to_description_type: + self.description = self.extension_to_description_type[extension]() + elif extension in ObjectDescription.mesh_extensions: + self.description = self.world.conf.default_description_type() + else: + raise UnsupportedFileExtension(self.name, path) + + def set_mobile_robot_pose(self, pose: Pose) -> None: + """ + Set the goal for the mobile base joints of a mobile robot to reach a target pose. This is used for example when + the simulator does not support setting the pose of the robot directly (e.g. MuJoCo). + + :param pose: The target pose. + """ + goal = self.get_mobile_base_joint_goal(pose) + self.set_multiple_joint_positions(goal) + + def get_mobile_base_joint_goal(self, pose: Pose) -> Dict[str, float]: + """ + Get the goal for the mobile base joints of a mobile robot to reach a target pose. + + :param pose: The target pose. + :return: The goal for the mobile base joints. + """ + target_translation, target_angle = self.get_mobile_base_pose_difference(pose) + # Get the joints of the base link + mobile_base_joints = self.world.get_robot_mobile_base_joints() + return {mobile_base_joints.translation_x: target_translation.x, + mobile_base_joints.translation_y: target_translation.y, + mobile_base_joints.angular_z: target_angle} + + def get_mobile_base_pose_difference(self, pose: Pose) -> Tuple[Point, float]: + """ + Get the difference between the current and the target pose of the mobile base. + + :param pose: The target pose. + :return: The difference between the current and the target pose of the mobile base. + """ + return self.original_pose.get_position_diff(pose), self.original_pose.get_z_angle_difference(pose) + + @property + def joint_actuators(self) -> Optional[Dict[str, str]]: + """ + The joint actuators of the robot. + """ + if self.obj_type == ObjectType.ROBOT: + return self.robot_description.joint_actuators + return None + + @property + def has_actuators(self) -> bool: + """ + True if the object has actuators, otherwise False. + """ + return self.robot_description.has_actuators + + @property + def robot_description(self) -> RobotDescription: + """ + The current robot description. + """ + return self.world.robot_description + + def get_actuator_for_joint(self, joint: Joint) -> Optional[str]: + """ + Get the actuator name for a joint. + + :param joint: The joint object for which to get the actuator. + :return: The name of the actuator. + """ + return self.robot_description.get_actuator_for_joint(joint.name) + + def get_multiple_link_positions(self, links: List[Link]) -> Dict[str, List[float]]: + """ + Get the positions of multiple links of the object. + + :param links: The link objects of which to get the positions. + :return: The positions of the links. + """ + return self.world.get_multiple_link_positions(links) + + def get_multiple_link_orientations(self, links: List[Link]) -> Dict[str, List[float]]: + """ + Get the orientations of multiple links of the object. + + :param links: The link objects of which to get the orientations. + :return: The orientations of the links. + """ + return self.world.get_multiple_link_orientations(links) + + def get_multiple_link_poses(self, links: List[Link]) -> Dict[str, Pose]: + """ + Get the poses of multiple links of the object. + + :param links: The link objects of which to get the poses. + :return: The poses of the links. + """ + return self.world.get_multiple_link_poses(links) + + def get_poses_of_attached_objects(self) -> Dict[Object, Pose]: + """ + Get the poses of the attached objects. + + :return: The poses of the attached objects + """ + return {child_object: attachment.get_child_object_pose() + for child_object, attachment in self.attachments.items() if not attachment.loose} + + def get_target_poses_of_attached_objects_given_parent(self, pose: Pose) -> Dict[Object, Pose]: + """ + Get the target poses of the attached objects of an object. Given the pose of the parent object. (i.e. the poses + to which the attached objects will move when the parent object is at the given pose) + + :param pose: The pose of the parent object. + :return: The target poses of the attached objects + """ + return {child_object: attachment.get_child_object_pose_given_parent(pose) for child_object, attachment + in self.attachments.items() if not attachment.loose} + + @property + def name(self): + """ + The name of the object. + """ + return self._name + + @name.setter + def name(self, name: str): + """ + Set the name of the object. + """ + self._name = name + if name in [obj.name for obj in self.world.objects]: + raise ObjectAlreadyExists(self) @property def pose(self): + """ + The current pose of the object. + """ return self.get_pose() @pose.setter def pose(self, pose: Pose): + """ + Set the pose of the object. + """ self.set_pose(pose) - def _load_object_and_get_id(self, path: Optional[str] = None, - ignore_cached_files: Optional[bool] = False) -> Tuple[int, Union[str, None]]: + @property + def transform(self): + """ + The current transform of the object. + """ + return self.get_pose().to_transform(self.tf_frame) + + def _spawn_object_and_get_id(self) -> int: """ Loads an object to the given World with the given position and orientation. The rgba_color will only be used when an .obj or .stl file is given. @@ -115,28 +291,18 @@ def _load_object_and_get_id(self, path: Optional[str] = None, and this URDf file will be loaded instead. When spawning a URDf file a new file will be created in the cache directory, if there exists none. This new file will have resolved mesh file paths, meaning there will be no references - to ROS packges instead there will be absolute file paths. + to ROS packages instead there will be absolute file paths. - :param path: The path to the description file, if None then no file will be loaded, this is useful when the PyCRAM is not responsible for loading the file but another system is. - :param ignore_cached_files: Whether to ignore files in the cache directory. :return: The unique id of the object and the path of the file that was loaded. """ - if path is not None: - try: - path = self.world.update_cache_dir_with_object(path, ignore_cached_files, self) - except FileNotFoundError as e: - logging.error("Could not generate description from file.") - raise e + if isinstance(self.description, GenericObjectDescription): + return self.world.load_generic_object_and_get_id(self.description, pose=self._current_pose) + + path = self.path if self.world.conf.let_pycram_handle_spawning else self.name try: - simulator_object_path = path - if simulator_object_path is None: - # This is useful when the object is already loaded in the simulator so it would use its name instead of - # its path - simulator_object_path = self.name - obj_id = self.world.load_object_and_get_id(simulator_object_path, Pose(self.get_position_as_list(), - self.get_orientation_as_list())) - return obj_id, path + obj_id = self.world.load_object_and_get_id(path, self._current_pose, self.obj_type) + return obj_id except Exception as e: logging.error( @@ -145,9 +311,31 @@ def _load_object_and_get_id(self, path: Optional[str] = None, os.remove(path) raise e + def _update_world_robot_and_description(self): + """ + Initialize the robot description of the object, load the description from the RobotDescriptionManager and set + the robot as the current robot in the World. Also add the virtual mobile base joints to the robot. + """ + rdm = RobotDescriptionManager() + rdm.load_description(self.description.name) + World.robot = self + self._add_virtual_move_base_joints() + + def _add_virtual_move_base_joints(self): + """ + Add the virtual mobile base joints to the robot description. + """ + virtual_joints = self.robot_description.virtual_mobile_base_joints + if virtual_joints is None: + return + child_link = self.description.get_root() + axes = virtual_joints.get_axes() + for joint_name, joint_type in virtual_joints.get_types().items(): + self.description.add_joint(joint_name, child_link, joint_type, axes[joint_name], is_virtual=True) + def _init_joint_name_and_id_map(self) -> None: """ - Creates a dictionary which maps the joint names to their unique ids and vice versa. + Create a dictionary which maps the joint names to their unique ids and vice versa. """ n_joints = len(self.joint_names) self.joint_name_to_id = dict(zip(self.joint_names, range(n_joints))) @@ -155,7 +343,7 @@ def _init_joint_name_and_id_map(self) -> None: def _init_link_name_and_id_map(self) -> None: """ - Creates a dictionary which maps the link names to their unique ids and vice versa. + Create a dictionary which maps the link names to their unique ids and vice versa. """ n_links = len(self.link_names) self.link_name_to_id: Dict[str, int] = dict(zip(self.link_names, range(n_links))) @@ -164,7 +352,7 @@ def _init_link_name_and_id_map(self) -> None: def _init_links_and_update_transforms(self) -> None: """ - Initializes the link objects from the URDF file and creates a dictionary which maps the link names to the + Initialize the link objects from the URDF file and creates a dictionary which maps the link names to the corresponding link objects. """ self.links = {} @@ -184,32 +372,54 @@ def _init_joints(self): """ self.joints = {} for joint_name, joint_id in self.joint_name_to_id.items(): - joint_description = self.description.get_joint_by_name(joint_name) - self.joints[joint_name] = self.description.Joint(joint_id, joint_description, self) + parsed_joint_description = self.description.get_joint_by_name(joint_name) + is_virtual = self.is_joint_virtual(joint_name) + self.joints[joint_name] = self.description.Joint(joint_id, parsed_joint_description, self, is_virtual) + + def is_joint_virtual(self, name: str): + """ + Check if a joint is virtual. + """ + return self.description.is_joint_virtual(name) + + @property + def virtual_joint_names(self): + """ + The names of the virtual joints. + """ + return self.description.virtual_joint_names - def _add_to_world_sync_obj_queue(self) -> None: + @property + def virtual_joints(self): """ - Adds this object to the objects queue of the WorldSync object of the World. + The virtual joints as a list. """ - self.world.world_sync.add_obj_queue.put(self) + return [joint for joint in self.joints.values() if joint.is_virtual] + + @property + def has_one_link(self) -> bool: + """ + True if the object has only one link, otherwise False. + """ + return len(self.links) == 1 @property def link_names(self) -> List[str]: """ - :return: The name of each link as a list. + The names of the links as a list. """ return self.world.get_object_link_names(self) @property def joint_names(self) -> List[str]: """ - :return: The name of each joint as a list. + The names of the joints as a list. """ return self.world.get_object_joint_names(self) def get_link(self, link_name: str) -> ObjectDescription.Link: """ - Returns the link object with the given name. + Return the link object with the given name. :param link_name: The name of the link. :return: The link object. @@ -218,7 +428,7 @@ def get_link(self, link_name: str) -> ObjectDescription.Link: def get_link_pose(self, link_name: str) -> Pose: """ - Returns the pose of the link with the given name. + Return the pose of the link with the given name. :param link_name: The name of the link. :return: The pose of the link. @@ -227,7 +437,7 @@ def get_link_pose(self, link_name: str) -> Pose: def get_link_position(self, link_name: str) -> Point: """ - Returns the position of the link with the given name. + Return the position of the link with the given name. :param link_name: The name of the link. :return: The position of the link. @@ -236,7 +446,7 @@ def get_link_position(self, link_name: str) -> Point: def get_link_position_as_list(self, link_name: str) -> List[float]: """ - Returns the position of the link with the given name. + Return the position of the link with the given name. :param link_name: The name of the link. :return: The position of the link. @@ -245,7 +455,7 @@ def get_link_position_as_list(self, link_name: str) -> List[float]: def get_link_orientation(self, link_name: str) -> Quaternion: """ - Returns the orientation of the link with the given name. + Return the orientation of the link with the given name. :param link_name: The name of the link. :return: The orientation of the link. @@ -254,7 +464,7 @@ def get_link_orientation(self, link_name: str) -> Quaternion: def get_link_orientation_as_list(self, link_name: str) -> List[float]: """ - Returns the orientation of the link with the given name. + Return the orientation of the link with the given name. :param link_name: The name of the link. :return: The orientation of the link. @@ -263,7 +473,7 @@ def get_link_orientation_as_list(self, link_name: str) -> List[float]: def get_link_tf_frame(self, link_name: str) -> str: """ - Returns the tf frame of the link with the given name. + Return the tf frame of the link with the given name. :param link_name: The name of the link. :return: The tf frame of the link. @@ -272,7 +482,7 @@ def get_link_tf_frame(self, link_name: str) -> str: def get_link_axis_aligned_bounding_box(self, link_name: str) -> AxisAlignedBoundingBox: """ - Returns the axis aligned bounding box of the link with the given name. + Return the axis aligned bounding box of the link with the given name. :param link_name: The name of the link. :return: The axis aligned bounding box of the link. @@ -281,7 +491,7 @@ def get_link_axis_aligned_bounding_box(self, link_name: str) -> AxisAlignedBound def get_transform_between_links(self, from_link: str, to_link: str) -> Transform: """ - Returns the transform between two links. + Return the transform between two links. :param from_link: The name of the link from which the transform should be calculated. :param to_link: The name of the link to which the transform should be calculated. @@ -290,7 +500,7 @@ def get_transform_between_links(self, from_link: str, to_link: str) -> Transform def get_link_color(self, link_name: str) -> Color: """ - Returns the color of the link with the given name. + Return the color of the link with the given name. :param link_name: The name of the link. :return: The color of the link. @@ -299,7 +509,7 @@ def get_link_color(self, link_name: str) -> Color: def set_link_color(self, link_name: str, color: List[float]) -> None: """ - Sets the color of the link with the given name. + Set the color of the link with the given name. :param link_name: The name of the link. :param color: The new color of the link. @@ -308,7 +518,7 @@ def set_link_color(self, link_name: str, color: List[float]) -> None: def get_link_geometry(self, link_name: str) -> Union[VisualShape, None]: """ - Returns the geometry of the link with the given name. + Return the geometry of the link with the given name. :param link_name: The name of the link. :return: The geometry of the link. @@ -317,7 +527,7 @@ def get_link_geometry(self, link_name: str) -> Union[VisualShape, None]: def get_link_transform(self, link_name: str) -> Transform: """ - Returns the transform of the link with the given name. + Return the transform of the link with the given name. :param link_name: The name of the link. :return: The transform of the link. @@ -326,7 +536,7 @@ def get_link_transform(self, link_name: str) -> Transform: def get_link_origin(self, link_name: str) -> Pose: """ - Returns the origin of the link with the given name. + Return the origin of the link with the given name. :param link_name: The name of the link. :return: The origin of the link as a 'Pose'. @@ -335,7 +545,7 @@ def get_link_origin(self, link_name: str) -> Pose: def get_link_origin_transform(self, link_name: str) -> Transform: """ - Returns the origin transform of the link with the given name. + Return the origin transform of the link with the given name. :param link_name: The name of the link. :return: The origin transform of the link. @@ -358,16 +568,16 @@ def __repr__(self): def remove(self) -> None: """ - Removes this object from the World it currently resides in. + Remove this object from the World it currently resides in. For the object to be removed it has to be detached from all objects it - is currently attached to. After this is done a call to world remove object is done + is currently attached to. After this call world remove object to remove this Object from the simulation/world. """ self.world.remove_object(self) - def reset(self, remove_saved_states=True) -> None: + def reset(self, remove_saved_states=False) -> None: """ - Resets the Object to the state it was first spawned in. + Reset the Object to the state it was first spawned in. All attached objects will be detached, all joints will be set to the default position of 0 and the object will be set to the position and orientation in which it was spawned. @@ -380,13 +590,23 @@ def reset(self, remove_saved_states=True) -> None: if remove_saved_states: self.remove_saved_states() + def has_type_environment(self) -> bool: + """ + Check if the object is of type environment. + + :return: True if the object is of type environment, False otherwise. + """ + return self.obj_type == ObjectType.ENVIRONMENT + def attach(self, child_object: Object, parent_link: Optional[str] = None, child_link: Optional[str] = None, - bidirectional: Optional[bool] = True) -> None: + bidirectional: bool = True, + coincide_the_objects: bool = False, + parent_to_child_transform: Optional[Transform] = None) -> None: """ - Attaches another object to this object. This is done by + Attach another object to this object. This is done by saving the transformation between the given link, if there is one, and the base pose of the other object. Additionally, the name of the link, to which the object is attached, will be saved. @@ -399,11 +619,15 @@ def attach(self, :param parent_link: The link name of this object. :param child_link: The link name of the other object. :param bidirectional: If the attachment should be a loose attachment. + :param coincide_the_objects: If True the object frames will be coincided. + :param parent_to_child_transform: The transform from the parent to the child object. """ parent_link = self.links[parent_link] if parent_link else self.root_link child_link = child_object.links[child_link] if child_link else child_object.root_link - attachment = Attachment(parent_link, child_link, bidirectional) + if coincide_the_objects and parent_to_child_transform is None: + parent_to_child_transform = Transform() + attachment = Attachment(parent_link, child_link, bidirectional, parent_to_child_transform) self.attachments[child_object] = attachment child_object.attachments[self] = attachment.get_inverse() @@ -412,7 +636,7 @@ def attach(self, def detach(self, child_object: Object) -> None: """ - Detaches another object from this object. This is done by + Detache another object from this object. This is done by deleting the attachment from the attachments dictionary of both objects and deleting the constraint of the simulator. Afterward the detachment event of the corresponding World will be fired. @@ -437,7 +661,7 @@ def update_attachment_with_object(self, child_object: Object): def get_position(self) -> Point: """ - Returns the position of this Object as a list of xyz. + Return the position of this Object as a list of xyz. :return: The current position of this object """ @@ -445,7 +669,7 @@ def get_position(self) -> Point: def get_orientation(self) -> Pose.orientation: """ - Returns the orientation of this object as a list of xyzw, representing a quaternion. + Return the orientation of this object as a list of xyzw, representing a quaternion. :return: A list of xyzw """ @@ -453,7 +677,7 @@ def get_orientation(self) -> Pose.orientation: def get_position_as_list(self) -> List[float]: """ - Returns the position of this Object as a list of xyz. + Return the position of this Object as a list of xyz. :return: The current position of this object """ @@ -461,7 +685,7 @@ def get_position_as_list(self) -> List[float]: def get_base_position_as_list(self) -> List[float]: """ - Returns the position of this Object as a list of xyz. + Return the position of this Object as a list of xyz. :return: The current position of this object """ @@ -469,7 +693,7 @@ def get_base_position_as_list(self) -> List[float]: def get_orientation_as_list(self) -> List[float]: """ - Returns the orientation of this object as a list of xyzw, representing a quaternion. + Return the orientation of this object as a list of xyzw, representing a quaternion. :return: A list of xyzw """ @@ -477,15 +701,17 @@ def get_orientation_as_list(self) -> List[float]: def get_pose(self) -> Pose: """ - Returns the position of this object as a list of xyz. Alias for :func:`~Object.get_position`. + Return the position of this object as a list of xyz. Alias for :func:`~Object.get_position`. :return: The current pose of this object """ + if self.world.conf.update_poses_from_sim_on_get: + self.update_pose() return self._current_pose - def set_pose(self, pose: Pose, base: Optional[bool] = False, set_attachments: Optional[bool] = True) -> None: + def set_pose(self, pose: Pose, base: bool = False, set_attachments: bool = True) -> None: """ - Sets the Pose of the object. + Set the Pose of the object. :param pose: New Pose for the object :param base: If True places the object base instead of origin at the specified position and orientation @@ -501,23 +727,24 @@ def set_pose(self, pose: Pose, base: Optional[bool] = False, set_attachments: Op self._set_attached_objects_poses() def reset_base_pose(self, pose: Pose): - self.world.reset_object_base_pose(self, pose) - self.update_pose() + if self.world.reset_object_base_pose(self, pose): + self.update_pose() def update_pose(self): """ - Updates the current pose of this object from the world, and updates the poses of all links. + Update the current pose of this object from the world, and updates the poses of all links. """ self._current_pose = self.world.get_object_pose(self) + # TODO: Probably not needed, need to test self._update_all_links_poses() self.update_link_transforms() def _update_all_links_poses(self): """ - Updates the poses of all links by getting them from the simulator. + Update the poses of all links by getting them from the simulator. """ for link in self.links.values(): - link._update_pose() + link.update_pose() def move_base_to_origin_pose(self) -> None: """ @@ -528,7 +755,7 @@ def move_base_to_origin_pose(self) -> None: def save_state(self, state_id) -> None: """ - Saves the state of this object by saving the state of all links and attachments. + Save the state of this object by saving the state of all links and attachments. :param state_id: The unique id of the state. """ @@ -538,7 +765,7 @@ def save_state(self, state_id) -> None: def save_links_states(self, state_id: int) -> None: """ - Saves the state of all links of this object. + Save the state of all links of this object. :param state_id: The unique id of the state. """ @@ -547,7 +774,7 @@ def save_links_states(self, state_id: int) -> None: def save_joints_states(self, state_id: int) -> None: """ - Saves the state of all joints of this object. + Save the state of all joints of this object. :param state_id: The unique id of the state. """ @@ -556,40 +783,107 @@ def save_joints_states(self, state_id: int) -> None: @property def current_state(self) -> ObjectState: - return ObjectState(self.get_pose().copy(), self.attachments.copy(), self.link_states.copy(), self.joint_states.copy()) + """ + The current state of this object as an ObjectState. + """ + return ObjectState(self.get_pose().copy(), self.attachments.copy(), self.link_states.copy(), + self.joint_states.copy(), self.world.conf.get_pose_tolerance()) @current_state.setter def current_state(self, state: ObjectState) -> None: - if self.get_pose().dist(state.pose) != 0.0: + """ + Set the current state of this object to the given state. + """ + if self.current_state != state: self.set_pose(state.pose, base=False, set_attachments=False) - - self.set_attachments(state.attachments) - self.link_states = state.link_states - self.joint_states = state.joint_states + self.set_attachments(state.attachments) + self.link_states = state.link_states + self.joint_states = state.joint_states def set_attachments(self, attachments: Dict[Object, Attachment]) -> None: """ - Sets the attachments of this object to the given attachments. + Set the attachments of this object to the given attachments. + + :param attachments: A dictionary with the object as key and the attachment as value. + """ + self.detach_objects_not_in_attachments(attachments) + self.attach_objects_in_attachments(attachments) + + def detach_objects_not_in_attachments(self, attachments: Dict[Object, Attachment]) -> None: + """ + Detach objects that are not in the attachments list and are in the current attachments list. + + :param attachments: A dictionary with the object as key and the attachment as value. + """ + copy_of_attachments = self.attachments.copy() + for obj, attachment in copy_of_attachments.items(): + original_obj = obj + if self.world.is_prospection_world and len(attachments) > 0 \ + and not list(attachments.keys())[0].world.is_prospection_world: + obj = self.world.get_object_for_prospection_object(obj) + if obj not in attachments: + if attachment.is_inverse: + original_obj.detach(self) + else: + self.detach(original_obj) + + def attach_objects_in_attachments(self, attachments: Dict[Object, Attachment]) -> None: + """ + Attach objects that are in the given attachments list but not in the current attachments list. :param attachments: A dictionary with the object as key and the attachment as value. """ for obj, attachment in attachments.items(): - if self.world.is_prospection_world and not obj.world.is_prospection_world: - # In case this object is in the prospection world and the other object is not, the attachment will no - # be set. - continue + is_prospection = self.world.is_prospection_world and not obj.world.is_prospection_world + if is_prospection: + obj = self.world.get_prospection_object_for_object(obj) if obj in self.attachments: if self.attachments[obj] != attachment: - self.detach(obj) + if attachment.is_inverse: + obj.detach(self) + else: + self.detach(obj) else: continue - self.attach(obj, attachment.parent_link.name, attachment.child_link.name, - attachment.bidirectional) + self.mimic_attachment_with_object(attachment, obj) + + def mimic_attachment_with_object(self, attachment: Attachment, child_object: Object) -> None: + """ + Mimic the given attachment for this and the given child objects. + + :param attachment: The attachment to mimic. + :param child_object: The child object. + """ + att_transform = self.get_attachment_transform_with_object(attachment, child_object) + if attachment.is_inverse: + child_object.attach(self, attachment.child_link.name, attachment.parent_link.name, + attachment.bidirectional, + parent_to_child_transform=att_transform.invert()) + else: + self.attach(child_object, attachment.parent_link.name, attachment.child_link.name, + attachment.bidirectional, parent_to_child_transform=att_transform) + + def get_attachment_transform_with_object(self, attachment: Attachment, child_object: Object) -> Transform: + """ + Return the attachment transform for the given parent and child objects, taking into account the prospection + world. + + :param attachment: The attachment. + :param child_object: The child object. + :return: The attachment transform. + """ + if self.world != child_object.world: + raise WorldMismatchErrorBetweenObjects(self, child_object) + att_transform = attachment.parent_to_child_transform.copy() + if self.world.is_prospection_world and not attachment.parent_object.world.is_prospection_world: + att_transform.frame = self.tf_prospection_world_prefix + att_transform.frame + att_transform.child_frame_id = self.tf_prospection_world_prefix + att_transform.child_frame_id + return att_transform @property def link_states(self) -> Dict[int, LinkState]: """ - Returns the current state of all links of this object. + The current state of all links of this object. :return: A dictionary with the link id as key and the current state of the link as value. """ @@ -598,7 +892,7 @@ def link_states(self) -> Dict[int, LinkState]: @link_states.setter def link_states(self, link_states: Dict[int, LinkState]) -> None: """ - Sets the current state of all links of this object. + Set the current state of all links of this object. :param link_states: A dictionary with the link id as key and the current state of the link as value. """ @@ -608,7 +902,7 @@ def link_states(self, link_states: Dict[int, LinkState]) -> None: @property def joint_states(self) -> Dict[int, JointState]: """ - Returns the current state of all joints of this object. + The current state of all joints of this object. :return: A dictionary with the joint id as key and the current state of the joint as value. """ @@ -617,16 +911,20 @@ def joint_states(self) -> Dict[int, JointState]: @joint_states.setter def joint_states(self, joint_states: Dict[int, JointState]) -> None: """ - Sets the current state of all joints of this object. + Set the current state of all joints of this object. :param joint_states: A dictionary with the joint id as key and the current state of the joint as value. """ for joint in self.joints.values(): - joint.current_state = joint_states[joint.id] + if joint.name not in self.robot_virtual_move_base_joints_names(): + joint.current_state = joint_states[joint.id] + + def robot_virtual_move_base_joints_names(self): + return self.robot_description.virtual_mobile_base_joints.names def remove_saved_states(self) -> None: """ - Removes all saved states of this object. + Remove all saved states of this object. """ super().remove_saved_states() self.remove_links_saved_states() @@ -634,28 +932,30 @@ def remove_saved_states(self) -> None: def remove_links_saved_states(self) -> None: """ - Removes all saved states of the links of this object. + Remove all saved states of the links of this object. """ for link in self.links.values(): link.remove_saved_states() def remove_joints_saved_states(self) -> None: """ - Removes all saved states of the joints of this object. + Remove all saved states of the joints of this object. """ for joint in self.joints.values(): joint.remove_saved_states() def _set_attached_objects_poses(self, already_moved_objects: Optional[List[Object]] = None) -> None: """ - Updates the positions of all attached objects. This is done + Update the positions of all attached objects. This is done by calculating the new pose in world coordinate frame and setting the base pose of the attached objects to this new pose. - After this the _set_attached_objects method of all attached objects - will be called. + After this call _set_attached_objects method for all attached objects. - :param already_moved_objects: A list of Objects that were already moved, these will be excluded to prevent loops in the update. + :param already_moved_objects: A list of Objects that were already moved, these will be excluded to prevent loops + in the update. """ + if not self.world.conf.let_pycram_move_attached_objects: + return if already_moved_objects is None: already_moved_objects = [] @@ -671,13 +971,12 @@ def _set_attached_objects_poses(self, already_moved_objects: Optional[List[Objec child.update_attachment_with_object(self) else: - link_to_object = attachment.parent_to_child_transform - child.set_pose(link_to_object.to_pose(), set_attachments=False) + child.set_pose(attachment.get_child_link_target_pose(), set_attachments=False) child._set_attached_objects_poses(already_moved_objects + [self]) def set_position(self, position: Union[Pose, Point, List], base=False) -> None: """ - Sets this Object to the given position, if base is true the bottom of the Object will be placed at the position + Set this Object to the given position, if base is true, place the bottom of the Object at the position instead of the origin in the center of the Object. The given position can either be a Pose, in this case only the position is used or a geometry_msgs.msg/Point which is the position part of a Pose. @@ -690,10 +989,13 @@ def set_position(self, position: Union[Pose, Point, List], base=False) -> None: pose.frame = position.frame elif isinstance(position, Point): target_position = position - elif isinstance(position, list): - target_position = position + elif isinstance(position, List): + if len(position) == 3: + target_position = Point(*position) + else: + raise ValueError("The given position has to be a list of 3 values.") else: - raise TypeError("The given position has to be a Pose, Point or a list of xyz.") + raise TypeError("The given position has to be a Pose, Point or an iterable of xyz values.") pose.position = target_position pose.orientation = self.get_orientation() @@ -701,7 +1003,7 @@ def set_position(self, position: Union[Pose, Point, List], base=False) -> None: def set_orientation(self, orientation: Union[Pose, Quaternion, List, Tuple, np.ndarray]) -> None: """ - Sets the orientation of the Object to the given orientation. Orientation can either be a Pose, in this case only + Set the orientation of the Object to the given orientation. Orientation can either be a Pose, in this case only the orientation of this pose is used or a geometry_msgs.msg/Quaternion which is the orientation of a Pose. :param orientation: Target orientation given as a list of xyzw. @@ -724,7 +1026,7 @@ def set_orientation(self, orientation: Union[Pose, Quaternion, List, Tuple, np.n def get_joint_id(self, name: str) -> int: """ - Returns the unique id for a joint name. As used by the world/simulator. + Return the unique id for a joint name. As used by the world/simulator. :param name: The joint name :return: The unique id @@ -733,35 +1035,35 @@ def get_joint_id(self, name: str) -> int: def get_root_link_description(self) -> LinkDescription: """ - Returns the root link of the URDF of this object. + Return the root link of the URDF of this object. :return: The root link as defined in the URDF of this object. """ for link_description in self.description.links: - if link_description.name == self.root_link_name: + if link_description.name == self.description.get_root(): return link_description @property def root_link(self) -> ObjectDescription.Link: """ - Returns the root link of this object. + The root link of this object. :return: The root link of this object. """ return self.links[self.description.get_root()] @property - def root_link_name(self) -> str: + def tip_link(self) -> ObjectDescription.Link: """ - Returns the name of the root link of this object. + The tip link of this object. - :return: The name of the root link of this object. + :return: The tip link of this object. """ - return self.description.get_root() + return self.links[self.description.get_tip()] def get_root_link_id(self) -> int: """ - Returns the unique id of the root link of this object. + Return the unique id of the root link of this object. :return: The unique id of the root link of this object. """ @@ -769,7 +1071,7 @@ def get_root_link_id(self) -> int: def get_link_id(self, link_name: str) -> int: """ - Returns a unique id for a link name. + Return a unique id for a link name. :param link_name: The name of the link. :return: The unique id of the link. @@ -778,7 +1080,7 @@ def get_link_id(self, link_name: str) -> int: def get_link_by_id(self, link_id: int) -> ObjectDescription.Link: """ - Returns the link for a given unique link id + Return the link for a given unique link id :param link_id: The unique id of the link. :return: The link object. @@ -787,40 +1089,50 @@ def get_link_by_id(self, link_id: int) -> ObjectDescription.Link: def reset_all_joints_positions(self) -> None: """ - Sets the current position of all joints to 0. This is useful if the joints should be reset to their default + Set the current position of all joints to 0. This is useful if the joints should be reset to their default """ - joint_names = list(self.joint_name_to_id.keys()) + joint_names = [joint.name for joint in self.joints.values()] + if len(joint_names) == 0: + return joint_positions = [0] * len(joint_names) - self.set_joint_positions(dict(zip(joint_names, joint_positions))) + self.set_multiple_joint_positions(dict(zip(joint_names, joint_positions))) - def set_joint_positions(self, joint_poses: dict) -> None: + def set_joint_position(self, joint_name: str, joint_position: float) -> None: """ - Sets the current position of multiple joints at once, this method should be preferred when setting - multiple joints at once instead of running :func:`~Object.set_joint_position` in a loop. + Set the position of the given joint to the given joint pose and updates the poses of all attached objects. - :param joint_poses: + :param joint_name: The name of the joint + :param joint_position: The target pose for this joint """ - for joint_name, joint_position in joint_poses.items(): - self.joints[joint_name].position = joint_position - # self.update_pose() - self._update_all_links_poses() - self.update_link_transforms() - self._set_attached_objects_poses() + if self.world.reset_joint_position(self.joints[joint_name], joint_position): + self._update_on_joint_position_change() - def set_joint_position(self, joint_name: str, joint_position: float) -> None: + @deprecated("Use set_multiple_joint_positions instead") + def set_joint_positions(self, joint_positions: Dict[str, float]) -> None: + self.set_multiple_joint_positions(joint_positions) + + def set_multiple_joint_positions(self, joint_positions: Dict[str, float]) -> None: """ - Sets the position of the given joint to the given joint pose and updates the poses of all attached objects. + Set the current position of multiple joints at once, this method should be preferred when setting + multiple joints at once instead of running :func:`~Object.set_joint_position` in a loop. - :param joint_name: The name of the joint - :param joint_position: The target pose for this joint + :param joint_positions: A dictionary with the joint names as keys and the target positions as values. """ - self.joints[joint_name].position = joint_position + joint_positions = {self.joints[joint_name]: joint_position + for joint_name, joint_position in joint_positions.items()} + if self.world.set_multiple_joint_positions(joint_positions): + self._update_on_joint_position_change() + + def _update_on_joint_position_change(self): + self.update_pose() self._update_all_links_poses() self.update_link_transforms() self._set_attached_objects_poses() def get_joint_position(self, joint_name: str) -> float: """ + Return the current position of the given joint. + :param joint_name: The name of the joint :return: The current position of the given joint """ @@ -828,6 +1140,8 @@ def get_joint_position(self, joint_name: str) -> float: def get_joint_damping(self, joint_name: str) -> float: """ + Return the damping of the given joint (friction). + :param joint_name: The name of the joint :return: The damping of the given joint """ @@ -835,6 +1149,8 @@ def get_joint_damping(self, joint_name: str) -> float: def get_joint_upper_limit(self, joint_name: str) -> float: """ + Return the upper limit of the given joint. + :param joint_name: The name of the joint :return: The upper limit of the given joint """ @@ -842,6 +1158,8 @@ def get_joint_upper_limit(self, joint_name: str) -> float: def get_joint_lower_limit(self, joint_name: str) -> float: """ + Return the lower limit of the given joint. + :param joint_name: The name of the joint :return: The lower limit of the given joint """ @@ -849,6 +1167,8 @@ def get_joint_lower_limit(self, joint_name: str) -> float: def get_joint_axis(self, joint_name: str) -> Point: """ + Return the axis of the given joint. + :param joint_name: The name of the joint :return: The axis of the given joint """ @@ -856,6 +1176,8 @@ def get_joint_axis(self, joint_name: str) -> Point: def get_joint_type(self, joint_name: str) -> JointType: """ + Return the type of the given joint. + :param joint_name: The name of the joint :return: The type of the given joint """ @@ -863,6 +1185,8 @@ def get_joint_type(self, joint_name: str) -> JointType: def get_joint_limits(self, joint_name: str) -> Tuple[float, float]: """ + Return the lower and upper limits of the given joint. + :param joint_name: The name of the joint :return: The lower and upper limits of the given joint """ @@ -870,6 +1194,8 @@ def get_joint_limits(self, joint_name: str) -> Tuple[float, float]: def get_joint_child_link(self, joint_name: str) -> ObjectDescription.Link: """ + Return the child link of the given joint. + :param joint_name: The name of the joint :return: The child link of the given joint """ @@ -877,6 +1203,8 @@ def get_joint_child_link(self, joint_name: str) -> ObjectDescription.Link: def get_joint_parent_link(self, joint_name: str) -> ObjectDescription.Link: """ + Return the parent link of the given joint. + :param joint_name: The name of the joint :return: The parent link of the given joint """ @@ -884,7 +1212,7 @@ def get_joint_parent_link(self, joint_name: str) -> ObjectDescription.Link: def find_joint_above_link(self, link_name: str, joint_type: JointType) -> str: """ - Traverses the chain from 'link' to the URDF origin and returns the first joint that is of type 'joint_type'. + Traverse the chain from 'link' to the URDF origin and return the first joint that is of type 'joint_type'. :param link_name: AbstractLink name above which the joint should be found :param joint_type: Joint type that should be searched for @@ -898,35 +1226,46 @@ def find_joint_above_link(self, link_name: str, joint_type: JointType) -> str: container_joint = element break if not container_joint: - rospy.logwarn(f"No joint of type {joint_type} found above link {link_name}") + logwarn(f"No joint of type {joint_type} found above link {link_name}") return container_joint + def get_multiple_joint_positions(self, joint_names: List[str]) -> Dict[str, float]: + """ + Return the positions of multiple joints at once. + + :param joint_names: A list of joint names. + :return: A dictionary with the joint names as keys and the joint positions as values. + """ + return self.world.get_multiple_joint_positions([self.joints[joint_name] for joint_name in joint_names]) + def get_positions_of_all_joints(self) -> Dict[str, float]: """ - Returns the positions of all joints of the object as a dictionary of joint names and joint positions. + Return the positions of all joints of the object as a dictionary of joint names and joint positions. :return: A dictionary with all joints positions'. """ return {j.name: j.position for j in self.joints.values()} - def update_link_transforms(self, transform_time: Optional[rospy.Time] = None) -> None: + def update_link_transforms(self, transform_time: Optional[Time] = None) -> None: """ - Updates the transforms of all links of this object using time 'transform_time' or the current ros time. + Update the transforms of all links of this object using time 'transform_time' or the current ros time. + + :param transform_time: The time to use for the transform update. """ for link in self.links.values(): link.update_transform(transform_time) - def contact_points(self) -> List: + def contact_points(self) -> ContactPointsList: """ - Returns a list of contact points of this Object with other Objects. + Return a list of contact points of this Object with other Objects. :return: A list of all contact points with other objects """ return self.world.get_object_contact_points(self) - def contact_points_simulated(self) -> List: + def contact_points_simulated(self) -> ContactPointsList: """ - Returns a list of all contact points between this Object and other Objects after stepping the simulation once. + Return a list of all contact points between this Object and other Objects after stepping the simulation once. :return: A list of contact points between this Object and other Objects """ @@ -936,9 +1275,28 @@ def contact_points_simulated(self) -> List: self.world.restore_state(state_id) return contact_points + def closest_points(self, max_distance: float) -> ClosestPointsList: + """ + Return a list of closest points between this Object and other Objects. + + :param max_distance: The maximum distance between the closest points + :return: A list of closest points between this Object and other Objects + """ + return self.world.get_object_closest_points(self, max_distance) + + def closest_points_with_obj(self, other_object: Object, max_distance: float) -> ClosestPointsList: + """ + Return a list of closest points between this Object and another Object. + + :param other_object: The other object + :param max_distance: The maximum distance between the closest points + :return: A list of closest points between this Object and the other Object + """ + return self.world.get_closest_points_between_objects(self, other_object, max_distance) + def set_color(self, rgba_color: Color) -> None: """ - Changes the color of this object, the color has to be given as a list + Change the color of this object, the color has to be given as a list of RGBA values. :param rgba_color: The color as Color object with RGBA values between 0 and 1 @@ -953,7 +1311,7 @@ def set_color(self, rgba_color: Color) -> None: def get_color(self) -> Union[Color, Dict[str, Color]]: """ - This method returns the rgba_color of this object. The return is either: + Return the rgba_color of this object. The return is either: 1. A Color object with RGBA values, this is the case if the object only has one link (this happens for example if the object is spawned from a .obj or .stl file) @@ -961,7 +1319,8 @@ def get_color(self) -> Union[Color, Dict[str, Color]]: Please keep in mind that not every link may have a rgba_color. This is dependent on the URDF from which the object is spawned. - :return: The rgba_color as Color object with RGBA values between 0 and 1 or a dict with the link name as key and the rgba_color as value. + :return: The rgba_color as Color object with RGBA values between 0 and 1 or a dict with the link name as key and + the rgba_color as value. """ link_to_color_dict = self.links_colors @@ -979,12 +1338,16 @@ def links_colors(self) -> Dict[str, Color]: def get_axis_aligned_bounding_box(self) -> AxisAlignedBoundingBox: """ + Return the axis aligned bounding box of this object. + :return: The axis aligned bounding box of this object. """ return self.world.get_object_axis_aligned_bounding_box(self) def get_base_origin(self) -> Pose: """ + Return the origin of the base/bottom of this object. + :return: the origin of the base/bottom of this object. """ aabb = self.get_axis_aligned_bounding_box() @@ -995,40 +1358,43 @@ def get_base_origin(self) -> Pose: def get_joint_by_id(self, joint_id: int) -> Joint: """ - Returns the joint object with the given id. + Return the joint object with the given id. :param joint_id: The unique id of the joint. :return: The joint object. """ return dict([(joint.id, joint) for joint in self.joints.values()])[joint_id] + def get_link_for_attached_objects(self) -> Dict[Object, ObjectDescription.Link]: + """ + Return a dictionary which maps attached object to the link of this object to which the given object is attached. + + :return: The link of this object to which the given object is attached. + """ + return {obj: attachment.parent_link for obj, attachment in self.attachments.items()} + def copy_to_prospection(self) -> Object: """ - Copies this object to the prospection world. + Copy this object to the prospection world. :return: The copied object in the prospection world. """ - obj = Object(self.name, self.obj_type, self.path, type(self.description), self.get_pose(), - self.world.prospection_world, self.color) - obj.current_state = self.current_state - return obj + return self.copy_to_world(self.world.prospection_world) - def __copy__(self) -> Object: + def copy_to_world(self, world: World) -> Object: """ - Returns a copy of this object. The copy will have the same name, type, path, description, pose, world and color. + Copy this object to the given world. - :return: A copy of this object. + :param world: The world to which the object should be copied. + :return: The copied object in the given world. """ - obj = Object(self.name, self.obj_type, self.path, type(self.description), self.get_pose(), - self.world.prospection_world, self.color) - obj.current_state = self.current_state + obj = Object(self.name, self.obj_type, self.path, self.description, self.get_pose(), + world, self.color) return obj def __eq__(self, other): - if not isinstance(other, Object): - return False - return (self.id == other.id and self.world == other.world and self.name == other.name - and self.obj_type == other.obj_type) + return (isinstance(other, Object) and self.id == other.id and self.name == other.name + and self.world == other.world) def __hash__(self): - return hash((self.name, self.obj_type, self.id, self.world.id)) + return hash((self.id, self.name, self.world)) diff --git a/src/pycram/world_reasoning.py b/src/pycram/world_reasoning.py index f4892676b..8fd0501b9 100644 --- a/src/pycram/world_reasoning.py +++ b/src/pycram/world_reasoning.py @@ -1,12 +1,14 @@ -from typing_extensions import List, Tuple, Optional, Union, Dict - import numpy as np +from typing_extensions import List, Tuple, Optional, Union, Dict -from .external_interfaces.ik import try_to_reach, try_to_reach_with_grasp +from .datastructures.dataclasses import ContactPointsList from .datastructures.pose import Pose, Transform +from .datastructures.world import World, UseProspectionWorld +from .external_interfaces.ik import try_to_reach, try_to_reach_with_grasp from .robot_description import RobotDescription +from .utils import RayTestUtils from .world_concepts.world_object import Object -from .datastructures.world import World, UseProspectionWorld +from .config import world_conf as conf def stable(obj: Object) -> bool: @@ -49,41 +51,44 @@ def contact( with UseProspectionWorld(): prospection_obj1 = World.current_world.get_prospection_object_for_object(object1) prospection_obj2 = World.current_world.get_prospection_object_for_object(object2) - World.current_world.perform_collision_detection() - con_points = World.current_world.get_contact_points_between_two_objects(prospection_obj1, prospection_obj2) - + con_points: ContactPointsList = World.current_world.get_contact_points_between_two_objects(prospection_obj1, + prospection_obj2) + objects_are_in_contact = len(con_points) > 0 if return_links: - contact_links = [] - for point in con_points: - contact_links.append((prospection_obj1.get_link_by_id(point[3]), - prospection_obj2.get_link_by_id(point[4]))) - return con_points != (), contact_links - + contact_links = [(point.link_a, point.link_b) for point in con_points] + return objects_are_in_contact, contact_links else: - return con_points != () + return objects_are_in_contact def get_visible_objects( camera_pose: Pose, - front_facing_axis: Optional[List[float]] = None) -> Tuple[np.ndarray, Pose]: + front_facing_axis: Optional[List[float]] = None, + plot_segmentation_mask: bool = False) -> Tuple[np.ndarray, Pose]: """ - Returns a segmentation mask of the objects that are visible from the given camera pose and the front facing axis. + Return a segmentation mask of the objects that are visible from the given camera pose and the front facing axis. :param camera_pose: The pose of the camera in world coordinate frame. :param front_facing_axis: The axis, of the camera frame, which faces to the front of the robot. Given as list of xyz + :param plot_segmentation_mask: If the segmentation mask should be plotted :return: A segmentation mask of the objects that are visible and the pose of the point at exactly 2 meters in front of the camera in the direction of the front facing axis with respect to the world coordinate frame. """ - front_facing_axis = RobotDescription.current_robot_description.get_default_camera().front_facing_axis + if front_facing_axis is None: + front_facing_axis = RobotDescription.current_robot_description.get_default_camera().front_facing_axis - world_to_cam = camera_pose.to_transform("camera") + camera_frame = RobotDescription.current_robot_description.get_camera_frame() + world_to_cam = camera_pose.to_transform(camera_frame) - cam_to_point = Transform(list(np.multiply(front_facing_axis, 2)), [0, 0, 0, 1], "camera", + cam_to_point = Transform(list(np.multiply(front_facing_axis, 2)), [0, 0, 0, 1], camera_frame, "point") target_point = (world_to_cam * cam_to_point).to_pose() seg_mask = World.current_world.get_images_for_target(target_point, camera_pose)[2] + if plot_segmentation_mask: + RayTestUtils.plot_segmentation_mask(seg_mask) + return seg_mask, target_point @@ -91,7 +96,8 @@ def visible( obj: Object, camera_pose: Pose, front_facing_axis: Optional[List[float]] = None, - threshold: float = 0.8) -> bool: + threshold: float = 0.8, + plot_segmentation_mask: bool = False) -> bool: """ Checks if an object is visible from a given position. This will be achieved by rendering the object alone and counting the visible pixel, then rendering the complete scene and compare the visible pixels with the @@ -101,6 +107,7 @@ def visible( :param camera_pose: The pose of the camera in map frame :param front_facing_axis: The axis, of the camera frame, which faces to the front of the robot. Given as list of xyz :param threshold: The minimum percentage of the object that needs to be visible for this method to return true. + :param plot_segmentation_mask: If the segmentation mask should be plotted. :return: True if the object is visible from the camera_position False if not """ with UseProspectionWorld(): @@ -115,7 +122,7 @@ def visible( else: obj.set_pose(Pose([100, 100, 0], [0, 0, 0, 1]), set_attachments=False) - seg_mask, target_point = get_visible_objects(camera_pose, front_facing_axis) + seg_mask, target_point = get_visible_objects(camera_pose, front_facing_axis, plot_segmentation_mask) max_pixel = np.array(seg_mask == prospection_obj.id).sum() World.current_world.restore_state(state_id) @@ -133,7 +140,8 @@ def visible( def occluding( obj: Object, camera_pose: Pose, - front_facing_axis: Optional[List[float]] = None) -> List[Object]: + front_facing_axis: Optional[List[float]] = None, + plot_segmentation_mask: bool = False) -> List[Object]: """ Lists all objects which are occluding the given object. This works similar to 'visible'. First the object alone will be rendered and the position of the pixels of the object in the picture will be saved. @@ -143,6 +151,7 @@ def occluding( :param obj: The object for which occlusion should be checked :param camera_pose: The pose of the camera in world coordinate frame :param front_facing_axis: The axis, of the camera frame, which faces to the front of the robot. Given as list of xyz + :param plot_segmentation_mask: If the segmentation mask should be plotted :return: A list of occluding objects """ @@ -156,7 +165,7 @@ def occluding( else: other_obj.set_pose(Pose([100, 100, 0], [0, 0, 0, 1])) - seg_mask, target_point = get_visible_objects(camera_pose, front_facing_axis) + seg_mask, target_point = get_visible_objects(camera_pose, front_facing_axis, plot_segmentation_mask) # All indices where the object that could be occluded is in the image # [0] at the end is to reduce by one dimension because dstack adds an unnecessary dimension @@ -224,17 +233,15 @@ def blocking( :return: A list of objects the robot is in collision with when reaching for the specified object or None if the pose or object is not reachable. """ - prospection_robot = World.current_world.get_prospection_object_for_object(robot) with UseProspectionWorld(): + prospection_robot = World.current_world.get_prospection_object_for_object(robot) if grasp: try_to_reach_with_grasp(pose_or_object, prospection_robot, gripper_name, grasp) else: try_to_reach(pose_or_object, prospection_robot, gripper_name) - block = [] - for obj in World.current_world.objects: - if contact(prospection_robot, obj): - block.append(World.current_world.get_object_for_prospection_object(obj)) + block = [World.current_world.get_object_for_prospection_object(obj) for obj in World.current_world.objects + if contact(prospection_robot, obj)] return block @@ -257,7 +264,7 @@ def link_pose_for_joint_config( joint_config: Dict[str, float], link_name: str) -> Pose: """ - Returns the pose a link would be in if the given joint configuration would be applied to the object. + Get the pose a link would be in if the given joint configuration would be applied to the object. This is done by using the respective object in the prospection world and applying the joint configuration to this one. After applying the joint configuration the link position is taken from there. diff --git a/src/pycram/worlds/bullet_world.py b/src/pycram/worlds/bullet_world.py index 77d6c1958..90851e466 100755 --- a/src/pycram/worlds/bullet_world.py +++ b/src/pycram/worlds/bullet_world.py @@ -5,18 +5,20 @@ import time import numpy as np -import pybullet as p -import rosgraph -import rospy +import pycram_bullet as p from geometry_msgs.msg import Point -from typing_extensions import List, Optional, Dict +from typing_extensions import List, Optional, Dict, Any +from ..datastructures.dataclasses import Color, AxisAlignedBoundingBox, MultiBody, VisualShape, BoxVisualShape, \ + ClosestPoint, LateralFriction, ContactPoint, ContactPointsList, ClosestPointsList from ..datastructures.enums import ObjectType, WorldMode, JointType from ..datastructures.pose import Pose -from ..object_descriptors.urdf import ObjectDescription from ..datastructures.world import World +from ..object_descriptors.generic import ObjectDescription as GenericObjectDescription +from ..object_descriptors.urdf import ObjectDescription +from ..validation.goal_validator import (validate_multiple_joint_positions, validate_joint_position, + validate_object_pose, validate_multiple_object_poses) from ..world_concepts.constraints import Constraint -from ..datastructures.dataclasses import Color, AxisAlignedBoundingBox, MultiBody, VisualShape, BoxVisualShape from ..world_concepts.world_object import Object Link = ObjectDescription.Link @@ -31,13 +33,7 @@ class is the main interface to the Bullet Physics Engine and should be used to s manipulate the Bullet World. """ - extension: str = ObjectDescription.get_file_extension() - - # Check is for sphinx autoAPI to be able to work in a CI workflow - if rosgraph.is_master_online(): # and "/pycram" not in rosnode.get_node_names(): - rospy.init_node('pycram') - - def __init__(self, mode: WorldMode = WorldMode.DIRECT, is_prospection_world: bool = False, sim_frequency=240): + def __init__(self, mode: WorldMode = WorldMode.DIRECT, is_prospection_world: bool = False): """ Creates a new simulation, the type decides of the simulation should be a rendered window or just run in the background. There can only be one rendered simulation. @@ -46,7 +42,7 @@ def __init__(self, mode: WorldMode = WorldMode.DIRECT, is_prospection_world: boo :param mode: Can either be "GUI" for rendered window or "DIRECT" for non-rendered. The default is "GUI" :param is_prospection_world: For internal usage, decides if this BulletWorld should be used as a shadow world. """ - super().__init__(mode=mode, is_prospection_world=is_prospection_world, simulation_frequency=sim_frequency) + super().__init__(mode=mode, is_prospection_world=is_prospection_world) # This disables file caching from PyBullet, since this would also cache # files that can not be loaded @@ -60,7 +56,7 @@ def __init__(self, mode: WorldMode = WorldMode.DIRECT, is_prospection_world: boo self.set_gravity([0, 0, -9.8]) if not is_prospection_world: - _ = Object("floor", ObjectType.ENVIRONMENT, "plane" + self.extension, + _ = Object("floor", ObjectType.ENVIRONMENT, "plane.urdf", world=self) def _init_world(self, mode: WorldMode): @@ -68,7 +64,30 @@ def _init_world(self, mode: WorldMode): self._gui_thread.start() time.sleep(0.1) - def load_object_and_get_id(self, path: Optional[str] = None, pose: Optional[Pose] = None) -> int: + def load_generic_object_and_get_id(self, description: GenericObjectDescription, + pose: Optional[Pose] = None) -> int: + """ + Creates a visual and collision box in the simulation. + """ + # Create visual shape + vis_shape = p.createVisualShape(p.GEOM_BOX, halfExtents=description.shape_data, + rgbaColor=description.color.get_rgba(), physicsClientId=self.id) + + # Create collision shape + col_shape = p.createCollisionShape(p.GEOM_BOX, halfExtents=description.shape_data, physicsClientId=self.id) + + # Create MultiBody with both visual and collision shapes + obj_id = p.createMultiBody(baseMass=1.0, baseCollisionShapeIndex=col_shape, baseVisualShapeIndex=vis_shape, + basePosition=description.origin.position_as_list(), + baseOrientation=description.origin.orientation_as_list(), physicsClientId=self.id) + + if pose is not None: + self._set_object_pose_by_id(obj_id, pose) + # Assuming you have a list to keep track of created objects + return obj_id + + def load_object_and_get_id(self, path: Optional[str] = None, pose: Optional[Pose] = None, + obj_type: Optional[ObjectType] = None) -> int: if pose is None: pose = Pose() return self._load_object_and_get_id(path, pose) @@ -80,11 +99,21 @@ def _load_object_and_get_id(self, path: str, pose: Pose) -> int: basePosition=pose.position_as_list(), baseOrientation=pose.orientation_as_list(), physicsClientId=self.id) - def remove_object_from_simulator(self, obj: Object) -> None: - p.removeBody(obj.id, self.id) + def _remove_visual_object(self, obj_id: int) -> bool: + self._remove_body(obj_id) + return True + + def remove_object_from_simulator(self, obj: Object) -> bool: + self._remove_body(obj.id) + return True + + def _remove_body(self, body_id: int) -> Any: + """ + Remove a body from PyBullet using the body id. - def remove_object_by_id(self, obj_id: int) -> None: - p.removeBody(obj_id, self.id) + :param body_id: The id of the body. + """ + return p.removeBody(body_id, self.id) def add_constraint(self, constraint: Constraint) -> int: @@ -111,6 +140,21 @@ def get_object_joint_names(self, obj: Object) -> List[str]: return [p.getJointInfo(obj.id, i, physicsClientId=self.id)[1].decode('utf-8') for i in range(self.get_object_number_of_joints(obj))] + def get_multiple_link_poses(self, links: List[Link]) -> Dict[str, Pose]: + return {link.name: self.get_link_pose(link) for link in links} + + def get_multiple_link_positions(self, links: List[Link]) -> Dict[str, List[float]]: + return {link.name: self.get_link_position(link) for link in links} + + def get_multiple_link_orientations(self, links: List[Link]) -> Dict[str, List[float]]: + return {link.name: self.get_link_orientation(link) for link in links} + + def get_link_position(self, link: Link) -> List[float]: + return self.get_link_pose(link).position_as_list() + + def get_link_orientation(self, link: Link) -> List[float]: + return self.get_link_pose(link).orientation_as_list() + def get_link_pose(self, link: ObjectDescription.Link) -> Pose: bullet_link_state = p.getLinkState(link.object_id, link.id, physicsClientId=self.id) return Pose(*bullet_link_state[4:6]) @@ -128,29 +172,94 @@ def get_object_number_of_links(self, obj: Object) -> int: def perform_collision_detection(self) -> None: p.performCollisionDetection(physicsClientId=self.id) - def get_object_contact_points(self, obj: Object) -> List: + def get_object_contact_points(self, obj: Object) -> ContactPointsList: """ - For a more detailed explanation of the - returned list please look at: - `PyBullet Doc `_ + Get the contact points of the object with akk other objects in the world. The contact points are returned as a + ContactPointsList object. + + :param obj: The object for which the contact points should be returned. + :return: The contact points of the object with all other objects in the world. """ self.perform_collision_detection() - return p.getContactPoints(obj.id, physicsClientId=self.id) + points_list = p.getContactPoints(obj.id, physicsClientId=self.id) + return ContactPointsList([ContactPoint(**self.parse_points_list_to_args(point)) for point in points_list + if len(point) > 0]) - def get_contact_points_between_two_objects(self, obj1: Object, obj2: Object) -> List: + def get_contact_points_between_two_objects(self, obj_a: Object, obj_b: Object) -> ContactPointsList: self.perform_collision_detection() - return p.getContactPoints(obj1.id, obj2.id, physicsClientId=self.id) + points_list = p.getContactPoints(obj_a.id, obj_b.id, physicsClientId=self.id) + return ContactPointsList([ContactPoint(**self.parse_points_list_to_args(point)) for point in points_list + if len(point) > 0]) - def reset_joint_position(self, joint: ObjectDescription.Joint, joint_position: str) -> None: + def get_closest_points_between_objects(self, obj_a: Object, obj_b: Object, distance: float) -> ClosestPointsList: + points_list = p.getClosestPoints(obj_a.id, obj_b.id, distance, physicsClientId=self.id) + return ClosestPointsList([ClosestPoint(**self.parse_points_list_to_args(point)) for point in points_list + if len(point) > 0]) + + def parse_points_list_to_args(self, point: List) -> Dict: + """ + Parses the list of points to a list of dictionaries with the keys as the names of the arguments of the + ContactPoint class. + + :param point: The list of points. + """ + return {"link_a": self.get_object_by_id(point[1]).get_link_by_id(point[3]), + "link_b": self.get_object_by_id(point[2]).get_link_by_id(point[4]), + "position_on_object_a": point[5], + "position_on_object_b": point[6], + "normal_on_b": point[7], + "distance": point[8], + "normal_force": point[9], + "lateral_friction_1": LateralFriction(point[10], point[11]), + "lateral_friction_2": LateralFriction(point[12], point[13])} + + @validate_multiple_joint_positions + def set_multiple_joint_positions(self, joint_positions: Dict[Joint, float]) -> bool: + for joint, joint_position in joint_positions.items(): + self.reset_joint_position(joint, joint_position) + return True + + @validate_joint_position + def reset_joint_position(self, joint: Joint, joint_position: float) -> bool: p.resetJointState(joint.object_id, joint.id, joint_position, physicsClientId=self.id) + return True - def reset_object_base_pose(self, obj: Object, pose: Pose) -> None: - p.resetBasePositionAndOrientation(obj.id, pose.position_as_list(), pose.orientation_as_list(), + def get_multiple_joint_positions(self, joints: List[Joint]) -> Dict[str, float]: + return {joint.name: self.get_joint_position(joint) for joint in joints} + + @validate_multiple_object_poses + def reset_multiple_objects_base_poses(self, objects: Dict[Object, Pose]) -> bool: + for obj, pose in objects.items(): + self.reset_object_base_pose(obj, pose) + return True + + @validate_object_pose + def reset_object_base_pose(self, obj: Object, pose: Pose) -> bool: + return self._set_object_pose_by_id(obj.id, pose) + + def _set_object_pose_by_id(self, obj_id: int, pose: Pose) -> bool: + p.resetBasePositionAndOrientation(obj_id, pose.position_as_list(), pose.orientation_as_list(), physicsClientId=self.id) + return True def step(self): p.stepSimulation(physicsClientId=self.id) + def get_multiple_object_poses(self, objects: List[Object]) -> Dict[str, Pose]: + return {obj.name: self.get_object_pose(obj) for obj in objects} + + def get_multiple_object_positions(self, objects: List[Object]) -> Dict[str, List[float]]: + return {obj.name: self.get_object_pose(obj).position_as_list() for obj in objects} + + def get_multiple_object_orientations(self, objects: List[Object]) -> Dict[str, List[float]]: + return {obj.name: self.get_object_pose(obj).orientation_as_list() for obj in objects} + + def get_object_position(self, obj: Object) -> List[float]: + return self.get_object_pose(obj).position_as_list() + + def get_object_orientation(self, obj: Object) -> List[float]: + return self.get_object_pose(obj).orientation_as_list() + def get_object_pose(self, obj: Object) -> Pose: return Pose(*p.getBasePositionAndOrientation(obj.id, physicsClientId=self.id)) @@ -193,7 +302,7 @@ def join_gui_thread_if_exists(self): if self._gui_thread: self._gui_thread.join() - def save_physics_simulator_state(self) -> int: + def save_physics_simulator_state(self, state_id: Optional[int] = None, use_same_id: bool = False) -> int: return p.saveState(physicsClientId=self.id) def restore_physics_simulator_state(self, state_id): @@ -202,14 +311,15 @@ def restore_physics_simulator_state(self, state_id): def remove_physics_simulator_state(self, state_id: int): p.removeState(state_id, physicsClientId=self.id) - def add_vis_axis(self, pose: Pose, - length: Optional[float] = 0.2) -> None: + def _add_vis_axis(self, pose: Pose, + length: Optional[float] = 0.2) -> int: """ Creates a Visual object which represents the coordinate frame at the given position and orientation. There can be an unlimited amount of vis axis objects. :param pose: The pose at which the axis should be spawned :param length: Optional parameter to configure the length of the axes + :return: The id of the spawned object """ pose_in_map = self.local_transformer.transform_pose(pose, "map") @@ -231,9 +341,11 @@ def add_vis_axis(self, pose: Pose, link_joint_axis=[Point(1, 0, 0), Point(0, 1, 0), Point(0, 0, 1)], link_collision_shape_indices=[-1, -1, -1]) - self.vis_axis.append(self.create_multi_body(multibody)) + body_id = self._create_multi_body(multibody) + self.vis_axis.append(body_id) + return body_id - def remove_vis_axis(self) -> None: + def _remove_vis_axis(self) -> None: """ Removes all spawned vis axis objects that are currently in this BulletWorld. """ @@ -250,13 +362,13 @@ def ray_test_batch(self, from_positions: List[List[float]], to_positions: List[L return p.rayTestBatch(from_positions, to_positions, numThreads=num_threads, physicsClientId=self.id) - def create_visual_shape(self, visual_shape: VisualShape) -> int: + def _create_visual_shape(self, visual_shape: VisualShape) -> int: return p.createVisualShape(visual_shape.visual_geometry_type.value, rgbaColor=visual_shape.rgba_color.get_rgba(), visualFramePosition=visual_shape.visual_frame_position, physicsClientId=self.id, **visual_shape.shape_data()) - def create_multi_body(self, multi_body: MultiBody) -> int: + def _create_multi_body(self, multi_body: MultiBody) -> int: return p.createMultiBody(baseVisualShapeIndex=-multi_body.base_visual_shape_index, linkVisualShapeIndices=multi_body.link_visual_shape_indices, basePosition=multi_body.base_pose.position_as_list(), @@ -289,9 +401,9 @@ def get_images_for_target(self, return list(p.getCameraImage(size, size, view_matrix, projection_matrix, physicsClientId=self.id))[2:5] - def add_text(self, text: str, position: List[float], orientation: Optional[List[float]] = None, - size: Optional[float] = None, color: Optional[Color] = Color(), life_time: Optional[float] = 0, - parent_object_id: Optional[int] = None, parent_link_id: Optional[int] = None) -> int: + def _add_text(self, text: str, position: List[float], orientation: Optional[List[float]] = None, + size: Optional[float] = None, color: Optional[Color] = Color(), life_time: Optional[float] = 0, + parent_object_id: Optional[int] = None, parent_link_id: Optional[int] = None) -> int: args = {} if orientation: args["textOrientation"] = orientation @@ -305,7 +417,7 @@ def add_text(self, text: str, position: List[float], orientation: Optional[List[ args["parentLinkIndex"] = parent_link_id return p.addUserDebugText(text, position, color.get_rgb(), physicsClientId=self.id, **args) - def remove_text(self, text_id: Optional[int] = None) -> None: + def _remove_text(self, text_id: Optional[int] = None) -> None: if text_id is not None: p.removeUserDebugItem(text_id, physicsClientId=self.id) else: @@ -395,7 +507,7 @@ def run(self): width, height, dist = (p.getDebugVisualizerCamera()[0], p.getDebugVisualizerCamera()[1], p.getDebugVisualizerCamera()[10]) - #print("width: ", width, "height: ", height, "dist: ", dist) + # print("width: ", width, "height: ", height, "dist: ", dist) camera_target_position = p.getDebugVisualizerCamera(self.world.id)[11] # Get vectors used for movement on x,y,z Vector @@ -550,5 +662,6 @@ def run(self): cameraTargetPosition=camera_target_position, physicsClientId=self.world.id) if visible == 0: camera_target_position = (0.0, -50, 50) - p.resetBasePositionAndOrientation(sphere_uid, camera_target_position, [0, 0, 0, 1], physicsClientId=self.world.id) + p.resetBasePositionAndOrientation(sphere_uid, camera_target_position, [0, 0, 0, 1], + physicsClientId=self.world.id) time.sleep(1. / 80.) diff --git a/src/pycram/worlds/multiverse.py b/src/pycram/worlds/multiverse.py new file mode 100644 index 000000000..b4b65380e --- /dev/null +++ b/src/pycram/worlds/multiverse.py @@ -0,0 +1,658 @@ +import logging +from time import sleep + +import numpy as np +from tf.transformations import quaternion_matrix +from typing_extensions import List, Dict, Optional, Union, Tuple + +from .multiverse_communication.client_manager import MultiverseClientManager +from .multiverse_communication.clients import MultiverseController, MultiverseReader, MultiverseWriter, MultiverseAPI +from ..config.multiverse_conf import MultiverseConfig +from ..datastructures.dataclasses import AxisAlignedBoundingBox, Color, ContactPointsList, ContactPoint +from ..datastructures.enums import WorldMode, JointType, ObjectType, MultiverseBodyProperty, MultiverseJointPosition, \ + MultiverseJointCMD +from ..datastructures.pose import Pose +from ..datastructures.world import World +from ..description import Link, Joint, ObjectDescription +from ..object_descriptors.mjcf import ObjectDescription as MJCF +from ..robot_description import RobotDescription +from ..ros.logging import logwarn, logerr +from ..utils import RayTestUtils, wxyz_to_xyzw, xyzw_to_wxyz +from ..validation.goal_validator import validate_object_pose, validate_multiple_joint_positions, \ + validate_joint_position, validate_multiple_object_poses +from ..world_concepts.constraints import Constraint +from ..world_concepts.world_object import Object + + +class Multiverse(World): + """ + This class implements an interface between Multiverse and PyCRAM. + """ + + conf: MultiverseConfig = MultiverseConfig + """ + The Multiverse configuration. + """ + + supported_joint_types = (JointType.REVOLUTE, JointType.CONTINUOUS, JointType.PRISMATIC) + """ + A Tuple for the supported pycram joint types in Multiverse. + """ + + added_multiverse_resources: bool = False + """ + A flag to check if the multiverse resources have been added. + """ + + simulation: Optional[str] = None + """ + The simulation name to be used in the Multiverse world (this is the name defined in + the multiverse configuration file). + """ + + Object.extension_to_description_type[MJCF.get_file_extension()] = MJCF + """ + Add the MJCF description extension to the extension to description type mapping for the objects. + """ + + def __init__(self, mode: Optional[WorldMode] = WorldMode.DIRECT, + is_prospection: Optional[bool] = False, + simulation_name: str = "pycram_test", + clear_cache: bool = False): + """ + Initialize the Multiverse Socket and the PyCram World. + + :param mode: The mode of the world (DIRECT or GUI). + :param is_prospection: Whether the world is prospection or not. + :param simulation_name: The name of the simulation. + :param clear_cache: Whether to clear the cache or not. + """ + + self.latest_save_id: Optional[int] = None + self.saved_simulator_states: Dict = {} + self._make_sure_multiverse_resources_are_added(clear_cache=clear_cache) + + if Multiverse.simulation is None: + if simulation_name is None: + logging.error("Simulation name not provided") + raise ValueError("Simulation name not provided") + Multiverse.simulation = simulation_name + + self.simulation = (self.conf.prospection_world_prefix if is_prospection else "") + Multiverse.simulation + self.client_manager = MultiverseClientManager(self.conf.simulation_wait_time_factor) + self._init_clients(is_prospection=is_prospection) + + World.__init__(self, mode, is_prospection) + + self._init_constraint_and_object_id_name_map_collections() + + self.ray_test_utils = RayTestUtils(self.ray_test_batch, self.object_id_to_name) + + if not self.is_prospection_world: + self._spawn_floor() + + if self.conf.use_static_mode: + self.api_requester.pause_simulation() + + def _init_clients(self, is_prospection: bool = False): + """ + Initialize the Multiverse clients that will be used to communicate with the Multiverse server. + Each client is responsible for a specific task, e.g. reading data from the server, writing data to the serve, + calling the API, or controlling the robot joints. + + :param is_prospection: Whether the world is prospection or not. + """ + self.reader: MultiverseReader = self.client_manager.create_reader( + is_prospection_world=is_prospection) + self.writer: MultiverseWriter = self.client_manager.create_writer( + self.simulation, + is_prospection_world=is_prospection) + self.api_requester: MultiverseAPI = self.client_manager.create_api_requester( + self.simulation, + is_prospection_world=is_prospection) + if self.conf.use_controller: + self.joint_controller: MultiverseController = self.client_manager.create_controller( + is_prospection_world=is_prospection) + + def _init_constraint_and_object_id_name_map_collections(self): + self.last_object_id: int = -1 + self.last_constraint_id: int = -1 + self.constraints: Dict[int, Constraint] = {} + self.object_name_to_id: Dict[str, int] = {} + self.object_id_to_name: Dict[int, str] = {} + + def _init_world(self, mode: WorldMode): + pass + + def _make_sure_multiverse_resources_are_added(self, clear_cache: bool = False): + """ + Add the multiverse resources to the pycram world resources, and change the data directory and cache manager. + + :param clear_cache: Whether to clear the cache or not. + """ + if not self.added_multiverse_resources: + if clear_cache: + World.cache_manager.clear_cache() + World.add_resource_path(self.conf.resources_path, prepend=True) + World.change_cache_dir_path(self.conf.resources_path) + self.added_multiverse_resources = True + + def remove_multiverse_resources(self): + """ + Remove the multiverse resources from the pycram world resources. + """ + if self.added_multiverse_resources: + World.remove_resource_path(self.conf.resources_path) + World.change_cache_dir_path(self.conf.cache_dir) + self.added_multiverse_resources = False + + def _spawn_floor(self): + """ + Spawn the plane in the simulator. + """ + self.floor = Object("floor", ObjectType.ENVIRONMENT, "plane.urdf", + world=self) + + def get_images_for_target(self, target_pose: Pose, + cam_pose: Pose, + size: int = 256, + camera_min_distance: float = 0.1, + camera_max_distance: int = 3, + plot: bool = False) -> List[np.ndarray]: + """ + Uses ray test to get the images for the target object. (target_pose is currently not used) + """ + camera_description = RobotDescription.current_robot_description.get_default_camera() + camera_frame = RobotDescription.current_robot_description.get_camera_frame() + return self.ray_test_utils.get_images_for_target(cam_pose, camera_description, camera_frame, + size, camera_min_distance, camera_max_distance, plot) + + @staticmethod + def get_joint_position_name(joint: Joint) -> MultiverseJointPosition: + """ + Get the attribute name of the joint position in the Multiverse from the pycram joint type. + + :param joint: The joint. + """ + return MultiverseJointPosition.from_pycram_joint_type(joint.type) + + def spawn_robot_with_controller(self, name: str, pose: Pose) -> None: + """ + Spawn the robot in the simulator. + + :param name: The name of the robot. + :param pose: The pose of the robot. + """ + actuator_joint_commands = { + actuator_name: [self.get_joint_cmd_name(self.robot_description.joint_types[joint_name]).value] + for joint_name, actuator_name in self.robot_joint_actuators.items() + } + self.joint_controller.init_controller(actuator_joint_commands) + self.writer.spawn_robot_with_actuators(name, pose.position_as_list(), + xyzw_to_wxyz(pose.orientation_as_list()), + actuator_joint_commands) + + def load_object_and_get_id(self, name: Optional[str] = None, + pose: Optional[Pose] = None, + obj_type: Optional[ObjectType] = None) -> int: + """ + Spawn the object in the simulator and return the object id. Object name has to be unique and has to be same as + the name of the object in the description file. + + :param name: The name of the object to be loaded. + :param pose: The pose of the object. + :param obj_type: The type of the object. + """ + if pose is None: + pose = Pose() + + # Do not spawn objects with type environment as they should be already present in the simulator through the + # multiverse description file (.muv file). + if not obj_type == ObjectType.ENVIRONMENT: + self.spawn_object(name, obj_type, pose) + + return self._update_object_id_name_maps_and_get_latest_id(name) + + def spawn_object(self, name: str, object_type: ObjectType, pose: Pose) -> None: + """ + Spawn the object in the simulator. + + :param name: The name of the object. + :param object_type: The type of the object. + :param pose: The pose of the object. + """ + if object_type == ObjectType.ROBOT and self.conf.use_controller: + self.spawn_robot_with_controller(name, pose) + else: + self._set_body_pose(name, pose) + + def _update_object_id_name_maps_and_get_latest_id(self, name: str) -> int: + """ + Update the object id name maps and return the latest object id. + + :param name: The name of the object. + :return: The latest object id. + """ + self.last_object_id += 1 + self.object_name_to_id[name] = self.last_object_id + self.object_id_to_name[self.last_object_id] = name + return self.last_object_id + + def get_object_joint_names(self, obj: Object) -> List[str]: + return [joint.name for joint in obj.description.joints if joint.type in self.supported_joint_types] + + def get_object_link_names(self, obj: Object) -> List[str]: + return [link.name for link in obj.description.links] + + def get_link_position(self, link: Link) -> List[float]: + return self.reader.get_body_position(link.name) + + def get_link_orientation(self, link: Link) -> List[float]: + return self.reader.get_body_orientation(link.name) + + def get_multiple_link_positions(self, links: List[Link]) -> Dict[str, List[float]]: + return self.reader.get_multiple_body_positions([link.name for link in links]) + + def get_multiple_link_orientations(self, links: List[Link]) -> Dict[str, List[float]]: + return self.reader.get_multiple_body_orientations([link.name for link in links]) + + @validate_joint_position + def reset_joint_position(self, joint: Joint, joint_position: float) -> bool: + if self.conf.use_controller and self.joint_has_actuator(joint): + self._reset_joint_position_using_controller(joint, joint_position) + else: + self._set_multiple_joint_positions_without_controller({joint: joint_position}) + return True + + def _reset_joint_position_using_controller(self, joint: Joint, joint_position: float) -> bool: + """ + Reset the position of a joint in the simulator using the controller. + + :param joint: The joint. + :param joint_position: The position of the joint. + :return: True if the joint position is reset successfully. + """ + self.joint_controller.set_body_property(self.get_actuator_for_joint(joint), + self.get_joint_cmd_name(joint.type), + [joint_position]) + return True + + @validate_multiple_joint_positions + def set_multiple_joint_positions(self, joint_positions: Dict[Joint, float]) -> bool: + """ + Set the positions of multiple joints in the simulator. Also check if the joint is controlled by an actuator + and use the controller to set the joint position if the joint is controlled. + + :param joint_positions: The dictionary of joints and positions. + :return: True if the joint positions are set successfully (this means that the joint positions are set without + errors, but not necessarily that the joint positions are set to the specified values). + """ + + if self.conf.use_controller: + controlled_joints = self.get_controlled_joints(list(joint_positions.keys())) + if len(controlled_joints) > 0: + controlled_joint_positions = {joint: joint_positions[joint] for joint in controlled_joints} + self._set_multiple_joint_positions_using_controller(controlled_joint_positions) + joint_positions = {joint: joint_positions[joint] for joint in joint_positions.keys() + if joint not in controlled_joints} + if len(joint_positions) > 0: + self._set_multiple_joint_positions_without_controller(joint_positions) + + return True + + def get_controlled_joints(self, joints: Optional[List[Joint]] = None) -> List[Joint]: + """ + Get the joints that are controlled by an actuator from the list of joints. + + :param joints: The list of joints to check. + :return: The list of controlled joints. + """ + joints = self.robot.joints if joints is None else joints + return [joint for joint in joints if self.joint_has_actuator(joint)] + + def _set_multiple_joint_positions_without_controller(self, joint_positions: Dict[Joint, float]) -> None: + """ + Set the positions of multiple joints in the simulator without using the controller. + + :param joint_positions: The dictionary of joints and positions. + """ + joints_data = {joint.name: {self.get_joint_position_name(joint): [position]} + for joint, position in joint_positions.items()} + self.writer.send_multiple_body_data_to_server(joints_data) + + def _set_multiple_joint_positions_using_controller(self, joint_positions: Dict[Joint, float]) -> bool: + """ + Set the positions of multiple joints in the simulator using the controller. + + :param joint_positions: The dictionary of joints and positions. + """ + controlled_joints_data = {self.get_actuator_for_joint(joint): + {self.get_joint_cmd_name(joint.type): [position]} + for joint, position in joint_positions.items()} + self.joint_controller.send_multiple_body_data_to_server(controlled_joints_data) + return True + + def get_joint_position(self, joint: Joint) -> Optional[float]: + joint_position_name = self.get_joint_position_name(joint) + data = self.reader.get_body_data(joint.name, [joint_position_name]) + if data is not None: + return data[joint_position_name.value][0] + + def get_multiple_joint_positions(self, joints: List[Joint]) -> Optional[Dict[str, float]]: + joint_names = [joint.name for joint in joints] + data = self.reader.get_multiple_body_data(joint_names, {joint.name: [self.get_joint_position_name(joint)] + for joint in joints}) + if data is not None: + return {name: list(value.values())[0][0] for name, value in data.items()} + + @staticmethod + def get_joint_cmd_name(joint_type: JointType) -> MultiverseJointCMD: + """ + Get the attribute name of the joint command in the Multiverse from the pycram joint type. + + :param joint_type: The pycram joint type. + """ + return MultiverseJointCMD.from_pycram_joint_type(joint_type) + + def get_link_pose(self, link: Link) -> Optional[Pose]: + return self._get_body_pose(link.name) + + def get_multiple_link_poses(self, links: List[Link]) -> Dict[str, Pose]: + return self._get_multiple_body_poses([link.name for link in links]) + + def get_object_pose(self, obj: Object) -> Pose: + if obj.has_type_environment(): + return Pose() + return self._get_body_pose(obj.name) + + def get_multiple_object_poses(self, objects: List[Object]) -> Dict[str, Pose]: + """ + Set the poses of multiple objects in the simulator. If the object is of type environment, the pose will be + the default pose. + + :param objects: The list of objects. + :return: The dictionary of object names and poses. + """ + non_env_objects = [obj for obj in objects if not obj.has_type_environment()] + all_poses = self._get_multiple_body_poses([obj.name for obj in non_env_objects]) + all_poses.update({obj.name: Pose() for obj in objects if obj.has_type_environment()}) + return all_poses + + @validate_object_pose + def reset_object_base_pose(self, obj: Object, pose: Pose) -> bool: + if obj.has_type_environment(): + return False + + if (obj.obj_type == ObjectType.ROBOT and + RobotDescription.current_robot_description.virtual_mobile_base_joints is not None): + obj.set_mobile_robot_pose(pose) + else: + self._set_body_pose(obj.name, pose) + + return True + + @validate_multiple_object_poses + def reset_multiple_objects_base_poses(self, objects: Dict[Object, Pose]) -> None: + """ + Reset the poses of multiple objects in the simulator. + + :param objects: The dictionary of objects and poses. + """ + for obj in objects.keys(): + if (obj.obj_type == ObjectType.ROBOT and + RobotDescription.current_robot_description.virtual_mobile_base_joints is not None): + obj.set_mobile_robot_pose(objects[obj]) + objects = {obj: pose for obj, pose in objects.items() if obj.obj_type not in [ObjectType.ENVIRONMENT, + ObjectType.ROBOT]} + self._set_multiple_body_poses({obj.name: pose for obj, pose in objects.items()}) + + def _set_body_pose(self, body_name: str, pose: Pose) -> None: + """ + Reset the pose of a body (object, link, or joint) in the simulator. + + :param body_name: The name of the body. + :param pose: The pose of the body. + """ + self._set_multiple_body_poses({body_name: pose}) + + def _set_multiple_body_poses(self, body_poses: Dict[str, Pose]) -> None: + """ + Reset the poses of multiple bodies in the simulator. + + :param body_poses: The dictionary of body names and poses. + """ + self.writer.set_multiple_body_poses({name: {MultiverseBodyProperty.POSITION: pose.position_as_list(), + MultiverseBodyProperty.ORIENTATION: + xyzw_to_wxyz(pose.orientation_as_list()), + MultiverseBodyProperty.RELATIVE_VELOCITY: [0.0] * 6} + for name, pose in body_poses.items()}) + + def _get_body_pose(self, body_name: str, wait: Optional[bool] = True) -> Optional[Pose]: + """ + Get the pose of a body in the simulator. + + :param body_name: The name of the body. + :param wait: Whether to wait until the pose is received. + :return: The pose of the body. + """ + data = self.reader.get_body_pose(body_name, wait) + return Pose(data[MultiverseBodyProperty.POSITION.value], + wxyz_to_xyzw(data[MultiverseBodyProperty.ORIENTATION.value])) + + def _get_multiple_body_poses(self, body_names: List[str]) -> Dict[str, Pose]: + """ + Get the poses of multiple bodies in the simulator. + + :param body_names: The list of body names. + """ + return self.reader.get_multiple_body_poses(body_names) + + def get_multiple_object_positions(self, objects: List[Object]) -> Dict[str, List[float]]: + return self.reader.get_multiple_body_positions([obj.name for obj in objects]) + + def get_object_position(self, obj: Object) -> List[float]: + return self.reader.get_body_position(obj.name) + + def get_multiple_object_orientations(self, objects: List[Object]) -> Dict[str, List[float]]: + return self.reader.get_multiple_body_orientations([obj.name for obj in objects]) + + def get_object_orientation(self, obj: Object) -> List[float]: + return self.reader.get_body_orientation(obj.name) + + def multiverse_reset_world(self): + """ + Reset the world using the Multiverse API. + """ + self.writer.reset_world() + + def disconnect_from_physics_server(self) -> None: + MultiverseClientManager.stop_all_clients() + + def join_threads(self) -> None: + self.reader.stop_thread = True + self.reader.join() + + def _remove_visual_object(self, obj_id: int) -> bool: + logwarn("Currently multiverse does not create visual objects") + return False + + def remove_object_from_simulator(self, obj: Object) -> bool: + if obj.obj_type != ObjectType.ENVIRONMENT: + self.writer.remove_body(obj.name) + return True + logwarn("Cannot remove environment objects") + return False + + def add_constraint(self, constraint: Constraint) -> int: + + if constraint.type != JointType.FIXED: + logging.error("Only fixed constraints are supported in Multiverse") + raise ValueError + + if not self.conf.let_pycram_move_attached_objects: + self.api_requester.attach(constraint) + + return self._update_constraint_collection_and_get_latest_id(constraint) + + def _update_constraint_collection_and_get_latest_id(self, constraint: Constraint) -> int: + """ + Update the constraint collection and return the latest constraint id. + + :param constraint: The constraint to be added. + :return: The latest constraint id. + """ + self.last_constraint_id += 1 + self.constraints[self.last_constraint_id] = constraint + return self.last_constraint_id + + def remove_constraint(self, constraint_id) -> None: + constraint = self.constraints.pop(constraint_id) + self.api_requester.detach(constraint) + + def perform_collision_detection(self) -> None: + ... + + def get_object_contact_points(self, obj: Object) -> ContactPointsList: + """ + Note: Currently Multiverse only gets one contact point per contact objects. + """ + multiverse_contact_points = self.api_requester.get_contact_points(obj) + contact_points = ContactPointsList([]) + body_link = None + for point in multiverse_contact_points: + if point.body_name == "world": + point.body_name = "floor" + body_object = self.get_object_by_name(point.body_name) + if body_object is None: + for obj in self.objects: + for link in obj.links.values(): + if link.name == point.body_name: + body_link = link + break + else: + body_link = body_object.root_link + if body_link is None: + logging.error(f"Body link not found: {point.body_name}") + raise ValueError(f"Body link not found: {point.body_name}") + contact_points.append(ContactPoint(obj.root_link, body_link)) + contact_points[-1].force_x_in_world_frame = point.contact_force[0] + contact_points[-1].force_y_in_world_frame = point.contact_force[1] + contact_points[-1].force_z_in_world_frame = point.contact_force[2] + contact_points[-1].normal_on_b = point.contact_force[2] + contact_points[-1].normal_force = point.contact_force[2] + return contact_points + + @staticmethod + def _get_normal_force_on_object_from_contact_force(obj: Object, contact_force: List[float]) -> float: + """ + Get the normal force on an object from the contact force exerted by another object that is expressed in the + world frame. Thus transforming the contact force to the object frame is necessary. + + :param obj: The object. + :param contact_force: The contact force. + :return: The normal force on the object. + """ + obj_quat = obj.get_orientation_as_list() + obj_rot_matrix = quaternion_matrix(obj_quat)[:3, :3] + # invert the rotation matrix to get the transformation from world to object frame + obj_rot_matrix = np.linalg.inv(obj_rot_matrix) + contact_force_array = obj_rot_matrix @ np.array(contact_force).reshape(3, 1) + return contact_force_array.flatten().tolist()[2] + + def get_contact_points_between_two_objects(self, obj1: Object, obj2: Object) -> ContactPointsList: + obj1_contact_points = self.get_object_contact_points(obj1) + return obj1_contact_points.get_points_of_object(obj2) + + def ray_test(self, from_position: List[float], to_position: List[float]) -> Optional[int]: + ray_test_result = self.ray_test_batch([from_position], [to_position])[0] + return ray_test_result[0] if ray_test_result[0] != -1 else None + + def ray_test_batch(self, from_positions: List[List[float]], + to_positions: List[List[float]], + num_threads: int = 1, + return_distance: bool = False) -> Union[List, Tuple[List, List[float]]]: + """ + Note: Currently, num_threads is not used in Multiverse. + """ + ray_results = self.api_requester.get_objects_intersected_with_rays(from_positions, to_positions) + results = [] + distances = [] + for ray_result in ray_results: + results.append([]) + if ray_result.intersected(): + body_name = ray_result.body_name + if body_name == "world": + results[-1].append(0) # The floor id, which is always 0 since the floor is spawned first. + elif body_name in self.object_name_to_id.keys(): + results[-1].append(self.object_name_to_id[body_name]) + else: + for obj in self.objects: + if body_name in obj.links.keys(): + results[-1].append(obj.id) + break + else: + results[-1].append(-1) + if return_distance: + distances.append(ray_result.distance) + if return_distance: + return results, distances + else: + return results + + def step(self): + """ + Perform a simulation step in the simulator, this is useful when use_static_mode is True. + """ + if self.conf.use_static_mode: + self.api_requester.unpause_simulation() + sleep(self.simulation_time_step) + self.api_requester.pause_simulation() + + def save_physics_simulator_state(self, state_id: Optional[int] = None, use_same_id: bool = False) -> int: + if state_id is None: + self.latest_save_id = 0 if self.latest_save_id is None else self.latest_save_id + int(not use_same_id) + state_id = self.latest_save_id + save_name = f"save_{state_id}" + self.saved_simulator_states[state_id] = self.api_requester.save(save_name) + return state_id + + def remove_physics_simulator_state(self, state_id: int) -> None: + self.saved_simulator_states.pop(state_id) + + def restore_physics_simulator_state(self, state_id: int) -> None: + self.api_requester.load(self.saved_simulator_states[state_id]) + + def set_link_color(self, link: Link, rgba_color: Color): + logwarn("set_link_color is not implemented in Multiverse") + + def get_link_color(self, link: Link) -> Color: + logwarn("get_link_color is not implemented in Multiverse") + return Color() + + def get_colors_of_object_links(self, obj: Object) -> Dict[str, Color]: + logwarn("get_colors_of_object_links is not implemented in Multiverse") + return {} + + def get_object_axis_aligned_bounding_box(self, obj: Object) -> AxisAlignedBoundingBox: + logerr("get_object_axis_aligned_bounding_box for multi-link objects is not implemented in Multiverse") + raise NotImplementedError + + def get_link_axis_aligned_bounding_box(self, link: Link) -> AxisAlignedBoundingBox: + logerr("get_link_axis_aligned_bounding_box is not implemented in Multiverse") + raise NotImplementedError + + def set_realtime(self, real_time: bool) -> None: + logwarn("set_realtime is not implemented as an API in Multiverse, it is configured in the" + "multiverse configuration file (.muv file) as rtf_required where a value of 1 means real-time") + + def set_gravity(self, gravity_vector: List[float]) -> None: + logwarn("set_gravity is not implemented in Multiverse") + + def check_object_exists(self, obj: Object) -> bool: + """ + Check if the object exists in the Multiverse world. + + :param obj: The object. + :return: True if the object exists, False otherwise. + """ + return self.api_requester.check_object_exists(obj) diff --git a/src/pycram/worlds/multiverse_communication/__init__.py b/src/pycram/worlds/multiverse_communication/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/pycram/worlds/multiverse_communication/client_manager.py b/src/pycram/worlds/multiverse_communication/client_manager.py new file mode 100644 index 000000000..a1b768172 --- /dev/null +++ b/src/pycram/worlds/multiverse_communication/client_manager.py @@ -0,0 +1,97 @@ +from typing_extensions import Optional, Type, Union, Dict + +from ...worlds.multiverse_communication.clients import MultiverseWriter, MultiverseAPI, MultiverseClient, \ + MultiverseReader, MultiverseController + +from ...config.multiverse_conf import MultiverseConfig as Conf + + +class MultiverseClientManager: + BASE_PORT: int = Conf.BASE_CLIENT_PORT + """ + The base port of the Multiverse client. + """ + clients: Optional[Dict[str, MultiverseClient]] = {} + """ + The list of Multiverse clients. + """ + last_used_port: int = BASE_PORT + + def __init__(self, simulation_wait_time_factor: Optional[float] = 1.0): + """ + Initialize the Multiverse client manager. + + :param simulation_wait_time_factor: The simulation wait time factor. + """ + self.simulation_wait_time_factor = simulation_wait_time_factor + + def create_reader(self, is_prospection_world: Optional[bool] = False) -> MultiverseReader: + """ + Create a Multiverse reader client. + + :param is_prospection_world: Whether the reader is connected to the prospection world. + """ + return self.create_client(MultiverseReader, "reader", is_prospection_world) + + def create_writer(self, simulation: str, is_prospection_world: Optional[bool] = False) -> MultiverseWriter: + """ + Create a Multiverse writer client. + + :param simulation: The name of the simulation that the writer is connected to + (usually the name defined in the .muv file). + :param is_prospection_world: Whether the writer is connected to the prospection world. + """ + return self.create_client(MultiverseWriter, "writer", is_prospection_world, + simulation=simulation) + + def create_controller(self, is_prospection_world: Optional[bool] = False) -> MultiverseController: + """ + Create a Multiverse controller client. + + :param is_prospection_world: Whether the controller is connected to the prospection world. + """ + return self.create_client(MultiverseController, "controller", is_prospection_world) + + def create_api_requester(self, simulation: str, is_prospection_world: Optional[bool] = False) -> MultiverseAPI: + """ + Create a Multiverse API client. + + :param simulation: The name of the simulation that the API is connected to + (usually the name defined in the .muv file). + :param is_prospection_world: Whether the API is connected to the prospection world. + """ + return self.create_client(MultiverseAPI, "api_requester", is_prospection_world, simulation=simulation) + + def create_client(self, + client_type: Type[MultiverseClient], + name: Optional[str] = None, + is_prospection_world: Optional[bool] = False, + **kwargs) -> Union[MultiverseClient, MultiverseAPI, + MultiverseReader, MultiverseWriter, MultiverseController]: + """ + Create a Multiverse client. + + :param client_type: The type of the client to create. + :param name: The name of the client. + :param is_prospection_world: Whether the client is connected to the prospection world. + :param kwargs: Any other keyword arguments that should be passed to the client constructor. + """ + MultiverseClientManager.last_used_port += 1 + name = (name or client_type.__name__) + f"_{self.last_used_port}" + client = client_type(name, self.last_used_port, is_prospection_world=is_prospection_world, + simulation_wait_time_factor=self.simulation_wait_time_factor, **kwargs) + self.clients[name] = client + return client + + @classmethod + def stop_all_clients(cls): + """ + Stop all clients. + """ + for client in cls.clients: + if isinstance(client, MultiverseReader): + client.stop_thread = True + client.join() + elif isinstance(client, MultiverseClient): + client.stop() + cls.clients = {} diff --git a/src/pycram/worlds/multiverse_communication/clients.py b/src/pycram/worlds/multiverse_communication/clients.py new file mode 100644 index 000000000..b10959a90 --- /dev/null +++ b/src/pycram/worlds/multiverse_communication/clients.py @@ -0,0 +1,832 @@ +import datetime +import logging +import os +import threading +from time import time, sleep + +from typing_extensions import List, Dict, Tuple, Optional, Callable, Union + +from .socket import MultiverseSocket, MultiverseMetaData +from ...config.multiverse_conf import MultiverseConfig as Conf +from ...datastructures.dataclasses import RayResult, MultiverseContactPoint +from ...datastructures.enums import (MultiverseAPIName as API, MultiverseBodyProperty as BodyProperty, + MultiverseProperty as Property) +from ...datastructures.pose import Pose +from ...ros.logging import logwarn +from ...utils import wxyz_to_xyzw +from ...world_concepts.constraints import Constraint +from ...world_concepts.world_object import Object, Link + + +class MultiverseClient(MultiverseSocket): + + def __init__(self, name: str, port: int, is_prospection_world: bool = False, + simulation_wait_time_factor: float = 1.0, **kwargs): + """ + Initialize the Multiverse client, which connects to the Multiverse server. + + :param name: The name of the client. + :param port: The port of the client. + :param is_prospection_world: Whether the client is connected to the prospection world. + :param simulation_wait_time_factor: The simulation wait time factor (default is 1.0), which can be used to + increase or decrease the wait time for the simulation. + """ + meta_data = MultiverseMetaData() + meta_data.simulation_name = (Conf.prospection_world_prefix if is_prospection_world else "") + name + meta_data.world_name = ((Conf.prospection_world_prefix if is_prospection_world else "") + + meta_data.world_name) + self.is_prospection_world = is_prospection_world + super().__init__(port=str(port), meta_data=meta_data) + self.simulation_wait_time_factor = simulation_wait_time_factor + self.run() + + +class MultiverseReader(MultiverseClient): + MAX_WAIT_TIME_FOR_DATA: datetime.timedelta = Conf.READER_MAX_WAIT_TIME_FOR_DATA + """ + The maximum wait time for the data in seconds. + """ + + def __init__(self, name: str, port: int, is_prospection_world: bool = False, + simulation_wait_time_factor: float = 1.0, **kwargs): + """ + Initialize the Multiverse reader, which reads the data from the Multiverse server in a separate thread. + This class provides methods to get data (e.g., position, orientation) from the Multiverse server. + + :param port: The port of the Multiverse reader client. + :param is_prospection_world: Whether the reader is connected to the prospection world. + :param simulation_wait_time_factor: The simulation wait time factor. + """ + super().__init__(name, port, is_prospection_world, simulation_wait_time_factor=simulation_wait_time_factor) + + self.request_meta_data["receive"][""] = [""] + + self.data_lock = threading.Lock() + self.thread = threading.Thread(target=self.receive_all_data_from_server) + self.stop_thread = False + + self.thread.start() + + def get_body_pose(self, name: str, wait: bool = False) -> Optional[Dict[str, List[float]]]: + """ + Get the body pose from the multiverse server. + + :param name: The name of the body. + :param wait: Whether to wait for the data. + :return: The position and orientation of the body. + """ + return self.get_body_data(name, [BodyProperty.POSITION, BodyProperty.ORIENTATION], wait=wait) + + def get_multiple_body_poses(self, body_names: List[str], wait: bool = False) -> Optional[Dict[str, Pose]]: + """ + Get the body poses from the multiverse server for multiple bodies. + + :param body_names: The names of the bodies. + :param wait: Whether to wait for the data. + :return: The positions and orientations of the bodies as a dictionary. + """ + data = self.get_multiple_body_data(body_names, + {name: [BodyProperty.POSITION, BodyProperty.ORIENTATION] + for name in body_names + }, + wait=wait) + if data is not None: + return {name: Pose(data[name][BodyProperty.POSITION.value], + wxyz_to_xyzw(data[name][BodyProperty.ORIENTATION.value])) + for name in body_names} + + def get_body_position(self, name: str, wait: bool = False) -> Optional[List[float]]: + """ + Get the body position from the multiverse server. + + :param name: The name of the body. + :param wait: Whether to wait for the data. + :return: The position of the body. + """ + return self.get_body_property(name, BodyProperty.POSITION, wait=wait) + + def get_multiple_body_positions(self, body_names: List[str], + wait: bool = False) -> Optional[Dict[str, List[float]]]: + """ + Get the body positions from the multiverse server for multiple bodies. + + :param body_names: The names of the bodies. + :param wait: Whether to wait for the data. + :return: The positions of the bodies as a dictionary. + """ + return self.get_multiple_body_properties(body_names, [BodyProperty.POSITION], wait=wait) + + def get_body_orientation(self, name: str, wait: bool = False) -> Optional[List[float]]: + """ + Get the body orientation from the multiverse server. + + :param name: The name of the body. + :param wait: Whether to wait for the data. + :return: The orientation of the body. + """ + return self.get_body_property(name, BodyProperty.ORIENTATION, wait=wait) + + def get_multiple_body_orientations(self, body_names: List[str], + wait: bool = False) -> Optional[Dict[str, List[float]]]: + """ + Get the body orientations from the multiverse server for multiple bodies. + + :param body_names: The names of the bodies. + :param wait: Whether to wait for the data. + :return: The orientations of the bodies as a dictionary. + """ + data = self.get_multiple_body_properties(body_names, [BodyProperty.ORIENTATION], wait=wait) + if data is not None: + return {name: wxyz_to_xyzw(data[name][BodyProperty.ORIENTATION.value]) for name in body_names} + + def get_body_property(self, name: str, property_: Property, wait: bool = False) -> Optional[List[float]]: + """ + Get the body property from the multiverse server. + + :param name: The name of the body. + :param property_: The property of the body as a Property. + :param wait: Whether to wait for the data. + :return: The property of the body. + """ + data = self.get_body_data(name, [property_], wait=wait) + if data is not None: + return data[property_.value] + + def get_multiple_body_properties(self, body_names: List[str], properties: List[Property], + wait: bool = False) -> Optional[Dict[str, Dict[str, List[float]]]]: + """ + Get the body properties from the multiverse server for multiple bodies. + + :param body_names: The names of the bodies. + :param properties: The properties of the bodies. + :param wait: Whether to wait for the data. + :return: The properties of the bodies as a dictionary. + """ + return self.get_multiple_body_data(body_names, {name: properties for name in body_names}, wait=wait) + + def get_body_data(self, name: str, + properties: Optional[List[Property]] = None, + wait: bool = False) -> Optional[Dict]: + """ + Get the body data from the multiverse server. + + :param name: The name of the body. + :param properties: The properties of the body. + :param wait: Whether to wait for the data. + :return: The body data as a dictionary. + """ + if wait: + return self.wait_for_body_data(name, properties) + + data = self.get_received_data() + if self.check_for_body_data(name, data, properties): + return data[name] + + def get_multiple_body_data(self, body_names: List[str], + properties: Optional[Dict[str, List[Property]]] = None, + wait: bool = False) -> Optional[Dict]: + """ + Get the body data from the multiverse server for multiple bodies. + + :param body_names: The names of the bodies. + :param properties: The properties of the bodies. + :param wait: Whether to wait for the data. + :return: The body data as a dictionary. + """ + + if wait: + return self.wait_for_multiple_body_data(body_names, properties) + + data = self.get_received_data() + if self.check_multiple_body_data(body_names, data, properties): + return {name: data[name] for name in body_names} + + def wait_for_body_data(self, name: str, properties: Optional[List[Property]] = None) -> Dict: + """ + Wait for the body data from the multiverse server. + + :param name: The name of the body. + :param properties: The properties of the body. + :return: The body data as a dictionary. + """ + return self._wait_for_body_data_template(name, self.check_for_body_data, properties)[name] + + def wait_for_multiple_body_data(self, body_names: List[str], + properties: Optional[Dict[str, List[Property]]] = None) -> Dict: + """ + Wait for the body data from the multiverse server for multiple bodies. + + :param body_names: The names of the bodies. + :param properties: The properties of the bodies. + :return: The body data as a dictionary. + """ + return self._wait_for_body_data_template(body_names, self.check_multiple_body_data, properties) + + def _wait_for_body_data_template(self, body_names: Union[str, List[str]], + check_func: Callable[[Union[str, List[str]], Dict, Union[Dict, List]], bool], + properties: Optional[Union[Dict, List]] = None) -> Dict: + """ + Wait for the body data from the multiverse server for multiple bodies. + + :param body_names: The names of the bodies. + :param properties: The properties of the bodies. + :param check_func: The function to check if the data is received. + :return: The body data as a dictionary. + """ + start = time() + data_received_flag = False + while time() - start < self.MAX_WAIT_TIME_FOR_DATA.total_seconds(): + received_data = self.get_received_data() + data_received_flag = check_func(body_names, received_data, properties) + if data_received_flag: + return received_data + if not data_received_flag: + properties_str = "Data" if properties is None else f"Properties {properties}" + msg = f"{properties_str} for {body_names} not received within {self.MAX_WAIT_TIME_FOR_DATA} seconds" + logging.error(msg) + raise ValueError(msg) + + def check_multiple_body_data(self, body_names: List[str], data: Dict, + properties: Optional[Dict[str, List[Property]]] = None) -> bool: + """ + Check if the body data is received from the multiverse server for multiple bodies. + + :param body_names: The names of the bodies. + :param data: The data received from the multiverse server. + :param properties: The properties of the bodies. + :return: Whether the body data is received. + """ + if properties is None: + return all([self.check_for_body_data(name, data) for name in body_names]) + else: + return all([self.check_for_body_data(name, data, properties[name]) for name in body_names]) + + @staticmethod + def check_for_body_data(name: str, data: Dict, properties: Optional[List[Property]] = None) -> bool: + """ + Check if the body data is received from the multiverse server. + + :param name: The name of the body. + :param data: The data received from the multiverse server. + :param properties: The properties of the body. + :return: Whether the body data is received. + """ + if properties is None: + return name in data + else: + return name in data and all([prop.value in data[name] and None not in data[name][prop.value] + for prop in properties]) + + def get_received_data(self): + """ + Get the latest received data from the multiverse server. + """ + self.data_lock.acquire() + data = self.response_meta_data["receive"] + self.data_lock.release() + return data + + def receive_all_data_from_server(self): + """ + Get all data from the multiverse server. + """ + while not self.stop_thread: + self.request_meta_data["receive"][""] = [""] + self.data_lock.acquire() + self.send_and_receive_meta_data() + self.data_lock.release() + sleep(0.01) + self.stop() + + def join(self): + self.thread.join() + + +class MultiverseWriter(MultiverseClient): + + def __init__(self, name: str, port: int, simulation: Optional[str] = None, + is_prospection_world: bool = False, + simulation_wait_time_factor: float = 1.0, **kwargs): + """ + Initialize the Multiverse writer, which writes the data to the Multiverse server. + This class provides methods to send data (e.g., position, orientation) to the Multiverse server. + + :param port: The port of the Multiverse writer client. + :param simulation: The name of the simulation that the writer is connected to + (usually the name defined in the .muv file). + :param is_prospection_world: Whether the writer is connected to the prospection world. + :param simulation_wait_time_factor: The wait time factor for the simulation (default is 1.0), which can be used + to increase or decrease the wait time for the simulation. + """ + super().__init__(name, port, is_prospection_world, simulation_wait_time_factor=simulation_wait_time_factor) + self.simulation = simulation + + def spawn_robot_with_actuators(self, robot_name: str, position: List[float], orientation: List[float], + actuator_joint_commands: Optional[Dict[str, List[str]]] = None) -> None: + """ + Spawn the robot with controlled actuators in the simulation. + + :param robot_name: The name of the robot. + :param position: The position of the robot. + :param orientation: The orientation of the robot. + :param actuator_joint_commands: A dictionary mapping actuator names to joint command names. + """ + send_meta_data = {robot_name: [BodyProperty.POSITION.value, BodyProperty.ORIENTATION.value, + BodyProperty.RELATIVE_VELOCITY.value]} + relative_velocity = [0.0] * 6 + data = [self.sim_time, *position, *orientation, *relative_velocity] + self.send_data_to_server(data, send_meta_data=send_meta_data, receive_meta_data=actuator_joint_commands) + + def _reset_request_meta_data(self, set_simulation_name: bool = True): + """ + Reset the request metadata. + + :param set_simulation_name: Whether to set the simulation name to the value of self.simulation_name. + """ + self.request_meta_data = { + "meta_data": self._meta_data.__dict__.copy(), + "send": {}, + "receive": {}, + } + if self.simulation is not None and set_simulation_name: + self.request_meta_data["meta_data"]["simulation_name"] = self.simulation + + def set_body_pose(self, body_name: str, position: List[float], orientation: List[float]) -> None: + """ + Set the body pose in the simulation. + + :param body_name: The name of the body. + :param position: The position of the body. + :param orientation: The orientation of the body. + """ + self.send_body_data_to_server(body_name, + {BodyProperty.POSITION: position, + BodyProperty.ORIENTATION: orientation, + BodyProperty.RELATIVE_VELOCITY: [0.0] * 6}) + + def set_multiple_body_poses(self, body_data: Dict[str, Dict[BodyProperty, List[float]]]) -> None: + """ + Set the body poses in the simulation for multiple bodies. + + :param body_data: The data to be sent for multiple bodies. + """ + self.send_multiple_body_data_to_server(body_data) + + def set_body_position(self, body_name: str, position: List[float]) -> None: + """ + Set the body position in the simulation. + + :param body_name: The name of the body. + :param position: The position of the body. + """ + self.set_body_property(body_name, BodyProperty.POSITION, position) + + def set_body_orientation(self, body_name: str, orientation: List[float]) -> None: + """ + Set the body orientation in the simulation. + + :param body_name: The name of the body. + :param orientation: The orientation of the body. + """ + self.set_body_property(body_name, BodyProperty.ORIENTATION, orientation) + + def set_body_property(self, body_name: str, property_: Property, value: List[float]) -> None: + """ + Set the body property in the simulation. + + :param body_name: The name of the body. + :param property_: The property of the body. + :param value: The value of the property. + """ + self.send_body_data_to_server(body_name, {property_: value}) + + def remove_body(self, body_name: str) -> None: + """ + Remove the body from the simulation. + + :param body_name: The name of the body. + """ + self.send_data_to_server([self.sim_time], + send_meta_data={body_name: []}, + receive_meta_data={body_name: []}) + + def reset_world(self) -> None: + """ + Reset the world in the simulation. + """ + self.send_data_to_server([0], set_simulation_name=False) + + def send_body_data_to_server(self, body_name: str, body_data: Dict[Property, List[float]]) -> Dict: + """ + Send data to the multiverse server. + + :param body_name: The name of the body. + :param body_data: The data to be sent. + :return: The response from the server. + """ + send_meta_data = {body_name: list(map(str, body_data.keys()))} + flattened_data = [value for data in body_data.values() for value in data] + return self.send_data_to_server([self.sim_time, *flattened_data], send_meta_data=send_meta_data) + + def send_multiple_body_data_to_server(self, body_data: Dict[str, Dict[Property, List[float]]]) -> Dict: + """ + Send data to the multiverse server for multiple bodies. + + :param body_data: The data to be sent for multiple bodies. + :return: The response from the server. + """ + send_meta_data = {body_name: list(map(str, data.keys())) for body_name, data in body_data.items()} + response_meta_data = self.send_meta_data_and_get_response(send_meta_data) + body_names = list(response_meta_data["send"].keys()) + flattened_data = [value for body_name in body_names for data in body_data[body_name].values() + for value in data] + self.send_data = [self.sim_time, *flattened_data] + self.send_and_receive_data() + return self.response_meta_data + + def send_meta_data_and_get_response(self, send_meta_data: Dict) -> Dict: + """ + Send metadata to the multiverse server and get the response. + + :param send_meta_data: The metadata to be sent. + :return: The response from the server. + """ + self._reset_request_meta_data() + self.request_meta_data["send"] = send_meta_data + self.send_and_receive_meta_data() + return self.response_meta_data + + def send_data_to_server(self, data: List, + send_meta_data: Optional[Dict] = None, + receive_meta_data: Optional[Dict] = None, + set_simulation_name: bool = True) -> Dict: + """ + Send data to the multiverse server. + + :param data: The data to be sent. + :param send_meta_data: The metadata to be sent. + :param receive_meta_data: The metadata to be received. + :param set_simulation_name: Whether to set the simulation name to the value of self.simulation. + :return: The response from the server. + """ + self._reset_request_meta_data(set_simulation_name=set_simulation_name) + if send_meta_data: + self.request_meta_data["send"] = send_meta_data + if receive_meta_data: + self.request_meta_data["receive"] = receive_meta_data + self.send_and_receive_meta_data() + self.send_data = data + self.send_and_receive_data() + return self.response_meta_data + + +class MultiverseController(MultiverseWriter): + + def __init__(self, name: str, port: int, is_prospection_world: bool = False, **kwargs): + """ + Initialize the Multiverse controller, which controls the robot in the simulation. + This class provides methods to send controller data to the Multiverse server. + + :param port: The port of the Multiverse controller client. + :param is_prospection_world: Whether the controller is connected to the prospection world. + """ + super().__init__(name, port, is_prospection_world=is_prospection_world) + + def init_controller(self, actuator_joint_commands: Dict[str, List[str]]) -> None: + """ + Initialize the controller by sending the controller data to the multiverse server. + + :param actuator_joint_commands: A dictionary mapping actuator names to joint command names. + """ + self.send_data_to_server([self.sim_time] + [0.0] * len(actuator_joint_commands), + send_meta_data=actuator_joint_commands) + + +class MultiverseAPI(MultiverseClient): + API_REQUEST_WAIT_TIME: datetime.timedelta = datetime.timedelta(milliseconds=200) + """ + The wait time for the API request in seconds. + """ + APIs_THAT_NEED_WAIT_TIME: List[API] = [API.ATTACH] + + def __init__(self, name: str, port: int, simulation: str, is_prospection_world: bool = False, + simulation_wait_time_factor: float = 1.0): + """ + Initialize the Multiverse API, which sends API requests to the Multiverse server. + This class provides methods like attach and detach objects, get contact points, and other API requests. + + :param port: The port of the Multiverse API client. + :param simulation: The name of the simulation that the API is connected to + (usually the name defined in the .muv file). + :param is_prospection_world: Whether the API is connected to the prospection world. + :param simulation_wait_time_factor: The simulation wait time factor, which can be used to increase or decrease + the wait time for the simulation. + """ + super().__init__(name, port, is_prospection_world, simulation_wait_time_factor=simulation_wait_time_factor) + self.simulation = simulation + self.wait: bool = False # Whether to wait after sending the API request. + + def save(self, save_name: str, save_directory: Optional[str] = None) -> str: + """ + Save the current state of the simulation. + + :param save_name: The name of the save. + :param save_directory: The path to save the simulation, can be relative or absolute. If the path is relative, + it will be saved in the saved folder in multiverse. + :return: The save path. + """ + response = self._request_single_api_callback(API.SAVE, self.get_save_path(save_name, save_directory)) + return response[0] + + def load(self, save_name: str, save_directory: Optional[str] = None) -> None: + """ + Load the saved state of the simulation. + + :param save_name: The name of the save. + :param save_directory: The path to load the simulation, can be relative or absolute. If the path is relative, + it will be loaded from the saved folder in multiverse. + """ + self._request_single_api_callback(API.LOAD, self.get_save_path(save_name, save_directory)) + + @staticmethod + def get_save_path(save_name: str, save_directory: Optional[str] = None) -> str: + """ + Get the save path. + + :param save_name: The save name. + :param save_directory: The save directory. + :return: The save path. + """ + return save_name if save_directory is None else os.path.join(save_directory, save_name) + + def attach(self, constraint: Constraint) -> None: + """ + Request to attach the child link to the parent link. + + :param constraint: The constraint. + """ + self.wait = True + parent_link_name, child_link_name = self.get_constraint_link_names(constraint) + attachment_pose = self._get_attachment_pose_as_string(constraint) + self._attach(child_link_name, parent_link_name, attachment_pose) + + def _attach(self, child_link_name: str, parent_link_name: str, attachment_pose: str) -> None: + """ + Attach the child link to the parent link. + + :param child_link_name: The name of the child link. + :param parent_link_name: The name of the parent link. + :param attachment_pose: The attachment pose. + """ + self._request_single_api_callback(API.ATTACH, child_link_name, parent_link_name, + attachment_pose) + + def get_constraint_link_names(self, constraint: Constraint) -> Tuple[str, str]: + """ + Get the link names of the constraint. + + :param constraint: The constraint. + :return: The link names of the constraint. + """ + return self.get_parent_link_name(constraint), self.get_constraint_child_link_name(constraint) + + def get_parent_link_name(self, constraint: Constraint) -> str: + """ + Get the parent link name of the constraint. + + :param constraint: The constraint. + :return: The parent link name of the constraint. + """ + return self.get_link_name_for_constraint(constraint.parent_link) + + def get_constraint_child_link_name(self, constraint: Constraint) -> str: + """ + Get the child link name of the constraint. + + :param constraint: The constraint. + :return: The child link name of the constraint. + """ + return self.get_link_name_for_constraint(constraint.child_link) + + @staticmethod + def get_link_name_for_constraint(link: Link) -> str: + """ + Get the link name from link object, if the link belongs to a one link object, return the object name. + + :param link: The link. + :return: The link name. + """ + return link.name if not link.is_only_link else link.object.name + + def detach(self, constraint: Constraint) -> None: + """ + Request to detach the child link from the parent link. + + :param constraint: The constraint. + """ + parent_link_name, child_link_name = self.get_constraint_link_names(constraint) + self._detach(child_link_name, parent_link_name) + + def _detach(self, child_link_name: str, parent_link_name: str) -> None: + """ + Detach the child link from the parent link. + + :param child_link_name: The name of the child link. + :param parent_link_name: The name of the parent link. + """ + self._request_single_api_callback(API.DETACH, child_link_name, parent_link_name) + + def _get_attachment_pose_as_string(self, constraint: Constraint) -> str: + """ + Get the attachment pose as a string. + + :param constraint: The constraint. + :return: The attachment pose as a string. + """ + pose = constraint.parent_to_child_transform.to_pose() + return self._pose_to_string(pose) + + @staticmethod + def _pose_to_string(pose: Pose) -> str: + """ + Convert the pose to a string. + + :param pose: The pose. + :return: The pose as a string. + """ + return f"{pose.position.x} {pose.position.y} {pose.position.z} {pose.orientation.w} {pose.orientation.x} " \ + f"{pose.orientation.y} {pose.orientation.z}" + + def check_object_exists(self, obj: Object) -> bool: + """ + Check if the object exists in the simulation. + + :param obj: The object. + :return: Whether the object exists in the simulation. + """ + return self._request_single_api_callback(API.EXIST, obj.name)[0] == 'yes' + + def get_contact_points(self, obj: Object) -> List[MultiverseContactPoint]: + """ + Request the contact points of an object, this includes the object names and the contact forces and torques. + + :param obj: The object. + :return: The contact points of the object as a list of MultiverseContactPoint. + """ + api_response_data = self._get_contact_points(obj.name) + body_names = api_response_data[API.GET_CONTACT_BODIES] + contact_efforts = self._parse_constraint_effort(api_response_data[API.GET_CONSTRAINT_EFFORT]) + return [MultiverseContactPoint(body_names[i], contact_efforts[:3], contact_efforts[3:]) + for i in range(len(body_names))] + + def get_objects_intersected_with_rays(self, from_positions: List[List[float]], + to_positions: List[List[float]]) -> List[RayResult]: + """ + Get the rays intersections with the objects from the from_positions to the to_positions. + + :param from_positions: The starting positions of the rays. + :param to_positions: The ending positions of the rays. + :return: The rays intersections with the objects as a list of RayResult. + """ + get_rays_response = self._get_rays(from_positions, to_positions) + return self._parse_get_rays_response(get_rays_response) + + def _get_rays(self, from_positions: List[List[float]], + to_positions: List[List[float]]) -> List[str]: + """ + Get the rays intersections with the objects from the from_positions to the to_positions. + + :param from_positions: The starting positions of the rays. + :param to_positions: The ending positions of the rays. + :return: The rays intersections with the objects as a dictionary. + """ + from_positions = self.list_of_positions_to_string(from_positions) + to_positions = self.list_of_positions_to_string(to_positions) + return self._request_single_api_callback(API.GET_RAYS, from_positions, to_positions) + + @staticmethod + def _parse_get_rays_response(response: List[str]) -> List[RayResult]: + """ + Parse the response of the get rays API. + + :param response: The response of the get rays API as a list of strings. + :return: The rays as a list of lists of floats. + """ + get_rays_results = [] + for ray_response in response: + if ray_response == "None": + get_rays_results.append(RayResult("", -1)) + else: + result = ray_response.split() + result[1] = float(result[1]) + get_rays_results.append(RayResult(*result)) + return get_rays_results + + @staticmethod + def list_of_positions_to_string(positions: List[List[float]]) -> str: + """ + Convert the list of positions to a string. + + :param positions: The list of positions. + :return: The list of positions as a string. + """ + return " ".join([f"{position[0]} {position[1]} {position[2]}" for position in positions]) + + @staticmethod + def _parse_constraint_effort(contact_effort: List[str]) -> List[float]: + """ + Parse the contact effort of an object. + + :param contact_effort: The contact effort of the object as a list of strings. + :return: The contact effort of the object as a list of floats. + """ + contact_effort = contact_effort[0].split() + if 'failed' in contact_effort: + logwarn("Failed to get contact effort") + return [0.0] * 6 + return list(map(float, contact_effort)) + + def _get_contact_points(self, object_name) -> Dict[API, List]: + """ + Request the contact points of an object. + + :param object_name: The name of the object. + :return: The contact points api response as a dictionary. + """ + return self._request_apis_callbacks({API.GET_CONTACT_BODIES: [object_name], + API.GET_CONSTRAINT_EFFORT: [object_name] + }) + + def pause_simulation(self) -> None: + """ + Pause the simulation. + """ + self._request_single_api_callback(API.PAUSE) + + def unpause_simulation(self) -> None: + """ + Unpause the simulation. + """ + self._request_single_api_callback(API.UNPAUSE) + + def _request_single_api_callback(self, api_name: API, *params) -> List[str]: + """ + Request a single API callback from the server. + + :param api_data: The API data to request the callback. + :return: The API response as a list of strings. + """ + response = self._request_apis_callbacks({api_name: list(params)}) + return response[api_name] + + def _request_apis_callbacks(self, api_data: Dict[API, List]) -> Dict[API, List[str]]: + """ + Request the API callbacks from the server. + + :param api_data: The API data to add to the request metadata. + :return: The API response as a list of strings. + """ + self._reset_api_callback() + for api_name, params in api_data.items(): + self._add_api_request(api_name.value, *params) + self._send_api_request() + responses = self._get_all_apis_responses() + if self.wait: + sleep(self.API_REQUEST_WAIT_TIME.total_seconds() * self.simulation_wait_time_factor) + self.wait = False + return responses + + def _get_all_apis_responses(self) -> Dict[API, List[str]]: + """ + Get all the API responses from the server. + + :return: The API responses as a list of APIData. + """ + list_of_api_responses = self.response_meta_data["api_callbacks_response"][self.simulation] + return {API[api_name.upper()]: response for api_response in list_of_api_responses + for api_name, response in api_response.items()} + + def _add_api_request(self, api_name: str, *params): + """ + Add an API request to the request metadata. + + :param api_name: The name of the API. + :param params: The parameters of the API. + """ + self.request_meta_data["api_callbacks"][self.simulation].append({api_name: list(params)}) + + def _send_api_request(self): + """ + Send the API request to the server. + """ + if "api_callbacks" not in self.request_meta_data: + logging.error("No API request to send") + raise ValueError + self.send_and_receive_meta_data() + self.request_meta_data.pop("api_callbacks") + + def _reset_api_callback(self): + """ + Initialize the API callback in the request metadata. + """ + self.request_meta_data["api_callbacks"] = {self.simulation: []} diff --git a/src/pycram/worlds/multiverse_communication/socket.py b/src/pycram/worlds/multiverse_communication/socket.py new file mode 100644 index 000000000..863a45d8a --- /dev/null +++ b/src/pycram/worlds/multiverse_communication/socket.py @@ -0,0 +1,238 @@ +#!/usr/bin/env python3 + +"""Multiverse Client base class.""" + +from multiverse_client_pybind import MultiverseClientPybind # noqa +from typing_extensions import Optional, List, Dict, Callable, TypeVar + +from ...datastructures.dataclasses import MultiverseMetaData +from ...config.multiverse_conf import MultiverseConfig as Conf +from ...ros.logging import loginfo, logwarn + +T = TypeVar("T") + + +class MultiverseSocket: + + def __init__( + self, + port: str, + host: str = Conf.HOST, + meta_data: MultiverseMetaData = MultiverseMetaData(), + ) -> None: + """ + Initialize the MultiverseSocket, connect to the Multiverse Server and start the communication. + + :param port: The port of the client. + :param host: The host of the client. + :param meta_data: The metadata for the Multiverse Client as MultiverseMetaData. + """ + if not isinstance(port, str) or port == "": + raise ValueError(f"Must specify client port for {self.__class__.__name__}") + self._send_data = None + self.port = port + self.host = host + self._meta_data = meta_data + self.client_name = self._meta_data.simulation_name + self._multiverse_socket = MultiverseClientPybind( + f"{Conf.SERVER_HOST}:{Conf.SERVER_PORT}" + ) + self.request_meta_data = { + "meta_data": self._meta_data.__dict__, + "send": {}, + "receive": {}, + } + self._api_callbacks: Optional[Dict] = None + + self._start_time = 0.0 + + def run(self) -> None: + """Run the client.""" + self.log_info("Start") + self._run() + + def _run(self) -> None: + """Run the client, should call the _connect_and_start() method. It's left to the user to implement this method + in threaded or non-threaded fashion. + """ + self._connect_and_start() + + def stop(self) -> None: + """Stop the client.""" + self._disconnect() + + @property + def request_meta_data(self) -> Dict: + """The request_meta_data which is sent to the server. + """ + return self._request_meta_data + + @request_meta_data.setter + def request_meta_data(self, request_meta_data: Dict) -> None: + """Set the request_meta_data, make sure to clear the `send` and `receive` field before setting the request + """ + self._request_meta_data = request_meta_data + self._multiverse_socket.set_request_meta_data(self._request_meta_data) + + @property + def response_meta_data(self) -> Dict: + """Get the response_meta_data. + + :return: The response_meta_data as a dictionary. + """ + response_meta_data = self._multiverse_socket.get_response_meta_data() + assert isinstance(response_meta_data, dict) + if response_meta_data == {}: + message = f"[Client {self.port}] Receive empty response meta data." + self.log_warn(message) + return response_meta_data + + def send_and_receive_meta_data(self): + """ + Send and receive the metadata, this should be called before sending and receiving data. + """ + self._communicate(True) + + def send_and_receive_data(self): + """ + Send and receive the data, this should be called after sending and receiving the metadata. + """ + self._communicate(False) + + @property + def send_data(self) -> List[float]: + """Get the send_data.""" + return self._send_data + + @send_data.setter + def send_data(self, send_data: List[float]) -> None: + """Set the send_data, the first element should be the current simulation time, + the rest should be the data to send with the following order: + double -> uint8_t -> uint16_t + + :param send_data: The data to send. + """ + assert isinstance(send_data, list) + self._send_data = send_data + self._multiverse_socket.set_send_data(self._send_data) + + @property + def receive_data(self) -> List[float]: + """Get the receive_data, the first element should be the current simulation time, + the rest should be the received data with the following order: + double -> uint8_t -> uint16_t + + :return: The received data. + """ + receive_data = self._multiverse_socket.get_receive_data() + assert isinstance(receive_data, list) + return receive_data + + @property + def api_callbacks(self) -> Dict[str, Callable[[List[str]], List[str]]]: + """Get the api_callbacks. + + :return: The api_callbacks as a dictionary of function names and their respective callbacks. + """ + return self._api_callbacks + + @api_callbacks.setter + def api_callbacks(self, api_callbacks: Dict[str, Callable[[List[str]], List[str]]]) -> None: + """Set the api_callbacks. + + :param api_callbacks: The api_callbacks as a dictionary of function names and their respective callbacks. + """ + self._multiverse_socket.set_api_callbacks(api_callbacks) + self._api_callbacks = api_callbacks + + def _bind_request_meta_data(self, request_meta_data: T) -> T: + """Bind the request_meta_data before sending it to the server. + + :param request_meta_data: The request_meta_data to bind. + :return: The bound request_meta_data. + """ + pass + + def _bind_response_meta_data(self, response_meta_data: T) -> T: + """Bind the response_meta_data after receiving it from the server. + + :param response_meta_data: The response_meta_data to bind. + :return: The bound response_meta_data. + """ + pass + + def _bind_send_data(self, send_data: T) -> T: + """Bind the send_data before sending it to the server. + + :param send_data: The send_data to bind. + :return: The bound send_data. + """ + pass + + def _bind_receive_data(self, receive_data: T) -> T: + """Bind the receive_data after receiving it from the server. + + :param receive_data: The receive_data to bind. + :return: The bound receive_data. + """ + pass + + def _connect_and_start(self) -> None: + """Connect to the server and start the client. + """ + self._multiverse_socket.connect(self.host, self.port) + self._multiverse_socket.start() + self._start_time = self._multiverse_socket.get_time_now() + + def _disconnect(self) -> None: + """Disconnect from the server. + """ + self._multiverse_socket.disconnect() + + def _communicate(self, resend_request_meta_data: bool = False) -> bool: + """Communicate with the server. + + :param resend_request_meta_data: Resend the request metadata. + :return: True if the communication was successful, False otherwise. + """ + return self._multiverse_socket.communicate(resend_request_meta_data) + + def _restart(self) -> None: + """Restart the client. + """ + self._disconnect() + self._connect_and_start() + + def log_info(self, message: str) -> None: + """Log information. + + :param message: The message to log. + """ + loginfo(self._message_template(message)) + + def log_warn(self, message: str) -> None: + """Warn the user. + + :param message: The message to warn about. + """ + logwarn(self._message_template(message)) + + def _message_template(self, message: str) -> str: + return (f"[{self.__class__.__name__}:{self.port}]: {message} : sim time {self.sim_time}," + f" world time {self.world_time}") + + @property + def world_time(self) -> float: + """Get the world time from the server. + + :return: The world time. + """ + return self._multiverse_socket.get_world_time() + + @property + def sim_time(self) -> float: + """Get the current simulation time. + + :return: The current simulation time. + """ + return self._multiverse_socket.get_time_now() - self._start_time diff --git a/test/bullet_world_testcase.py b/test/bullet_world_testcase.py index 54a48922e..4bb0a27b9 100644 --- a/test/bullet_world_testcase.py +++ b/test/bullet_world_testcase.py @@ -2,6 +2,7 @@ import unittest import pycram.tasktree +from pycram.datastructures.world import UseProspectionWorld from pycram.worlds.bullet_world import BulletWorld from pycram.world_concepts.world_object import Object from pycram.datastructures.pose import Pose @@ -9,7 +10,7 @@ from pycram.process_module import ProcessModule from pycram.datastructures.enums import ObjectType, WorldMode from pycram.object_descriptors.urdf import ObjectDescription -from pycram.ros.viz_marker_publisher import VizMarkerPublisher +from pycram.ros_utils.viz_marker_publisher import VizMarkerPublisher from pycram.ontology.ontology import OntologyManager, SOMA_ONTOLOGY_IRI @@ -29,22 +30,26 @@ def setUpClass(cls): RobotDescription.current_robot_description.name + cls.extension) cls.kitchen = Object("kitchen", ObjectType.ENVIRONMENT, "kitchen" + cls.extension) cls.cereal = Object("cereal", ObjectType.BREAKFAST_CEREAL, "breakfast_cereal.stl", - ObjectDescription, pose=Pose([1.3, 0.7, 0.95])) + pose=Pose([1.3, 0.7, 0.95])) ProcessModule.execution_delay = False cls.viz_marker_publisher = VizMarkerPublisher() OntologyManager(SOMA_ONTOLOGY_IRI) def setUp(self): - self.world.reset_world() + self.world.reset_world(remove_saved_states=True) + with UseProspectionWorld(): + pass # DO NOT WRITE TESTS HERE!!! # Test related to the BulletWorld should be written in test_bullet_world.py # Tests in here would not be properly executed in the CI def tearDown(self): - pycram.tasktree.reset_tree() + pycram.tasktree.task_tree.reset_tree() time.sleep(0.05) - self.world.reset_world() + self.world.reset_world(remove_saved_states=True) + with UseProspectionWorld(): + pass @classmethod def tearDownClass(cls): @@ -67,7 +72,7 @@ def setUpClass(cls): RobotDescription.current_robot_description.name + cls.extension) cls.kitchen = Object("kitchen", ObjectType.ENVIRONMENT, "kitchen" + cls.extension) cls.cereal = Object("cereal", ObjectType.BREAKFAST_CEREAL, "breakfast_cereal.stl", - ObjectDescription, pose=Pose([1.3, 0.7, 0.95])) + pose=Pose([1.3, 0.7, 0.95])) ProcessModule.execution_delay = False cls.viz_marker_publisher = VizMarkerPublisher() diff --git a/test/test_action_designator.py b/test/test_action_designator.py index 9b09feac8..1c7fb5a3d 100644 --- a/test/test_action_designator.py +++ b/test/test_action_designator.py @@ -17,16 +17,14 @@ class TestActionDesignatorGrounding(BulletWorldTestCase): def test_move_torso(self): description = action_designator.MoveTorsoAction([0.3]) - # SOMA ontology seems not provide a corresponding concept yet for MoveTorso - #self.assertTrue(description.ontology_concept_holders) self.assertEqual(description.ground().position, 0.3) with simulated_robot: description.resolve().perform() - self.assertEqual(self.world.robot.get_joint_position(RobotDescription.current_robot_description.torso_joint), 0.3) + self.assertEqual(self.world.robot.get_joint_position(RobotDescription.current_robot_description.torso_joint), + 0.3) def test_set_gripper(self): description = action_designator.SetGripperAction([Arms.LEFT], [GripperState.OPEN, GripperState.CLOSE]) - self.assertTrue(description.ontology_concept_holders) self.assertEqual(description.ground().gripper, Arms.LEFT) self.assertEqual(description.ground().motion, GripperState.OPEN) self.assertEqual(len(list(iter(description))), 2) @@ -38,21 +36,18 @@ def test_set_gripper(self): def test_release(self): object_description = object_designator.ObjectDesignatorDescription(names=["milk"]) description = action_designator.ReleaseAction([Arms.LEFT], object_description) - self.assertTrue(description.ontology_concept_holders) self.assertEqual(description.ground().gripper, Arms.LEFT) self.assertEqual(description.ground().object_designator.name, "milk") def test_grip(self): object_description = object_designator.ObjectDesignatorDescription(names=["milk"]) description = action_designator.GripAction([Arms.LEFT], object_description, [0.5]) - self.assertTrue(description.ontology_concept_holders) self.assertEqual(description.ground().gripper, Arms.LEFT) self.assertEqual(description.ground().object_designator.name, "milk") def test_park_arms(self): description = action_designator.ParkArmsAction([Arms.BOTH]) self.assertEqual(description.ground().arm, Arms.BOTH) - self.assertTrue(description.ontology_concept_holders) with simulated_robot: description.resolve().perform() for joint, pose in RobotDescription.current_robot_description.get_static_joint_chain("right", "park").items(): @@ -66,13 +61,11 @@ def test_navigate(self): with simulated_robot: description.resolve().perform() self.assertEqual(description.ground().target_location, Pose([1, 0, 0], [0, 0, 0, 1])) - self.assertTrue(description.ontology_concept_holders) self.assertEqual(self.robot.get_pose(), Pose([1, 0, 0])) def test_pick_up(self): object_description = object_designator.ObjectDesignatorDescription(names=["milk"]) description = action_designator.PickUpAction(object_description, [Arms.LEFT], [Grasp.FRONT]) - self.assertTrue(description.ontology_concept_holders) self.assertEqual(description.ground().object_designator.name, "milk") with simulated_robot: NavigateActionPerformable(Pose([0.6, 0.4, 0], [0, 0, 0, 1])).perform() @@ -83,7 +76,6 @@ def test_pick_up(self): def test_place(self): object_description = object_designator.ObjectDesignatorDescription(names=["milk"]) description = action_designator.PlaceAction(object_description, [Pose([1.3, 1, 0.9], [0, 0, 0, 1])], [Arms.LEFT]) - self.assertTrue(description.ontology_concept_holders) self.assertEqual(description.ground().object_designator.name, "milk") with simulated_robot: NavigateActionPerformable(Pose([0.6, 0.4, 0], [0, 0, 0, 1])).perform() @@ -94,7 +86,6 @@ def test_place(self): def test_look_at(self): description = action_designator.LookAtAction([Pose([1, 0, 1])]) - self.assertTrue(description.ontology_concept_holders) self.assertEqual(description.ground().target, Pose([1, 0, 1])) with simulated_robot: description.resolve().perform() @@ -105,7 +96,6 @@ def test_detect(self): self.milk.set_pose(Pose([1.5, 0, 1.2])) object_description = object_designator.ObjectDesignatorDescription(names=["milk"]) description = action_designator.DetectAction(object_description) - self.assertTrue(description.ontology_concept_holders) self.assertEqual(description.ground().object_designator.name, "milk") with simulated_robot: detected_object = description.resolve().perform() @@ -118,14 +108,12 @@ def test_detect(self): def test_open(self): object_description = object_designator.ObjectDesignatorDescription(names=["milk"]) description = action_designator.OpenAction(object_description, [Arms.LEFT]) - self.assertTrue(description.ontology_concept_holders) self.assertEqual(description.ground().object_designator.name, "milk") @unittest.skip def test_close(self): object_description = object_designator.ObjectDesignatorDescription(names=["milk"]) description = action_designator.CloseAction(object_description, [Arms.LEFT]) - self.assertTrue(description.ontology_concept_holders) self.assertEqual(description.ground().object_designator.name, "milk") def test_transport(self): @@ -134,7 +122,6 @@ def test_transport(self): [Arms.LEFT], [Pose([-1.35, 0.78, 0.95], [0.0, 0.0, 0.16439898301071468, 0.9863939245479175])]) - self.assertTrue(description.ontology_concept_holders) with simulated_robot: action_designator.MoveTorsoAction([0.2]).resolve().perform() description.resolve().perform() @@ -148,7 +135,6 @@ def test_grasping(self): self.robot.set_pose(Pose([-2.14, 1.06, 0])) milk_desig = object_designator.ObjectDesignatorDescription(names=["milk"]) description = action_designator.GraspingAction([Arms.RIGHT], milk_desig) - self.assertTrue(description.ontology_concept_holders) with simulated_robot: description.resolve().perform() dist = np.linalg.norm( @@ -161,7 +147,3 @@ def test_facing(self): FaceAtPerformable(self.milk.pose).perform() milk_in_robot_frame = LocalTransformer().transform_to_object_frame(self.milk.pose, self.robot) self.assertAlmostEqual(milk_in_robot_frame.position.y, 0.) - - -if __name__ == '__main__': - unittest.main() diff --git a/test/test_attachment.py b/test/test_attachment.py index d521d752a..bf8487942 100644 --- a/test/test_attachment.py +++ b/test/test_attachment.py @@ -20,6 +20,20 @@ def test_detach(self): self.assertTrue(self.robot not in self.milk.attachments) self.assertTrue(self.milk not in self.robot.attachments) + def test_detach_sync_in_prospection_world(self): + self.milk.attach(self.robot) + with UseProspectionWorld(): + pass + self.milk.detach(self.robot) + with UseProspectionWorld(): + pass + self.assertTrue(self.milk not in self.robot.attachments) + self.assertTrue(self.robot not in self.milk.attachments) + prospection_milk = self.world.get_prospection_object_for_object(self.milk) + prospection_robot = self.world.get_prospection_object_for_object(self.robot) + self.assertTrue(prospection_milk not in prospection_robot.attachments) + self.assertTrue(prospection_robot not in prospection_milk.attachments) + def test_attachment_behavior(self): self.robot.attach(self.milk) @@ -52,27 +66,28 @@ def test_prospection_object_attachments_not_changed_with_real_object(self): time.sleep(0.05) milk_2.attach(cereal_2) time.sleep(0.05) - prospection_milk = self.world.get_prospection_object_for_object(milk_2) - # self.assertTrue(cereal_2 not in prospection_milk.attachments) - prospection_cereal = self.world.get_prospection_object_for_object(cereal_2) - # self.assertTrue(prospection_cereal in prospection_milk.attachments) - self.assertTrue(prospection_milk.attachments == {}) - - # Assert that when prospection object is moved, the real object is not moved with UseProspectionWorld(): + prospection_milk = self.world.get_prospection_object_for_object(milk_2) + # self.assertTrue(cereal_2 not in prospection_milk.attachments) + prospection_cereal = self.world.get_prospection_object_for_object(cereal_2) + # self.assertTrue(prospection_cereal in prospection_milk.attachments) + self.assertTrue(prospection_cereal in prospection_milk.attachments.keys()) + + # Assert that when prospection object is moved, the real object is not moved prospection_milk_pos = prospection_milk.get_position() cereal_pos = cereal_2.get_position() - prospection_cereal_pos = prospection_cereal.get_position() + estimated_prospection_cereal_pos = prospection_cereal.get_position() + estimated_prospection_cereal_pos.x += 1 # Move prospection milk object prospection_milk_pos.x += 1 prospection_milk.set_position(prospection_milk_pos) - # Prospection object should not move + # Prospection cereal should move since it is attached to prospection milk new_prospection_cereal_pose = prospection_cereal.get_position() - self.assertTrue(new_prospection_cereal_pose == prospection_cereal_pos) + self.assertAlmostEqual(new_prospection_cereal_pose.x, estimated_prospection_cereal_pos.x, delta=0.01) - # Real cereal object should not move + # Also Real cereal object should not move since it is not affected by prospection milk new_cereal_pos = cereal_2.get_position() assumed_cereal_pos = cereal_pos self.assertTrue(new_cereal_pos == assumed_cereal_pos) @@ -80,22 +95,6 @@ def test_prospection_object_attachments_not_changed_with_real_object(self): self.world.remove_object(milk_2) self.world.remove_object(cereal_2) - def test_no_attachment_in_prospection_world(self): - milk_2 = Object("milk_2", ObjectType.MILK, "milk.stl", pose=Pose([1.3, 1, 0.9])) - cereal_2 = Object("cereal_2", ObjectType.BREAKFAST_CEREAL, "breakfast_cereal.stl", - pose=Pose([1.3, 0.7, 0.95])) - - milk_2.attach(cereal_2) - - prospection_milk = self.world.get_prospection_object_for_object(milk_2) - prospection_cereal = self.world.get_prospection_object_for_object(cereal_2) - - self.assertTrue(prospection_milk.attachments == {}) - self.assertTrue(prospection_cereal.attachments == {}) - - self.world.remove_object(milk_2) - self.world.remove_object(cereal_2) - def test_attaching_to_robot_and_moving(self): self.robot.attach(self.milk) milk_pos = self.milk.get_position() @@ -106,5 +105,3 @@ def test_attaching_to_robot_and_moving(self): new_milk_pos = self.milk.get_position() self.assertEqual(new_milk_pos.x, milk_pos.x + 1) - - diff --git a/test/test_bullet_world.py b/test/test_bullet_world.py index ec398df7a..565e98342 100644 --- a/test/test_bullet_world.py +++ b/test/test_bullet_world.py @@ -11,7 +11,7 @@ from pycram.object_descriptors.urdf import ObjectDescription from pycram.datastructures.dataclasses import Color from pycram.world_concepts.world_object import Object -from pycram.datastructures.world import UseProspectionWorld +from pycram.datastructures.world import UseProspectionWorld, World fix_missing_inertial = ObjectDescription.fix_missing_inertial @@ -53,8 +53,7 @@ def test_remove_object(self): self.assertTrue(milk_id in [obj.id for obj in self.world.objects]) self.world.remove_object(self.milk) self.assertTrue(milk_id not in [obj.id for obj in self.world.objects]) - BulletWorldTest.milk = Object("milk", ObjectType.MILK, "milk.stl", - ObjectDescription, pose=Pose([1.3, 1, 0.9])) + BulletWorldTest.milk = Object("milk", ObjectType.MILK, "milk.stl", pose=Pose([1.3, 1, 0.9])) def test_remove_robot(self): robot_id = self.robot.id @@ -65,7 +64,7 @@ def test_remove_robot(self): RobotDescription.current_robot_description.name + self.extension) def test_get_joint_position(self): - self.assertEqual(self.robot.get_joint_position("head_pan_joint"), 0.0) + self.assertAlmostEqual(self.robot.get_joint_position("head_pan_joint"), 0.0, delta=0.01) def test_get_object_contact_points(self): self.assertEqual(len(self.robot.contact_points()), 0) @@ -136,51 +135,34 @@ def test_equal_world_states(self): time.sleep(2.5) self.robot.set_pose(Pose([1, 0, 0], [0, 0, 0, 1])) self.assertFalse(self.world.world_sync.check_for_equal()) - self.world.prospection_world.object_states = self.world.current_state.object_states - time.sleep(0.05) - self.assertTrue(self.world.world_sync.check_for_equal()) + with UseProspectionWorld(): + self.assertTrue(self.world.world_sync.check_for_equal()) def test_add_resource_path(self): self.world.add_resource_path("test") - self.assertTrue("test" in self.world.data_directory) + self.assertTrue("test" in self.world.get_data_directories()) def test_no_prospection_object_found_for_given_object(self): milk_2 = Object("milk_2", ObjectType.MILK, "milk.stl", pose=Pose([1.3, 1, 0.9])) - time.sleep(0.05) try: prospection_milk_2 = self.world.get_prospection_object_for_object(milk_2) self.world.remove_object(milk_2) - time.sleep(0.1) self.world.get_prospection_object_for_object(milk_2) self.assertFalse(True) - except ValueError as e: - self.assertTrue(True) - - def test_no_object_found_for_given_prospection_object(self): - milk_2 = Object("milk_2", ObjectType.MILK, "milk.stl", pose=Pose([1.3, 1, 0.9])) - time.sleep(0.05) - prospection_milk = self.world.get_prospection_object_for_object(milk_2) - self.assertTrue(self.world.get_object_for_prospection_object(prospection_milk) == milk_2) - try: - self.world.remove_object(milk_2) - self.world.get_object_for_prospection_object(prospection_milk) - time.sleep(0.1) - self.assertFalse(True) - except ValueError as e: + except KeyError as e: self.assertTrue(True) - time.sleep(0.05) def test_real_object_position_does_not_change_with_prospection_object(self): milk_2_pos = [1.3, 1, 0.9] milk_2 = Object("milk_3", ObjectType.MILK, "milk.stl", pose=Pose(milk_2_pos)) time.sleep(0.05) milk_2_pos = milk_2.get_position() - prospection_milk = self.world.get_prospection_object_for_object(milk_2) - prospection_milk_pos = prospection_milk.get_position() - self.assertTrue(prospection_milk_pos == milk_2_pos) # Assert that when prospection object is moved, the real object is not moved with UseProspectionWorld(): + prospection_milk = self.world.get_prospection_object_for_object(milk_2) + prospection_milk_pos = prospection_milk.get_position() + self.assertTrue(prospection_milk_pos == milk_2_pos) prospection_milk_pos.x += 1 prospection_milk.set_position(prospection_milk_pos) self.assertTrue(prospection_milk.get_position() != milk_2.get_position()) @@ -191,32 +173,32 @@ def test_prospection_object_position_does_not_change_with_real_object(self): milk_2 = Object("milk_4", ObjectType.MILK, "milk.stl", pose=Pose(milk_2_pos)) time.sleep(0.05) milk_2_pos = milk_2.get_position() - prospection_milk = self.world.get_prospection_object_for_object(milk_2) - prospection_milk_pos = prospection_milk.get_position() - self.assertTrue(prospection_milk_pos == milk_2_pos) # Assert that when real object is moved, the prospection object is not moved with UseProspectionWorld(): + prospection_milk = self.world.get_prospection_object_for_object(milk_2) + prospection_milk_pos = prospection_milk.get_position() + self.assertTrue(prospection_milk_pos == milk_2_pos) milk_2_pos.x += 1 milk_2.set_position(milk_2_pos) self.assertTrue(prospection_milk.get_position() != milk_2.get_position()) self.world.remove_object(milk_2) def test_add_vis_axis(self): - self.world.add_vis_axis(self.robot.get_link_pose(RobotDescription.current_robot_description.get_camera_frame())) + self.world.add_vis_axis(self.robot.get_link_pose(RobotDescription.current_robot_description.get_camera_link())) self.assertTrue(len(self.world.vis_axis) == 1) self.world.remove_vis_axis() self.assertTrue(len(self.world.vis_axis) == 0) def test_add_text(self): - link: ObjectDescription.Link = self.robot.get_link(RobotDescription.current_robot_description.get_camera_frame()) + link: ObjectDescription.Link = self.robot.get_link(RobotDescription.current_robot_description.get_camera_link()) text_id = self.world.add_text("test", link.position_as_list, link.orientation_as_list, 1, Color(1, 0, 0, 1), 3, link.object_id, link.id) if self.world.mode == WorldMode.GUI: time.sleep(4) def test_remove_text(self): - link: ObjectDescription.Link = self.robot.get_link(RobotDescription.current_robot_description.get_camera_frame()) + link: ObjectDescription.Link = self.robot.get_link(RobotDescription.current_robot_description.get_camera_link()) text_id_1 = self.world.add_text("test 1", link.pose.position_as_list(), link.pose.orientation_as_list(), 1, Color(1, 0, 0, 1), 0, link.object_id, link.id) text_id = self.world.add_text("test 2", link.pose.position_as_list(), link.pose.orientation_as_list(), 1, @@ -229,7 +211,7 @@ def test_remove_text(self): time.sleep(3) def test_remove_all_text(self): - link: ObjectDescription.Link = self.robot.get_link(RobotDescription.current_robot_description.get_camera_frame()) + link: ObjectDescription.Link = self.robot.get_link(RobotDescription.current_robot_description.get_camera_link()) text_id_1 = self.world.add_text("test 1", link.pose.position_as_list(), link.pose.orientation_as_list(), 1, Color(1, 0, 0, 1), 0, link.object_id, link.id) text_id = self.world.add_text("test 2", link.pose.position_as_list(), link.pose.orientation_as_list(), 1, diff --git a/test/test_bullet_world_reasoning.py b/test/test_bullet_world_reasoning.py index 3fafe27d4..8d8c1061b 100644 --- a/test/test_bullet_world_reasoning.py +++ b/test/test_bullet_world_reasoning.py @@ -20,28 +20,35 @@ def test_visible(self): self.milk.set_pose(Pose([1.5, 0, 1.2])) self.robot.set_pose(Pose()) time.sleep(1) - camera_frame = RobotDescription.current_robot_description.get_camera_frame() - self.world.add_vis_axis(self.robot.get_link_pose(camera_frame)) - self.assertTrue(btr.visible(self.milk, self.robot.get_link_pose(camera_frame), + camera_link = RobotDescription.current_robot_description.get_camera_link() + self.world.add_vis_axis(self.robot.get_link_pose(camera_link)) + self.assertTrue(btr.visible(self.milk, self.robot.get_link_pose(camera_link), RobotDescription.current_robot_description.get_default_camera().front_facing_axis)) def test_occluding(self): self.milk.set_pose(Pose([3, 0, 1.2])) self.robot.set_pose(Pose()) - self.assertTrue(btr.occluding(self.milk, self.robot.get_link_pose(RobotDescription.current_robot_description.get_camera_frame()), + self.assertTrue(btr.occluding(self.milk, self.robot.get_link_pose( + RobotDescription.current_robot_description.get_camera_link()), RobotDescription.current_robot_description.get_default_camera().front_facing_axis) != []) def test_reachable(self): self.robot.set_pose(Pose()) time.sleep(1) - self.assertTrue(btr.reachable(Pose([0.5, -0.7, 1]), self.robot, RobotDescription.current_robot_description.kinematic_chains["right"].get_tool_frame())) - self.assertFalse(btr.reachable(Pose([2, 2, 1]), self.robot, RobotDescription.current_robot_description.kinematic_chains["right"].get_tool_frame())) + self.assertTrue(btr.reachable(Pose([0.5, -0.7, 1]), self.robot, + RobotDescription.current_robot_description.kinematic_chains[ + "right"].get_tool_frame())) + self.assertFalse(btr.reachable(Pose([2, 2, 1]), self.robot, + RobotDescription.current_robot_description.kinematic_chains[ + "right"].get_tool_frame())) def test_blocking(self): self.milk.set_pose(Pose([0.5, -0.7, 1])) self.robot.set_pose(Pose()) time.sleep(2) - self.assertTrue(btr.blocking(Pose([0.5, -0.7, 1]), self.robot, RobotDescription.current_robot_description.kinematic_chains["right"].get_tool_frame()) != []) + blocking = btr.blocking(Pose([0.5, -0.7, 1]), self.robot, + RobotDescription.current_robot_description.kinematic_chains["right"].get_tool_frame()) + self.assertTrue(blocking != []) def test_supporting(self): self.milk.set_pose(Pose([1.3, 0, 0.9])) diff --git a/test/test_cache_manager.py b/test/test_cache_manager.py index bad803f22..9f208d90f 100644 --- a/test/test_cache_manager.py +++ b/test/test_cache_manager.py @@ -1,20 +1,18 @@ +import os from pathlib import Path from bullet_world_testcase import BulletWorldTestCase -from pycram.datastructures.enums import ObjectType -from pycram.world_concepts.world_object import Object -import pathlib +from pycram.object_descriptors.urdf import ObjectDescription as URDFObject +from pycram.config import world_conf as conf class TestCacheManager(BulletWorldTestCase): def test_generate_description_and_write_to_cache(self): cache_manager = self.world.cache_manager - file_path = pathlib.Path(__file__).parent.resolve() - path = str(file_path) + "/../resources/apartment.urdf" + path = os.path.join(self.world.conf.resources_path, "objects/apartment.urdf") extension = Path(path).suffix - cache_path = self.world.cache_dir + "apartment.urdf" - apartment = Object("apartment", ObjectType.ENVIRONMENT, path) - cache_manager.generate_description_and_write_to_cache(path, apartment.name, extension, cache_path, - apartment.description) - self.assertTrue(cache_manager.is_cached(path, apartment.description)) + cache_path = os.path.join(self.world.conf.cache_dir, "apartment.urdf") + apartment = URDFObject(path) + apartment.generate_description_from_file(path, "apartment", extension, cache_path) + self.assertTrue(cache_manager.is_cached(path, apartment)) diff --git a/test/test_costmaps.py b/test/test_costmaps.py index 3258353e1..504bbe1a5 100644 --- a/test/test_costmaps.py +++ b/test/test_costmaps.py @@ -1,3 +1,5 @@ +import unittest + import numpy as np from random_events.variable import Continuous # import plotly.graph_objects as go @@ -5,11 +7,12 @@ from random_events.interval import * from bullet_world_testcase import BulletWorldTestCase -from pycram.costmaps import OccupancyCostmap +from pycram.costmaps import OccupancyCostmap, AlgebraicSemanticCostmap from pycram.datastructures.pose import Pose +import plotly.graph_objects as go -class TestCostmapsCase(BulletWorldTestCase): +class CostmapTestCase(BulletWorldTestCase): def test_raytest_bug(self): for i in range(30): @@ -55,3 +58,35 @@ def test_visualize(self): o = OccupancyCostmap(0.2, from_ros=False, size=200, resolution=0.02, origin=Pose([0, 0, 0], [0, 0, 0, 1])) o.visualize() + + +class SemanticCostmapTestCase(BulletWorldTestCase): + + def test_generate_map(self): + costmap = AlgebraicSemanticCostmap(self.kitchen, "kitchen_island_surface") + costmap.valid_area &= costmap.left() + costmap.valid_area &= costmap.top() + costmap.valid_area &= costmap.border(0.2) + self.assertEqual(len(costmap.valid_area.simple_sets), 2) + + def test_as_distribution(self): + costmap = AlgebraicSemanticCostmap(self.kitchen, "kitchen_island_surface") + costmap.valid_area &= costmap.right() & costmap.bottom() & costmap.border(0.2) + model = costmap.as_distribution() + self.assertEqual(len(model.nodes), 7) + # fig = go.Figure(model.plot(), model.plotly_layout()) + # fig.show() + # supp = model.support + # fig = go.Figure(supp.plot(), supp.plotly_layout()) + # fig.show() + + def test_iterate(self): + costmap = AlgebraicSemanticCostmap(self.kitchen, "kitchen_island_surface") + costmap.valid_area &= costmap.left() & costmap.top() & costmap.border(0.2) + for sample in iter(costmap): + self.assertIsInstance(sample, Pose) + self.assertTrue(costmap.valid_area.contains([sample.position.x, sample.position.y])) + + +class OntologySemanticLocationTestCase(unittest.TestCase): + ... \ No newline at end of file diff --git a/test/test_database_resolver.py b/test/test_database_resolver.py index 5f015d831..7392dbac2 100644 --- a/test/test_database_resolver.py +++ b/test/test_database_resolver.py @@ -2,7 +2,7 @@ import unittest import sqlalchemy import sqlalchemy.orm -import pycram.plan_failures +import pycram.failures from pycram.world_concepts.world_object import Object from pycram.datastructures.world import World from pycram.designators import action_designator @@ -24,7 +24,7 @@ pycrorm_uri = "mysql+pymysql://" + pycrorm_uri -@unittest.skipIf(pycrorm_uri is None, "pycrorm database is not available.") +@unittest.skip class DatabaseResolverTestCase(unittest.TestCase,): world: World milk: Object @@ -37,7 +37,8 @@ def setUpClass(cls) -> None: global pycrorm_uri cls.world = BulletWorld(WorldMode.DIRECT) cls.milk = Object("milk", ObjectType.MILK, "milk.stl", pose=Pose([1.3, 1, 0.9])) - cls.robot = Object(robot_description.name, ObjectType.ROBOT, RobotDescription.current_robot_description.name + ".urdf") + cls.robot = Object(RobotDescription.current_robot_description.name, + ObjectType.ROBOT, RobotDescription.current_robot_description.name + ".urdf") ProcessModule.execution_delay = False cls.engine = sqlalchemy.create_engine(pycrorm_uri) diff --git a/test/test_description.py b/test/test_description.py index a0d12b3b2..e324384ac 100644 --- a/test/test_description.py +++ b/test/test_description.py @@ -1,3 +1,4 @@ +import os.path import pathlib from bullet_world_testcase import BulletWorldTestCase @@ -22,10 +23,16 @@ def test_joint_child_link(self): def test_generate_description_from_mesh(self): file_path = pathlib.Path(__file__).parent.resolve() - self.assertTrue(self.milk.description.generate_description_from_file(str(file_path) + "/../resources/cached/milk.stl", - "milk", ".stl")) + cache_path = self.world.cache_manager.cache_dir + cache_path = os.path.join(cache_path, f"{self.milk.description.name}.urdf") + self.milk.description.generate_from_mesh_file(str(file_path) + "/../resources/milk.stl", "milk", cache_path) + self.assertTrue(self.world.cache_manager.is_cached(f"{self.milk.name}", self.milk.description)) def test_generate_description_from_description_file(self): file_path = pathlib.Path(__file__).parent.resolve() - self.assertTrue(self.milk.description.generate_description_from_file(str(file_path) + "/../resources/cached/milk.urdf", - "milk", ".urdf")) + file_extension = self.robot.description.get_file_extension() + pr2_path = str(file_path) + f"/../resources/robots/{self.robot.description.name}{file_extension}" + cache_path = self.world.cache_manager.cache_dir + cache_path = os.path.join(cache_path, f"{self.robot.description.name}.urdf") + self.robot.description.generate_from_description_file(pr2_path, cache_path) + self.assertTrue(self.world.cache_manager.is_cached(self.robot.name, self.robot.description)) diff --git a/test/test_error_checkers.py b/test/test_error_checkers.py new file mode 100644 index 000000000..63bf06416 --- /dev/null +++ b/test/test_error_checkers.py @@ -0,0 +1,131 @@ +from unittest import TestCase + +import numpy as np +from tf.transformations import quaternion_from_euler + +from pycram.datastructures.enums import JointType +from pycram.validation.error_checkers import calculate_angle_between_quaternions, \ + PoseErrorChecker, PositionErrorChecker, OrientationErrorChecker, RevoluteJointPositionErrorChecker, \ + PrismaticJointPositionErrorChecker, MultiJointPositionErrorChecker + +from pycram.datastructures.pose import Pose + + +class TestErrorCheckers(TestCase): + @classmethod + def setUpClass(cls): + pass + + @classmethod + def tearDownClass(cls): + pass + + def tearDown(self): + pass + + def test_calculate_quaternion_error(self): + quat_1 = [0.0, 0.0, 0.0, 1.0] + quat_2 = [0.0, 0.0, 0.0, 1.0] + error = calculate_angle_between_quaternions(quat_1, quat_2) + self.assertEqual(error, 0.0) + quat_2 = quaternion_from_euler(0, 0, np.pi/2) + error = calculate_angle_between_quaternions(quat_1, quat_2) + self.assertEqual(error, np.pi/2) + + def test_pose_error_checker(self): + pose_1 = Pose([0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]) + pose_2 = Pose([0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]) + error_checker = PoseErrorChecker() + error = error_checker.calculate_error(pose_1, pose_2) + self.assertEqual(error, [0.0, 0.0]) + self.assertTrue(error_checker.is_error_acceptable(pose_1, pose_2)) + quat = quaternion_from_euler(0, np.pi/2, 0) + pose_2 = Pose([0, 1, np.sqrt(3)], quat) + error = error_checker.calculate_error(pose_1, pose_2) + self.assertAlmostEqual(error[0], 2, places=5) + self.assertEqual(error[1], np.pi/2) + self.assertFalse(error_checker.is_error_acceptable(pose_1, pose_2)) + quat = quaternion_from_euler(0, 0, np.pi/360) + pose_2 = Pose([0, 0.0001, 0.0001], quat) + self.assertTrue(error_checker.is_error_acceptable(pose_1, pose_2)) + quat = quaternion_from_euler(0, 0, np.pi / 179) + pose_2 = Pose([0, 0.0001, 0.0001], quat) + self.assertFalse(error_checker.is_error_acceptable(pose_1, pose_2)) + + def test_position_error_checker(self): + position_1 = [0.0, 0.0, 0.0] + position_2 = [0.0, 0.0, 0.0] + error_checker = PositionErrorChecker() + error = error_checker.calculate_error(position_1, position_2) + self.assertEqual(error, 0.0) + self.assertTrue(error_checker.is_error_acceptable(position_1, position_2)) + position_2 = [1.0, 1.0, 1.0] + error = error_checker.calculate_error(position_1, position_2) + self.assertAlmostEqual(error, np.sqrt(3), places=5) + self.assertFalse(error_checker.is_error_acceptable(position_1, position_2)) + + def test_orientation_error_checker(self): + quat_1 = [0.0, 0.0, 0.0, 1.0] + quat_2 = [0.0, 0.0, 0.0, 1.0] + error_checker = OrientationErrorChecker() + error = error_checker.calculate_error(quat_1, quat_2) + self.assertEqual(error, 0.0) + self.assertTrue(error_checker.is_error_acceptable(quat_1, quat_2)) + quat_2 = quaternion_from_euler(0, 0, np.pi/2) + error = error_checker.calculate_error(quat_1, quat_2) + self.assertEqual(error, np.pi/2) + self.assertFalse(error_checker.is_error_acceptable(quat_1, quat_2)) + + def test_revolute_joint_position_error_checker(self): + position_1 = 0.0 + position_2 = 0.0 + error_checker = RevoluteJointPositionErrorChecker() + error = error_checker.calculate_error(position_1, position_2) + self.assertEqual(error, 0.0) + self.assertTrue(error_checker.is_error_acceptable(position_1, position_2)) + position_2 = np.pi/2 + error = error_checker.calculate_error(position_1, position_2) + self.assertEqual(error, np.pi/2) + self.assertFalse(error_checker.is_error_acceptable(position_1, position_2)) + + def test_prismatic_joint_position_error_checker(self): + position_1 = 0.0 + position_2 = 0.0 + error_checker = PrismaticJointPositionErrorChecker() + error = error_checker.calculate_error(position_1, position_2) + self.assertEqual(error, 0.0) + self.assertTrue(error_checker.is_error_acceptable(position_1, position_2)) + position_2 = 1.0 + error = error_checker.calculate_error(position_1, position_2) + self.assertEqual(error, 1.0) + self.assertFalse(error_checker.is_error_acceptable(position_1, position_2)) + + def test_list_of_poses_error_checker(self): + poses_1 = [Pose([0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]), + Pose([0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0])] + poses_2 = [Pose([0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]), + Pose([0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0])] + error_checker = PoseErrorChecker(is_iterable=True) + error = error_checker.calculate_error(poses_1, poses_2) + self.assertEqual(error, [[0.0, 0.0], [0.0, 0.0]]) + self.assertTrue(error_checker.is_error_acceptable(poses_1, poses_2)) + quat = quaternion_from_euler(0, np.pi/2, 0) + poses_2 = [Pose([0, 1, np.sqrt(3)], quat), + Pose([0, 1, np.sqrt(3)], quat)] + error = error_checker.calculate_error(poses_1, poses_2) + self.assertAlmostEqual(error[0][0], 2, places=5) + self.assertEqual(error[0][1], np.pi/2) + self.assertAlmostEqual(error[1][0], 2, places=5) + self.assertEqual(error[1][1], np.pi/2) + self.assertFalse(error_checker.is_error_acceptable(poses_1, poses_2)) + + def test_multi_joint_error_checker(self): + positions_1 = [0.0, 0.0] + positions_2 = [np.pi/2, 0.1] + joint_types = [JointType.REVOLUTE, JointType.PRISMATIC] + error_checker = MultiJointPositionErrorChecker(joint_types) + error = error_checker.calculate_error(positions_1, positions_2) + self.assertEqual(error, [np.pi/2, 0.1]) + self.assertFalse(error_checker.is_error_acceptable(positions_1, positions_2)) + positions_2 = [np.pi/180, 0.0001] + self.assertTrue(error_checker.is_error_acceptable(positions_1, positions_2)) diff --git a/test/test_failure_handling.py b/test/test_failure_handling.py index b28420f96..190a48922 100644 --- a/test/test_failure_handling.py +++ b/test/test_failure_handling.py @@ -7,7 +7,7 @@ from pycram.designators.action_designator import ParkArmsAction from pycram.datastructures.enums import ObjectType, Arms, WorldMode from pycram.failure_handling import Retry -from pycram.plan_failures import PlanFailure +from pycram.failures import PlanFailure from pycram.process_module import ProcessModule, simulated_robot from pycram.robot_description import RobotDescription from pycram.object_descriptors.urdf import ObjectDescription @@ -33,8 +33,8 @@ class FailureHandlingTest(unittest.TestCase): @classmethod def setUpClass(cls): cls.world = BulletWorld(WorldMode.DIRECT) - cls.robot = Object(RobotDescription.current_robot_description.name, ObjectType.ROBOT, RobotDescription.current_robot_description.name + extension, - ObjectDescription) + cls.robot = Object(RobotDescription.current_robot_description.name, ObjectType.ROBOT, + RobotDescription.current_robot_description.name + extension) ProcessModule.execution_delay = True def setUp(self): diff --git a/test/test_goal_validator.py b/test/test_goal_validator.py new file mode 100644 index 000000000..9b79cc114 --- /dev/null +++ b/test/test_goal_validator.py @@ -0,0 +1,321 @@ +import numpy as np +from tf.transformations import quaternion_from_euler +from typing_extensions import Optional, List + +from bullet_world_testcase import BulletWorldTestCase +from pycram.datastructures.enums import JointType +from pycram.datastructures.pose import Pose +from pycram.robot_description import RobotDescription +from pycram.validation.error_checkers import PoseErrorChecker, PositionErrorChecker, \ + OrientationErrorChecker, RevoluteJointPositionErrorChecker, PrismaticJointPositionErrorChecker, \ + MultiJointPositionErrorChecker +from pycram.validation.goal_validator import GoalValidator, PoseGoalValidator, \ + PositionGoalValidator, OrientationGoalValidator, JointPositionGoalValidator, MultiJointPositionGoalValidator, \ + MultiPoseGoalValidator, MultiPositionGoalValidator, MultiOrientationGoalValidator + + +class TestGoalValidator(BulletWorldTestCase): + + def test_single_pose_goal(self): + pose_goal_validators = PoseGoalValidator(self.milk.get_pose) + self.validate_pose_goal(pose_goal_validators) + + def test_single_pose_goal_generic(self): + pose_goal_validators = GoalValidator(PoseErrorChecker(), self.milk.get_pose) + self.validate_pose_goal(pose_goal_validators) + + def validate_pose_goal(self, goal_validator): + milk_goal_pose = Pose([1.3, 1.5, 0.9]) + goal_validator.register_goal(milk_goal_pose) + self.assertFalse(goal_validator.goal_achieved) + self.assertEqual(goal_validator.actual_percentage_of_goal_achieved, 0) + self.assertAlmostEqual(goal_validator.current_error.tolist()[0], 0.5, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[1], 0, places=5) + self.milk.set_pose(milk_goal_pose) + self.assertEqual(self.milk.get_pose(), milk_goal_pose) + self.assertTrue(goal_validator.goal_achieved) + self.assertEqual(goal_validator.actual_percentage_of_goal_achieved, 1) + self.assertAlmostEqual(goal_validator.current_error.tolist()[0], 0, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[1], 0, places=5) + + def test_single_position_goal_generic(self): + goal_validator = GoalValidator(PositionErrorChecker(), self.cereal.get_position_as_list) + self.validate_position_goal(goal_validator) + + def test_single_position_goal(self): + goal_validator = PositionGoalValidator(self.cereal.get_position_as_list) + self.validate_position_goal(goal_validator) + + def validate_position_goal(self, goal_validator): + cereal_goal_position = [1.3, 1.5, 0.95] + goal_validator.register_goal(cereal_goal_position) + self.assertFalse(goal_validator.goal_achieved) + self.assertEqual(goal_validator.actual_percentage_of_goal_achieved, 0) + self.assertEqual(goal_validator.current_error, 0.8) + self.cereal.set_position(cereal_goal_position) + self.assertEqual(self.cereal.get_position_as_list(), cereal_goal_position) + self.assertTrue(goal_validator.goal_achieved) + self.assertEqual(goal_validator.actual_percentage_of_goal_achieved, 1) + self.assertEqual(goal_validator.current_error, 0) + + def test_single_orientation_goal_generic(self): + goal_validator = GoalValidator(OrientationErrorChecker(), self.cereal.get_orientation_as_list) + self.validate_orientation_goal(goal_validator) + + def test_single_orientation_goal(self): + goal_validator = OrientationGoalValidator(self.cereal.get_orientation_as_list) + self.validate_orientation_goal(goal_validator) + + def validate_orientation_goal(self, goal_validator): + cereal_goal_orientation = quaternion_from_euler(0, 0, np.pi / 2) + goal_validator.register_goal(cereal_goal_orientation) + self.assertFalse(goal_validator.goal_achieved) + self.assertEqual(goal_validator.actual_percentage_of_goal_achieved, 0) + self.assertEqual(goal_validator.current_error, [np.pi / 2]) + self.cereal.set_orientation(cereal_goal_orientation) + for v1, v2 in zip(self.cereal.get_orientation_as_list(), cereal_goal_orientation.tolist()): + self.assertAlmostEqual(v1, v2, places=5) + self.assertTrue(goal_validator.goal_achieved) + self.assertAlmostEqual(goal_validator.actual_percentage_of_goal_achieved, 1, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[0], 0, places=5) + + def test_single_revolute_joint_position_goal_generic(self): + goal_validator = GoalValidator(RevoluteJointPositionErrorChecker(), self.robot.get_joint_position) + self.validate_revolute_joint_position_goal(goal_validator) + + def test_single_revolute_joint_position_goal(self): + goal_validator = JointPositionGoalValidator(self.robot.get_joint_position) + self.validate_revolute_joint_position_goal(goal_validator, JointType.REVOLUTE) + + def validate_revolute_joint_position_goal(self, goal_validator, joint_type: Optional[JointType] = None): + goal_joint_position = -np.pi / 4 + joint_name = 'l_shoulder_lift_joint' + if joint_type is not None: + goal_validator.register_goal(goal_joint_position, joint_type, joint_name) + else: + goal_validator.register_goal(goal_joint_position, joint_name) + self.assertFalse(goal_validator.goal_achieved) + self.assertEqual(goal_validator.actual_percentage_of_goal_achieved, 0) + self.assertEqual(goal_validator.current_error, abs(goal_joint_position)) + + for percent in [0.5, 1]: + self.robot.set_joint_position('l_shoulder_lift_joint', goal_joint_position * percent) + self.assertEqual(self.robot.get_joint_position('l_shoulder_lift_joint'), goal_joint_position * percent) + if percent == 1: + self.assertTrue(goal_validator.goal_achieved) + else: + self.assertFalse(goal_validator.goal_achieved) + self.assertAlmostEqual(goal_validator.actual_percentage_of_goal_achieved, percent, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[0], abs(goal_joint_position) * (1 - percent), + places=5) + + def test_single_prismatic_joint_position_goal_generic(self): + goal_validator = GoalValidator(PrismaticJointPositionErrorChecker(), self.robot.get_joint_position) + self.validate_prismatic_joint_position_goal(goal_validator) + + def test_single_prismatic_joint_position_goal(self): + goal_validator = JointPositionGoalValidator(self.robot.get_joint_position) + self.validate_prismatic_joint_position_goal(goal_validator, JointType.PRISMATIC) + + def validate_prismatic_joint_position_goal(self, goal_validator, joint_type: Optional[JointType] = None): + goal_joint_position = 0.2 + torso = RobotDescription.current_robot_description.torso_joint + if joint_type is not None: + goal_validator.register_goal(goal_joint_position, joint_type, torso) + else: + goal_validator.register_goal(goal_joint_position, torso) + self.assertFalse(goal_validator.goal_achieved) + self.assertEqual(goal_validator.actual_percentage_of_goal_achieved, 0) + self.assertEqual(goal_validator.current_error, abs(goal_joint_position)) + + for percent in [0.5, 1]: + self.robot.set_joint_position(torso, goal_joint_position * percent) + self.assertEqual(self.robot.get_joint_position(torso), goal_joint_position * percent) + if percent == 1: + self.assertTrue(goal_validator.goal_achieved) + else: + self.assertFalse(goal_validator.goal_achieved) + self.assertAlmostEqual(goal_validator.actual_percentage_of_goal_achieved, percent, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[0], abs(goal_joint_position) * (1 - percent), + places=5) + + def test_multi_joint_goal_generic(self): + joint_types = [JointType.PRISMATIC, JointType.REVOLUTE] + goal_validator = GoalValidator(MultiJointPositionErrorChecker(joint_types), + lambda x: list(self.robot.get_multiple_joint_positions(x).values())) + self.validate_multi_joint_goal(goal_validator) + + def test_multi_joint_goal(self): + joint_types = [JointType.PRISMATIC, JointType.REVOLUTE] + goal_validator = MultiJointPositionGoalValidator( + lambda x: list(self.robot.get_multiple_joint_positions(x).values())) + self.validate_multi_joint_goal(goal_validator, joint_types) + + def validate_multi_joint_goal(self, goal_validator, joint_types: Optional[List[JointType]] = None): + goal_joint_positions = np.array([0.2, -np.pi / 4]) + joint_names = ['torso_lift_joint', 'l_shoulder_lift_joint'] + if joint_types is not None: + goal_validator.register_goal(goal_joint_positions, joint_types, joint_names) + else: + goal_validator.register_goal(goal_joint_positions, joint_names) + self.assertFalse(goal_validator.goal_achieved) + self.assertEqual(goal_validator.actual_percentage_of_goal_achieved, 0) + self.assertTrue(np.allclose(goal_validator.current_error, np.array([0.2, abs(-np.pi / 4)]), atol=0.001)) + + for percent in [0.5, 1]: + current_joint_positions = goal_joint_positions * percent + self.robot.set_multiple_joint_positions(dict(zip(joint_names, current_joint_positions.tolist()))) + self.assertTrue(np.allclose(self.robot.get_joint_position('torso_lift_joint'), current_joint_positions[0], + atol=0.001)) + self.assertTrue( + np.allclose(self.robot.get_joint_position('l_shoulder_lift_joint'), current_joint_positions[1], + atol=0.001)) + if percent == 1: + self.assertTrue(goal_validator.goal_achieved) + else: + self.assertFalse(goal_validator.goal_achieved) + self.assertAlmostEqual(goal_validator.actual_percentage_of_goal_achieved, percent, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[0], abs(0.2) * (1 - percent), places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[1], abs(-np.pi / 4) * (1 - percent), places=5) + + def test_list_of_poses_goal_generic(self): + goal_validator = GoalValidator(PoseErrorChecker(is_iterable=True), + lambda: [self.robot.get_pose(), self.robot.get_pose()]) + self.validate_list_of_poses_goal(goal_validator) + + def test_list_of_poses_goal(self): + goal_validator = MultiPoseGoalValidator(lambda: [self.robot.get_pose(), self.robot.get_pose()]) + self.validate_list_of_poses_goal(goal_validator) + + def validate_list_of_poses_goal(self, goal_validator): + position_goal = [0.0, 1.0, 0.0] + orientation_goal = np.array([0, 0, np.pi / 2]) + poses_goal = [Pose(position_goal, quaternion_from_euler(*orientation_goal.tolist())), + Pose(position_goal, quaternion_from_euler(*orientation_goal.tolist()))] + goal_validator.register_goal(poses_goal) + self.assertFalse(goal_validator.goal_achieved) + self.assertEqual(goal_validator.actual_percentage_of_goal_achieved, 0) + self.assertTrue( + np.allclose(goal_validator.current_error, np.array([1.0, np.pi / 2, 1.0, np.pi / 2]), atol=0.001)) + + for percent in [0.5, 1]: + current_orientation_goal = orientation_goal * percent + current_pose_goal = Pose([0.0, 1.0 * percent, 0.0], + quaternion_from_euler(*current_orientation_goal.tolist())) + self.robot.set_pose(current_pose_goal) + self.assertTrue(np.allclose(self.robot.get_position_as_list(), current_pose_goal.position_as_list(), + atol=0.001)) + self.assertTrue(np.allclose(self.robot.get_orientation_as_list(), current_pose_goal.orientation_as_list(), + atol=0.001)) + if percent == 1: + self.assertTrue(goal_validator.goal_achieved) + else: + self.assertFalse(goal_validator.goal_achieved) + self.assertAlmostEqual(goal_validator.actual_percentage_of_goal_achieved, percent, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[0], 1 - percent, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[1], np.pi * (1 - percent) / 2, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[2], (1 - percent), places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[3], np.pi * (1 - percent) / 2, places=5) + + def test_list_of_positions_goal_generic(self): + goal_validator = GoalValidator(PositionErrorChecker(is_iterable=True), + lambda: [self.robot.get_position_as_list(), self.robot.get_position_as_list()]) + self.validate_list_of_positions_goal(goal_validator) + + def test_list_of_positions_goal(self): + goal_validator = MultiPositionGoalValidator(lambda: [self.robot.get_position_as_list(), + self.robot.get_position_as_list()]) + self.validate_list_of_positions_goal(goal_validator) + + def validate_list_of_positions_goal(self, goal_validator): + position_goal = [0.0, 1.0, 0.0] + positions_goal = [position_goal, position_goal] + goal_validator.register_goal(positions_goal) + self.assertFalse(goal_validator.goal_achieved) + self.assertEqual(goal_validator.actual_percentage_of_goal_achieved, 0) + self.assertTrue(np.allclose(goal_validator.current_error, np.array([1.0, 1.0]), atol=0.001)) + + for percent in [0.5, 1]: + current_position_goal = [0.0, 1.0 * percent, 0.0] + self.robot.set_position(current_position_goal) + self.assertTrue(np.allclose(self.robot.get_position_as_list(), current_position_goal, atol=0.001)) + if percent == 1: + self.assertTrue(goal_validator.goal_achieved) + else: + self.assertFalse(goal_validator.goal_achieved) + self.assertAlmostEqual(goal_validator.actual_percentage_of_goal_achieved, percent, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[0], 1 - percent, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[1], 1 - percent, places=5) + + def test_list_of_orientations_goal_generic(self): + goal_validator = GoalValidator(OrientationErrorChecker(is_iterable=True), + lambda: [self.robot.get_orientation_as_list(), + self.robot.get_orientation_as_list()]) + self.validate_list_of_orientations_goal(goal_validator) + + def test_list_of_orientations_goal(self): + goal_validator = MultiOrientationGoalValidator(lambda: [self.robot.get_orientation_as_list(), + self.robot.get_orientation_as_list()]) + self.validate_list_of_orientations_goal(goal_validator) + + def validate_list_of_orientations_goal(self, goal_validator): + orientation_goal = np.array([0, 0, np.pi / 2]) + orientations_goals = [quaternion_from_euler(*orientation_goal.tolist()), + quaternion_from_euler(*orientation_goal.tolist())] + goal_validator.register_goal(orientations_goals) + self.assertFalse(goal_validator.goal_achieved) + self.assertEqual(goal_validator.actual_percentage_of_goal_achieved, 0) + self.assertTrue(np.allclose(goal_validator.current_error, np.array([np.pi / 2, np.pi / 2]), atol=0.001)) + + for percent in [0.5, 1]: + current_orientation_goal = orientation_goal * percent + self.robot.set_orientation(quaternion_from_euler(*current_orientation_goal.tolist())) + self.assertTrue(np.allclose(self.robot.get_orientation_as_list(), + quaternion_from_euler(*current_orientation_goal.tolist()), + atol=0.001)) + if percent == 1: + self.assertTrue(goal_validator.goal_achieved) + else: + self.assertFalse(goal_validator.goal_achieved) + self.assertAlmostEqual(goal_validator.actual_percentage_of_goal_achieved, percent, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[0], np.pi * (1 - percent) / 2, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[1], np.pi * (1 - percent) / 2, places=5) + + def test_list_of_revolute_joint_positions_goal_generic(self): + goal_validator = GoalValidator(RevoluteJointPositionErrorChecker(is_iterable=True), + lambda x: list(self.robot.get_multiple_joint_positions(x).values())) + self.validate_list_of_revolute_joint_positions_goal(goal_validator) + + def test_list_of_revolute_joint_positions_goal(self): + goal_validator = MultiJointPositionGoalValidator( + lambda x: list(self.robot.get_multiple_joint_positions(x).values())) + self.validate_list_of_revolute_joint_positions_goal(goal_validator, [JointType.REVOLUTE, JointType.REVOLUTE]) + + def validate_list_of_revolute_joint_positions_goal(self, goal_validator, + joint_types: Optional[List[JointType]] = None): + goal_joint_position = -np.pi / 4 + goal_joint_positions = np.array([goal_joint_position, goal_joint_position]) + joint_names = ['l_shoulder_lift_joint', 'r_shoulder_lift_joint'] + if joint_types is not None: + goal_validator.register_goal(goal_joint_positions, joint_types, joint_names) + else: + goal_validator.register_goal(goal_joint_positions, joint_names) + self.assertFalse(goal_validator.goal_achieved) + self.assertEqual(goal_validator.actual_percentage_of_goal_achieved, 0) + self.assertTrue(np.allclose(goal_validator.current_error, + np.array([abs(goal_joint_position), abs(goal_joint_position)]), atol=0.001)) + + for percent in [0.5, 1]: + current_joint_position = goal_joint_positions * percent + self.robot.set_multiple_joint_positions(dict(zip(joint_names, current_joint_position))) + self.assertTrue(np.allclose(list(self.robot.get_multiple_joint_positions(joint_names).values()), + current_joint_position, atol=0.001)) + if percent == 1: + self.assertTrue(goal_validator.goal_achieved) + else: + self.assertFalse(goal_validator.goal_achieved) + self.assertAlmostEqual(goal_validator.actual_percentage_of_goal_achieved, percent, places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[0], abs(goal_joint_position) * (1 - percent), + places=5) + self.assertAlmostEqual(goal_validator.current_error.tolist()[1], abs(goal_joint_position) * (1 - percent), + places=5) diff --git a/test/test_language.py b/test/test_language.py index bb3fec509..362db9c0e 100644 --- a/test/test_language.py +++ b/test/test_language.py @@ -4,8 +4,9 @@ from pycram.designators.action_designator import * from pycram.designators.object_designator import BelieveObject from pycram.datastructures.enums import ObjectType, State +from pycram.failure_handling import RetryMonitor from pycram.fluent import Fluent -from pycram.plan_failures import PlanFailure +from pycram.failures import PlanFailure, NotALanguageExpression from pycram.datastructures.pose import Pose from pycram.language import Sequential, Language, Parallel, TryAll, TryInOrder, Monitor, Code from pycram.process_module import simulated_robot @@ -115,6 +116,80 @@ def monitor_func(): self.assertRaises(AttributeError, lambda: Monitor(monitor_func) >> Monitor(monitor_func)) + def test_retry_monitor_construction(self): + act = ParkArmsAction([Arms.BOTH]) + act2 = MoveTorsoAction([0.3]) + + def monitor_func(): + time.sleep(1) + return True + + def recovery1(): + return + + recover1 = Code(lambda: recovery1()) + recovery = {NotALanguageExpression: recover1} + + subplan = act + act2 >> Monitor(monitor_func) + plan = RetryMonitor(subplan, max_tries=6, recovery=recovery) + self.assertEqual(len(plan.recovery), 1) + self.assertIsInstance(plan.designator_description, Monitor) + + def test_retry_monitor_tries(self): + act = ParkArmsAction([Arms.BOTH]) + act2 = MoveTorsoAction([0.3]) + tries_counter = 0 + + def monitor_func(): + nonlocal tries_counter + tries_counter += 1 + return True + + subplan = act + act2 >> Monitor(monitor_func) + plan = RetryMonitor(subplan, max_tries=6) + try: + plan.perform() + except PlanFailure as e: + pass + self.assertEqual(tries_counter, 6) + + def test_retry_monitor_recovery(self): + recovery1_counter = 0 + recovery2_counter = 0 + + def monitor_func(): + if not hasattr(monitor_func, 'tries_counter'): + monitor_func.tries_counter = 0 + if monitor_func.tries_counter % 2: + monitor_func.tries_counter += 1 + return NotALanguageExpression + monitor_func.tries_counter += 1 + return PlanFailure + + def recovery1(): + nonlocal recovery1_counter + recovery1_counter += 1 + + def recovery2(): + nonlocal recovery2_counter + recovery2_counter += 1 + + recover1 = Code(lambda: recovery1()) + recover2 = Code(lambda: recovery2()) + recovery = {NotALanguageExpression: recover1, + PlanFailure: recover2} + + act = ParkArmsAction([Arms.BOTH]) + act2 = MoveTorsoAction([0.3]) + subplan = act + act2 >> Monitor(monitor_func) + plan = RetryMonitor(subplan, max_tries=6, recovery=recovery) + try: + plan.perform() + except PlanFailure as e: + pass + self.assertEqual(recovery1_counter, 2) + self.assertEqual(recovery2_counter, 3) + def test_repeat_construction(self): act = ParkArmsAction([Arms.BOTH]) act2 = MoveTorsoAction([0.3]) @@ -196,7 +271,7 @@ def raise_except(): plan = act + code with simulated_robot: - state = plan.perform() + state, _ = plan.perform() self.assertIsInstance(plan.exceptions[plan], PlanFailure) self.assertEqual(len(plan.exceptions.keys()), 1) self.assertEqual(state, State.FAILED) @@ -209,7 +284,7 @@ def raise_except(): plan = act - code with simulated_robot: - state = plan.perform() + state, _ = plan.perform() self.assertIsInstance(plan.exceptions[plan], list) self.assertIsInstance(plan.exceptions[plan][0], PlanFailure) self.assertEqual(len(plan.exceptions.keys()), 1) @@ -223,7 +298,7 @@ def raise_except(): plan = act | code with simulated_robot: - state = plan.perform() + state, _ = plan.perform() self.assertIsInstance(plan.exceptions[plan], list) self.assertIsInstance(plan.exceptions[plan][0], PlanFailure) self.assertEqual(len(plan.exceptions.keys()), 1) @@ -237,7 +312,7 @@ def raise_except(): plan = act ^ code with simulated_robot: - state = plan.perform() + state, _ = plan.perform() self.assertIsInstance(plan.exceptions[plan], list) self.assertIsInstance(plan.exceptions[plan][0], PlanFailure) self.assertEqual(len(plan.exceptions.keys()), 1) diff --git a/test/test_mjcf.py b/test/test_mjcf.py new file mode 100644 index 000000000..edfbb842b --- /dev/null +++ b/test/test_mjcf.py @@ -0,0 +1,40 @@ +from unittest import TestCase, skipIf +from dm_control import mjcf +try: + from pycram.object_descriptors.mjcf import ObjectDescription as MJCFObjDesc +except ImportError: + MJCFObjDesc = None + + +@skipIf(MJCFObjDesc is None, "Multiverse not found.") +class TestMjcf(TestCase): + model: MJCFObjDesc + + @classmethod + def setUpClass(cls): + # Example usage + model = mjcf.RootElement("test") + + model.default.dclass = 'default' + + # Define a simple model with bodies and joints + body1 = model.worldbody.add('body', name='body1') + body2 = body1.add('body', name='body2') + joint1 = body2.add('joint', name='joint1', type='hinge') + + body3 = body2.add('body', name='body3') + joint2 = body3.add('joint', name='joint2', type='slide') + + cls.model = MJCFObjDesc() + cls.model.update_description_from_string(model.to_xml_string()) + + def test_child_map(self): + self.assertEqual(self.model.child_map, {'body1': [('joint1', 'body2')], 'body2': [('joint2', 'body3')]}) + + def test_parent_map(self): + self.assertEqual(self.model.parent_map, {'body2': ('joint1', 'body1'), 'body3': ('joint2', 'body2')}) + + def test_get_chain(self): + self.assertEqual(self.model.get_chain('body1', 'body3'), + ['body1', 'joint1', 'body2', 'joint2', 'body3']) + diff --git a/test/test_move_and_pick_up.py b/test/test_move_and_pick_up.py index 013a8708d..2c4268950 100644 --- a/test/test_move_and_pick_up.py +++ b/test/test_move_and_pick_up.py @@ -9,7 +9,7 @@ from pycram.designators.action_designator import MoveTorsoActionPerformable from pycram.designators.specialized_designators.probabilistic.probabilistic_action import (MoveAndPickUp, GaussianCostmapModel) -from pycram.plan_failures import PlanFailure +from pycram.failures import PlanFailure from pycram.process_module import simulated_robot diff --git a/test/test_multiverse.py b/test/test_multiverse.py new file mode 100644 index 000000000..3164de50a --- /dev/null +++ b/test/test_multiverse.py @@ -0,0 +1,410 @@ +#!/usr/bin/env python3 +import os +import unittest + +import numpy as np +import psutil +from tf.transformations import quaternion_from_euler, quaternion_multiply +from typing_extensions import Optional, List + +from pycram.datastructures.dataclasses import ContactPointsList, ContactPoint +from pycram.datastructures.enums import ObjectType, Arms, JointType +from pycram.datastructures.pose import Pose +from pycram.robot_description import RobotDescriptionManager +from pycram.world_concepts.world_object import Object +from pycram.validation.error_checkers import calculate_angle_between_quaternions +from pycram.helper import get_robot_mjcf_path, parse_mjcf_actuators + +multiverse_installed = True +try: + from pycram.worlds.multiverse import Multiverse +except ImportError: + multiverse_installed = False + +processes = psutil.process_iter() +process_names = [p.name() for p in processes] +multiverse_running = True +mujoco_running = True +if 'multiverse_server' not in process_names: + multiverse_running = False +if 'mujoco' not in process_names: + mujoco_running = False + + +@unittest.skipIf(not multiverse_installed, "Multiverse is not installed.") +@unittest.skipIf(not multiverse_running, "Multiverse server is not running.") +@unittest.skipIf(not mujoco_running, "Mujoco is not running.") +class MultiversePyCRAMTestCase(unittest.TestCase): + if multiverse_installed: + multiverse: Multiverse + big_bowl: Optional[Object] = None + + @classmethod + def setUpClass(cls): + if not multiverse_installed: + return + cls.multiverse = Multiverse() + + @classmethod + def tearDownClass(cls): + cls.multiverse.exit(remove_saved_states=True) + cls.multiverse.remove_multiverse_resources() + + def tearDown(self): + self.multiverse.remove_all_objects() + + def test_spawn_xml_object(self): + bread = Object("bread_1", ObjectType.GENERIC_OBJECT, "bread_1.xml", pose=Pose([1, 1, 0.1])) + self.assert_poses_are_equal(bread.get_pose(), Pose([1, 1, 0.1])) + + def test_spawn_mesh_object(self): + milk = Object("milk", ObjectType.MILK, "milk.stl", pose=Pose([1, 1, 0.1])) + self.assert_poses_are_equal(milk.get_pose(), Pose([1, 1, 0.1])) + self.multiverse.simulate(0.2) + contact_points = milk.contact_points() + self.assertTrue(len(contact_points) > 0) + + def test_parse_mjcf_actuators(self): + mjcf_file = get_robot_mjcf_path("pal_robotics", "tiago_dual") + self.assertTrue(os.path.exists(mjcf_file)) + joint_actuators = parse_mjcf_actuators(mjcf_file) + self.assertIsInstance(joint_actuators, dict) + self.assertTrue(len(joint_actuators) > 0) + self.assertTrue("arm_left_1_joint" in joint_actuators) + self.assertTrue("arm_right_1_joint" in joint_actuators) + self.assertTrue(joint_actuators["arm_right_1_joint"] == "arm_right_1_actuator") + + def test_get_actuator_for_joint(self): + robot = self.spawn_robot() + joint_name = "arm_right_1_joint" + actuator_name = robot.get_actuator_for_joint(robot.joints[joint_name]) + self.assertEqual(actuator_name, "arm_right_1_actuator") + + def test_get_images_for_target(self): + robot = self.spawn_robot(robot_name='pr2') + camera_description = self.multiverse.robot_description.get_default_camera() + camera_link_name = camera_description.link_name + camera_pose = robot.get_link_pose(camera_link_name) + camera_frame = self.multiverse.robot_description.get_camera_frame() + camera_front_facing_axis = camera_description.front_facing_axis + milk_spawn_position = np.array(camera_front_facing_axis) * 0.5 + orientation = camera_pose.to_transform(camera_frame).invert().rotation_as_list() + milk = self.spawn_milk(milk_spawn_position.tolist(), orientation, frame=camera_frame) + _, depth, segmentation_mask = self.multiverse.get_images_for_target(milk.pose, camera_pose, plot=False) + self.assertIsInstance(depth, np.ndarray) + self.assertIsInstance(segmentation_mask, np.ndarray) + self.assertTrue(depth.shape == (256, 256)) + self.assertTrue(segmentation_mask.shape == (256, 256)) + self.assertTrue(milk.id in np.unique(segmentation_mask).flatten().tolist()) + avg_depth_of_milk = np.mean(depth[segmentation_mask == milk.id]) + self.assertAlmostEqual(avg_depth_of_milk, 0.5, delta=0.1) + + def test_reset_world(self): + set_position = [1, 1, 0.1] + milk = self.spawn_milk(set_position) + milk.set_position(set_position) + milk_position = milk.get_position_as_list() + self.assert_list_is_equal(milk_position[:2], set_position[:2], delta=self.multiverse.conf.position_tolerance) + self.multiverse.reset_world() + milk_pose = milk.get_pose() + self.assert_list_is_equal(milk_pose.position_as_list()[:2], + milk.original_pose.position_as_list()[:2], + delta=self.multiverse.conf.position_tolerance) + self.assert_orientation_is_equal(milk_pose.orientation_as_list(), milk.original_pose.orientation_as_list()) + + def test_spawn_robot_with_actuators_directly_from_multiverse(self): + if self.multiverse.conf.use_controller: + robot_name = "tiago_dual" + rdm = RobotDescriptionManager() + rdm.load_description(robot_name) + self.multiverse.spawn_robot_with_controller(robot_name, Pose([-2, -2, 0.001])) + + def test_spawn_object(self): + milk = self.spawn_milk([1, 1, 0.1]) + self.assertIsInstance(milk, Object) + milk_pose = milk.get_pose() + self.assert_list_is_equal(milk_pose.position_as_list()[:2], [1, 1], + delta=self.multiverse.conf.position_tolerance) + self.assert_orientation_is_equal(milk_pose.orientation_as_list(), milk.original_pose.orientation_as_list()) + + def test_remove_object(self): + milk = self.spawn_milk([1, 1, 0.1]) + milk.remove() + self.assertTrue(milk not in self.multiverse.objects) + self.assertFalse(self.multiverse.check_object_exists(milk)) + + def test_check_object_exists(self): + milk = self.spawn_milk([1, 1, 0.1]) + self.assertTrue(self.multiverse.check_object_exists(milk)) + + def test_set_position(self): + milk = self.spawn_milk([1, 1, 0.1]) + original_milk_position = milk.get_position_as_list() + original_milk_position[0] += 1 + milk.set_position(original_milk_position) + milk_position = milk.get_position_as_list() + self.assert_list_is_equal(milk_position[:2], original_milk_position[:2], + delta=self.multiverse.conf.position_tolerance) + + def test_update_position(self): + milk = self.spawn_milk([1, 1, 0.1]) + milk.update_pose() + milk_position = milk.get_position_as_list() + self.assert_list_is_equal(milk_position[:2], [1, 1], delta=self.multiverse.conf.position_tolerance) + + def test_set_joint_position(self): + if self.multiverse.robot is None: + robot = self.spawn_robot() + else: + robot = self.multiverse.robot + step = 0.2 + for joint in ['torso_lift_joint']: + joint_type = robot.joints[joint].type + original_joint_position = robot.get_joint_position(joint) + robot.set_joint_position(joint, original_joint_position + step) + joint_position = robot.get_joint_position(joint) + if not self.multiverse.conf.use_controller: + delta = self.multiverse.conf.prismatic_joint_position_tolerance if joint_type == JointType.PRISMATIC \ + else self.multiverse.conf.revolute_joint_position_tolerance + else: + delta = 0.18 + self.assertAlmostEqual(joint_position, original_joint_position + step, delta=delta) + + def test_spawn_robot(self): + if self.multiverse.robot is not None: + robot = self.multiverse.robot + else: + robot = self.spawn_robot(robot_name="pr2") + self.assertIsInstance(robot, Object) + self.assertTrue(robot in self.multiverse.objects) + self.assertTrue(self.multiverse.robot.name == robot.name) + + def test_destroy_robot(self): + if self.multiverse.robot is None: + self.spawn_robot() + self.assertTrue(self.multiverse.robot in self.multiverse.objects) + self.multiverse.robot.remove() + self.assertTrue(self.multiverse.robot not in self.multiverse.objects) + + def test_respawn_robot(self): + self.spawn_robot() + self.assertTrue(self.multiverse.robot in self.multiverse.objects) + self.multiverse.robot.remove() + self.assertTrue(self.multiverse.robot not in self.multiverse.objects) + self.spawn_robot() + self.assertTrue(self.multiverse.robot in self.multiverse.objects) + + def test_set_robot_position(self): + step = -1 + for i in range(3): + self.spawn_robot() + new_position = [-3 + step * i, -3 + step * i, 0.001] + self.multiverse.robot.set_position(new_position) + robot_position = self.multiverse.robot.get_position_as_list() + self.assert_list_is_equal(robot_position[:2], new_position[:2], + delta=self.multiverse.conf.position_tolerance) + self.tearDown() + + def test_set_robot_orientation(self): + self.spawn_robot() + for i in range(3): + current_quaternion = self.multiverse.robot.get_orientation_as_list() + # rotate by 45 degrees without using euler angles + rotation_quaternion = quaternion_from_euler(0, 0, np.pi / 4) + new_quaternion = quaternion_multiply(current_quaternion, rotation_quaternion) + self.multiverse.robot.set_orientation(new_quaternion) + robot_orientation = self.multiverse.robot.get_orientation_as_list() + quaternion_difference = calculate_angle_between_quaternions(new_quaternion, robot_orientation) + self.assertAlmostEqual(quaternion_difference, 0, delta=self.multiverse.conf.orientation_tolerance) + + def test_set_robot_pose(self): + self.spawn_robot(orientation=quaternion_from_euler(0, 0, np.pi / 4)) + position_step = -1 + angle_step = np.pi / 4 + num_steps = 10 + self.step_robot_pose(self.multiverse.robot, position_step, angle_step, num_steps) + position_step = 1 + angle_step = -np.pi / 4 + self.step_robot_pose(self.multiverse.robot, position_step, angle_step, num_steps) + + def step_robot_pose(self, robot, position_step, angle_step, num_steps): + original_position = robot.get_position_as_list() + original_orientation = robot.get_orientation_as_list() + for i in range(num_steps): + new_position = [original_position[0] + position_step * (i + 1), + original_position[1] + position_step * (i + 1), original_position[2]] + rotation_quaternion = quaternion_from_euler(0, 0, angle_step * (i + 1)) + new_quaternion = quaternion_multiply(original_orientation, rotation_quaternion) + new_pose = Pose(new_position, new_quaternion) + self.multiverse.robot.set_pose(new_pose) + robot_pose = self.multiverse.robot.get_pose() + self.assert_poses_are_equal(new_pose, robot_pose, + position_delta=self.multiverse.conf.position_tolerance, + orientation_delta=self.multiverse.conf.orientation_tolerance) + + def test_get_environment_pose(self): + apartment = Object("apartment", ObjectType.ENVIRONMENT, f"apartment.urdf") + pose = apartment.get_pose() + self.assertIsInstance(pose, Pose) + + def test_attach_object(self): + for _ in range(3): + milk = self.spawn_milk([1, 0.1, 0.1]) + cup = self.spawn_cup([1, 1.1, 0.1]) + milk.attach(cup) + self.assertTrue(cup in milk.attachments) + milk_position = milk.get_position_as_list() + milk_position[0] += 1 + cup_position = cup.get_position_as_list() + estimated_cup_position = cup_position.copy() + estimated_cup_position[0] += 1 + milk.set_position(milk_position) + new_cup_position = cup.get_position_as_list() + self.assert_list_is_equal(new_cup_position[:2], estimated_cup_position[:2], + self.multiverse.conf.position_tolerance) + self.tearDown() + + def test_detach_object(self): + for i in range(2): + milk = self.spawn_milk([1, 0, 0.1]) + cup = self.spawn_cup([1, 1, 0.1]) + milk.attach(cup) + self.assertTrue(cup in milk.attachments) + milk.detach(cup) + self.assertTrue(cup not in milk.attachments) + milk_position = milk.get_position_as_list() + milk_position[0] += 1 + cup_position = cup.get_position_as_list() + estimated_cup_position = cup_position.copy() + milk.set_position(milk_position) + new_milk_position = milk.get_position_as_list() + new_cup_position = cup.get_position_as_list() + self.assert_list_is_equal(new_milk_position[:2], milk_position[:2], + self.multiverse.conf.position_tolerance) + self.assert_list_is_equal(new_cup_position[:2], estimated_cup_position[:2], + self.multiverse.conf.position_tolerance) + self.tearDown() + + def test_attach_with_robot(self): + milk = self.spawn_milk([-1, -1, 0.1]) + robot = self.spawn_robot() + ee_link = self.multiverse.get_arm_tool_frame_link(Arms.RIGHT) + # Get position of milk relative to robot end effector + robot.attach(milk, ee_link.name, coincide_the_objects=False) + self.assertTrue(robot in milk.attachments) + milk_initial_pose = milk.root_link.get_pose_wrt_link(ee_link) + robot_position = 1.57 + robot.set_joint_position("arm_right_2_joint", robot_position) + milk_pose = milk.root_link.get_pose_wrt_link(ee_link) + self.assert_poses_are_equal(milk_initial_pose, milk_pose) + + def test_get_object_contact_points(self): + for i in range(10): + milk = self.spawn_milk([1, 1, 0.01], [0, -0.707, 0, 0.707]) + contact_points = self.multiverse.get_object_contact_points(milk) + self.assertIsInstance(contact_points, ContactPointsList) + self.assertEqual(len(contact_points), 1) + self.assertIsInstance(contact_points[0], ContactPoint) + self.assertTrue(contact_points[0].link_b.object, self.multiverse.floor) + cup = self.spawn_cup([1, 1, 0.15]) + # This is needed because the cup is spawned in the air, so it needs to fall + # to get in contact with the milk + self.multiverse.simulate(0.3) + contact_points = self.multiverse.get_object_contact_points(cup) + self.assertIsInstance(contact_points, ContactPointsList) + self.assertEqual(len(contact_points), 1) + self.assertIsInstance(contact_points[0], ContactPoint) + self.assertTrue(contact_points[0].link_b.object, milk) + self.tearDown() + + def test_get_contact_points_between_two_objects(self): + for i in range(3): + milk = self.spawn_milk([1, 1, 0.01], [0, -0.707, 0, 0.707]) + cup = self.spawn_cup([1, 1, 0.15]) + # This is needed because the cup is spawned in the air so it needs to fall + # to get in contact with the milk + self.multiverse.simulate(0.3) + contact_points = self.multiverse.get_contact_points_between_two_objects(milk, cup) + self.assertIsInstance(contact_points, ContactPointsList) + self.assertEqual(len(contact_points), 1) + self.assertIsInstance(contact_points[0], ContactPoint) + self.assertTrue(contact_points[0].link_a.object, milk) + self.assertTrue(contact_points[0].link_b.object, cup) + self.tearDown() + + def test_get_one_ray(self): + milk = self.spawn_milk([1, 1, 0.1]) + intersected_object = self.multiverse.ray_test([1, 2, 0.1], [1, 1.5, 0.1]) + self.assertTrue(intersected_object is None) + intersected_object = self.multiverse.ray_test([1, 2, 0.1], [1, 1, 0.1]) + self.assertTrue(intersected_object == milk.id) + + def test_get_rays(self): + milk = self.spawn_milk([1, 1, 0.1]) + intersected_objects = self.multiverse.ray_test_batch([[1, 2, 0.1], [1, 2, 0.1]], + [[1, 1.5, 0.1], [1, 1, 0.1]]) + self.assertTrue(intersected_objects[0][0] == -1) + self.assertTrue(intersected_objects[1][0] == milk.id) + + @staticmethod + def spawn_big_bowl() -> Object: + big_bowl = Object("big_bowl", ObjectType.GENERIC_OBJECT, "BigBowl.obj", + pose=Pose([2, 2, 0.1], [0, 0, 0, 1])) + return big_bowl + + @staticmethod + def spawn_milk(position: List, orientation: Optional[List] = None, frame="map") -> Object: + if orientation is None: + orientation = [0, 0, 0, 1] + milk = Object("milk_box", ObjectType.MILK, "milk_box.xml", + pose=Pose(position, orientation, frame=frame)) + return milk + + def spawn_robot(self, position: Optional[List[float]] = None, + orientation: Optional[List[float]] = None, + robot_name: Optional[str] = 'tiago_dual', + replace: Optional[bool] = True) -> Object: + if position is None: + position = [-2, -2, 0.001] + if orientation is None: + orientation = [0, 0, 0, 1] + if self.multiverse.robot is None or replace: + if self.multiverse.robot is not None: + self.multiverse.robot.remove() + robot = Object(robot_name, ObjectType.ROBOT, f"{robot_name}.urdf", + pose=Pose(position, [0, 0, 0, 1])) + else: + robot = self.multiverse.robot + robot.set_position(position) + return robot + + @staticmethod + def spawn_cup(position: List) -> Object: + cup = Object("cup", ObjectType.GENERIC_OBJECT, "Cup.obj", + pose=Pose(position, [0, 0, 0, 1])) + return cup + + def assert_poses_are_equal(self, pose1: Pose, pose2: Pose, + position_delta: Optional[float] = None, orientation_delta: Optional[float] = None): + if position_delta is None: + position_delta = self.multiverse.conf.position_tolerance + if orientation_delta is None: + orientation_delta = self.multiverse.conf.orientation_tolerance + self.assert_position_is_equal(pose1.position_as_list(), pose2.position_as_list(), delta=position_delta) + self.assert_orientation_is_equal(pose1.orientation_as_list(), pose2.orientation_as_list(), + delta=orientation_delta) + + def assert_position_is_equal(self, position1: List[float], position2: List[float], delta: Optional[float] = None): + if delta is None: + delta = self.multiverse.conf.position_tolerance + self.assert_list_is_equal(position1, position2, delta=delta) + + def assert_orientation_is_equal(self, orientation1: List[float], orientation2: List[float], + delta: Optional[float] = None): + if delta is None: + delta = self.multiverse.conf.orientation_tolerance + self.assertAlmostEqual(calculate_angle_between_quaternions(orientation1, orientation2), 0, delta=delta) + + def assert_list_is_equal(self, list1: List, list2: List, delta: float): + for i in range(len(list1)): + self.assertAlmostEqual(list1[i], list2[i], delta=delta) diff --git a/test/test_object.py b/test/test_object.py index 74c22f7b8..bede0300b 100644 --- a/test/test_object.py +++ b/test/test_object.py @@ -5,15 +5,18 @@ from pycram.datastructures.enums import JointType, ObjectType from pycram.datastructures.pose import Pose from pycram.datastructures.dataclasses import Color +from pycram.failures import UnsupportedFileExtension from pycram.world_concepts.world_object import Object +from pycram.object_descriptors.generic import ObjectDescription as GenericObjectDescription from geometry_msgs.msg import Point, Quaternion import pathlib + class TestObject(BulletWorldTestCase): def test_wrong_object_description_path(self): - with self.assertRaises(FileNotFoundError): + with self.assertRaises(UnsupportedFileExtension): milk = Object("milk_not_found", ObjectType.MILK, "wrong_path.sk") def test_malformed_object_description(self): @@ -160,3 +163,12 @@ def test_object_equal(self): self.assertEqual(self.milk, self.milk) self.assertNotEqual(self.milk, self.cereal) self.assertNotEqual(self.milk, self.world) + + +class GenericObjectTestCase(BulletWorldTestCase): + + def test_init_generic_object(self): + gen_obj_desc = GenericObjectDescription("robokudo_object", [0,0,0], [0.1, 0.1, 0.1]) + obj = Object("robokudo_object", ObjectType.MILK, None, gen_obj_desc) + pose = obj.get_pose() + self.assertTrue(isinstance(pose, Pose)) diff --git a/test/test_ontology.py b/test/test_ontology.py index 022d9021b..f4754d4b2 100644 --- a/test/test_ontology.py +++ b/test/test_ontology.py @@ -28,9 +28,10 @@ from pycram.ontology.ontology import OntologyManager, SOMA_HOME_ONTOLOGY_IRI, SOMA_ONTOLOGY_IRI from pycram.ontology.ontology_common import (OntologyConceptHolderStore, OntologyConceptHolder, - ONTOLOGY_SQL_BACKEND_FILE_EXTENSION, ONTOLOGY_OWL_FILE_EXTENSION) - + ONTOLOGY_SQL_BACKEND_FILE_EXTENSION, ONTOLOGY_OWL_FILE_EXTENSION, + ONTOLOGY_SQL_IN_MEMORY_BACKEND) +DEFAULT_LOCAL_ONTOLOGY_IRI = "default.owl" class TestOntologyManager(unittest.TestCase): ontology_manager: OntologyManager main_ontology: Optional[owlready2.Ontology] @@ -39,23 +40,29 @@ class TestOntologyManager(unittest.TestCase): @classmethod def setUpClass(cls): - cls.ontology_manager = OntologyManager(SOMA_ONTOLOGY_IRI) + # Try loading from remote `SOMA_ONTOLOGY_IRI`, which will fail given no internet access + cls.ontology_manager = OntologyManager(main_ontology_iri=SOMA_ONTOLOGY_IRI, + main_sql_backend_filename=os.path.join(Path.home(), + f"{Path(SOMA_ONTOLOGY_IRI).stem}{ONTOLOGY_SQL_BACKEND_FILE_EXTENSION}")) if cls.ontology_manager.initialized(): - cls.main_ontology = cls.ontology_manager.main_ontology cls.soma = cls.ontology_manager.soma cls.dul = cls.ontology_manager.dul else: - cls.main_ontology = None + # Else, load from `DEFAULT_LOCAL_ONTOLOGY_IRI` cls.soma = None cls.dul = None + cls.ontology_manager.main_ontology_iri = DEFAULT_LOCAL_ONTOLOGY_IRI + cls.ontology_manager.main_ontology_sql_backend = ONTOLOGY_SQL_IN_MEMORY_BACKEND + cls.ontology_manager.create_main_ontology_world() + cls.ontology_manager.create_main_ontology() + cls.main_ontology = cls.ontology_manager.main_ontology @classmethod def tearDownClass(cls): save_dir = cls.ontology_manager.get_main_ontology_dir() owl_filepath = f"{save_dir}/{Path(cls.ontology_manager.main_ontology_iri).stem}{ONTOLOGY_OWL_FILE_EXTENSION}" - sql_filepath = f"{save_dir}/{Path(owl_filepath).stem}{ONTOLOGY_SQL_BACKEND_FILE_EXTENSION}" os.remove(owl_filepath) - cls.remove_sql_file(sql_filepath) + cls.remove_sql_file(cls.ontology_manager.main_ontology_sql_backend) @classmethod def remove_sql_file(cls, sql_filepath: str): @@ -234,7 +241,7 @@ def test_ontology_reasoning(self): ontology_property_parent_class=owlready2.ObjectProperty, ontology=reasoning_ontology)) - # Define rules for "bigger_than" in [reasoning_ontology] + # Define rules for `transportability` & `co-residence` in [reasoning_ontology] with reasoning_ontology: def can_transport_itself(a: reasoning_ontology.Entity) -> bool: return a in a.can_transport @@ -295,11 +302,11 @@ def coresidents(a: reasoning_ontology.Entity, b: reasoning_ontology.Entity) -> b def test_ontology_save(self): save_dir = self.ontology_manager.get_main_ontology_dir() owl_filepath = f"{save_dir}/{Path(self.ontology_manager.main_ontology_iri).stem}{ONTOLOGY_OWL_FILE_EXTENSION}" - sql_filepath = f"{save_dir}/{Path(owl_filepath).stem}{ONTOLOGY_SQL_BACKEND_FILE_EXTENSION}" self.assertTrue(self.ontology_manager.save(owl_filepath)) self.assertTrue(Path(owl_filepath).is_file()) - self.assertTrue(Path(sql_filepath).is_file()) - + sql_backend = self.ontology_manager.main_ontology_sql_backend + if sql_backend != ONTOLOGY_SQL_IN_MEMORY_BACKEND: + self.assertTrue(Path(sql_backend).is_file()) if __name__ == '__main__': unittest.main() diff --git a/test/test_orm.py b/test/test_orm.py index 482ce14e3..6609e3992 100644 --- a/test/test_orm.py +++ b/test/test_orm.py @@ -1,5 +1,7 @@ import os +import time import unittest +import time from sqlalchemy import select import sqlalchemy.orm import pycram.orm.action_designator @@ -9,18 +11,20 @@ import pycram.orm.tasktree import pycram.tasktree from bullet_world_testcase import BulletWorldTestCase +from pycram.datastructures.dataclasses import Color +from pycram.ontology.ontology import OntologyManager, SOMA_ONTOLOGY_IRI +from pycram.ros_utils.viz_marker_publisher import VizMarkerPublisher from pycram.world_concepts.world_object import Object from pycram.designators import action_designator, object_designator, motion_designator -from pycram.designators.action_designator import ParkArmsActionPerformable, MoveTorsoActionPerformable, \ - SetGripperActionPerformable, PickUpActionPerformable, NavigateActionPerformable, TransportActionPerformable, \ - OpenActionPerformable, CloseActionPerformable, DetectActionPerformable, LookAtActionPerformable -from pycram.designators.object_designator import BelieveObject -from pycram.datastructures.enums import ObjectType +from pycram.designators.action_designator import * +from pycram.designators.object_designator import BelieveObject, ObjectPart +from pycram.datastructures.enums import ObjectType, WorldMode from pycram.datastructures.pose import Pose from pycram.process_module import simulated_robot -from pycram.tasktree import with_tree +from pycram.tasktree import with_tree, task_tree from pycram.orm.views import PickUpWithContextView -from pycram.datastructures.enums import Arms, Grasp, GripperState +from pycram.datastructures.enums import Arms, Grasp, GripperState, ObjectType +from pycram.worlds.bullet_world import BulletWorld class DatabaseTestCaseMixin(BulletWorldTestCase): @@ -39,7 +43,6 @@ def setUp(self): def tearDown(self): super().tearDown() - pycram.tasktree.reset_tree() pycram.orm.base.ProcessMetaData.reset() pycram.orm.base.Base.metadata.drop_all(self.engine) self.session.close() @@ -56,7 +59,7 @@ def test_schema_creation(self): self.assertTrue("NavigateAction" in tables) self.assertTrue("MoveTorsoAction" in tables) self.assertTrue("SetGripperAction" in tables) - self.assertTrue("Release" in tables) + self.assertTrue("ReleaseAction" in tables) self.assertTrue("GripAction" in tables) self.assertTrue("PickUpAction" in tables) self.assertTrue("PlaceAction" in tables) @@ -183,7 +186,7 @@ def test_plan_serialization(self): PickUpActionPerformable(object_description.resolve(), Arms.LEFT, Grasp.FRONT).perform() description.resolve().perform() pycram.orm.base.ProcessMetaData().description = "Unittest" - tt = pycram.tasktree.task_tree + tt = pycram.tasktree.task_tree.root tt.insert(self.session) action_results = self.session.scalars(select(pycram.orm.action_designator.Action)).all() motion_results = self.session.scalars(select(pycram.orm.motion_designator.Motion)).all() @@ -226,6 +229,24 @@ def test_transportAction(self): milk_object = self.session.scalars(select(pycram.orm.object_designator.Object)).first() self.assertEqual(milk_object.pose, result[0].object.pose) + def test_pickUpAction(self): + object_description = object_designator.ObjectDesignatorDescription(names=["milk"]) + previous_position = object_description.resolve().pose + with simulated_robot: + NavigateActionPerformable(Pose([0.6, 0.4, 0], [0, 0, 0, 1])).perform() + PickUpActionPerformable(object_description.resolve(), Arms.LEFT, Grasp.FRONT).perform() + NavigateActionPerformable(Pose([1.3, 1, 0.9], [0, 0, 0, 1])).perform() + PlaceActionPerformable(object_description.resolve(), Arms.LEFT, Pose([2.0, 1.6, 1.8], [0, 0, 0, 1])).perform() + pycram.orm.base.ProcessMetaData().description = "pickUpAction_test" + pycram.tasktree.task_tree.root.insert(self.session) + result = self.session.scalars(select(pycram.orm.base.Position) + .join(pycram.orm.action_designator.PickUpAction.object) + .join(pycram.orm.object_designator.Object.pose) + .join(pycram.orm.base.Pose.position)).first() + self.assertEqual(result.x, previous_position.position.x) + self.assertEqual(result.y, previous_position.position.y) + self.assertEqual(result.z, previous_position.position.z) + def test_lookAt_and_detectAction(self): object_description = object_designator.ObjectDesignatorDescription(names=["milk"]) action = DetectActionPerformable(object_description.resolve()) @@ -252,7 +273,7 @@ def test_setGripperAction(self): def test_open_and_closeAction(self): apartment = Object("apartment", ObjectType.ENVIRONMENT, "apartment.urdf") apartment_desig = BelieveObject(names=["apartment"]).resolve() - handle_desig = object_designator.ObjectPart(names=["handle_cab10_t"], part_of=apartment_desig).resolve() + handle_desig = object_designator.ObjectPart(names=["handle_cab10_t"], part_of=apartment_desig, type=ObjectType.ENVIRONMENT).resolve() self.kitchen.set_pose(Pose([20, 20, 0], [0, 0, 0, 1])) @@ -274,7 +295,61 @@ def test_open_and_closeAction(self): apartment.remove() +class BelieveObjectTestCase(unittest.TestCase): + engine: sqlalchemy.engine + session: sqlalchemy.orm.Session + + @classmethod + def setUpClass(cls): + cls.engine = sqlalchemy.create_engine("sqlite+pysqlite:///:memory:", echo=False) + environment_path = "apartment.urdf" + cls.world = BulletWorld(WorldMode.DIRECT) + cls.robot = Object("pr2", ObjectType.ROBOT, path="pr2.urdf", pose=Pose([1, 2, 0])) + cls.apartment = Object(environment_path[:environment_path.find(".")], ObjectType.ENVIRONMENT, environment_path) + cls.milk = Object("milk", ObjectType.MILK, "milk.stl", pose=Pose([1, -1.78, 0.55], [1, 0, 0, 0]), + color=Color(1, 0, 0, 1)) + cls.viz_marker_publisher = VizMarkerPublisher() + OntologyManager(SOMA_ONTOLOGY_IRI) + + def setUp(self): + self.world.reset_world() + pycram.orm.base.Base.metadata.create_all(self.engine) + self.session = sqlalchemy.orm.Session(bind=self.engine) + + def tearDown(self): + pycram.tasktree.task_tree.reset_tree() + time.sleep(0.05) + pycram.orm.base.ProcessMetaData.reset() + pycram.orm.base.Base.metadata.drop_all(self.engine) + self.session.close() + self.world.reset_world() + + @classmethod + def tearDownClass(cls): + cls.viz_marker_publisher._stop_publishing() + cls.world.exit() + + def test_believe_object(self): + # TODO: Find better way to separate BelieveObject no pose from Object pose + + with simulated_robot: + ParkArmsAction([Arms.BOTH]).resolve().perform() + + MoveTorsoAction([0.25]).resolve().perform() + NavigateAction(target_locations=[Pose([2, -1.89, 0])]).resolve().perform() + + LookAtAction(targets=[Pose([1, -1.78, 0.55])]).resolve().perform() + + object_desig = DetectAction(BelieveObject(types=[ObjectType.MILK])).resolve().perform() + TransportAction(object_desig, [Arms.LEFT], [Pose([4.8, 3.55, 0.8])]).resolve().perform() + + ParkArmsAction([Arms.BOTH]).resolve().perform() + pycram.orm.base.ProcessMetaData().description = "BelieveObject_test" + task_tree.root.insert(self.session) + + class ViewsSchemaTest(DatabaseTestCaseMixin): + def test_view_creation(self): pycram.orm.base.ProcessMetaData().description = "view_creation_test" pycram.tasktree.task_tree.root.insert(self.session) @@ -287,14 +362,16 @@ def test_view_creation(self): self.assertEqual(view.__table__.columns[3].name, "torso_height") self.assertEqual(view.__table__.columns[4].name, "relative_x") self.assertEqual(view.__table__.columns[5].name, "relative_y") - self.assertEqual(view.__table__.columns[6].name, "quaternion_x") - self.assertEqual(view.__table__.columns[7].name, "quaternion_y") - self.assertEqual(view.__table__.columns[8].name, "quaternion_z") - self.assertEqual(view.__table__.columns[9].name, "quaternion_w") + self.assertEqual(view.__table__.columns[6].name, "x") + self.assertEqual(view.__table__.columns[7].name, "y") + self.assertEqual(view.__table__.columns[8].name, "z") + self.assertEqual(view.__table__.columns[9].name, "w") self.assertEqual(view.__table__.columns[10].name, "obj_type") self.assertEqual(view.__table__.columns[11].name, "status") def test_pickUpWithContextView(self): + if self.engine.dialect.name == "sqlite": + return object_description = object_designator.ObjectDesignatorDescription(names=["milk"]) description = action_designator.PlaceAction(object_description, [Pose([1.3, 1, 0.9], [0, 0, 0, 1])], [Arms.LEFT]) self.assertEqual(description.ground().object_designator.name, "milk") @@ -315,6 +392,8 @@ def test_pickUpWithContextView(self): self.assertEqual(result.quaternion_w, 1) def test_pickUpWithContextView_conditions(self): + if self.engine.dialect.name == "sqlite": + return object_description = object_designator.ObjectDesignatorDescription(names=["milk"]) description = action_designator.PlaceAction(object_description, [Pose([1.3, 1, 0.9], [0, 0, 0, 1])], [Arms.LEFT]) self.assertEqual(description.ground().object_designator.name, "milk") diff --git a/test/test_robot_description.py b/test/test_robot_description.py index 3e3985054..e845d9ec6 100644 --- a/test/test_robot_description.py +++ b/test/test_robot_description.py @@ -3,7 +3,7 @@ from pycram.robot_description import RobotDescription, KinematicChainDescription, EndEffectorDescription, \ CameraDescription, RobotDescriptionManager from pycram.datastructures.enums import Arms, GripperState -from urdf_parser_py.urdf import URDF +from pycram.object_descriptors.urdf import ObjectDescription as URDF class TestRobotDescription(unittest.TestCase): @@ -11,7 +11,8 @@ class TestRobotDescription(unittest.TestCase): @classmethod def setUpClass(cls): cls.path = str(pathlib.Path(__file__).parent.resolve()) + '/../resources/robots/' + "pr2" + '.urdf' - cls.urdf_obj = URDF.from_xml_file(cls.path) + cls.path_turtlebot = str(pathlib.Path(__file__).parent.resolve()) + '/../resources/robots/' + "turtlebot" + '.urdf' + cls.urdf_obj = URDF(cls.path) def test_robot_description_construct(self): robot_description = RobotDescription("pr2", "base_link", "torso_lift_link", "torso_lift_joint", self.path) @@ -190,3 +191,13 @@ def test_load_robot_description(self): rdm.register_description(robot_description) rdm.load_description("pr2_test2") self.assertIs(RobotDescription.current_robot_description, robot_description) + + def test_robot_description_turtlebot(self): + robot_description = RobotDescription("turtlebot", "base_link", "base_link", "base_joint", self.path_turtlebot) + self.assertEqual(robot_description.name, "turtlebot") + self.assertEqual(robot_description.base_link, "base_link") + self.assertEqual(robot_description.torso_link, "base_link") + self.assertEqual(robot_description.torso_joint, "base_joint") + self.assertTrue(type(robot_description.urdf_object) is URDF) + self.assertEqual(len(robot_description.links), 11) + self.assertEqual(len(robot_description.joints), 10) diff --git a/test/test_task_tree.py b/test/test_task_tree.py index dee20a698..01bda73c8 100644 --- a/test/test_task_tree.py +++ b/test/test_task_tree.py @@ -8,7 +8,7 @@ import unittest import anytree from bullet_world_testcase import BulletWorldTestCase -import pycram.plan_failures +import pycram.failures from pycram.designators import object_designator, action_designator @@ -27,7 +27,7 @@ def plan(self): def setUp(self): super().setUp() - pycram.tasktree.reset_tree() + pycram.tasktree.task_tree.reset_tree() def test_tree_creation(self): """Test the creation and content of a task tree.""" @@ -48,11 +48,11 @@ def test_exception(self): @with_tree def failing_plan(): - raise pycram.plan_failures.PlanFailure("PlanFailure for UnitTesting") + raise pycram.failures.PlanFailure("PlanFailure for UnitTesting") - pycram.tasktree.reset_tree() + pycram.tasktree.task_tree.reset_tree() - self.assertRaises(pycram.plan_failures.PlanFailure, failing_plan) + self.assertRaises(pycram.failures.PlanFailure, failing_plan) tt = pycram.tasktree.task_tree @@ -85,6 +85,19 @@ def test_to_sql(self): result = tt.root.to_sql() self.assertIsNotNone(result) + def test_task_tree_singleton(self): + # Instantiate one TaskTree object + tree1 = pycram.tasktree.TaskTree() + + # Fill the tree + self.plan() + + # Instantiate another TaskTree object + tree2 = pycram.tasktree.TaskTree() + + # Check if both instances point to the same object and contain the same number of elements + self.assertEqual(len(tree1.root), len(tree2.root)) + self.assertIs(tree1, tree2) if __name__ == '__main__': unittest.main()