Skip to content

Commit

Permalink
Upgrade from gym to gymnasium
Browse files Browse the repository at this point in the history
  • Loading branch information
eugeneteoh committed May 11, 2024
1 parent 790a90e commit 34575c2
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 28 deletions.
4 changes: 2 additions & 2 deletions examples/rlbench_gym.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import gym
import gymnasium as gym
import rlbench.gym

env = gym.make('reach_target-state-v0', render_mode='human')
Expand All @@ -9,7 +9,7 @@
if i % episode_length == 0:
print('Reset Episode')
obs = env.reset()
obs, reward, terminate, _ = env.step(env.action_space.sample())
obs, reward, terminate, _, _ = env.step(env.action_space.sample())
env.render() # Note: rendering increases step time.

print('Done')
Expand Down
3 changes: 2 additions & 1 deletion rlbench/gym/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from gym.envs.registration import register
# from gym.envs.registration import register
from gymnasium import register
import rlbench.backend.task as task
import os
from rlbench.utils import name_to_task_class
Expand Down
42 changes: 19 additions & 23 deletions rlbench/gym/rlbench_env.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from typing import Union, Dict, Tuple

import gym
import gymnasium as gym
import numpy as np
from gym import spaces
from gymnasium import spaces
from pyrep.const import RenderMode
from pyrep.objects.dummy import Dummy
from pyrep.objects.vision_sensor import VisionSensor

from rlbench.action_modes.action_mode import MoveArmThenGripper
from rlbench.action_modes.arm_action_modes import JointVelocity
from rlbench.action_modes.action_mode import JointPositionActionMode, MoveArmThenGripper
from rlbench.action_modes.arm_action_modes import JointPosition
from rlbench.action_modes.gripper_action_modes import Discrete
from rlbench.environment import Environment
from rlbench.observation_config import ObservationConfig
Expand All @@ -17,10 +17,10 @@
class RLBenchEnv(gym.Env):
"""An gym wrapper for RLBench."""

metadata = {'render.modes': ['human', 'rgb_array']}
metadata = {'render_modes': ['human', 'rgb_array']}

def __init__(self, task_class, observation_mode='state',
render_mode: Union[None, str] = None):
render_mode: Union[None, str] = None, action_mode=None):
self._observation_mode = observation_mode
self._render_mode = render_mode
obs_config = ObservationConfig()
Expand All @@ -33,16 +33,18 @@ def __init__(self, task_class, observation_mode='state',
raise ValueError(
'Unrecognised observation_mode: %s.' % observation_mode)

action_mode = MoveArmThenGripper(JointVelocity(), Discrete())
if action_mode is None:
action_mode = JointPositionActionMode()
self.env = Environment(
action_mode, obs_config=obs_config, headless=True)
self.env.launch()
self.task = self.env.get_task(task_class)

_, obs = self.task.reset()

action_bounds = action_mode.action_bounds()
self.action_space = spaces.Box(
low=-1.0, high=1.0, shape=self.env.action_shape)
low=action_bounds[0], high=action_bounds[1], shape=self.env.action_shape)

if observation_mode == 'state':
self.observation_space = spaces.Box(
Expand Down Expand Up @@ -84,28 +86,22 @@ def _extract_obs(self, obs) -> Dict[str, np.ndarray]:
"front_rgb": obs.front_rgb,
}

def render(self) -> Union[None, np.ndarray]:
    """Render the environment (gymnasium API: mode fixed at construction).

    Returns:
        An (H, W, 3) uint8 RGB frame when the environment was created with
        ``render_mode='rgb_array'``; otherwise ``None`` ('human' rendering
        is shown live in the simulator's own viewer, nothing to return).
    """
    # NOTE(review): the visible __init__ assigns self._render_mode, while
    # this reads self.render_mode (gymnasium.Env's class attribute defaults
    # to None) — confirm self.render_mode is actually set somewhere.
    if self.render_mode == 'rgb_array':
        # capture_rgb() yields floats in [0, 1]. Clip BEFORE casting:
        # casting out-of-range floats to uint8 first would wrap mod 256,
        # making a post-cast clip a no-op.
        frame = np.clip(self._gym_cam.capture_rgb() * 255., 0, 255)
        return frame.astype(np.uint8)
    return None

def reset(self, seed=None, options=None):
    """Reset the task and return the first observation (gymnasium API).

    Args:
        seed: Optional RNG seed, forwarded to ``gymnasium.Env.reset`` so
            gymnasium seeds ``self.np_random`` consistently.
        options: Unused; accepted for gymnasium signature compatibility.

    Returns:
        Tuple ``(observation, info)`` where ``observation`` is the dict
        produced by ``self._extract_obs`` and ``info`` carries the task's
        natural-language descriptions under ``'text_descriptions'``.
    """
    super().reset(seed=seed)
    descriptions, obs = self.task.reset()
    return self._extract_obs(obs), {"text_descriptions": descriptions}

def step(self, action) -> Tuple[Dict[str, np.ndarray], float, bool, bool, dict]:
    """Advance the task by one action (gymnasium 5-tuple API).

    Args:
        action: Action in ``self.action_space``, passed to the task.

    Returns:
        ``(observation, reward, terminated, truncated, info)``. ``info``
        reports task success under the ``'success'`` key.
    """
    obs, reward, success, terminate = self.task.step(action)
    terminated = success
    # NOTE(review): episodes that end WITHOUT success are reported as
    # truncated. Under gymnasium semantics a genuine terminal failure
    # should be `terminated`, with `truncated` reserved for time limits —
    # confirm this mapping is intended.
    truncated = terminate and not success
    return self._extract_obs(obs), reward, terminated, truncated, {"success": success}

def close(self) -> None:
    """Shut down the underlying simulation environment, releasing its resources."""
    self.env.shutdown()
2 changes: 1 addition & 1 deletion rlbench/task_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def step(self, action) -> (Observation, int, bool):
raise RuntimeError(
'User requested shaped rewards, but task %s does not have '
'a defined reward() function.' % self._task.get_name())
return self._scene.get_observation(), reward, terminate
return self._scene.get_observation(), reward, success, terminate

def get_demos(self, amount: int, live_demos: bool = False,
image_paths: bool = False,
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def get_version(rel_path):
'rlbench.gym'
],
extras_require={
"dev": ["pytest", "html-testRunner", "gym"]
"gymnasium": ["gymnasium==1.0.0a1"],
"dev": ["pytest", "html-testRunner", "gym"]
},
package_data={'': ['*.ttm', '*.obj', '**/**/*.ttm', '**/**/*.obj'],
'rlbench': ['task_design.ttt']},
Expand Down

0 comments on commit 34575c2

Please sign in to comment.