Skip to content

Commit

Permalink
Update to support Gymnasium v1.0 (#610)
Browse files Browse the repository at this point in the history
  • Loading branch information
eleurent authored Aug 18, 2024
2 parents af85faf + 5df3cf7 commit 9db00e0
Show file tree
Hide file tree
Showing 23 changed files with 492 additions and 438 deletions.
20 changes: 3 additions & 17 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:

strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v2
Expand All @@ -25,20 +25,6 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
sudo pip install pygame
pip install -e .[deploy]
- name: Lint with flake8
run: |
pip install flake8
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
run: pip install .[testing]
- name: Test with pytest
run: |
pip install pytest
pip install pytest-cov
pytest --cov=./ --cov-report=xml
run: pytest --cov=./ --cov-report=xml
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ repos:
hooks:
- id: isort
args: ["--profile", "black"]
exclude: "__init__.py"
- repo: https://github.com/python/black
rev: 23.3.0
hooks:
Expand Down
8 changes: 0 additions & 8 deletions codecov.yml

This file was deleted.

49 changes: 28 additions & 21 deletions highway_env/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
import sys

__version__ = "1.9.1"
from gymnasium.envs.registration import register

__version__ = "2.0.0"

try:
from farama_notifications import notifications
Expand All @@ -15,96 +17,101 @@
# Hide pygame support prompt
os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1"

from gymnasium.envs.registration import register
from highway_env.envs.common.abstract import MultiAgentWrapper


def register_highway_envs():
def _register_highway_envs():
"""Import the envs module so that envs register themselves."""

from highway_env.envs.common.abstract import MultiAgentWrapper

# exit_env.py
register(
id="exit-v0",
entry_point="highway_env.envs:ExitEnv",
entry_point="highway_env.envs.exit_env:ExitEnv",
)

# highway_env.py
register(
id="highway-v0",
entry_point="highway_env.envs:HighwayEnv",
entry_point="highway_env.envs.highway_env:HighwayEnv",
)

register(
id="highway-fast-v0",
entry_point="highway_env.envs:HighwayEnvFast",
entry_point="highway_env.envs.highway_env:HighwayEnvFast",
)

# intersection_env.py
register(
id="intersection-v0",
entry_point="highway_env.envs:IntersectionEnv",
entry_point="highway_env.envs.intersection_env:IntersectionEnv",
)

register(
id="intersection-v1",
entry_point="highway_env.envs:ContinuousIntersectionEnv",
entry_point="highway_env.envs.intersection_env:ContinuousIntersectionEnv",
)

register(
id="intersection-multi-agent-v0",
entry_point="highway_env.envs:MultiAgentIntersectionEnv",
entry_point="highway_env.envs.intersection_env:MultiAgentIntersectionEnv",
)

register(
id="intersection-multi-agent-v1",
entry_point="highway_env.envs:MultiAgentIntersectionEnv",
entry_point="highway_env.envs.intersection_env:MultiAgentIntersectionEnv",
additional_wrappers=(MultiAgentWrapper.wrapper_spec(),),
)

# lane_keeping_env.py
register(
id="lane-keeping-v0",
entry_point="highway_env.envs:LaneKeepingEnv",
entry_point="highway_env.envs.lane_keeping_env:LaneKeepingEnv",
max_episode_steps=200,
)

# merge_env.py
register(
id="merge-v0",
entry_point="highway_env.envs:MergeEnv",
entry_point="highway_env.envs.merge_env:MergeEnv",
)

# parking_env.py
register(
id="parking-v0",
entry_point="highway_env.envs:ParkingEnv",
entry_point="highway_env.envs.parking_env:ParkingEnv",
)

register(
id="parking-ActionRepeat-v0",
entry_point="highway_env.envs:ParkingEnvActionRepeat",
entry_point="highway_env.envs.parking_env:ParkingEnvActionRepeat",
)

register(
id="parking-parked-v0", entry_point="highway_env.envs:ParkingEnvParkedVehicles"
id="parking-parked-v0",
entry_point="highway_env.envs.parking_env:ParkingEnvParkedVehicles",
)

# racetrack_env.py
register(
id="racetrack-v0",
entry_point="highway_env.envs:RacetrackEnv",
entry_point="highway_env.envs.racetrack_env:RacetrackEnv",
)

# roundabout_env.py
register(
id="roundabout-v0",
entry_point="highway_env.envs:RoundaboutEnv",
entry_point="highway_env.envs.roundabout_env:RoundaboutEnv",
)

# two_way_env.py
register(
id="two-way-v0", entry_point="highway_env.envs:TwoWayEnv", max_episode_steps=15
id="two-way-v0",
entry_point="highway_env.envs.two_way_env:TwoWayEnv",
max_episode_steps=15,
)

# u_turn_env.py
register(id="u-turn-v0", entry_point="highway_env.envs:UTurnEnv")
register(id="u-turn-v0", entry_point="highway_env.envs.u_turn_env:UTurnEnv")


_register_highway_envs()
45 changes: 35 additions & 10 deletions highway_env/envs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,35 @@
from highway_env.envs.highway_env import *
from highway_env.envs.merge_env import *
from highway_env.envs.parking_env import *
from highway_env.envs.roundabout_env import *
from highway_env.envs.two_way_env import *
from highway_env.envs.intersection_env import *
from highway_env.envs.lane_keeping_env import *
from highway_env.envs.u_turn_env import *
from highway_env.envs.exit_env import *
from highway_env.envs.racetrack_env import *
from highway_env.envs.exit_env import ExitEnv
from highway_env.envs.highway_env import HighwayEnv, HighwayEnvFast
from highway_env.envs.intersection_env import (
ContinuousIntersectionEnv,
IntersectionEnv,
MultiAgentIntersectionEnv,
)
from highway_env.envs.lane_keeping_env import LaneKeepingEnv
from highway_env.envs.merge_env import MergeEnv
from highway_env.envs.parking_env import (
ParkingEnv,
ParkingEnvActionRepeat,
ParkingEnvParkedVehicles,
)
from highway_env.envs.racetrack_env import RacetrackEnv
from highway_env.envs.roundabout_env import RoundaboutEnv
from highway_env.envs.two_way_env import TwoWayEnv
from highway_env.envs.u_turn_env import UTurnEnv

__all__ = [
"ExitEnv",
"HighwayEnv",
"HighwayEnvFast",
"IntersectionEnv",
"ContinuousIntersectionEnv",
"MultiAgentIntersectionEnv",
"LaneKeepingEnv",
"MergeEnv",
"ParkingEnv",
"ParkingEnvActionRepeat",
"RacetrackEnv",
"RoundaboutEnv",
"TwoWayEnv",
"UTurnEnv",
]
10 changes: 7 additions & 3 deletions highway_env/envs/common/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import gymnasium as gym
import numpy as np
from gymnasium import Wrapper
from gymnasium.utils import RecordConstructorArgs
from gymnasium.wrappers import RecordVideo

from highway_env import utils
Expand Down Expand Up @@ -426,10 +427,13 @@ def __deepcopy__(self, memo):
return result


class MultiAgentWrapper(Wrapper):
class MultiAgentWrapper(Wrapper, RecordConstructorArgs):
def __init__(self, env):
Wrapper.__init__(self, env)
RecordConstructorArgs.__init__(self)

def step(self, action):
obs, reward, terminated, truncated, info = super().step(action)
obs, _, _, truncated, info = super().step(action)
reward = info["agents_rewards"]
terminated = info["agents_terminated"]
truncated = info["agents_truncated"]
return obs, reward, terminated, truncated, info
6 changes: 3 additions & 3 deletions highway_env/envs/common/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
lateral: bool = True,
dynamical: bool = False,
clip: bool = True,
**kwargs
**kwargs,
) -> None:
"""
Create a continuous action space.
Expand Down Expand Up @@ -172,7 +172,7 @@ def __init__(
dynamical: bool = False,
clip: bool = True,
actions_per_axis: int = 3,
**kwargs
**kwargs,
) -> None:
super().__init__(
env,
Expand Down Expand Up @@ -216,7 +216,7 @@ def __init__(
longitudinal: bool = True,
lateral: bool = True,
target_speeds: Optional[Vector] = None,
**kwargs
**kwargs,
) -> None:
"""
Create a discrete action space of meta-actions.
Expand Down
10 changes: 5 additions & 5 deletions highway_env/envs/common/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
weights: List[float],
scaling: Optional[float] = None,
centering_position: Optional[List[float]] = None,
**kwargs
**kwargs,
) -> None:
super().__init__(env)
self.observation_shape = observation_shape
Expand Down Expand Up @@ -168,7 +168,7 @@ def __init__(
see_behind: bool = False,
observe_intentions: bool = False,
include_obstacles: bool = True,
**kwargs: dict
**kwargs: dict,
) -> None:
"""
:param env: The environment to observe
Expand Down Expand Up @@ -293,7 +293,7 @@ def __init__(
align_to_vehicle_axes: bool = False,
clip: bool = True,
as_image: bool = False,
**kwargs: dict
**kwargs: dict,
) -> None:
"""
:param env: The environment to observe
Expand Down Expand Up @@ -674,7 +674,7 @@ def observe(self) -> np.ndarray:
if self.order == "shuffled":
self.env.np_random.shuffle(obs[1:])
# Flatten
return obs
return obs.astype(self.space().dtype)


class LidarObservation(ObservationType):
Expand All @@ -687,7 +687,7 @@ def __init__(
cells: int = 16,
maximum_range: float = 60,
normalize: bool = True,
**kwargs
**kwargs,
):
super().__init__(env, **kwargs)
self.cells = cells
Expand Down
14 changes: 8 additions & 6 deletions highway_env/envs/exit_env.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import Dict, Text, Tuple
from __future__ import annotations

import numpy as np

from highway_env import utils
from highway_env.envs import CircularLane, HighwayEnv, Vehicle
from highway_env.envs.common.action import Action
from highway_env.envs.highway_env import HighwayEnv
from highway_env.road.lane import CircularLane
from highway_env.road.road import Road, RoadNetwork
from highway_env.vehicle.controller import ControlledVehicle
from highway_env.vehicle.kinematics import Vehicle


class ExitEnv(HighwayEnv):
Expand Down Expand Up @@ -44,10 +46,10 @@ def _reset(self) -> None:
self._create_road()
self._create_vehicles()

def step(self, action) -> Tuple[np.ndarray, float, bool, dict]:
obs, reward, terminal, info = super().step(action)
def step(self, action) -> tuple[np.ndarray, float, bool, bool, dict]:
obs, reward, terminated, truncated, info = super().step(action)
info.update({"is_success": self._is_success()})
return obs, reward, terminal, info
return obs, reward, terminated, truncated, info

def _create_road(
self, road_length=1000, exit_position=400, exit_length=100
Expand Down Expand Up @@ -154,7 +156,7 @@ def _reward(self, action: Action) -> float:
reward = np.clip(reward, 0, 1)
return reward

def _rewards(self, action: Action) -> Dict[Text, float]:
def _rewards(self, action: Action) -> dict[str, float]:
lane_index = (
self.vehicle.target_lane_index
if isinstance(self.vehicle, ControlledVehicle)
Expand Down
2 changes: 1 addition & 1 deletion highway_env/envs/intersection_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _info(self, obs: np.ndarray, action: int) -> dict:
info["agents_rewards"] = tuple(
self._agent_reward(action, vehicle) for vehicle in self.controlled_vehicles
)
info["agents_dones"] = tuple(
info["agents_terminated"] = tuple(
self._agent_is_terminal(vehicle) for vehicle in self.controlled_vehicles
)
return info
Expand Down
2 changes: 1 addition & 1 deletion highway_env/road/road.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def straight_road_network(
*nodes_str,
StraightLane(
origin, end, line_types=line_types, speed_limit=speed_limit
)
),
)
return net

Expand Down
Loading

0 comments on commit 9db00e0

Please sign in to comment.