diff --git a/Makefile b/Makefile index d3580c8a..08844165 100644 --- a/Makefile +++ b/Makefile @@ -49,4 +49,7 @@ conda-upload: ./scripts/conda_upload.sh doc: - ./scripts/gen_api_docs.sh \ No newline at end of file + ./scripts/gen_api_docs.sh + +upload-codecov: + codecov --file coverage.xml -t $(CODECOV_TOKEN) \ No newline at end of file diff --git a/examples/arena/run_arena.py b/examples/arena/run_arena.py index 9eb26b7e..e880884c 100644 --- a/examples/arena/run_arena.py +++ b/examples/arena/run_arena.py @@ -26,6 +26,7 @@ def run_arena( seed=0, total_games: int = 10, max_game_onetime: int = 5, + use_tqdm: bool = True, ): env_wrappers = [RecordWinner] if render: @@ -33,7 +34,7 @@ def run_arena( env_wrappers.append(TictactoeRender) - arena = make_arena("tictactoe_v3", env_wrappers=env_wrappers, use_tqdm=True) + arena = make_arena("tictactoe_v3", env_wrappers=env_wrappers, use_tqdm=use_tqdm) agent1 = LocalAgent("../selfplay/opponent_templates/random_opponent") agent2 = LocalAgent("../selfplay/opponent_templates/random_opponent") @@ -52,4 +53,4 @@ def run_arena( if __name__ == "__main__": run_arena(render=False, parallel=True, seed=0, total_games=100, max_game_onetime=10) - # run_arena(render=True, parallel=True, seed=1, total_games=10, max_game_onetime=2) + # run_arena(render=False, parallel=False, seed=1, total_games=1, max_game_onetime=1,use_tqdm=False) diff --git a/examples/snake/jidi_random_vs_openrl_random.py b/examples/snake/jidi_random_vs_openrl_random.py index 364beb8e..37eb9e98 100644 --- a/examples/snake/jidi_random_vs_openrl_random.py +++ b/examples/snake/jidi_random_vs_openrl_random.py @@ -28,6 +28,7 @@ def run_arena( seed=0, total_games: int = 10, max_game_onetime: int = 5, + use_tqdm: bool = True, ): env_wrappers = [RecordWinner] @@ -36,7 +37,7 @@ def run_arena( f"snakes_{player_num}v{player_num}", env_wrappers=env_wrappers, render=render, - use_tqdm=True, + use_tqdm=use_tqdm, ) agent1 = JiDiAgent("./submissions/random_agent", player_num=player_num) @@ -55,4 +56,12 @@ def run_arena( if __name__ == "__main__": - run_arena(render=False, parallel=True, seed=0, total_games=100, max_game_onetime=5) + # run_arena(render=False, parallel=True, seed=0, total_games=100, max_game_onetime=5) + run_arena( + render=False, + parallel=False, + seed=0, + total_games=1, + max_game_onetime=1, + use_tqdm=False, + ) diff --git a/openrl/envs/PettingZoo/__init__.py b/openrl/envs/PettingZoo/__init__.py index b1384f81..6b29092f 100644 --- a/openrl/envs/PettingZoo/__init__.py +++ b/openrl/envs/PettingZoo/__init__.py @@ -35,6 +35,7 @@ def PettingZoo_make(id, render_mode, disable_env_checker, **kwargs): from pettingzoo.classic import tictactoe_v3 env = tictactoe_v3.env(render_mode=render_mode) + else: raise NotImplementedError return env diff --git a/openrl/envs/snake/snake_pettingzoo.py b/openrl/envs/snake/snake_pettingzoo.py index 136bb420..fb34358a 100644 --- a/openrl/envs/snake/snake_pettingzoo.py +++ b/openrl/envs/snake/snake_pettingzoo.py @@ -83,7 +83,8 @@ def action_space(self, agent): return deepcopy(self._action_spaces[agent]) def observe(self, agent): - return self.raw_obs[self.agent_name_to_slice[agent]] + obs = self.raw_obs[self.agent_name_to_slice[agent]] + return obs def reset( self, diff --git a/openrl/selfplay/opponents/random_opponent.py b/openrl/selfplay/opponents/random_opponent.py index b3ff8569..1f396c34 100644 --- a/openrl/selfplay/opponents/random_opponent.py +++ b/openrl/selfplay/opponents/random_opponent.py @@ -41,7 +41,11 @@ def _sample_random_action( ): action_space = self.env.action_space(player_name) if isinstance(action_space, list): + if not isinstance(observation, list): + observation = [observation] + action = [] + for obs, space in zip(observation, action_space): mask = obs.get("action_mask", None) action.append(space.sample(mask)) diff --git a/openrl/supports/opendata/utils/opendata_utils.py b/openrl/supports/opendata/utils/opendata_utils.py index 7b387d70..4cae62df 100644 --- a/openrl/supports/opendata/utils/opendata_utils.py +++ b/openrl/supports/opendata/utils/opendata_utils.py @@ -48,6 +48,7 @@ def data_server_wrapper(fp): def load_dataset(data_path: str, split: str): from datasets import load_from_disk + if Path(data_path).exists(): dataset = load_from_disk("{}/{}".format(data_path, split)) elif "data_server:" in data_path: diff --git a/setup.py b/setup.py index f00d23fc..494e4c4c 100644 --- a/setup.py +++ b/setup.py @@ -69,6 +69,7 @@ def get_extra_requires() -> dict: "retro": ["gym-retro"], "super_mario": ["gym-super-mario-bros"], } + req["test"].extend(req["selfplay"]) return req diff --git a/tests/test_arena/test_reproducibility.py b/tests/test_arena/test_reproducibility.py new file mode 100644 index 00000000..0d186ab0 --- /dev/null +++ b/tests/test_arena/test_reproducibility.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +import os +import sys + +import pytest + +from openrl.arena import make_arena +from openrl.arena.agents.local_agent import LocalAgent +from openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner + + +def run_arena( + render: bool = False, + parallel: bool = True, + seed=0, + total_games: int = 10, + max_game_onetime: int = 5, +): + env_wrappers = [RecordWinner] + if render: + from examples.selfplay.tictactoe_utils.tictactoe_render import TictactoeRender + + env_wrappers.append(TictactoeRender) + + arena = make_arena("tictactoe_v3", env_wrappers=env_wrappers, use_tqdm=False) + + agent1 = LocalAgent("./examples/selfplay/opponent_templates/random_opponent") + agent2 = LocalAgent("./examples/selfplay/opponent_templates/random_opponent") + + arena.reset( + agents={"agent1": agent1, "agent2": agent2}, + total_games=total_games, + max_game_onetime=max_game_onetime, + seed=seed, + ) + result = arena.run(parallel=parallel) + arena.close() + print(result) + return result + + +@pytest.mark.unittest +def test_seed(): + seed = 0 + test_time = 3 + pre_result = None + for parallel in [False, True]: + for i in range(test_time): + result = run_arena(seed=seed, parallel=parallel, total_games=20) + if pre_result is not None: + assert pre_result == result, f"parallel={parallel}, seed={seed}" + pre_result = result + + +if __name__ == "__main__": + sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) diff --git a/tests/test_supports/test_opendata/test_opendata.py b/tests/test_supports/test_opendata/test_opendata.py index 4a0fe64f..44807906 100644 --- a/tests/test_supports/test_opendata/test_opendata.py +++ b/tests/test_supports/test_opendata/test_opendata.py @@ -30,7 +30,5 @@ def test_data_abs_path(): assert data_abs_path(data_path) == data_path - - if __name__ == "__main__": sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))