Skip to content

Commit

Permalink
Update for Gymnasium alpha 2
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts committed May 21, 2024
1 parent 96abd7d commit aadb895
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 37 deletions.
8 changes: 1 addition & 7 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,7 @@ jobs:
# cpu version of pytorch
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
# Install Atari Roms
pip install autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz
pip install .[extra_no_roms,tests,docs]
pip install .[extra,tests,docs]
# Use headless version
pip install opencv-python-headless
- name: Lint with ruff
Expand Down
43 changes: 16 additions & 27 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,37 +70,13 @@
""" # noqa:E501

# Atari Games download is sometimes problematic:
# https://github.com/Farama-Foundation/AutoROM/issues/39
# That's why we define extra packages without it.
extra_no_roms = [
# For render
"opencv-python",
"pygame",
# Tensorboard support
"tensorboard>=2.9.1",
# Checking memory taken by replay buffer
"psutil",
# For progress bar callback
"tqdm",
"rich",
# For atari games,
"shimmy[atari]~=1.3.0",
"pillow",
]

extra_packages = extra_no_roms + [ # noqa: RUF005
# For atari roms,
"autorom[accept-rom-license]~=0.6.1",
]


setup(
name="stable_baselines3",
packages=[package for package in find_packages() if package.startswith("stable_baselines3")],
package_data={"stable_baselines3": ["py.typed", "version.txt"]},
install_requires=[
"gymnasium==1.0.0a1",
"gymnasium>=1.0.0a1",
"numpy>=1.20",
"torch>=1.13",
# For saving models
Expand Down Expand Up @@ -133,8 +109,21 @@
# Copy button for code snippets
"sphinx_copybutton",
],
"extra": extra_packages,
"extra_no_roms": extra_no_roms,
"extra": [
# For render
"opencv-python",
"pygame",
# Tensorboard support
"tensorboard>=2.9.1",
# Checking memory taken by replay buffer
"psutil",
# For progress bar callback
"tqdm",
"rich",
# For atari games,
"ale-py>=0.9.0",
"pillow",
],
},
description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.",
author="Antonin Raffin",
Expand Down
5 changes: 2 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import os
import shutil

import ale_py
import gymnasium as gym
import numpy as np
import pytest
import torch as th
from gymnasium import spaces
from shimmy import registration

import stable_baselines3 as sb3
from stable_baselines3 import A2C
Expand All @@ -25,8 +25,7 @@
)
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

# a hack to get atari environment registered for 1.0.0 alpha 1
registration._register_atari_envs()
gym.register_envs(ale_py)


@pytest.mark.parametrize("env_id", ["CartPole-v1", lambda: gym.make("CartPole-v1")])
Expand Down

0 comments on commit aadb895

Please sign in to comment.