Skip to content

Commit

Permalink
Create an imitation_learning directory, called imitation_learning_dir (
Browse files Browse the repository at this point in the history
…#397)

and move generate_bc_trajectories* there.
  • Loading branch information
tvmarino authored Dec 17, 2024
1 parent 8e00aca commit 4702cb2
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from absl import logging
import gin

from compiler_opt.rl import generate_bc_trajectories_lib
from compiler_opt.rl.imitation_learning import generate_bc_trajectories_lib
from compiler_opt.tools import generate_test_model # pylint:disable=unused-import

from tf_agents.system import system_multiprocessing as multiprocessing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

from google.protobuf import text_format # pytype: disable=pyi-error

from compiler_opt.rl import generate_bc_trajectories_lib
from compiler_opt.rl.imitation_learning import generate_bc_trajectories_lib
from compiler_opt.rl import env
from compiler_opt.rl import env_test

Expand Down Expand Up @@ -652,12 +652,12 @@ def setUp(self):
gin.parse_config_files_and_bindings(
config_files=['compiler_opt/rl/inlining/gin_configs/common.gin'],
bindings=[
('compiler_opt.rl.generate_bc_trajectories_test.'
('generate_bc_trajectories_test.'
'MockModuleWorker.clang_path="/test/clang/path"'),
('compiler_opt.rl.generate_bc_trajectories_test.'
('generate_bc_trajectories_test.'
'MockModuleWorker.exploration_frac=1.0'),
('compiler_opt.rl.generate_bc_trajectories_test'
'.MockModuleWorker.reward_key="default"'),
('generate_bc_trajectories_test.'
'MockModuleWorker.reward_key="default"'),
])
return super().setUp()

Expand Down

0 comments on commit 4702cb2

Please sign in to comment.