Skip to content

Commit

Permalink
Fix missing joint_position_action and add gripper action (#221)
Browse files Browse the repository at this point in the history
* Fix missing joint_position_action and add gripper action

* Remove requirements.txt as it is not needed anymore

* Change pyrep commit hash temporarily

* Update pyrep hash

* Remove unwanted comment

* Clean up environment for each task test

* Use pytest-xdist to parallelise tests

* Add verbose flag for pytest

* Fix test being flaky due to non-determinism
  • Loading branch information
eugeneteoh authored Apr 10, 2024
1 parent eece2ab commit 790a90e
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 26 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/task_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,5 @@ jobs:
export QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT
pip install ".[dev]"
python3 -m unittest discover tests/demos
pip install "pytest-xdist[psutil]"
pytest -v -n auto tests/unit
3 changes: 2 additions & 1 deletion .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,5 @@ jobs:
export QT_QPA_PLATFORM_PLUGIN_PATH=$COPPELIASIM_ROOT
pip install ".[dev]"
python3 -m unittest discover tests/unit
pip install "pytest-xdist[psutil]"
pytest -v -n auto tests/unit
6 changes: 0 additions & 6 deletions requirements.txt

This file was deleted.

36 changes: 20 additions & 16 deletions rlbench/backend/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(self,

self._robot_shapes = self.robot.arm.get_objects_in_tree(
object_type=ObjectType.SHAPE)
self._execute_demo_joint_position_action = None
self._joint_position_action = None

def load(self, task: Task) -> None:
"""Loads the task and positions at the centre of the workspace.
Expand Down Expand Up @@ -337,6 +337,8 @@ def get_demo(self, record: bool = True,
demo = []
if record:
self.pyrep.step() # Need this here or get_force doesn't work...
self._joint_position_action = None
gripper_open = 1.0 if self.robot.gripper.get_open_amount()[0] > 0.9 else 0.0
demo.append(self.get_observation())
while True:
success = False
Expand Down Expand Up @@ -366,7 +368,7 @@ def get_demo(self, record: bool = True,
while not done:
done = path.step()
self.step()
self._execute_demo_joint_position_action = path.get_executed_joint_position_action()
self._joint_position_action = np.append(path.get_executed_joint_position_action(), gripper_open)
self._demo_record_step(demo, record, callable_each_step)
success, term = self.task.success()

Expand All @@ -385,9 +387,10 @@ def get_demo(self, record: bool = True,
if not contains_param:
done = False
while not done:
done = gripper.actuate(1.0, 0.04)
self.pyrep.step()
self.task.step()
gripper_open = 1.0
done = gripper.actuate(gripper_open, 0.04)
self.step()
self._joint_position_action = np.append(path.get_executed_joint_position_action(), gripper_open)
if self._obs_config.record_gripper_closing:
self._demo_record_step(
demo, record, callable_each_step)
Expand All @@ -397,9 +400,10 @@ def get_demo(self, record: bool = True,
if not contains_param:
done = False
while not done:
done = gripper.actuate(0.0, 0.04)
self.pyrep.step()
self.task.step()
gripper_open = 0.0
done = gripper.actuate(gripper_open, 0.04)
self.step()
self._joint_position_action = np.append(path.get_executed_joint_position_action(), gripper_open)
if self._obs_config.record_gripper_closing:
self._demo_record_step(
demo, record, callable_each_step)
Expand All @@ -409,9 +413,10 @@ def get_demo(self, record: bool = True,
num = float(rest[:rest.index(')')])
done = False
while not done:
done = gripper.actuate(num, 0.04)
self.pyrep.step()
self.task.step()
gripper_open = num
done = gripper.actuate(gripper_open, 0.04)
self.step()
self._joint_position_action = np.append(path.get_executed_joint_position_action(), gripper_open)
if self._obs_config.record_gripper_closing:
self._demo_record_step(
demo, record, callable_each_step)
Expand All @@ -429,8 +434,8 @@ def get_demo(self, record: bool = True,
# (e.g. ball rowling to goal)
if not success:
for _ in range(10):
self.pyrep.step()
self.task.step()
self.step()
self._joint_position_action = np.append(path.get_executed_joint_position_action(), gripper_open)
self._demo_record_step(demo, record, callable_each_step)
success, term = self.task.success()
if success:
Expand Down Expand Up @@ -545,8 +550,7 @@ def _get_cam_data(cam: VisionSensor, name: str):
misc.update(_get_cam_data(self._cam_front, 'front_camera'))
misc.update(_get_cam_data(self._cam_wrist, 'wrist_camera'))
misc.update({"variation_index": self._variation_index})
if self._execute_demo_joint_position_action is not None:
if self._joint_position_action is not None:
# Store the actual requested joint positions during demo collection
misc.update({"executed_demo_joint_position_action": self._execute_demo_joint_position_action})
self._execute_demo_joint_position_action = None
misc.update({"joint_position_action": self._joint_position_action})
return misc
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_version(rel_path):
raise RuntimeError("Unable to find version string.")

core_requirements = [
"pyrep @ git+https://github.com/stepjam/PyRep.git@076ca15c57f2495a4194da03565891ab1aaa317e",
"pyrep @ git+https://github.com/stepjam/PyRep.git@cd9830b58ef09538562b785fc0c257f528f1762b",
"numpy",
"Pillow",
"pyquaternion",
Expand All @@ -60,7 +60,7 @@ def get_version(rel_path):
'rlbench.gym'
],
extras_require={
"dev": ["html-testRunner", "gym"]
"dev": ["pytest", "html-testRunner", "gym"]
},
package_data={'': ['*.ttm', '*.obj', '**/**/*.ttm', '**/**/*.obj'],
'rlbench': ['task_design.ttt']},
Expand Down
32 changes: 32 additions & 0 deletions tests/unit/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,35 @@ def test_swap_arm(self):
robot_setup=robot_config)
self.env.launch()
self.env.shutdown()

def test_executed_jp_action(self):
for task_cls in [ReachTarget, TakeLidOffSaucepan]:
with self.subTest(task_cls=task_cls):
task = self.get_task(
task_cls, JointPosition(True))
num_episodes = 20
demos = task.get_demos(num_episodes, live_demos=True)
total_reward = 0.0
# Check if executed joint position action is stored
for demo in demos:
jp_action = []
self.assertTrue("joint_position_action" not in demo[0].misc)
for t, obs in enumerate(demo):
if t == 0:
# First timestep should not have an action
self.assertTrue('joint_position_action' not in obs.misc)
else:
self.assertTrue("joint_position_action" in obs.misc)
jp_action.append(obs.misc["joint_position_action"])

task.reset_to_demo(demo)
for t, action in enumerate(jp_action):
obs, reward, term = task.step(action)
if term:
break
total_reward += reward

success_rate = total_reward / num_episodes
self.assertTrue(success_rate >= 0.9)
self.env.shutdown()

0 comments on commit 790a90e

Please sign in to comment.