-
Notifications
You must be signed in to change notification settings - Fork 64
/
run_agent.py
236 lines (209 loc) · 7.57 KB
/
run_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import argparse
import asyncio
import json
import logging
import os
import time
import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import docker
from agents.registry import Agent
from agents.registry import registry as agent_registry
from agents.run import run_in_container
from environment.defaults import DEFAULT_CONTAINER_CONFIG_PATH
from mlebench.data import is_dataset_prepared
from mlebench.registry import Competition, registry
from mlebench.utils import create_run_dir, get_logger, get_runs_dir, get_timestamp
# Module-level logger; main() reports run-group progress through it.
logger = get_logger(__name__)
@dataclass(frozen=True)
class Task:
    """Immutable description of one queued run.

    A task pairs a single agent with a single competition for a single seed,
    together with the directories and container configuration the run needs.
    """

    # Identifier of the run; the worker stores the run's outcome under this key.
    run_id: str
    # Index of this repetition of the competition.
    seed: int
    # Image name for the run (main() fills this with the agent's name).
    image: str
    # Directory of the run group (main() uses the parent of `path_to_run`).
    path_to_run_group: Path
    # Directory of this individual run; the worker writes `run.log` here.
    path_to_run: Path
    # Agent to execute.
    agent: Agent
    # Competition the agent is run on.
    competition: Competition
    # Parsed container configuration, forwarded to `run_in_container`.
    container_config: dict[str, Any]
async def worker(
    idx: int,
    queue: asyncio.Queue[Task],
    client: docker.DockerClient,
    tasks_outputs: dict[str, dict[str, Any]],
) -> None:
    """Consume tasks from ``queue`` forever, running each one in a container.

    For every task a dedicated logger that also writes to ``run.log`` inside
    the run directory is set up, ``run_in_container`` is executed in a separate
    thread, and the outcome is stored in ``tasks_outputs`` under the task's
    ``run_id`` (normally ``{"success": True}`` or ``{"success": False}``).

    The coroutine never returns on its own; the caller is expected to cancel
    it once the queue has been drained.

    NOTE: ``retain_container`` is read from the module-level ``args`` object,
    which only exists when this file is executed as a script.

    Args:
        idx: Worker index, used only to tag log messages.
        queue: Queue of pending tasks.
        client: Docker client passed through to ``run_in_container``.
        tasks_outputs: Shared mapping that receives one entry per processed task.
    """
    while True:
        task = await queue.get()

        # Fall back to the module logger until the per-run logger exists, so
        # that a failure during setup can still be reported.
        run_logger = logger
        log_file_handler = None
        task_output = {}

        try:
            # Logger setup lives inside the try block: if it fails (e.g. the
            # log file cannot be opened) the finally clause still calls
            # queue.task_done(), so queue.join() cannot wait forever.
            run_logger = get_logger(str(task.path_to_run))
            log_file_handler = logging.FileHandler(task.path_to_run / "run.log")
            root_handlers = logging.getLogger().handlers
            if root_handlers:
                # Match the formatting of the existing (root) handler.
                log_file_handler.setFormatter(root_handlers[0].formatter)
            run_logger.addHandler(log_file_handler)
            run_logger.propagate = False

            run_logger.info(
                f"[Worker {idx}] Running seed {task.seed} for {task.competition.id} and agent {task.agent.name}"
            )

            # run_in_container is blocking, so run it in a thread to keep the
            # event loop free for the other workers.
            await asyncio.to_thread(
                run_in_container,
                client=client,
                competition=task.competition,
                agent=task.agent,
                image=task.image,  # main() sets this to the agent's name
                container_config=task.container_config,
                retain_container=args.retain,
                run_dir=task.path_to_run,
                logger=run_logger,
            )
            task_output["success"] = True

            run_logger.info(
                f"[Worker {idx}] Finished running seed {task.seed} for {task.competition.id} and agent {task.agent.name}"
            )
        except Exception as e:
            stack_trace = traceback.format_exc()
            run_logger.error(type(e))
            run_logger.error(stack_trace)
            run_logger.error(
                f"Run failed for seed {task.seed}, agent {task.agent.id} and competition "
                f"{task.competition.id}"
            )
            task_output["success"] = False
        finally:
            tasks_outputs[task.run_id] = task_output
            try:
                # Detach and close the per-run file handler; otherwise every
                # processed task would leave an open file behind.
                if log_file_handler is not None:
                    run_logger.removeHandler(log_file_handler)
                    log_file_handler.close()
            finally:
                queue.task_done()
async def main(args):
    """Run one agent on every competition of a set, ``n_seeds`` times each.

    The function validates the inputs, creates one ``Task`` per
    (seed, competition) pair, processes the tasks with ``n_workers``
    concurrent workers and finally writes ``metadata.json`` into the run-group
    directory.

    Args:
        args: Parsed command-line namespace. This function reads ``agent_id``,
            ``data_dir``, ``competition_set``, ``container_config``,
            ``n_seeds`` and ``n_workers``.

    Raises:
        ValueError: If ``n_workers`` is smaller than 1, if the agent needs a
            privileged container and
            ``I_ACCEPT_RUNNING_PRIVILEGED_CONTAINERS`` is not set to a truthy
            value, or if the dataset of any competition is not prepared.
    """
    global registry

    # Without at least one worker nobody would ever call queue.task_done(),
    # and queue.join() below would wait forever.
    if args.n_workers < 1:
        raise ValueError(f"`n_workers` must be at least 1, got {args.n_workers}.")

    client = docker.from_env()
    registry = registry.set_data_dir(Path(args.data_dir))

    agent = agent_registry.get_agent(args.agent_id)
    if agent.privileged and not (
        os.environ.get("I_ACCEPT_RUNNING_PRIVILEGED_CONTAINERS", "False").lower()
        in ("true", "1", "t")
    ):
        raise ValueError(
            "Agent requires running in a privileged container, but the environment variable `I_ACCEPT_RUNNING_PRIVILEGED_CONTAINERS` is not set to `True`! "
            "Carefully consider if you wish to run this agent before continuing. See agents/README.md for more details."
        )

    run_group = f"{get_timestamp()}_run-group_{agent.name}"

    # Load competition ids and check all are prepared
    with open(args.competition_set, "r") as f:
        competition_ids = [line.strip() for line in f.read().splitlines() if line.strip()]
    for competition_id in competition_ids:
        competition = registry.get_competition(competition_id)
        if not is_dataset_prepared(competition):
            raise ValueError(
                f"Dataset for competition `{competition.id}` is not prepared! "
                f"Please run `mlebench prepare -c {competition.id}` to prepare the dataset."
            )

    with open(args.container_config, "r") as f:
        container_config = json.load(f)

    # Create tasks for each (competition * seed)
    logger.info(f"Launching run group: {run_group}")
    tasks = []
    for seed in range(args.n_seeds):
        for competition_id in competition_ids:
            competition = registry.get_competition(competition_id)
            run_dir = create_run_dir(competition.id, agent.id, run_group)
            run_id = run_dir.stem
            task = Task(
                run_id=run_id,
                seed=seed,
                image=agent.name,
                agent=agent,
                competition=competition,
                path_to_run_group=run_dir.parent,
                path_to_run=run_dir,
                container_config=container_config,
            )
            tasks.append(task)

    logger.info(f"Creating {args.n_workers} workers to serve {len(tasks)} tasks...")

    # Create queue of tasks, and assign workers to run them
    queue = asyncio.Queue()
    for task in tasks:
        queue.put_nowait(task)
    workers = []
    tasks_outputs = {}
    for idx in range(args.n_workers):
        w = asyncio.create_task(worker(idx, queue, client, tasks_outputs))
        workers.append(w)

    # Wait for all tasks to be completed and collect results
    started_at = time.monotonic()
    await queue.join()
    time_taken = time.monotonic() - started_at

    for w in workers:
        w.cancel()  # Cancel all workers now that the queue is empty

    # return_exceptions=True keeps the expected CancelledError of each worker
    # from propagating out of gather().
    await asyncio.gather(*workers, return_exceptions=True)

    # Generate metadata.json
    metadata = {
        "run_group": run_group,
        "created_at": get_timestamp(),
        "runs": tasks_outputs,
    }
    run_group_dir = get_runs_dir() / run_group
    # When no task was created (empty competition set or n_seeds == 0) no run
    # directory exists yet, so make sure the target directory is there.
    run_group_dir.mkdir(parents=True, exist_ok=True)
    with open(run_group_dir / "metadata.json", "w") as f:
        json.dump(metadata, f, indent=4, sort_keys=False, default=str)
    logger.info(f"{args.n_workers} workers ran for {time_taken:.2f} seconds in total")
# Script entry point: parse the command line and run `main`.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run an agent on a set of competitions in a Docker container."
    )
    parser.add_argument(
        "--agent-id",
        help="Agent ID of the agent to run.",
        type=str,
        # Without an agent ID there is nothing to run; fail at parse time
        # instead of passing `None` on to the agent registry.
        required=True,
    )
    parser.add_argument(
        "--competition-set",
        type=str,
        required=True,
        help="Path to a text file with a single competition ID on each line",
    )
    parser.add_argument(
        "--n-workers",
        type=int,
        required=False,
        default=1,
        help="Number of workers to run in parallel",
    )
    parser.add_argument(
        "--n-seeds",
        type=int,
        required=False,
        default=1,
        help="Number of seeds to run for each competition",
    )
    parser.add_argument(
        "--container-config",
        help="Path to a JSON file with an environment configuration; these args will be passed to `docker.from_env().containers.create`",
        type=str,
        required=False,
        default=DEFAULT_CONTAINER_CONFIG_PATH,
    )
    parser.add_argument(
        "--retain",
        help="Whether to retain the container after the run instead of removing it.",
        action="store_true",
        required=False,
        default=False,
    )
    # NOTE(review): `--run-dir` is accepted, but neither `main` nor `worker`
    # in this file reads `args.run_dir`.
    parser.add_argument(
        "--run-dir",
        help="Path to the directory where all assets associated with the run are stored.",
        type=str,
        required=False,
        default=None,
    )
    parser.add_argument(
        "--data-dir",
        help="Path to the directory containing the competition data.",
        type=str,
        required=False,
        default=registry.get_data_dir(),
    )

    # `args` is deliberately a module-level name: `worker` reads `args.retain`.
    args = parser.parse_args()
    # Re-binds the module-level logger with the same call used at import time.
    logger = get_logger(__name__)

    asyncio.run(main(args))