From 56c87ba537772db9597dda6f2239ab09c5c15bad Mon Sep 17 00:00:00 2001 From: Eugene Teoh Date: Sun, 12 May 2024 11:10:26 +0100 Subject: [PATCH] Support vector env --- examples/rlbench_gym.py | 2 +- examples/rlbench_gym_vector.py | 18 +++++ rlbench/gym/rlbench_env.py | 123 +++++++++++++++++---------------- 3 files changed, 81 insertions(+), 62 deletions(-) create mode 100644 examples/rlbench_gym_vector.py diff --git a/examples/rlbench_gym.py b/examples/rlbench_gym.py index dac21fc0f..2807a8c28 100644 --- a/examples/rlbench_gym.py +++ b/examples/rlbench_gym.py @@ -1,7 +1,7 @@ import gymnasium as gym import rlbench.gym -env = gym.make('reach_target-state-v0', render_mode='human') +env = gym.make('reach_target-vision-v0', render_mode="human") training_steps = 120 episode_length = 40 diff --git a/examples/rlbench_gym_vector.py b/examples/rlbench_gym_vector.py new file mode 100644 index 000000000..6c25a12e8 --- /dev/null +++ b/examples/rlbench_gym_vector.py @@ -0,0 +1,18 @@ +import gymnasium as gym +import rlbench.gym + +if __name__ == "__main__": + # Only works with spawn (multiprocessing) context + env = gym.make_vec('reach_target-vision-v0', num_envs=2, vectorization_mode="async", vector_kwargs={"context": "spawn"}) + + training_steps = 120 + episode_length = 40 + for i in range(training_steps): + if i % episode_length == 0: + print('Reset Episode') + obs = env.reset() + obs, reward, terminate, _, _ = env.step(env.action_space.sample()) + env.render() # Note: rendering increases step time. + + print('Done') + env.close() diff --git a/rlbench/gym/rlbench_env.py b/rlbench/gym/rlbench_env.py index 0dd870359..8febb54c5 100644 --- a/rlbench/gym/rlbench_env.py +++ b/rlbench/gym/rlbench_env.py @@ -3,24 +3,22 @@ import gymnasium as gym import numpy as np from gymnasium import spaces -from pyrep.const import RenderMode from pyrep.objects.dummy import Dummy from pyrep.objects.vision_sensor import VisionSensor +from pyrep.const import RenderMode + -from rlbench.action_modes.action_mode import JointPositionActionMode, MoveArmThenGripper -from rlbench.action_modes.arm_action_modes import JointPosition -from rlbench.action_modes.gripper_action_modes import Discrete +from rlbench.action_modes.action_mode import JointPositionActionMode from rlbench.environment import Environment from rlbench.observation_config import ObservationConfig - class RLBenchEnv(gym.Env): """An gym wrapper for RLBench.""" - - metadata = {'render_modes': ['human', 'rgb_array']} + metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4} def __init__(self, task_class, observation_mode='state', render_mode: Union[None, str] = None, action_mode=None): + self.task_class = task_class self._observation_mode = observation_mode self._render_mode = render_mode obs_config = ObservationConfig() @@ -32,76 +30,79 @@ def __init__(self, task_class, observation_mode='state', else: raise ValueError( 'Unrecognised observation_mode: %s.' % observation_mode) - + self.obs_config = obs_config if action_mode is None: action_mode = JointPositionActionMode() - self.env = Environment( - action_mode, obs_config=obs_config, headless=True) - self.env.launch() - self.task = self.env.get_task(task_class) - - _, obs = self.task.reset() - - action_bounds = action_mode.action_bounds() - self.action_space = spaces.Box( - low=action_bounds[0], high=action_bounds[1], shape=self.env.action_shape) - - if observation_mode == 'state': - self.observation_space = spaces.Box( - low=-np.inf, high=np.inf, shape=obs.get_low_dim_data().shape) - elif observation_mode == 'vision': - self.observation_space = spaces.Dict({ - "state": spaces.Box( - low=-np.inf, high=np.inf, - shape=obs.get_low_dim_data().shape), + self.action_mode = action_mode + + self.rlbench_env = Environment( + action_mode=self.action_mode, + obs_config=self.obs_config, + headless=True, + ) + self.rlbench_env.launch() + self.rlbench_task_env = self.rlbench_env.get_task(self.task_class) + if render_mode is not None: + cam_placeholder = Dummy("cam_cinematic_placeholder") + self.gym_cam = VisionSensor.create([640, 360]) + self.gym_cam.set_pose(cam_placeholder.get_pose()) + if render_mode == "human": + self.gym_cam.set_render_mode(RenderMode.OPENGL3_WINDOWED) + else: + self.gym_cam.set_render_mode(RenderMode.OPENGL3) + _, obs = self.rlbench_task_env.reset() + + self.observation_space = { + "state": spaces.Box( + low=-np.inf, high=np.inf, shape=obs.get_low_dim_data().shape), + } + if observation_mode == 'vision': + self.observation_space.update({ "left_shoulder_rgb": spaces.Box( - low=0, high=1, shape=obs.left_shoulder_rgb.shape), + low=0, high=255, shape=obs.left_shoulder_rgb.shape, dtype=np.uint8), "right_shoulder_rgb": spaces.Box( - low=0, high=1, shape=obs.right_shoulder_rgb.shape), + low=0, high=255, shape=obs.right_shoulder_rgb.shape, dtype=np.uint8), "wrist_rgb": spaces.Box( - low=0, high=1, shape=obs.wrist_rgb.shape), + low=0, high=255, shape=obs.wrist_rgb.shape, dtype=np.uint8), "front_rgb": spaces.Box( - low=0, high=1, shape=obs.front_rgb.shape), - }) - - if render_mode is not None: - # Add the camera to the scene - cam_placeholder = Dummy('cam_cinematic_placeholder') - self._gym_cam = VisionSensor.create([640, 360]) - self._gym_cam.set_pose(cam_placeholder.get_pose()) - if render_mode == 'human': - self._gym_cam.set_render_mode(RenderMode.OPENGL3_WINDOWED) - else: - self._gym_cam.set_render_mode(RenderMode.OPENGL3) - - def _extract_obs(self, obs) -> Dict[str, np.ndarray]: - if self._observation_mode == 'state': - return obs.get_low_dim_data() - elif self._observation_mode == 'vision': - return { - "state": obs.get_low_dim_data(), - "left_shoulder_rgb": obs.left_shoulder_rgb, - "right_shoulder_rgb": obs.right_shoulder_rgb, - "wrist_rgb": obs.wrist_rgb, - "front_rgb": obs.front_rgb, - } - - def render(self) -> Union[None, np.ndarray]: + low=0, high=255, shape=obs.front_rgb.shape, dtype=np.uint8), + }) + self.observation_space = spaces.Dict(self.observation_space) + + action_low, action_high = action_mode.action_bounds() + self.action_space = spaces.Box( + low=action_low, high=action_high, shape=self.rlbench_env.action_shape) + + def _extract_obs(self, rlbench_obs): + gym_obs = {} + gym_obs["state"] = np.float32(rlbench_obs.get_low_dim_data()) + if self._observation_mode == 'vision': + gym_obs.update({ + "left_shoulder_rgb": rlbench_obs.left_shoulder_rgb, + "right_shoulder_rgb": rlbench_obs.right_shoulder_rgb, + "wrist_rgb": rlbench_obs.wrist_rgb, + "front_rgb": rlbench_obs.front_rgb, + }) + return gym_obs + + def render(self): if self.render_mode == 'rgb_array': - frame = self._gym_cam.capture_rgb() + frame = self.gym_cam.capture_rgb() frame = np.clip((frame * 255.).astype(np.uint8), 0, 255) return frame def reset(self, seed=None, options=None): super().reset(seed=seed) - descriptions, obs = self.task.reset() + descriptions, obs = self.rlbench_task_env.reset() return self._extract_obs(obs), {"text_descriptions": descriptions} - def step(self, action) -> Tuple[Dict[str, np.ndarray], float, bool, dict]: - obs, reward, success, _terminate = self.task.step(action) + def step(self, action): + obs, reward, success, _terminate = self.rlbench_task_env.step(action) terminated = success truncated = _terminate and not success return self._extract_obs(obs), reward, terminated, truncated, {"success": success} def close(self) -> None: - self.env.shutdown() + self.rlbench_env.shutdown() + +