diff --git a/Gallery.md b/Gallery.md
index 7bf4e4c4..bb7a2364 100644
--- a/Gallery.md
+++ b/Gallery.md
@@ -73,5 +73,6 @@ Users are also welcome to contribute their own training examples and demos to th
| [GridWorld](./examples/gridworld/)
| ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [code](./examples/gridworld/) |
| [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)
| ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![image](https://img.shields.io/badge/-image-red) | [code](./examples/super_mario/) |
| [Gym Retro](https://github.com/openai/retro)
| ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![image](https://img.shields.io/badge/-image-red) | [code](./examples/retro/) |
+| [Crafter](https://github.com/danijar/crafter)
| ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![image](https://img.shields.io/badge/-image-red) | [code](./examples/crafter/) |
\ No newline at end of file
diff --git a/README.md b/README.md
index afccec54..564fa376 100644
--- a/README.md
+++ b/README.md
@@ -126,6 +126,7 @@ Environments currently supported by OpenRL (for more details, please refer to [G
- [GridWorld](./examples/gridworld/)
- [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)
- [Gym Retro](https://github.com/openai/retro)
+- [Crafter](https://github.com/danijar/crafter)
This framework has undergone multiple iterations by the [OpenRL-Lab](https://github.com/OpenRL-Lab) team which has
applied it in academic research.
diff --git a/README_zh.md b/README_zh.md
index ed86c7e4..d2aed051 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -101,6 +101,7 @@ OpenRL目前支持的环境(更多详情请参考 [Gallery](Gallery.md)):
- [GridWorld](./examples/gridworld/)
- [Super Mario Bros](https://github.com/Kautenja/gym-super-mario-bros)
- [Gym Retro](https://github.com/openai/retro)
+- [Crafter](https://github.com/danijar/crafter)
该框架经过了[OpenRL-Lab](https://github.com/OpenRL-Lab)的多次迭代并应用于学术研究,目前已经成为了一个成熟的强化学习框架。
OpenRL-Lab将持续维护和更新OpenRL,欢迎大家加入我们的[开源社区](./docs/CONTRIBUTING_zh.md),一起为强化学习的发展做出贡献。
diff --git a/examples/snake/submissions/rule_v1/submission.py b/examples/snake/submissions/rule_v1/submission.py
index db9a81e9..14a4b414 100644
--- a/examples/snake/submissions/rule_v1/submission.py
+++ b/examples/snake/submissions/rule_v1/submission.py
@@ -243,8 +243,9 @@ def step(self): # delay: prevent rear-end collision
and self.state + state == 0
): # third claim or more
print(
- "snake {} meets third or more claim in grid ({}, {})"
- .format(key, x_, y_)
+ "snake {} meets third or more claim in grid ({}, {})".format(
+ key, x_, y_
+ )
)
controversy = self.controversy[(x_, y_)]
pprint.pprint(controversy)
diff --git a/openrl/buffers/offpolicy_replay_data.py b/openrl/buffers/offpolicy_replay_data.py
index 31e52e85..7b67fcdd 100644
--- a/openrl/buffers/offpolicy_replay_data.py
+++ b/openrl/buffers/offpolicy_replay_data.py
@@ -251,9 +251,7 @@ def feed_forward_generator(
batch_size = n_rollout_threads * (episode_length - 1) * num_agents
if mini_batch_size is None:
- assert (
- batch_size >= num_mini_batch
- ), (
+ assert batch_size >= num_mini_batch, (
"DQN requires the number of processes ({}) "
"* number of steps ({}) * number of agents ({}) = {} "
"to be greater than or equal to the number of DQN mini batches ({})."
diff --git a/openrl/buffers/replay_data.py b/openrl/buffers/replay_data.py
index a8f4c1b7..b81b493f 100644
--- a/openrl/buffers/replay_data.py
+++ b/openrl/buffers/replay_data.py
@@ -561,9 +561,7 @@ def feed_forward_generator(
batch_size = n_rollout_threads * episode_length * num_agents
if mini_batch_size is None:
- assert (
- batch_size >= num_mini_batch
- ), (
+ assert batch_size >= num_mini_batch, (
"PPO requires the number of processes ({}) "
"* number of steps ({}) * number of agents ({}) = {} "
"to be greater than or equal to the number of PPO mini batches ({})."
@@ -658,9 +656,7 @@ def feed_forward_critic_obs_generator(
batch_size = n_rollout_threads * episode_length
if mini_batch_size is None:
- assert (
- batch_size >= num_mini_batch
- ), (
+ assert batch_size >= num_mini_batch, (
"PPO requires the number of processes ({}) "
"* number of steps ({}) * number of agents ({}) = {} "
"to be greater than or equal to the number of PPO mini batches ({})."
@@ -721,9 +717,7 @@ def feed_forward_generator_transformer(
batch_size = n_rollout_threads * episode_length
if mini_batch_size is None:
- assert (
- batch_size >= num_mini_batch
- ), (
+ assert batch_size >= num_mini_batch, (
"PPO requires the number of processes ({}) "
"* number of steps ({}) = {} "
"to be greater than or equal to the number of PPO mini batches ({})."
diff --git a/openrl/envs/mpe/rendering.py b/openrl/envs/mpe/rendering.py
index a7197dca..9e458999 100644
--- a/openrl/envs/mpe/rendering.py
+++ b/openrl/envs/mpe/rendering.py
@@ -58,8 +58,9 @@ def get_display(spec):
return pyglet.canvas.Display(spec)
else:
raise error.Error(
- "Invalid display specification: {}. (Must be a string like :0 or None.)"
- .format(spec)
+ "Invalid display specification: {}. (Must be a string like :0 or None.)".format(
+ spec
+ )
)
diff --git a/openrl/envs/nlp/utils/metrics/meteor.py b/openrl/envs/nlp/utils/metrics/meteor.py
index ab15e66d..e9930265 100644
--- a/openrl/envs/nlp/utils/metrics/meteor.py
+++ b/openrl/envs/nlp/utils/metrics/meteor.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-""" METEOR metric. """
+"""METEOR metric."""
import datasets
import evaluate
diff --git a/openrl/envs/vec_env/async_venv.py b/openrl/envs/vec_env/async_venv.py
index e4f10d2b..220f999c 100644
--- a/openrl/envs/vec_env/async_venv.py
+++ b/openrl/envs/vec_env/async_venv.py
@@ -342,9 +342,7 @@ def step_send(self, actions: np.ndarray):
pipe.send(("step", action))
self._state = AsyncState.WAITING_STEP
- def step_fetch(
- self, timeout: Optional[Union[int, float]] = None
- ) -> Union[
+ def step_fetch(self, timeout: Optional[Union[int, float]] = None) -> Union[
Tuple[Any, NDArray[Any], NDArray[Any], List[Dict[str, Any]]],
Tuple[Any, NDArray[Any], NDArray[Any], NDArray[Any], List[Dict[str, Any]]],
]:
diff --git a/openrl/envs/vec_env/wrappers/base_wrapper.py b/openrl/envs/vec_env/wrappers/base_wrapper.py
index 85ca5082..87e5f8a9 100644
--- a/openrl/envs/vec_env/wrappers/base_wrapper.py
+++ b/openrl/envs/vec_env/wrappers/base_wrapper.py
@@ -230,8 +230,9 @@ def step(self, actions, *args, **kwargs):
)
else:
raise ValueError(
- "Invalid step return value, expected 4 or 5 values, got {} values"
- .format(len(results))
+ "Invalid step return value, expected 4 or 5 values, got {} values".format(
+ len(results)
+ )
)
def observation(self, observation: ObsType) -> ObsType:
diff --git a/openrl/selfplay/opponents/utils.py b/openrl/selfplay/opponents/utils.py
index 42ddbb2b..73abc041 100644
--- a/openrl/selfplay/opponents/utils.py
+++ b/openrl/selfplay/opponents/utils.py
@@ -47,7 +47,7 @@ def check_opponent_template(opponent_template: Union[str, Path]):
def get_opponent_info(
- info_path: Optional[Union[str, Path]]
+ info_path: Optional[Union[str, Path]],
) -> Optional[Dict[str, str]]:
if info_path is None:
return None
diff --git a/openrl/selfplay/wrappers/opponent_pool_wrapper.py b/openrl/selfplay/wrappers/opponent_pool_wrapper.py
index d42c17d1..a24ae10c 100644
--- a/openrl/selfplay/wrappers/opponent_pool_wrapper.py
+++ b/openrl/selfplay/wrappers/opponent_pool_wrapper.py
@@ -111,9 +111,10 @@ def on_episode_end(
else:
loser_id = self.opponent.opponent_id
loser_ids.append(loser_id)
- assert set(winner_ids).isdisjoint(set(loser_ids)), (
- "winners and losers must be disjoint, but get winners: {}, losers: {}"
- .format(winner_ids, loser_ids)
+ assert set(winner_ids).isdisjoint(
+ set(loser_ids)
+ ), "winners and losers must be disjoint, but get winners: {}, losers: {}".format(
+ winner_ids, loser_ids
)
battle_info = {"winner_ids": winner_ids, "loser_ids": loser_ids}
self.api_client.add_battle_result(battle_info)