Skip to content

Commit

Permalink
Makes sure that either both of base_path and keep_temps are set
Browse files Browse the repository at this point in the history
or neither of them is set. Further, passes keep_temps to
```compilation_runner.get_workdir_context``` and ensures that the flag
and argument are not set at the same time.
  • Loading branch information
tvmarino committed Dec 20, 2024
1 parent 588bed7 commit 8dd9070
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 20 deletions.
11 changes: 10 additions & 1 deletion compiler_opt/rl/compilation_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,23 @@ def __exit__(self, exc, value, tb):
pass


def get_workdir_context(keep_temps: Optional[str] = None):
  """Return a context which manages how temporary directories are handled.

  When either the --keep_temps flag or the keep_temps argument is specified,
  temporary directories are created under that path and kept past exit;
  otherwise a self-deleting tempfile.TemporaryDirectory is used.

  Args:
    keep_temps: Put temporary files into the given directory and keep them
      past exit when compiling.

  Returns:
    A context manager yielding the working directory path.

  Raises:
    ValueError: If both the --keep_temps flag and the keep_temps argument
      are specified at the same time.
  """
  if keep_temps and _KEEP_TEMPS.value:
    # The flag and the argument are two competing sources of truth; refuse
    # ambiguous configurations instead of silently preferring one.
    # Note: these must be f-strings, otherwise the braces are emitted
    # literally in the error message.
    raise ValueError(f'Only one of flag keep_temps={_KEEP_TEMPS.value} '
                     f'and arg keep_temps={keep_temps} should be specified.')
  if _KEEP_TEMPS.value is not None:
    tempdir_context = NonTemporaryDirectory(dir=_KEEP_TEMPS.value)
  elif keep_temps:
    tempdir_context = NonTemporaryDirectory(dir=keep_temps)
  else:
    tempdir_context = tempfile.TemporaryDirectory()  # pylint: disable=consider-using-with
  return tempdir_context
Expand Down
24 changes: 20 additions & 4 deletions compiler_opt/rl/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def clang_session(
module: corpus.LoadedModuleSpec,
task_type: Type[MLGOTask],
*,
keep_temps: Optional[str] = None,
interactive: bool,
):
"""Context manager for clang session.
Expand All @@ -236,12 +237,15 @@ def clang_session(
clang_path: The clang binary to use for the InteractiveClang session.
module: The module to compile with clang.
task_type: Type of the MLGOTask to use.
keep_temps: Put temporary files into the given directory and keep them
past exit when compiling.
interactive: Whether to use an interactive or default clang instance
Yields:
Either the constructed InteractiveClang or DefaultClang object.
"""
tempdir_context = compilation_runner.get_workdir_context()
tempdir_context = compilation_runner.get_workdir_context(
keep_temps=keep_temps)
with tempdir_context as td:
task_working_dir = os.path.join(td, '__task_working_dir__')
os.mkdir(task_working_dir)
Expand Down Expand Up @@ -290,6 +294,7 @@ def _get_scores() -> dict[str, float]:
def _get_clang_generator(
clang_path: str,
task_type: Type[MLGOTask],
keep_temps: Optional[str] = None,
interactive_only: bool = False,
) -> Generator[Optional[Tuple[ClangProcess, InteractiveClang]],
Optional[corpus.LoadedModuleSpec], None]:
Expand All @@ -298,6 +303,8 @@ def _get_clang_generator(
Args:
clang_path: Path to the clang binary to use within InteractiveClang.
task_type: Type of the MLGO task to use.
keep_temps: Put temporary files into the given directory and keep them
past exit when compiling.
interactive_only: If set to true the returned tuple of generators is
iclang, iclang instead of iclang, clang
Expand All @@ -315,12 +322,17 @@ def _get_clang_generator(
# https://github.com/google/yapf/issues/1092
module = yield
with clang_session(
clang_path, module, task_type, interactive=True) as iclang:
clang_path, module, task_type, keep_temps=keep_temps,
interactive=True) as iclang:
if interactive_only:
yield iclang, iclang
else:
with clang_session(
clang_path, module, task_type, interactive=False) as clang:
clang_path,
module,
task_type,
keep_temps=keep_temps,
interactive=False) as clang:
yield iclang, clang


Expand All @@ -340,10 +352,14 @@ def __init__(
task_type: Type[MLGOTask],
obs_spec,
action_spec,
keep_temps: Optional[str] = None,
interactive_only: bool = False,
):
self._clang_generator = _get_clang_generator(
clang_path, task_type, interactive_only=interactive_only)
clang_path,
task_type,
keep_temps=keep_temps,
interactive_only=interactive_only)
self._obs_spec = obs_spec
self._action_spec = action_spec

Expand Down
29 changes: 20 additions & 9 deletions compiler_opt/rl/imitation_learning/generate_bc_trajectories_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Generator, Union
import json

# from absl import flags
from absl import flags
from absl import logging
import bisect
import dataclasses
Expand All @@ -46,6 +46,13 @@
from compiler_opt.distributed import buffered_scheduler
from compiler_opt.distributed.local import local_worker_manager

_BASE_PATH = flags.DEFINE_string(
'base_path', None, ('If specified, the temp compiled binaries throughout'
'the trajectory generation will be saved in base_path'
'for linking the final binary.'))

FLAGS = flags.FLAGS

ProfilingDictValueType = Dict[str, Union[str, float, int]]


Expand Down Expand Up @@ -313,11 +320,10 @@ def __init__(
max_horizon_to_explore=np.inf,
explore_on_features: Optional[Dict[str, Callable[[tf.Tensor],
bool]]] = None,
obs_action_specs: Optional[Tuple[
time_step.TimeStep,
tensor_spec.BoundedTensorSpec,
]] = None,
obs_action_specs: Optional[Tuple[time_step.TimeStep,
tensor_spec.BoundedTensorSpec,]] = None,
reward_key: str = '',
keep_temps: Optional[str] = None,
**kwargs,
):
self._loaded_module_spec = loaded_module_spec
Expand All @@ -343,6 +349,7 @@ def __init__(
task_type=mlgo_task_type,
obs_spec=obs_spec,
action_spec=action_spec,
keep_temps=keep_temps,
interactive_only=True,
)
if self._env.action_spec:
Expand Down Expand Up @@ -744,10 +751,8 @@ def __init__(
exploration_policy_paths: Optional[str] = None,
explore_on_features: Optional[Dict[str, Callable[[tf.Tensor],
bool]]] = None,
obs_action_specs: Optional[Tuple[
time_step.TimeStep,
tensor_spec.BoundedTensorSpec,
]] = None,
obs_action_specs: Optional[Tuple[time_step.TimeStep,
tensor_spec.BoundedTensorSpec,]] = None,
base_path: Optional[str] = None,
partitions: List[float] = [
0.,
Expand Down Expand Up @@ -918,6 +923,10 @@ def gen_trajectories(
worker_manager_class: A pool of workers hosted on the local machines, each
in its own process.
"""
if (None in (_BASE_PATH.value, FLAGS.keep_temps) and
not all(el is None for el in (_BASE_PATH.value, FLAGS.keep_temps))):
raise ValueError(('Both flags keep_temps={FLAGS.keep_temps} and'
'base_path={_BASE_PATH.value} should be set or be None'))
cps = corpus.Corpus(data_path=data_path, delete_flags=delete_flags)
logging.info('Done loading module specs from corpus.')

Expand All @@ -944,6 +953,8 @@ def gen_trajectories(
mlgo_task_type=mlgo_task_type,
callable_policies=callable_policies,
explore_on_features=explore_on_features,
base_path=_BASE_PATH.value,
keep_temps=FLAGS.keep_temps,
gin_config_str=gin.config_str(),
) as lwm:

Expand Down
2 changes: 0 additions & 2 deletions compiler_opt/rl/inlining/gin_configs/imitation_learning.gin
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@ generate_bc_trajectories_lib.ModuleWorker.mlgo_task_type=@env.InliningForSizeTas
generate_bc_trajectories_lib.ModuleWorker.policy_paths=['']
generate_bc_trajectories_lib.ModuleWorker.exploration_policy_paths=[]
generate_bc_trajectories_lib.ModuleWorker.explore_on_features=None
generate_bc_trajectories_lib.ModuleWorker.base_path=''
generate_bc_trajectories_lib.ModuleWorker.partitions=[
285.0, 376.0, 452.0, 512.0, 571.0, 627.5, 720.0, 809.5, 1304.0, 1832.0,
2467.0, 3344.0, 4545.0, 6459.0, 9845.0, 17953.0, 29430.5, 85533.5,
124361.0]
generate_bc_trajectories_lib.ModuleWorker.reward_key='default'
# generate_bc_trajectories_lib.ModuleWorker.gin_config_str=None

generate_bc_trajectories_lib.gen_trajectories.data_path=''
generate_bc_trajectories_lib.gen_trajectories.delete_flags=('-split-dwarf-file', '-split-dwarf-output')
Expand Down
4 changes: 2 additions & 2 deletions compiler_opt/rl/policy_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ def __init__(self, policy_dict: Dict[str, tf_policy.TFPolicy]):
self._policy_saver_dict: Dict[str, Tuple[
policy_saver.PolicySaver, tf_policy.TFPolicy]] = {
policy_name: (policy_saver.PolicySaver(
policy, batch_size=1, use_nest_path_signatures=False), policy
) for policy_name, policy in policy_dict.items()
policy, batch_size=1, use_nest_path_signatures=False), policy)
for policy_name, policy in policy_dict.items()
}

def _write_output_signature(
Expand Down
4 changes: 2 additions & 2 deletions compiler_opt/rl/train_locally.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ def sequence_example_iterator_fn(seq_ex: List[str]):

# Repeat for num_policy_iterations iterations.
t1 = time.time()
while (llvm_trainer.global_step_numpy()
< num_policy_iterations * num_iterations):
while (llvm_trainer.global_step_numpy() <
num_policy_iterations * num_iterations):
t2 = time.time()
logging.info('Last iteration took: %f', t2 - t1)
t1 = t2
Expand Down

0 comments on commit 8dd9070

Please sign in to comment.