Skip to content

Commit

Permalink
Fix env attribute forwarding
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts committed Apr 3, 2024
1 parent 1f8c554 commit c32e198
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,14 +553,14 @@ def test_rollout_success_rate_on_policy_algorithm(tmp_path):
env = Monitor(DummySuccessEnv(dummy_successes, ep_steps), filename=monitor_file, info_keywords=("is_success",))

# Equip the model of a custom logger to check the success_rate info
model = PPO("MlpPolicy", env=env, stats_window_size=STATS_WINDOW_SIZE, n_steps=env.steps_per_log, verbose=1)
model = PPO("MlpPolicy", env=env, stats_window_size=STATS_WINDOW_SIZE, n_steps=env.env.steps_per_log, verbose=1)
logger = InMemoryLogger()
model.set_logger(logger)

# Make the model learn and check that the success rate corresponds to the ratio of dummy successes
model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1)
model.learn(total_timesteps=env.env.ep_per_log * ep_steps, log_interval=1)
assert logger.name_to_value["rollout/success_rate"] == 0.3
model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1)
model.learn(total_timesteps=env.env.ep_per_log * ep_steps, log_interval=1)
assert logger.name_to_value["rollout/success_rate"] == 0.5
model.learn(total_timesteps=env.ep_per_log * ep_steps, log_interval=1)
model.learn(total_timesteps=env.env.ep_per_log * ep_steps, log_interval=1)
assert logger.name_to_value["rollout/success_rate"] == 0.8

0 comments on commit c32e198

Please sign in to comment.