Support vector env

eugeneteoh committed May 12, 2024
1 parent 34575c2 commit 56c87ba
Showing 3 changed files with 81 additions and 62 deletions.
examples/rlbench_gym.py: 2 changes (1 addition, 1 deletion)
@@ -1,7 +1,7 @@
 import gymnasium as gym
 import rlbench.gym

-env = gym.make('reach_target-state-v0', render_mode='human')
+env = gym.make('reach_target-vision-v0', render_mode="human")

 training_steps = 120
 episode_length = 40
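The hunk above is truncated; the rest of examples/rlbench_gym.py presumably keeps the usual Gymnasium loop, mirrored by the vector example below. A minimal sketch under that assumption (editor's illustration, not part of the commit):

for i in range(training_steps):
    if i % episode_length == 0:
        print('Reset Episode')
        obs, info = env.reset()
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    env.render()  # Note: rendering increases step time.

print('Done')
env.close()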
examples/rlbench_gym_vector.py: 18 changes (18 additions, 0 deletions)
@@ -0,0 +1,18 @@
+import gymnasium as gym
+import rlbench.gym
+
+if __name__ == "__main__":
+    # Only works with spawn (multiprocessing) context
+    env = gym.make_vec('reach_target-vision-v0', num_envs=2, vectorization_mode="async", vector_kwargs={"context": "spawn"})
+
+    training_steps = 120
+    episode_length = 40
+    for i in range(training_steps):
+        if i % episode_length == 0:
+            print('Reset Episode')
+            obs = env.reset()
+        obs, reward, terminate, _, _ = env.step(env.action_space.sample())
+        env.render()  # Note: rendering increases step time.
+
+    print('Done')
+    env.close()
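Gymnasium's vector API batches observations, rewards, and termination flags across sub-environments, and resets finished sub-environments automatically. A minimal sketch of driving the env created above with a seeded reset and batched actions (editor's illustration, not part of the commit; it assumes the reach_target-vision-v0 registration used throughout this repo):

import gymnasium as gym
import rlbench.gym

if __name__ == "__main__":
    envs = gym.make_vec('reach_target-vision-v0', num_envs=2,
                        vectorization_mode="async",
                        vector_kwargs={"context": "spawn"})
    # reset(seed=42) seeds each sub-env; every obs entry is batched,
    # e.g. obs["state"] has shape (2, state_dim).
    obs, info = envs.reset(seed=42)
    for _ in range(40):
        actions = envs.action_space.sample()  # batched: one action per sub-env
        obs, rewards, terminations, truncations, infos = envs.step(actions)
    envs.close()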
rlbench/gym/rlbench_env.py: 123 changes (62 additions, 61 deletions)
@@ -3,24 +3,22 @@
 import gymnasium as gym
 import numpy as np
 from gymnasium import spaces
+from pyrep.const import RenderMode
 from pyrep.objects.dummy import Dummy
 from pyrep.objects.vision_sensor import VisionSensor
-from pyrep.const import RenderMode

-from rlbench.action_modes.action_mode import JointPositionActionMode, MoveArmThenGripper
-from rlbench.action_modes.arm_action_modes import JointPosition
-from rlbench.action_modes.gripper_action_modes import Discrete
+from rlbench.action_modes.action_mode import JointPositionActionMode
 from rlbench.environment import Environment
 from rlbench.observation_config import ObservationConfig


 class RLBenchEnv(gym.Env):
     """An gym wrapper for RLBench."""

-    metadata = {'render_modes': ['human', 'rgb_array']}
+    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

     def __init__(self, task_class, observation_mode='state',
                  render_mode: Union[None, str] = None, action_mode=None):
+        self.task_class = task_class
         self._observation_mode = observation_mode
         self._render_mode = render_mode
         obs_config = ObservationConfig()
@@ -32,76 +30,79 @@ def __init__(self, task_class, observation_mode='state',
         else:
             raise ValueError(
                 'Unrecognised observation_mode: %s.' % observation_mode)

+        self.obs_config = obs_config
         if action_mode is None:
             action_mode = JointPositionActionMode()
-        self.env = Environment(
-            action_mode, obs_config=obs_config, headless=True)
-        self.env.launch()
-        self.task = self.env.get_task(task_class)
-
-        _, obs = self.task.reset()
-
-        action_bounds = action_mode.action_bounds()
-        self.action_space = spaces.Box(
-            low=action_bounds[0], high=action_bounds[1], shape=self.env.action_shape)
-
-        if observation_mode == 'state':
-            self.observation_space = spaces.Box(
-                low=-np.inf, high=np.inf, shape=obs.get_low_dim_data().shape)
-        elif observation_mode == 'vision':
-            self.observation_space = spaces.Dict({
-                "state": spaces.Box(
-                    low=-np.inf, high=np.inf,
-                    shape=obs.get_low_dim_data().shape),
+        self.action_mode = action_mode
+
+        self.rlbench_env = Environment(
+            action_mode=self.action_mode,
+            obs_config=self.obs_config,
+            headless=True,
+        )
+        self.rlbench_env.launch()
+        self.rlbench_task_env = self.rlbench_env.get_task(self.task_class)
+        if render_mode is not None:
+            cam_placeholder = Dummy("cam_cinematic_placeholder")
+            self.gym_cam = VisionSensor.create([640, 360])
+            self.gym_cam.set_pose(cam_placeholder.get_pose())
+            if render_mode == "human":
+                self.gym_cam.set_render_mode(RenderMode.OPENGL3_WINDOWED)
+            else:
+                self.gym_cam.set_render_mode(RenderMode.OPENGL3)
+        _, obs = self.rlbench_task_env.reset()
+
+        self.observation_space = {
+            "state": spaces.Box(
+                low=-np.inf, high=np.inf, shape=obs.get_low_dim_data().shape),
+        }
+        if observation_mode == 'vision':
+            self.observation_space.update({
                 "left_shoulder_rgb": spaces.Box(
-                    low=0, high=1, shape=obs.left_shoulder_rgb.shape),
+                    low=0, high=255, shape=obs.left_shoulder_rgb.shape, dtype=np.uint8),
                 "right_shoulder_rgb": spaces.Box(
-                    low=0, high=1, shape=obs.right_shoulder_rgb.shape),
+                    low=0, high=255, shape=obs.right_shoulder_rgb.shape, dtype=np.uint8),
                 "wrist_rgb": spaces.Box(
-                    low=0, high=1, shape=obs.wrist_rgb.shape),
+                    low=0, high=255, shape=obs.wrist_rgb.shape, dtype=np.uint8),
                 "front_rgb": spaces.Box(
-                    low=0, high=1, shape=obs.front_rgb.shape),
-                })
-
-        if render_mode is not None:
-            # Add the camera to the scene
-            cam_placeholder = Dummy('cam_cinematic_placeholder')
-            self._gym_cam = VisionSensor.create([640, 360])
-            self._gym_cam.set_pose(cam_placeholder.get_pose())
-            if render_mode == 'human':
-                self._gym_cam.set_render_mode(RenderMode.OPENGL3_WINDOWED)
-            else:
-                self._gym_cam.set_render_mode(RenderMode.OPENGL3)
-
-    def _extract_obs(self, obs) -> Dict[str, np.ndarray]:
-        if self._observation_mode == 'state':
-            return obs.get_low_dim_data()
-        elif self._observation_mode == 'vision':
-            return {
-                "state": obs.get_low_dim_data(),
-                "left_shoulder_rgb": obs.left_shoulder_rgb,
-                "right_shoulder_rgb": obs.right_shoulder_rgb,
-                "wrist_rgb": obs.wrist_rgb,
-                "front_rgb": obs.front_rgb,
-            }
-
-    def render(self) -> Union[None, np.ndarray]:
+                    low=0, high=255, shape=obs.front_rgb.shape, dtype=np.uint8),
+                })
+        self.observation_space = spaces.Dict(self.observation_space)
+
+        action_low, action_high = action_mode.action_bounds()
+        self.action_space = spaces.Box(
+            low=action_low, high=action_high, shape=self.rlbench_env.action_shape)
+
+    def _extract_obs(self, rlbench_obs):
+        gym_obs = {}
+        gym_obs["state"] = np.float32(rlbench_obs.get_low_dim_data())
+        if self._observation_mode == 'vision':
+            gym_obs.update({
+                "left_shoulder_rgb": rlbench_obs.left_shoulder_rgb,
+                "right_shoulder_rgb": rlbench_obs.right_shoulder_rgb,
+                "wrist_rgb": rlbench_obs.wrist_rgb,
+                "front_rgb": rlbench_obs.front_rgb,
+            })
+        return gym_obs
+
+    def render(self):
         if self.render_mode == 'rgb_array':
-            frame = self._gym_cam.capture_rgb()
+            frame = self.gym_cam.capture_rgb()
             frame = np.clip((frame * 255.).astype(np.uint8), 0, 255)
             return frame

     def reset(self, seed=None, options=None):
         super().reset(seed=seed)
-        descriptions, obs = self.task.reset()
+        descriptions, obs = self.rlbench_task_env.reset()
         return self._extract_obs(obs), {"text_descriptions": descriptions}

-    def step(self, action) -> Tuple[Dict[str, np.ndarray], float, bool, dict]:
-        obs, reward, success, _terminate = self.task.step(action)
+    def step(self, action):
+        obs, reward, success, _terminate = self.rlbench_task_env.step(action)
         terminated = success
         truncated = _terminate and not success
         return self._extract_obs(obs), reward, terminated, truncated, {"success": success}

     def close(self) -> None:
-        self.env.shutdown()
+        self.rlbench_env.shutdown()
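After this rewrite a single env exposes a spaces.Dict observation whose RGB entries are uint8 in [0, 255], and step() returns Gymnasium's five-tuple. A minimal consumer sketch (editor's illustration, not part of the commit; it assumes the registered task id used in the examples):

import gymnasium as gym
import numpy as np
import rlbench.gym

env = gym.make('reach_target-vision-v0')
obs, info = env.reset(seed=0)  # info holds "text_descriptions", per reset() above
state = obs["state"]  # low-dim vector, cast to float32 by _extract_obs
front = obs["front_rgb"].astype(np.float32) / 255.0  # rescale uint8 image to [0, 1]
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
done = terminated or truncated  # info holds "success", per step() above
env.close()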

