Skip to content

Commit

Permalink
agilerl precommit formatting changes
Browse files Browse the repository at this point in the history
  • Loading branch information
nicku-a committed Jun 24, 2024
1 parent c9b5189 commit f0eda2e
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 25 deletions.
38 changes: 25 additions & 13 deletions tutorials/AgileRL/agilerl_dqn_curriculum.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
import torch
import wandb
import yaml
from pettingzoo.classic import connect_four_v3
from tqdm import tqdm, trange

from agilerl.components.replay_buffer import ReplayBuffer
from agilerl.hpo.mutation import Mutations
from agilerl.hpo.tournament import TournamentSelection
from agilerl.utils.utils import create_population
from tqdm import tqdm, trange

from pettingzoo.classic import connect_four_v3


class CurriculumEnv:
Expand Down Expand Up @@ -835,9 +835,13 @@ def transform_and_flip(observation, player):
train_actions_hist[p1_action] += 1

env.step(p1_action) # Act in environment
observation, cumulative_reward, done, truncation, _ = (
env.last()
)
(
observation,
cumulative_reward,
done,
truncation,
_,
) = env.last()
p1_next_state, p1_next_state_flipped = transform_and_flip(
observation, player=1
)
Expand Down Expand Up @@ -938,9 +942,13 @@ def transform_and_flip(observation, player):
rewards = []
for i in range(evo_loop):
env.reset() # Reset environment at start of episode
observation, cumulative_reward, done, truncation, _ = (
env.last()
)
(
observation,
cumulative_reward,
done,
truncation,
_,
) = env.last()

player = -1  # Tracker for which player's turn it is

Expand Down Expand Up @@ -994,9 +1002,13 @@ def transform_and_flip(observation, player):
eval_actions_hist[action] += 1

env.step(action) # Act in environment
observation, cumulative_reward, done, truncation, _ = (
env.last()
)
(
observation,
cumulative_reward,
done,
truncation,
_,
) = env.last()

if (player > 0 and opponent_first) or (
player < 0 and not opponent_first
Expand All @@ -1021,7 +1033,7 @@ def transform_and_flip(observation, player):
f" Train Mean Score: {np.mean(agent.scores[-episodes_per_epoch:])} Train Mean Turns: {mean_turns} Eval Mean Fitness: {np.mean(fitnesses)} Eval Best Fitness: {np.max(fitnesses)} Eval Mean Turns: {eval_turns} Total Steps: {total_steps}"
)
pbar.update(0)

if wb:
# Format action histograms for visualisation
train_actions_hist = [
Expand Down
6 changes: 3 additions & 3 deletions tutorials/AgileRL/agilerl_maddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
import numpy as np
import supersuit as ss
import torch
from pettingzoo.atari import space_invaders_v2
from tqdm import trange

from agilerl.components.multi_agent_replay_buffer import MultiAgentReplayBuffer
from agilerl.hpo.mutation import Mutations
from agilerl.hpo.tournament import TournamentSelection
from agilerl.utils.utils import create_population
from agilerl.wrappers.pettingzoo_wrappers import PettingZooVectorizationParallelWrapper
from tqdm import trange

from pettingzoo.atari import space_invaders_v2

if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand Down
6 changes: 3 additions & 3 deletions tutorials/AgileRL/agilerl_matd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@

import numpy as np
import torch
from pettingzoo.mpe import simple_speaker_listener_v4
from tqdm import trange

from agilerl.components.multi_agent_replay_buffer import MultiAgentReplayBuffer
from agilerl.hpo.mutation import Mutations
from agilerl.hpo.tournament import TournamentSelection
from agilerl.utils.utils import create_population
from agilerl.wrappers.pettingzoo_wrappers import PettingZooVectorizationParallelWrapper
from tqdm import trange

from pettingzoo.mpe import simple_speaker_listener_v4

if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand Down
4 changes: 2 additions & 2 deletions tutorials/AgileRL/render_agilerl_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import imageio
import numpy as np
import torch
from agilerl.algorithms.dqn import DQN
from agilerl_dqn_curriculum import Opponent, transform_and_flip
from pettingzoo.classic import connect_four_v3
from PIL import Image, ImageDraw, ImageFont

from agilerl.algorithms.dqn import DQN
from pettingzoo.classic import connect_four_v3


# Define function to return image
Expand Down
4 changes: 2 additions & 2 deletions tutorials/AgileRL/render_agilerl_maddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import numpy as np
import supersuit as ss
import torch
from pettingzoo.atari import space_invaders_v2
from agilerl.algorithms.maddpg import MADDPG
from PIL import Image, ImageDraw

from agilerl.algorithms.maddpg import MADDPG
from pettingzoo.atari import space_invaders_v2


# Define function to return image
Expand Down
4 changes: 2 additions & 2 deletions tutorials/AgileRL/render_agilerl_matd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import imageio
import numpy as np
import torch
from pettingzoo.mpe import simple_speaker_listener_v4
from agilerl.algorithms.matd3 import MATD3
from PIL import Image, ImageDraw

from agilerl.algorithms.matd3 import MATD3
from pettingzoo.mpe import simple_speaker_listener_v4


# Define function to return image
Expand Down

0 comments on commit f0eda2e

Please sign in to comment.