diff --git a/sprint/saycan_eval.py b/sprint/saycan_eval.py index b5d5553..d2f6102 100644 --- a/sprint/saycan_eval.py +++ b/sprint/saycan_eval.py @@ -32,13 +32,11 @@ ) from sprint.utils.utils import make_primitive_annotation_eval_dataset from sprint.utils.data_utils import process_annotation +from sprint.utils.wandb_info import WANDB_PROJECT_NAME, WANDB_ENTITY_NAME from sprint.rollouts.saycan_rollout import run_policy_multi_process os.environ["TOKENIZERS_PARALLELISM"] = "false" -WANDB_ENTITY_NAME = "clvr" -WANDB_PROJECT_NAME = "p-bootstrap-llm" - def setup_mp( result_queue, @@ -93,9 +91,7 @@ def multithread_dataset_aggregation( # asynchronously collect results from result_queue num_env_samples = 0 num_finished_tasks = 0 - num_rollouts = ( - config.num_eval_tasks if eval else config.num_rollouts_per_epoch - ) + num_rollouts = config.num_eval_tasks if eval else config.num_rollouts_per_epoch with tqdm(total=num_rollouts) as pbar: while num_finished_tasks < num_rollouts: result = result_queue.get() @@ -142,18 +138,11 @@ def multiprocess_rollout( video_captions = [] extra_info = defaultdict(list) - num_rollouts = ( - config.num_eval_tasks if eval else config.num_rollouts_per_epoch - ) + num_rollouts = config.num_eval_tasks if eval else config.num_rollouts_per_epoch # create tasks for MP Queue # create tasks for thread/process Queue - args_func = lambda subgoal: ( - True, - True, - epsilon, - subgoal - ) + args_func = lambda subgoal: (True, True, epsilon, subgoal) for subgoal in range(num_rollouts): task_queue.put(args_func(subgoal)) @@ -323,7 +312,7 @@ def signal_handler(sig, frame): result_queue, config, 0, - True, + True, ) wandb.log( eval_metrics,