diff --git a/compiler_opt/rl/imitation_learning/generate_bc_trajectories.py b/compiler_opt/rl/imitation_learning/generate_bc_trajectories.py index 76b069ed..c333642e 100644 --- a/compiler_opt/rl/imitation_learning/generate_bc_trajectories.py +++ b/compiler_opt/rl/imitation_learning/generate_bc_trajectories.py @@ -21,19 +21,19 @@ import gin from compiler_opt.rl.imitation_learning import generate_bc_trajectories_lib -from compiler_opt.tools import generate_test_model # pylint:disable=unused-import from tf_agents.system import system_multiprocessing as multiprocessing -flags.FLAGS['gin_files'].allow_override = True -flags.FLAGS['gin_bindings'].allow_override = True - -FLAGS = flags.FLAGS +_GIN_FILES = flags.DEFINE_multi_string( + 'gin_files', [], 'List of paths to gin configuration files.') +_GIN_BINDINGS = flags.DEFINE_multi_string( + 'gin_bindings', [], + 'Gin bindings to override the values set in the config files.') def main(_): gin.parse_config_files_and_bindings( - FLAGS.gin_files, bindings=FLAGS.gin_bindings, skip_unknown=True) + _GIN_FILES.value, bindings=_GIN_BINDINGS.value, skip_unknown=True) logging.info(gin.config_str()) generate_bc_trajectories_lib.gen_trajectories() diff --git a/compiler_opt/rl/imitation_learning/generate_bc_trajectories_lib.py b/compiler_opt/rl/imitation_learning/generate_bc_trajectories_lib.py index 8f881669..fbd3cdd8 100644 --- a/compiler_opt/rl/imitation_learning/generate_bc_trajectories_lib.py +++ b/compiler_opt/rl/imitation_learning/generate_bc_trajectories_lib.py @@ -20,7 +20,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Generator, Union import json -from absl import flags +# from absl import flags from absl import logging import bisect import dataclasses @@ -46,13 +46,6 @@ from compiler_opt.distributed import buffered_scheduler from compiler_opt.distributed.local import local_worker_manager -from compiler_opt.tools import generate_test_model # pylint:disable=unused-import - -flags.FLAGS['gin_files'].allow_override = True -flags.FLAGS['gin_bindings'].allow_override = True - -FLAGS = flags.FLAGS - ProfilingDictValueType = Dict[str, Union[str, float, int]] diff --git a/compiler_opt/rl/imitation_learning/generate_bc_trajectories_test.py b/compiler_opt/rl/imitation_learning/generate_bc_trajectories_test.py index c3037833..f87551e1 100644 --- a/compiler_opt/rl/imitation_learning/generate_bc_trajectories_test.py +++ b/compiler_opt/rl/imitation_learning/generate_bc_trajectories_test.py @@ -16,7 +16,6 @@ import functools from absl import app -from absl import flags import gin import json from typing import List @@ -36,8 +35,8 @@ from compiler_opt.rl import env from compiler_opt.rl import env_test -flags.FLAGS['gin_files'].allow_override = True -flags.FLAGS['gin_bindings'].allow_override = True +# flags.FLAGS['gin_files'].allow_override = True +# flags.FLAGS['gin_bindings'].allow_override = True _eps = 1e-5 @@ -648,17 +647,17 @@ def select_best_exploration(self, mock_popen, loaded_module_spec): class GenTrajectoriesTest(tf.test.TestCase): def setUp(self): - with gin.unlock_config(): - gin.parse_config_files_and_bindings( - config_files=['compiler_opt/rl/inlining/gin_configs/common.gin'], - bindings=[ - ('generate_bc_trajectories_test.' - 'MockModuleWorker.clang_path="/test/clang/path"'), - ('generate_bc_trajectories_test.' - 'MockModuleWorker.exploration_frac=1.0'), - ('generate_bc_trajectories_test.' - 'MockModuleWorker.reward_key="default"'), - ]) + with gin.config_scope('gen_trajectories_test'): + with gin.unlock_config(): + gin.bind_parameter( + 'generate_bc_trajectories_test.MockModuleWorker.clang_path', + '/test/clang/path') + gin.bind_parameter( + 'generate_bc_trajectories_test.MockModuleWorker.exploration_frac', + 1.0) + gin.bind_parameter( + 'generate_bc_trajectories_test.MockModuleWorker.reward_key', + 'default') return super().setUp() def test_gen_trajectories(self): diff --git a/compiler_opt/rl/inlining/gin_configs/imitation_learning.gin b/compiler_opt/rl/inlining/gin_configs/imitation_learning.gin index cb654885..9b69ffdd 100644 --- a/compiler_opt/rl/inlining/gin_configs/imitation_learning.gin +++ b/compiler_opt/rl/inlining/gin_configs/imitation_learning.gin @@ -7,11 +7,14 @@ env.InliningForSizeTask.llvm_size_path='' generate_bc_trajectories_lib.ModuleWorker.clang_path='' generate_bc_trajectories_lib.ModuleWorker.mlgo_task_type=@env.InliningForSizeTask -generate_bc_trajectories_lib.ModuleWorker.policy_paths=[] +generate_bc_trajectories_lib.ModuleWorker.policy_paths=[''] generate_bc_trajectories_lib.ModuleWorker.exploration_policy_paths=[] generate_bc_trajectories_lib.ModuleWorker.explore_on_features=None generate_bc_trajectories_lib.ModuleWorker.base_path='' -generate_bc_trajectories_lib.ModuleWorker.partitions=[0.,] +generate_bc_trajectories_lib.ModuleWorker.partitions=[ + 285.0, 376.0, 452.0, 512.0, 571.0, 627.5, 720.0, 809.5, 1304.0, 1832.0, + 2467.0, 3344.0, 4545.0, 6459.0, 9845.0, 17953.0, 29430.5, 85533.5, + 124361.0] generate_bc_trajectories_lib.ModuleWorker.reward_key='default' # generate_bc_trajectories_lib.ModuleWorker.gin_config_str=None diff --git a/compiler_opt/rl/inlining/imitation_learning_runner.py b/compiler_opt/rl/inlining/imitation_learning_runner.py index 84724331..9d5bbfb1 100644 --- a/compiler_opt/rl/inlining/imitation_learning_runner.py +++ b/compiler_opt/rl/inlining/imitation_learning_runner.py @@ -22,19 +22,19 @@ from compiler_opt.rl.imitation_learning import generate_bc_trajectories_lib from compiler_opt.rl.inlining import imitation_learning_config -from compiler_opt.tools import generate_test_model # pylint:disable=unused-import from tf_agents.system import system_multiprocessing as multiprocessing -flags.FLAGS['gin_files'].allow_override = True -flags.FLAGS['gin_bindings'].allow_override = True - -FLAGS = flags.FLAGS +_GIN_FILES = flags.DEFINE_multi_string( + 'gin_files', [], 'List of paths to gin configuration files.') +_GIN_BINDINGS = flags.DEFINE_multi_string( + 'gin_bindings', [], + 'Gin bindings to override the values set in the config files.') def main(_): gin.parse_config_files_and_bindings( - FLAGS.gin_files, bindings=FLAGS.gin_bindings, skip_unknown=True) + _GIN_FILES.value, bindings=_GIN_BINDINGS.value, skip_unknown=True) logging.info(gin.config_str()) generate_bc_trajectories_lib.gen_trajectories(