forked from jchengai/forecast-mae
-
Notifications
You must be signed in to change notification settings - Fork 0
/
preprocess.py
66 lines (49 loc) · 2 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from argparse import ArgumentParser
from pathlib import Path
from typing import List
import ray
from tqdm import tqdm
from src.datamodule.av2_extractor import Av2Extractor
from src.datamodule.av2_extractor_multiagent import Av2ExtractorMultiAgent
from src.utils.ray_utils import ActorHandle, ProgressBar
# Initialise Ray with a 16-CPU limit. This statement runs at import time, so it
# also executes when --parallel is not passed and Ray is never used.
# NOTE(review): the CPU count is hard-coded; consider deriving it from the
# machine or exposing it as a CLI flag.
ray.init(num_cpus=16)
def glob_files(data_root: Path, mode: str) -> List[Path]:
    """Collect every ``*.parquet`` scenario file under ``data_root / mode``.

    The search is recursive. Results are sorted so that batching and progress
    output are reproducible across runs and filesystems: ``Path.rglob`` yields
    entries in directory-listing order, which is not guaranteed to be stable.

    Args:
        data_root: Dataset root directory containing one sub-directory per split.
        mode: Split sub-directory name (this script uses "train", "val", "test").

    Returns:
        Sorted list of matching file paths; empty if the split directory does
        not exist.
    """
    file_root = data_root / mode
    if not file_root.is_dir():
        # Make the "missing split" case explicit instead of relying on rglob.
        return []
    return sorted(file_root.rglob("*.parquet"))
@ray.remote
def preprocess_batch(extractor: Av2Extractor, file_list: List[Path], pb: ActorHandle):
    """Ray task: call ``extractor.save`` on each file in ``file_list``.

    One unit of progress is sent to the progress-bar actor ``pb`` per file.

    If saving a file raises, progress for that file and for every file still
    left in the batch is reported before the exception is re-raised. Otherwise
    the bar's total could never be reached, and a driver waiting for it
    (``ProgressBar.print_until_done``) would presumably block forever.
    NOTE(review): confirm against ``src.utils.ray_utils.ProgressBar``.
    The task itself still fails, so the error surfaces to whoever ``ray.get``s
    its result.

    Args:
        extractor: Object whose ``save(path)`` method processes one file.
        file_list: Files handled by this task.
        pb: Handle to the progress-bar actor; its ``update`` method is called
            remotely with the number of files completed.
    """
    for done, file in enumerate(file_list):
        try:
            extractor.save(file)
        except Exception:
            # Account for this file plus the unprocessed remainder, then fail.
            pb.update.remote(len(file_list) - done)
            raise
        pb.update.remote(1)
def preprocess(args):
    """Run feature extraction over the train, val and test splits.

    For each split, every ``*.parquet`` file under ``<data_root>/<mode>`` is
    passed to an extractor's ``save`` method. Output goes to
    ``<data_root>/multiagent-baseline/<mode>`` when ``args.multiagent`` is set,
    otherwise to ``<data_root>/forecast-mae/<mode>``. A split with no input
    files is skipped, though its output directory is still created.

    Args:
        args: Parsed CLI namespace with ``data_root`` (str), ``batch`` (int,
            files per Ray task), ``parallel`` (bool) and ``multiagent`` (bool).

    Raises:
        In parallel mode, the Ray task error re-raised by ``ray.get`` if any
        batch task failed.
    """
    batch = args.batch
    data_root = Path(args.data_root)

    for mode in ["train", "val", "test"]:
        if args.multiagent:
            save_dir = data_root / "multiagent-baseline" / mode
            extractor = Av2ExtractorMultiAgent(save_path=save_dir, mode=mode)
        else:
            save_dir = data_root / "forecast-mae" / mode
            extractor = Av2Extractor(save_path=save_dir, mode=mode)
        save_dir.mkdir(exist_ok=True, parents=True)

        scenario_files = glob_files(data_root, mode)
        if not scenario_files:
            # Nothing to process. In parallel mode this also avoids waiting on a
            # progress bar with a total of 0, which may never receive an update.
            print(f"preprocess {mode}-set: no scenario files found, skipping")
            continue

        if args.parallel:
            pb = ProgressBar(len(scenario_files), f"preprocess {mode}-set")
            pb_actor = pb.actor
            task_refs = [
                preprocess_batch.remote(
                    extractor, scenario_files[i : i + batch], pb_actor
                )
                for i in range(0, len(scenario_files), batch)
            ]
            pb.print_until_done()
            # Fetch the task results so any worker-side exception is re-raised
            # here instead of being silently dropped with the unused refs.
            ray.get(task_refs)
        else:
            for file in tqdm(scenario_files):
                extractor.save(file)
if __name__ == "__main__":
    parser = ArgumentParser(
        description="Preprocess scenario parquet files for the train/val/test splits."
    )
    parser.add_argument(
        "--data_root", "-d", type=str, required=True,
        help="dataset root containing train/, val/ and test/ split directories",
    )
    parser.add_argument(
        "--batch", "-b", type=int, default=50,
        help="number of files per Ray task when --parallel is set (default: 50)",
    )
    parser.add_argument(
        "--parallel", "-p", action="store_true",
        help="process files with Ray tasks instead of sequentially",
    )
    parser.add_argument(
        "--multiagent", "-m", action="store_true",
        help="use the multi-agent extractor and write to multiagent-baseline/",
    )
    args = parser.parse_args()
    if args.parallel and args.batch < 1:
        # The batch size is used as a range() step: 0 raises ValueError and a
        # negative value schedules no tasks at all.
        parser.error("--batch must be a positive integer")
    preprocess(args)