[BC] Generate trajectories #388

Merged 19 commits on Nov 13, 2024

@tvmarino tvmarino commented Nov 1, 2024

Adds gen_trajectories(), the function that loads the corpus of modules, creates the worker manager used with ModuleWorker, collects the results, and writes them to file.
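As a rough illustration of the flow described above (load a corpus, fan the work out to workers, collect the results, write them to a fixed number of output files), here is a minimal sketch. It is not the code in this PR: `process_module`, `gen_trajectories_sketch`, the text output format, and the round-robin sharding are hypothetical stand-ins, and a thread pool takes the place of the worker manager and ModuleWorker.

```python
import concurrent.futures
import os
import tempfile
from typing import List, Optional


def process_module(module_name: str) -> Optional[str]:
    """Hypothetical stand-in for the per-module work; None signals a failure."""
    if not module_name:
        return None
    return f"trajectory-for-{module_name}"


def gen_trajectories_sketch(
    corpus_elements: List[str],
    output_dir: str,
    num_workers: Optional[int] = None,
    num_output_files: int = 1,
) -> List[str]:
    """Processes every corpus element and shards the successes across files."""
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as pool:
        # pool.map preserves the order of corpus_elements.
        results = list(pool.map(process_module, corpus_elements))

    succeeded = [r for r in results if r is not None]
    paths = []
    for file_idx in range(num_output_files):
        path = os.path.join(output_dir, f"trajectories-{file_idx}.txt")
        # Round-robin sharding: file i receives results i, i + n, i + 2n, ...
        with open(path, "w", encoding="utf-8") as f:
            f.writelines(
                line + "\n" for line in succeeded[file_idx::num_output_files]
            )
        paths.append(path)
    return paths


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp_dir:
        print(gen_trajectories_sketch(["a", "b", "c"], tmp_dir, num_output_files=2))
```

The parameter names `num_workers` and `num_output_files` mirror the signature fragment shown in the diff; everything else is assumed for the sake of the example.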

num_workers: Optional[int] = None,
num_output_files: int = 1,
profiling_file_path: Optional[str] = None,
worker_wait: int = 10,

Collaborator

nit: rename this to worker_wait_sec so the name also documents the unit.

Collaborator Author

Done.

total_work = len(corpus_elements)
total_failed_examples = 0
total_write_files = num_output_files
total_profiles_max: List[Optional[Dict[str, Union[str, float, int]]]] = []

Collaborator

There is a lot of nesting here. A trick to help comprehension is to alias the type at module level, e.g. (I'm making the names up):

ExperimentValueType = Union[str, float, int]
ExperimentResultType = Dict[str, ExperimentValueType]

Collaborator Author

I set ProfilingDictValueType = Dict[str, Union[str, float, int]].
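
For illustration, module-level aliases along the lines suggested above might look like the sketch below. `ProfilingDictValueType` is the name given in the reply; `ProfilingValueType`, `count_present_profiles`, and the sample dictionary keys are made up for this example.

```python
from typing import Dict, List, Optional, Union

# Module-level aliases: each layer of nesting gets a readable name.
ProfilingValueType = Union[str, float, int]  # hypothetical name
ProfilingDictValueType = Dict[str, ProfilingValueType]


def count_present_profiles(
    profiles: List[Optional[ProfilingDictValueType]],
) -> int:
    """Counts the entries that actually carry a profile (hypothetical helper)."""
    return sum(1 for profile in profiles if profile is not None)


# The annotation is now one level deep instead of four.
total_profiles_max: List[Optional[ProfilingDictValueType]] = [
    {"module_name": "a.o", "compile_time_sec": 1.5, "size": 42},  # made-up keys
    None,
]
```

Aliases like these are plain assignments, so they cost nothing at runtime and can be reused in every signature that handles the same structure.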

total_successful_examples = 0
total_work = len(corpus_elements)
total_failed_examples = 0
total_write_files = num_output_files

Collaborator

Why not keep using num_output_files directly? You never mutate it.

Collaborator Author

That's true. I will fix this.

Dict[str, Union[str, float, int]]],
tf.train.SequenceExample]]] = []

for written_files in range(total_write_files):

Collaborator

s/written_files/written_file_index?

Collaborator Author

Changed it to written_files_idx.

logging.INFO,
('%d success, %d failed out of %d, modules processed'
' %d\n timing compiler: %f'),
10,

Collaborator

what's 10?

Collaborator Author

According to the log_every_n_seconds docstring, it is the number of seconds that must elapse between log messages.
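
To make the meaning of the 10 concrete: absl's call has the shape `logging.log_every_n_seconds(level, msg, n_seconds, *args)`, so the 10 in the diff is `n_seconds`. The following is a stdlib-only sketch of that throttling idea, not absl's implementation; `ThrottledLogger` and its injectable clock are hypothetical.

```python
import logging
import time
from typing import Callable, Optional


class ThrottledLogger:
    """Emits a message at most once per `n_seconds` (hypothetical helper)."""

    def __init__(
        self,
        n_seconds: float,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self._n_seconds = n_seconds
        self._clock = clock
        self._last_emit: Optional[float] = None

    def log(self, level: int, msg: str, *args: object) -> bool:
        """Logs `msg % args` if enough time has passed; returns True if it logged."""
        now = self._clock()
        if self._last_emit is not None and now - self._last_emit < self._n_seconds:
            return False  # still inside the quiet window, drop the message
        self._last_emit = now
        logging.log(level, msg, *args)
        return True
```

In a tight loop over many modules, a call like `throttled.log(logging.INFO, '%d success', count)` would then produce one progress line roughly every ten seconds rather than one per iteration.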

modules_processed,
time_compiler_calls,
)
if len(succeeded) == 0:

Collaborator

Doesn't "if not succeeded:" work here?
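
A short illustration of the suggestion, assuming `succeeded` is a plain Python list (the diff does not show its type). Empty built-in sequences are falsy, so the two checks agree for lists. The function names are hypothetical.

```python
from typing import List


def has_no_successes_len(succeeded: List[str]) -> bool:
    """Explicit length comparison, as written in the diff."""
    return len(succeeded) == 0


def has_no_successes_truthy(succeeded: List[str]) -> bool:
    """Idiomatic form: an empty list is falsy, so `not succeeded` is True for []."""
    return not succeeded
```

One behavioural difference to keep in mind: `not succeeded` is also true when `succeeded` is None, whereas `len(None)` raises a TypeError.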


max_profiles_path = ''
pol_profiles_path = ''
if profiling_file_path:

Collaborator

So if profiling_file_path isn't given, what happens with the open() below?

Collaborator Author

The context is set to contextlib.nullcontext(), so I think nothing happens.

Collaborator

Would it be cleaner to check here and open the files only when profiling_file_path is set?

Collaborator Author

Yes, I am addressing this.
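
The two approaches discussed in this thread can be compared side by side. This is a sketch under assumptions, not the change that landed: the function names, the single output file, and the JSON format are all hypothetical.

```python
import contextlib
import json
import os
import tempfile
from typing import Dict, List, Optional, Union

ProfilingDictValueType = Dict[str, Union[str, float, int]]


def write_with_nullcontext(
    profiles: List[ProfilingDictValueType], path: Optional[str] = None
) -> bool:
    """Pattern described in the reply: fall back to a no-op context manager."""
    ctx = open(path, "w", encoding="utf-8") if path else contextlib.nullcontext()
    with ctx as f:
        if f is None:  # nullcontext() yields None, so there is nothing to write to
            return False
        json.dump(profiles, f)
        return True


def write_with_guard(
    profiles: List[ProfilingDictValueType], path: Optional[str] = None
) -> bool:
    """Reviewer's suggestion: only touch the file system when a path is set."""
    if not path:
        return False
    with open(path, "w", encoding="utf-8") as f:
        json.dump(profiles, f)
    return True


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp_dir:
        target = os.path.join(tmp_dir, "profiles.json")
        print(write_with_guard([{"module": "a.o", "time": 1.5}], target))
```

Both versions behave the same from the caller's point of view. The guarded version keeps the "no path" case out of the `with` body, which is the readability gain the reviewer is pointing at.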

ModuleWorker and ModuleExplorer to replace classes and callables which
were passed as a gin.config, so that they are passed directly to gen_trajectories.
This is because gin.config classes and callables cannot be pickled for
multiprocessing purposes.
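
The general pickling constraint behind this can be seen with the standard library alone. The sketch below does not use gin and does not reproduce the gin-specific failure; it only shows that a callable reachable by its module-level name can be pickled while a locally defined one cannot, which is the restriction multiprocessing places on anything sent to a worker process. `top_level_policy`, `make_local_policy`, and `is_picklable` are hypothetical names.

```python
import pickle
from typing import Callable


def top_level_policy(x: int) -> int:
    """Module-level function: pickle stores a reference to its qualified name."""
    return x + 1


def make_local_policy() -> Callable[[int], int]:
    """Returns a function defined inside another function."""

    def local_policy(x: int) -> int:
        return x + 1

    return local_policy


def is_picklable(obj: object) -> bool:
    """Returns True if pickle can serialize `obj`, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
```

When this is run as a script, `is_picklable(top_level_policy)` succeeds because the function can be looked up by name on import, while `is_picklable(make_local_policy())` and `is_picklable(lambda x: x)` both fail. Passing importable classes and callables directly, as the commit message describes, keeps the arguments on the picklable side of that line.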
@tvmarino tvmarino merged commit ad31887 into google:main Nov 13, 2024
15 checks passed
@tvmarino tvmarino deleted the bc_generate_trajectories branch November 13, 2024 20:11