diff --git a/compiler_opt/distributed/worker_test.py b/compiler_opt/distributed/worker_test.py index 87a657ea..df4b5cd5 100644 --- a/compiler_opt/distributed/worker_test.py +++ b/compiler_opt/distributed/worker_test.py @@ -30,12 +30,13 @@ def __init__(self, argument): class WorkerTest(absltest.TestCase): def test_gin_args(self): - with gin.unlock_config(): - gin.bind_parameter('_test.SomeType.argument', 42) - real_args = worker.get_full_worker_args( - SomeType, more_args=2, even_more_args='hi') - self.assertDictEqual(real_args, - dict(argument=42, more_args=2, even_more_args='hi')) + with gin.config_scope('worker_test'): + with gin.unlock_config(): + gin.bind_parameter('_test.SomeType.argument', 42) + real_args = worker.get_full_worker_args( + SomeType, more_args=2, even_more_args='hi') + self.assertDictEqual(real_args, + dict(argument=42, more_args=2, even_more_args='hi')) if __name__ == '__main__': diff --git a/compiler_opt/es/blackbox_learner_test.py b/compiler_opt/es/blackbox_learner_test.py index 4136ab22..0ceaf25b 100644 --- a/compiler_opt/es/blackbox_learner_test.py +++ b/compiler_opt/es/blackbox_learner_test.py @@ -40,6 +40,9 @@ class BlackboxLearnerTests(absltest.TestCase): """Tests for blackbox_learner""" + def tearDown(self): + gin.clear_config() + def setUp(self): super().setUp() diff --git a/compiler_opt/rl/agent_config_test.py b/compiler_opt/rl/agent_config_test.py index cad69d31..c9bbb033 100644 --- a/compiler_opt/rl/agent_config_test.py +++ b/compiler_opt/rl/agent_config_test.py @@ -51,34 +51,42 @@ def setUp(self): super().setUp() def test_create_behavioral_cloning_agent(self): - gin.bind_parameter('create_agent.policy_network', q_network.QNetwork) - gin.bind_parameter('BehavioralCloningAgent.optimizer', - tf.compat.v1.train.AdamOptimizer()) - tf_agent = agent_config.create_agent( - agent_config.BCAgentConfig( - time_step_spec=self._time_step_spec, action_spec=self._action_spec), - preprocessing_layer_creator=_observation_processing_layer) - self.assertIsInstance(tf_agent, - behavioral_cloning_agent.BehavioralCloningAgent) + with gin.config_scope('test_create_behavioral_cloning_agent'): + gin.bind_parameter('create_agent.policy_network', q_network.QNetwork) + gin.bind_parameter('BehavioralCloningAgent.optimizer', + tf.compat.v1.train.AdamOptimizer()) + tf_agent = agent_config.create_agent( + agent_config.BCAgentConfig( + time_step_spec=self._time_step_spec, + action_spec=self._action_spec), + preprocessing_layer_creator=_observation_processing_layer) + self.assertIsInstance(tf_agent, + behavioral_cloning_agent.BehavioralCloningAgent) def test_create_dqn_agent(self): - gin.bind_parameter('create_agent.policy_network', q_network.QNetwork) - gin.bind_parameter('DqnAgent.optimizer', tf.compat.v1.train.AdamOptimizer()) - tf_agent = agent_config.create_agent( - agent_config.DQNAgentConfig( - time_step_spec=self._time_step_spec, action_spec=self._action_spec), - preprocessing_layer_creator=_observation_processing_layer) - self.assertIsInstance(tf_agent, dqn_agent.DqnAgent) + with gin.config_scope('test_create_dqn_agent'): + gin.bind_parameter('create_agent.policy_network', q_network.QNetwork) + gin.bind_parameter('DqnAgent.optimizer', + tf.compat.v1.train.AdamOptimizer()) + tf_agent = agent_config.create_agent( + agent_config.DQNAgentConfig( + time_step_spec=self._time_step_spec, + action_spec=self._action_spec), + preprocessing_layer_creator=_observation_processing_layer) + self.assertIsInstance(tf_agent, dqn_agent.DqnAgent) def test_create_ppo_agent(self): - gin.bind_parameter('create_agent.policy_network', - actor_distribution_network.ActorDistributionNetwork) - gin.bind_parameter('PPOAgent.optimizer', tf.compat.v1.train.AdamOptimizer()) - tf_agent = agent_config.create_agent( - agent_config.PPOAgentConfig( - time_step_spec=self._time_step_spec, action_spec=self._action_spec), - preprocessing_layer_creator=_observation_processing_layer) - self.assertIsInstance(tf_agent, ppo_agent.PPOAgent) + with gin.config_scope('test_create_ppo_agent'): + gin.bind_parameter('create_agent.policy_network', + actor_distribution_network.ActorDistributionNetwork) + gin.bind_parameter('PPOAgent.optimizer', + tf.compat.v1.train.AdamOptimizer()) + tf_agent = agent_config.create_agent( + agent_config.PPOAgentConfig( + time_step_spec=self._time_step_spec, + action_spec=self._action_spec), + preprocessing_layer_creator=_observation_processing_layer) + self.assertIsInstance(tf_agent, ppo_agent.PPOAgent) if __name__ == '__main__':