mask_ipa_pretrain.py
from comet_ml import Experiment
import hydra
import os
import torch
from dataloader.large_dataset import Cath
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader
from dataloader.collator import CollatorIPAPretrain
from model.ipa.ipa_net import IPANetPredictor
from torch.optim import Adam, lr_scheduler
from trainer.mask_ipa_trainer import Trainer
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


@hydra.main(version_base=None, config_path="conf", config_name="mask_pretrain")
def main(cfg: DictConfig):
    # Optional Comet ML experiment tracking.
    if cfg.comet.use:
        experiment = Experiment(
            project_name=cfg.comet.project_name,
            workspace=cfg.comet.workspace,
            auto_output_logging="simple",
            log_graph=True,
            log_code=False,
            log_git_metadata=False,
            log_git_patch=False,
            auto_param_logging=False,
            auto_metric_logging=False
        )
        experiment.log_parameters(OmegaConf.to_container(cfg))
        experiment.set_name(cfg.experiment.name)
        if cfg.comet.comet_tag:
            experiment.add_tag(cfg.comet.comet_tag)
    else:
        experiment = None

    output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
    print(OmegaConf.to_yaml(cfg))
    print(f"Output directory: {output_dir}")
    if experiment:
        experiment.log_parameters({"output_dir": output_dir})

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Build the training set from every file in the train directory.
    train_ID = os.listdir(cfg.dataset.train_dir)
    train_dataset = Cath(train_ID, cfg.dataset.train_dir)
    # The collator implements the residue masking scheme: a fraction of
    # positions is selected as candidates (candi_rate), then masked,
    # replaced, or kept according to the remaining rates.
    collator = CollatorIPAPretrain(candi_rate=cfg.train.candi_rate, mask_rate=cfg.train.mask_rate,
                                   replace_rate=cfg.train.replace_rate, keep_rate=cfg.train.keep_rate)
    train_loader = DataLoader(train_dataset, batch_size=cfg.train.batch_size, shuffle=True, num_workers=6,
                              collate_fn=collator)

    model = IPANetPredictor(dropout=cfg.model.ipa_drop_out).to(device)
    print(f"Total parameters: {count_parameters(model)}")

    steps_per_epoch = len(train_loader)
    optimizer = Adam(model.parameters(), lr=cfg.train.lr, betas=(0.95, 0.999), weight_decay=cfg.train.weight_decay)
    if cfg.train.scheduler:
        scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=cfg.train.lr,
                                            total_steps=cfg.train.train_epochs * steps_per_epoch)
    else:
        scheduler = None

    # Unreduced cross-entropy: reduction='none' leaves a per-position loss,
    # so the trainer can aggregate it (e.g. over masked positions only).
    loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
    trainer = Trainer(config=cfg, model=model, optimizer=optimizer, epochs=cfg.train.train_epochs, loss_fn=loss_fn,
                      train_dataloader=train_loader, output_dir=output_dir, device=device,
                      scheduler=scheduler, experiment=experiment)
    trainer.fit()
    print(f"Output directory: {output_dir}")


if __name__ == "__main__":
    main()
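Since main() is wrapped in @hydra.main, the script reads its defaults from conf/mask_pretrain.yaml and accepts dotted command-line overrides for any config field referenced above; a minimal sketch of a launch (the override values are illustrative, not defaults from this repo):

python mask_ipa_pretrain.py train.batch_size=8 train.lr=1e-4 comet.use=false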